diff --git a/fastdeploy/config.py b/fastdeploy/config.py index 7fdc98d88..1c52f27d2 100644 --- a/fastdeploy/config.py +++ b/fastdeploy/config.py @@ -298,6 +298,9 @@ class ParallelConfig: self.use_internode_ll_two_stage: bool = False self.max_num_batched_tokens: int = 2048 + + # Max chunk len in infer. Must be overridden + self.max_chunk_len: int = 0 # splitwise role self.splitwise_role: str = "mixed" # guided decoding backend @@ -364,6 +367,10 @@ class ParallelConfig: ) dist.collective._set_custom_gid(None) + def postprocess(self): + # 2048 is extra buffer for decoding. It should be made more accurate in the future + self.max_chunk_len = int(envs.FD_MAX_EXTRA_NUM_BATCHED_TOKENS) + self.max_num_batched_tokens + 2048 + def print(self): """ print all config @@ -1276,6 +1283,8 @@ class FDConfig: self.cache_config.postprocess(self.max_num_batched_tokens, self.max_num_seqs) self.cache_config.max_block_num_per_seq = int(self.max_model_len // self.cache_config.block_size) + self.parallel_config.postprocess() + if self.guided_decoding_backend == "auto": if self.model_config.enable_mm: self.guided_decoding_backend = "off" diff --git a/fastdeploy/envs.py b/fastdeploy/envs.py index be1adcdb8..782fef741 100644 --- a/fastdeploy/envs.py +++ b/fastdeploy/envs.py @@ -114,6 +114,7 @@ environment_variables: dict[str, Callable[[], Any]] = { "ENCODE_FEATURE_BOS_AK": lambda: os.getenv("ENCODE_FEATURE_BOS_AK"), "ENCODE_FEATURE_BOS_SK": lambda: os.getenv("ENCODE_FEATURE_BOS_SK"), "FD_ENABLE_E2W_TENSOR_CONVERT": lambda: int(os.getenv("FD_ENABLE_E2W_TENSOR_CONVERT", "0")), + "FD_MAX_EXTRA_NUM_BATCHED_TOKENS": lambda: int(os.getenv("FD_MAX_EXTRA_NUM_BATCHED_TOKENS", "16384")), } diff --git a/fastdeploy/spec_decode/mtp.py b/fastdeploy/spec_decode/mtp.py index d4d3fc9f5..02a48d394 100644 --- a/fastdeploy/spec_decode/mtp.py +++ b/fastdeploy/spec_decode/mtp.py @@ -329,7 +329,7 @@ class MTPProposer(Proposer): self.target_model_inputs["decoder_tile_ids_per_batch"] ) self.model_inputs["target_hidden_states"]
= paddle.full( - [self.max_model_len * self.fd_config.max_prefill_batch, self.model_config.hidden_size], 0, dtype="bfloat16" + [self.parallel_config.max_chunk_len, self.model_config.hidden_size], 0, dtype="bfloat16" ) tmp_position_ids = paddle.arange(self.parallel_config.max_model_len).reshape((1, -1))