[Optimization] Reduce memory allocation for cudaGraph (#4838)

* optimize memory allocation

* rename env
Authored by freeliuzc on 2025-11-06 13:32:47 +08:00, committed by GitHub.
parent e0d98d00bc, commit bbae094cb9
3 changed files with 11 additions and 1 deletion


@@ -298,6 +298,9 @@ class ParallelConfig:
self.use_internode_ll_two_stage: bool = False
self.max_num_batched_tokens: int = 2048
# Max chunk length during inference. Must be overridden.
self.max_chunk_len: int = 0
# splitwise role
self.splitwise_role: str = "mixed"
# guided decoding backend
@@ -364,6 +367,10 @@ class ParallelConfig:
)
dist.collective._set_custom_gid(None)
def postprocess(self):
# 2048 is an extra buffer for decoding. This should be made more accurate in the future.
self.max_chunk_len = int(envs.FD_MAX_EXTRA_NUM_BATCHED_TOKENS) + self.max_num_batched_tokens + 2048
def print(self):
"""
print all config
@@ -1276,6 +1283,8 @@ class FDConfig:
self.cache_config.postprocess(self.max_num_batched_tokens, self.max_num_seqs)
self.cache_config.max_block_num_per_seq = int(self.max_model_len // self.cache_config.block_size)
self.parallel_config.postprocess()
if self.guided_decoding_backend == "auto":
if self.model_config.enable_mm:
self.guided_decoding_backend = "off"
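
For reference, a minimal standalone sketch of the sizing rule this hunk introduces, using the defaults visible in the diff (FD_MAX_EXTRA_NUM_BATCHED_TOKENS=16384, max_num_batched_tokens=2048, plus the fixed 2048-token decoding buffer); the helper function name is hypothetical and not part of FastDeploy.

import os

DECODE_BUFFER = 2048  # fixed extra buffer for decoding, per the comment in postprocess()

def compute_max_chunk_len(max_num_batched_tokens=2048):
    # Mirrors ParallelConfig.postprocess(): env-controlled extra tokens + batched tokens + buffer
    extra = int(os.getenv("FD_MAX_EXTRA_NUM_BATCHED_TOKENS", "16384"))
    return extra + max_num_batched_tokens + DECODE_BUFFER

print(compute_max_chunk_len())  # 16384 + 2048 + 2048 = 20480 with defaults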


@@ -114,6 +114,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
"ENCODE_FEATURE_BOS_AK": lambda: os.getenv("ENCODE_FEATURE_BOS_AK"),
"ENCODE_FEATURE_BOS_SK": lambda: os.getenv("ENCODE_FEATURE_BOS_SK"),
"FD_ENABLE_E2W_TENSOR_CONVERT": lambda: int(os.getenv("FD_ENABLE_E2W_TENSOR_CONVERT", "0")),
"FD_MAX_EXTRA_NUM_BATCHED_TOKENS": lambda: int(os.getenv("FD_MAX_EXTRA_NUM_BATCHED_TOKENS", "16384")),
}
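
The new entry follows the same lazy os.getenv pattern as the rest of the table. As a hedged usage note, lowering the variable directly lowers max_chunk_len; the sketch below shows how the entry resolves, and the export line is only a generic shell illustration, not a documented FastDeploy recipe.

import os

# How the dict entry above resolves when read; 16384 is the default from the diff.
FD_MAX_EXTRA_NUM_BATCHED_TOKENS = int(os.getenv("FD_MAX_EXTRA_NUM_BATCHED_TOKENS", "16384"))

# Example: `export FD_MAX_EXTRA_NUM_BATCHED_TOKENS=4096` before launch gives
# max_chunk_len = 4096 + max_num_batched_tokens + 2048, shrinking the bfloat16 buffer below.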


@@ -329,7 +329,7 @@ class MTPProposer(Proposer):
self.target_model_inputs["decoder_tile_ids_per_batch"]
)
self.model_inputs["target_hidden_states"] = paddle.full(
-    [self.max_model_len * self.fd_config.max_prefill_batch, self.model_config.hidden_size], 0, dtype="bfloat16"
+    [self.parallel_config.max_chunk_len, self.model_config.hidden_size], 0, dtype="bfloat16"
)
tmp_position_ids = paddle.arange(self.parallel_config.max_model_len).reshape((1, -1))
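
A rough back-of-the-envelope comparison of the target_hidden_states buffer before and after this change; max_model_len, max_prefill_batch, and hidden_size below are placeholder assumptions chosen only to illustrate the scale of the saving, not values from the repository.

# Hypothetical sizes; real deployments will differ.
max_model_len = 32768
max_prefill_batch = 8
hidden_size = 4096
max_chunk_len = 20480        # default from the new postprocess(): 16384 + 2048 + 2048
bytes_per_elem = 2           # bfloat16

old_bytes = max_model_len * max_prefill_batch * hidden_size * bytes_per_elem
new_bytes = max_chunk_len * hidden_size * bytes_per_elem
print(f"old: {old_bytes / 2**30:.2f} GiB, new: {new_bytes / 2**30:.2f} GiB")
# old: 2.00 GiB, new: 0.16 GiB -- the buffer now scales with batched-token limits, not max_model_len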