Mirror of https://github.com/PaddlePaddle/FastDeploy.git
[Optimization] Reduce memory allocate for cudaGraph (#4838)
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
* optimize memory allocate
* rename env
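In short: the MTP proposer's target_hidden_states buffer is now sized by a per-step token budget (max_chunk_len) instead of by max_model_len * max_prefill_batch. A minimal sketch of the new sizing chain, using a standalone toy class that mirrors the diff below (not the real FastDeploy ParallelConfig):

import os

class ToyParallelConfig:
    """Toy stand-in for ParallelConfig; only the fields touched by this commit."""

    def __init__(self):
        self.max_num_batched_tokens = 2048
        self.max_chunk_len = 0  # must be overridden by postprocess()

    def postprocess(self):
        # env headroom + per-step token budget + 2048 extra buffer for decoding
        extra = int(os.getenv("FD_MAX_EXTRA_NUM_BATCHED_TOKENS", "16384"))
        self.max_chunk_len = extra + self.max_num_batched_tokens + 2048

cfg = ToyParallelConfig()
cfg.postprocess()
print(cfg.max_chunk_len)  # 20480 with the defaults above
# The MTP proposer then allocates [max_chunk_len, hidden_size] rows of hidden
# states instead of [max_model_len * max_prefill_batch, hidden_size].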
@@ -298,6 +298,9 @@ class ParallelConfig:
         self.use_internode_ll_two_stage: bool = False

         self.max_num_batched_tokens: int = 2048
+
+        # Max chunk len in infer. Must be overrided
+        self.max_chunk_len: int = 0
         # splitwise role
         self.splitwise_role: str = "mixed"
         # guided decoding backend
@@ -364,6 +367,10 @@ class ParallelConfig:
         )
         dist.collective._set_custom_gid(None)

+    def postprocess(self):
+        # 2048 is extra buffer for decoding. It should be more accurate in future
+        self.max_chunk_len = int(envs.FD_MAX_EXTRA_NUM_BATCHED_TOKENS) + self.max_num_batched_tokens + 2048
+
     def print(self):
         """
         print all config
@@ -1276,6 +1283,8 @@ class FDConfig:
         self.cache_config.postprocess(self.max_num_batched_tokens, self.max_num_seqs)
         self.cache_config.max_block_num_per_seq = int(self.max_model_len // self.cache_config.block_size)

+        self.parallel_config.postprocess()
+
         if self.guided_decoding_backend == "auto":
             if self.model_config.enable_mm:
                 self.guided_decoding_backend = "off"
@@ -114,6 +114,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
     "ENCODE_FEATURE_BOS_AK": lambda: os.getenv("ENCODE_FEATURE_BOS_AK"),
     "ENCODE_FEATURE_BOS_SK": lambda: os.getenv("ENCODE_FEATURE_BOS_SK"),
     "FD_ENABLE_E2W_TENSOR_CONVERT": lambda: int(os.getenv("FD_ENABLE_E2W_TENSOR_CONVERT", "0")),
+    "FD_MAX_EXTRA_NUM_BATCHED_TOKENS": lambda: int(os.getenv("FD_MAX_EXTRA_NUM_BATCHED_TOKENS", "16384")),
 }
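The new variable keeps a 16384-token headroom as its default and can be overridden through the process environment before the config is built. A minimal sketch (the 4096 value is hypothetical; the entry above reads the variable via os.getenv when ParallelConfig.postprocess() runs):

import os

# Hypothetical override: shrink the extra token headroom from 16384 to 4096.
# Must be set in the environment before ParallelConfig.postprocess() reads it.
os.environ["FD_MAX_EXTRA_NUM_BATCHED_TOKENS"] = "4096"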
@@ -329,7 +329,7 @@ class MTPProposer(Proposer):
             self.target_model_inputs["decoder_tile_ids_per_batch"]
         )
         self.model_inputs["target_hidden_states"] = paddle.full(
-            [self.max_model_len * self.fd_config.max_prefill_batch, self.model_config.hidden_size], 0, dtype="bfloat16"
+            [self.parallel_config.max_chunk_len, self.model_config.hidden_size], 0, dtype="bfloat16"
         )

         tmp_position_ids = paddle.arange(self.parallel_config.max_model_len).reshape((1, -1))
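The saving comes from the target_hidden_states allocation above: it used to hold max_model_len * max_prefill_batch rows of hidden states and now holds only max_chunk_len rows. A rough back-of-the-envelope comparison, with hypothetical values for the model-dependent numbers (max_model_len, max_prefill_batch, and hidden_size are illustrative, not taken from the PR):

# Back-of-the-envelope size of the bfloat16 target_hidden_states buffer
# before and after this change (all model-dependent values are hypothetical).
max_model_len = 131072               # assumed context length
max_prefill_batch = 3                # assumed prefill batch limit
hidden_size = 8192                   # assumed hidden size
max_chunk_len = 16384 + 2048 + 2048  # defaults: env headroom + max_num_batched_tokens + decode buffer
bytes_per_elem = 2                   # bfloat16

old_bytes = max_model_len * max_prefill_batch * hidden_size * bytes_per_elem
new_bytes = max_chunk_len * hidden_size * bytes_per_elem
print(f"old: {old_bytes / 2**30:.2f} GiB, new: {new_bytes / 2**30:.2f} GiB")
# -> old: 6.00 GiB, new: 0.31 GiB for these assumed values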