diff --git a/fastdeploy/config.py b/fastdeploy/config.py
index 1a7204044..8d9aa9ab0 100644
--- a/fastdeploy/config.py
+++ b/fastdeploy/config.py
@@ -1628,10 +1628,6 @@ class FDConfig:
         else:
             self.scheduler_config.max_num_batched_tokens = self.model_config.max_model_len
 
-        self.scheduler_config.max_chunk_len = (
-            self.scheduler_config.max_num_batched_tokens + self.scheduler_config.max_extra_num_batched_tokens
-        )
-
         if self.long_prefill_token_threshold == 0:
             self.long_prefill_token_threshold = int(self.model_config.max_model_len * 0.04)
 
diff --git a/fastdeploy/scheduler/config.py b/fastdeploy/scheduler/config.py
index 3be17c48c..6bce91e81 100644
--- a/fastdeploy/scheduler/config.py
+++ b/fastdeploy/scheduler/config.py
@@ -270,7 +270,6 @@ class SchedulerConfig:
         self.name = "local"  # "local" for LocalScheduler or "global" for GlobalScheduler
         self.max_num_batched_tokens = 2048  # base token_num for text inputs
         self.max_extra_num_batched_tokens = 16384  # extra token_num for multimodal inputs
-        self.max_chunk_len = 18432  # max supported token_num = max_num_batched_tokens + max_extra_num_batched_tokens
         self.max_num_seqs = 34
         self.splitwise_role = "mixed"
         self.config = None
diff --git a/fastdeploy/spec_decode/mtp.py b/fastdeploy/spec_decode/mtp.py
index d7648c3e7..321ef5a1b 100644
--- a/fastdeploy/spec_decode/mtp.py
+++ b/fastdeploy/spec_decode/mtp.py
@@ -355,7 +355,13 @@ class MTPProposer(Proposer):
             self.target_model_inputs["decoder_tile_ids_per_batch"]
         )
         self.model_inputs["target_hidden_states"] = paddle.full(
-            [self.fd_config.scheduler_config.max_chunk_len, self.model_config.hidden_size], 0, dtype="bfloat16"
+            [
+                self.fd_config.scheduler_config.max_num_batched_tokens
+                + self.fd_config.scheduler_config.max_extra_num_batched_tokens,
+                self.model_config.hidden_size,
+            ],
+            0,
+            dtype="bfloat16",
         )
 
         tmp_position_ids = paddle.arange(self.model_config.max_model_len).reshape((1, -1))
diff --git a/scripts/extract_mtp_weight_from_safetensor.py b/scripts/extract_mtp_weight_from_safetensor.py
index 1ac1fcfa5..cc5482b4f 100644
--- a/scripts/extract_mtp_weight_from_safetensor.py
+++ b/scripts/extract_mtp_weight_from_safetensor.py
@@ -17,7 +17,9 @@
 import argparse
 import json
 import os
+import re
 
+import numpy as np
 import paddle
 from paddleformers.transformers.model_utils import shard_checkpoint
 from paddleformers.utils.env import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME
@@ -46,6 +48,28 @@ def parse_args():
     return parser.parse_args()
 
 
+def dtype_byte_size(dtype):
+    """
+    Returns the size (in bytes) occupied by one parameter of type `dtype`.
+
+    Example:
+
+    ```py
+    >>> dtype_byte_size(paddle.float32)
+    4
+    ```
+    """
+    if str(dtype) in {"paddle.bool", "bool"}:
+        return 1 / 8
+    if str(dtype) in {"paddle.float8_e4m3fn", "paddle.float8_e5m2", "float8_e4m3fn", "float8_e5m2"}:
+        return 1
+    bit_search = re.search(r"[^\d](\d+)$", str(dtype))
+    if bit_search is None:
+        raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
+    bit_size = int(bit_search.groups()[0])
+    return bit_size // 8
+
+
 def extract_mtp_weights(input_dir: str) -> dict:
     """
     Load all MTP-related weights from safetensors files in input_dir.
@@ -103,6 +127,18 @@ def save_safetensors(state_dict: dict, output_dir: str):
         logger.info(f"Saving shard: {save_path}")
         safe_save_file(shard, save_path, metadata={"format": "np"})
 
+    # With a single shard, shard_checkpoint returns no index (the file would contain null); build it manually
+    if len(shards) == 1:
+        logger.info("Generating index file for the single shard")
+        weight_size = 0
+        for weight in shards["model.safetensors"].values():
+            weight_size += np.prod(weight.shape) * dtype_byte_size(weight.dtype)
+
+        index = {
+            "metadata": {"total_size": int(weight_size)},
+            "weight_map": {k: "model.safetensors" for k in shards["model.safetensors"]},
+        }
+
     index_path = os.path.join(output_dir, SAFE_WEIGHTS_INDEX_NAME)
     with open(index_path, "w", encoding="utf-8") as f:
         json.dump(index, f, indent=2)
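Reviewer note: a minimal, self-contained sanity check of the new single-shard index path, runnable outside FastDeploy. It copies `dtype_byte_size` from the patch and rebuilds the index the same way the new branch in `save_safetensors` does; the toy weight names and shapes below are made up for illustration.

```python
import json
import re

import numpy as np
import paddle


def dtype_byte_size(dtype):
    """Bytes per element for a paddle dtype (copied from the patch above)."""
    if str(dtype) in {"paddle.bool", "bool"}:
        return 1 / 8
    if str(dtype) in {"paddle.float8_e4m3fn", "paddle.float8_e5m2", "float8_e4m3fn", "float8_e5m2"}:
        return 1
    bit_search = re.search(r"[^\d](\d+)$", str(dtype))
    if bit_search is None:
        raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
    return int(bit_search.groups()[0]) // 8


# Toy single-shard state dict standing in for real MTP weights.
shard = {
    "model.mtp_linear.weight": paddle.zeros([4, 8], dtype="bfloat16"),
    "model.mtp_linear.bias": paddle.zeros([8], dtype="float32"),
}

# Same index construction as the new single-shard branch in save_safetensors.
total_size = sum(int(np.prod(w.shape)) * dtype_byte_size(w.dtype) for w in shard.values())
index = {
    "metadata": {"total_size": int(total_size)},
    "weight_map": {k: "model.safetensors" for k in shard},
}
print(json.dumps(index, indent=2))  # total_size: 4*8*2 (bf16) + 8*4 (fp32) = 96
```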