Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-12-24 13:28:13 +08:00)
[Speculative Decoding][Cherry Pick]Update extract_mtp_weight script and optimize config (#5213)
* update extract_mtp_model
* modify config usage
@@ -1628,10 +1628,6 @@ class FDConfig:
         else:
             self.scheduler_config.max_num_batched_tokens = self.model_config.max_model_len
 
-        self.scheduler_config.max_chunk_len = (
-            self.scheduler_config.max_num_batched_tokens + self.scheduler_config.max_extra_num_batched_tokens
-        )
-
         if self.long_prefill_token_threshold == 0:
             self.long_prefill_token_threshold = int(self.model_config.max_model_len * 0.04)
 
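The hunk above drops the derived `max_chunk_len` computation from `FDConfig`; the long-prefill default that remains is unchanged context. As a quick illustration of that retained rule (a standalone sketch with an assumed `max_model_len`, not code from this repo):

```python
# Hedged sketch of the retained long-prefill default shown above.
# max_model_len = 32768 is an assumed example value.
max_model_len = 32768
long_prefill_token_threshold = 0  # 0 means "not set by the user"

if long_prefill_token_threshold == 0:
    # Default to 4% of the maximum model context length.
    long_prefill_token_threshold = int(max_model_len * 0.04)

print(long_prefill_token_threshold)  # 1310
```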
@@ -270,7 +270,6 @@ class SchedulerConfig:
         self.name = "local"  # "local" for LocalScheduler or "global" for GlobalScheduler
         self.max_num_batched_tokens = 2048  # base token_num for text inputs
         self.max_extra_num_batched_tokens = 16384  # extra token_num for multimodal inputs
-        self.max_chunk_len = 18432  # max supported token_num = max_num_batched_tokens + max_extra_num_batched_tokens
         self.max_num_seqs = 34
         self.splitwise_role = "mixed"
         self.config = None
@@ -355,7 +355,13 @@ class MTPProposer(Proposer):
             self.target_model_inputs["decoder_tile_ids_per_batch"]
         )
         self.model_inputs["target_hidden_states"] = paddle.full(
-            [self.fd_config.scheduler_config.max_chunk_len, self.model_config.hidden_size], 0, dtype="bfloat16"
+            [
+                self.fd_config.scheduler_config.max_num_batched_tokens
+                + self.fd_config.scheduler_config.max_extra_num_batched_tokens,
+                self.model_config.hidden_size,
+            ],
+            0,
+            dtype="bfloat16",
         )
 
         tmp_position_ids = paddle.arange(self.model_config.max_model_len).reshape((1, -1))
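Taken together, the three config hunks above remove `max_chunk_len` as a stored field: `FDConfig` no longer computes it, `SchedulerConfig` no longer carries the 18432 default, and the one consumer (the `MTPProposer` `target_hidden_states` buffer) now sums the two source fields inline, so the buffer size can no longer drift from a stale cached value. A minimal standalone sketch of the resulting sizing, using the defaults from the `SchedulerConfig` hunk (the `hidden_size` value is an assumption for illustration):

```python
import paddle

# Defaults taken from the SchedulerConfig hunk above.
max_num_batched_tokens = 2048         # base token budget for text inputs
max_extra_num_batched_tokens = 16384  # extra token budget for multimodal inputs
hidden_size = 4096                    # assumed hidden size, for illustration only

# The buffer is now sized from the two source fields directly.
target_hidden_states = paddle.full(
    [max_num_batched_tokens + max_extra_num_batched_tokens, hidden_size],
    0,
    dtype="bfloat16",
)
assert target_hidden_states.shape[0] == 18432  # equals the old max_chunk_len default
```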
@@ -17,7 +17,9 @@
 import argparse
+import json
 import os
+import re
 
 import numpy as np
 import paddle
 from paddleformers.transformers.model_utils import shard_checkpoint
 from paddleformers.utils.env import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME
@@ -46,6 +48,28 @@ def parse_args():
     return parser.parse_args()
 
 
+def dtype_byte_size(dtype):
+    """
+    Returns the size (in bytes) occupied by one parameter of type `dtype`.
+
+    Example:
+
+    ```py
+    >>> dtype_byte_size(paddle.float32)
+    4
+    ```
+    """
+    if str(dtype) in {"paddle.bool", "bool"}:
+        return 1 / 8
+    if str(dtype) in {"paddle.float8_e4m3fn", "paddle.float8_e5m2", "float8_e4m3fn", "float8_e5m2"}:
+        return 1
+    bit_search = re.search(r"[^\d](\d+)$", str(dtype))
+    if bit_search is None:
+        raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
+    bit_size = int(bit_search.groups()[0])
+    return bit_size // 8
+
+
 def extract_mtp_weights(input_dir: str) -> dict:
     """
     Load all MTP-related weights from safetensors files in input_dir.
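The new `dtype_byte_size` helper parses the trailing bit width out of the dtype's string form, with special cases for sub-byte `bool` and the 8-bit float types. A few hedged spot checks (assuming the function above is in scope; plain strings work as inputs because the function normalizes via `str(dtype)`):

```python
# Spot checks for dtype_byte_size; string dtypes stand in for paddle dtypes.
assert dtype_byte_size("paddle.bfloat16") == 2       # trailing "16" bits -> 2 bytes
assert dtype_byte_size("paddle.float32") == 4        # trailing "32" bits -> 4 bytes
assert dtype_byte_size("paddle.float8_e4m3fn") == 1  # special-cased 8-bit float
assert dtype_byte_size("paddle.bool") == 0.125       # 1/8 byte per element
```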
@@ -103,6 +127,18 @@ def save_safetensors(state_dict: dict, output_dir: str):
         logger.info(f"Saving shard: {save_path}")
         safe_save_file(shard, save_path, metadata={"format": "np"})
 
+    # If only one shard is returned, SAFE_WEIGHTS_INDEX_NAME will be null
+    if len(shards) == 1:
+        logger.info("Generate index file for single shard")
+        weight_size = 0
+        for key, weight in shards["model.safetensors"].items():
+            weight_size += np.prod(weight.shape) * dtype_byte_size(weight.dtype)
+
+        index = {
+            "metadata": {"total_size": int(weight_size)},
+            "weight_map": {k: "model.safetensors" for k in shards["model.safetensors"].keys()},
+        }
+
+        index_path = os.path.join(output_dir, SAFE_WEIGHTS_INDEX_NAME)
+        with open(index_path, "w", encoding="utf-8") as f:
+            json.dump(index, f, indent=2)
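The added block covers the single-shard case: `shard_checkpoint` returns no index when everything fits in one file, but downstream loaders may still expect the `SAFE_WEIGHTS_INDEX_NAME` file (typically `model.safetensors.index.json`), so the script now writes one that maps every weight to the lone shard. A hedged standalone sketch of the index this produces (the weight name and shape are made up for illustration):

```python
import json

import numpy as np

# Hypothetical single-shard state dict; key and shape are illustrative only.
shard = {"mtp.linear.weight": np.zeros((4, 8), dtype=np.float16)}

# 2 bytes per float16 element, mirroring dtype_byte_size's result.
weight_size = sum(int(np.prod(w.shape)) * 2 for w in shard.values())
index = {
    "metadata": {"total_size": weight_size},
    "weight_map": {k: "model.safetensors" for k in shard},
}
print(json.dumps(index, indent=2))
# {
#   "metadata": {"total_size": 64},
#   "weight_map": {"mtp.linear.weight": "model.safetensors"}
# }
```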