[Speculative Decoding][Cherry Pick]Update extract_mtp_weight script and optimize config (#5213)

* update extract_mtp_model

* modify config usage
This commit is contained in:
freeliuzc
2025-11-25 14:42:55 +08:00
committed by GitHub
parent e581b7d7d9
commit a11d17cee9
4 changed files with 43 additions and 6 deletions

View File

@@ -17,7 +17,9 @@
import argparse
import json
import os
import re
import numpy as np
import paddle
from paddleformers.transformers.model_utils import shard_checkpoint
from paddleformers.utils.env import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME
@@ -46,6 +48,28 @@ def parse_args():
return parser.parse_args()
def dtype_byte_size(dtype):
"""
Returns the size (in bytes) occupied by one parameter of type `dtype`.
Example:
```py
>>> dtype_byte_size(paddle.float32)
4
```
"""
if str(dtype) in {"paddle.bool", "bool"}:
return 1 / 8
if str(dtype) in {"paddle.float8_e4m3fn", "paddle.float8_e5m2", "float8_e4m3fn", "float8_e5m2"}:
return 1
bit_search = re.search(r"[^\d](\d+)$", str(dtype))
if bit_search is None:
raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
bit_size = int(bit_search.groups()[0])
return bit_size // 8
def extract_mtp_weights(input_dir: str) -> dict:
"""
Load all MTP-related weights from safetensors files in input_dir.
@@ -103,6 +127,18 @@ def save_safetensors(state_dict: dict, output_dir: str):
logger.info(f"Saving shard: {save_path}")
safe_save_file(shard, save_path, metadata={"format": "np"})
# If only one shard is returned, SAFE_WEIGHTS_INDEX_NAME will be null
if len(shards) == 1:
logger.info("Generate index file for single shard")
weight_size = 0
for key, weight in shards["model.safetensors"].items():
weight_size += np.prod(weight.shape) * dtype_byte_size(weight.dtype)
index = {
"metadata": {"total_size": int(weight_size)},
"weight_map": {k: "model.safetensors" for k in shards["model.safetensors"].keys()},
}
index_path = os.path.join(output_dir, SAFE_WEIGHTS_INDEX_NAME)
with open(index_path, "w", encoding="utf-8") as f:
json.dump(index, f, indent=2)