From 438c9f785a1fa7b444b2d42e099223d312bbc041 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E5=91=A8=E5=91=A8=E5=91=A8?=
 <39978853+zhoutianzi666@users.noreply.github.com>
Date: Mon, 8 Dec 2025 16:47:44 +0800
Subject: [PATCH] [BugFix] 0 not into cuda graph to save memory (#5426)

---
 fastdeploy/config.py                       |  3 ---
 .../custom_all_reduce/custom_all_reduce.py |  4 ++++
 fastdeploy/worker/gpu_model_runner.py      | 16 +++++++---------
 3 files changed, 11 insertions(+), 12 deletions(-)

diff --git a/fastdeploy/config.py b/fastdeploy/config.py
index 475b9f6ff..f1eb23852 100644
--- a/fastdeploy/config.py
+++ b/fastdeploy/config.py
@@ -1577,9 +1577,6 @@ class FDConfig:
         self.graph_opt_config._set_cudagraph_sizes(max_capture_size=max_capture_shape)
         self.graph_opt_config.init_with_cudagrpah_size(max_capture_size=max_capture_shape)
 
-        if self.parallel_config.use_ep:
-            self.graph_opt_config.cudagraph_capture_sizes += [0]
-
         self.tokenizer = tokenizer
         self.ips = ips
         self.tool_parser = tool_parser
diff --git a/fastdeploy/distributed/custom_all_reduce/custom_all_reduce.py b/fastdeploy/distributed/custom_all_reduce/custom_all_reduce.py
index dfbed094d..0c9be796c 100644
--- a/fastdeploy/distributed/custom_all_reduce/custom_all_reduce.py
+++ b/fastdeploy/distributed/custom_all_reduce/custom_all_reduce.py
@@ -207,6 +207,10 @@ class CustomAllreduce:
 
     def custom_all_reduce(self, input: paddle.Tensor) -> Optional[paddle.Tensor]:
         """The main allreduce API that provides support for cuda graph."""
+
+        if input.shape[0] == 0:
+            return input
+
         if self.capturing:
             lib = cuda_wrapper.CudaRTLibrary()
             stream = paddle.device.current_stream()
diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py
index 9b550f104..2a0248894 100644
--- a/fastdeploy/worker/gpu_model_runner.py
+++ b/fastdeploy/worker/gpu_model_runner.py
@@ -1020,14 +1020,10 @@ class GPUModelRunner(ModelRunnerBase):
         """
         # NOTE(gongshaotian): The maximum decoding length is equal to the expected decoded tokens plus the eos token
         max_dec_len = expected_decode_len + 1
-        if batch_size == 0:
-            # Note(ZKK): divided by 0 is invalid, here we give a input_length = 1
-            input_length = 1
-        else:
-            input_length = min(
-                num_tokens // (1 if capture_prefill else batch_size),
-                self.model_config.max_model_len - max_dec_len,
-            )
+        input_length = min(
+            num_tokens // (1 if capture_prefill else batch_size),
+            self.model_config.max_model_len - max_dec_len,
+        )
 
         # NOTE(wanglongzhi): When the full length is too large, DeepEP's buffer size will not be enough to cause the result to appear nan.
         # TODO(wanglongzhi): Figure out the accurate buffer size of DeepEP.
@@ -1490,7 +1486,9 @@ class GPUModelRunner(ModelRunnerBase):
 
         # When support capture both prefill-only and decode-only, this will use [only_prefill_use_cudagraph or only_decode_use_cudagraph]
         self.forward_meta.step_use_cudagraph = (
-            only_prefill_use_cudagraph if self.cudagraph_only_prefill else only_decode_use_cudagraph
+            only_prefill_use_cudagraph
+            if self.cudagraph_only_prefill
+            else only_decode_use_cudagraph and self.forward_meta.ids_remove_padding.shape[0] > 0
        )
 
         # Set forward_meta.is_dummy_or_profile_run to True to skip init_kv_signal_per_query for attention backends
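
For illustration only, a minimal self-contained sketch of the guard pattern this patch introduces: batch size 0 is kept out of the CUDA graph capture list, and the all-reduce path returns an empty input untouched instead of launching a collective. `plan_capture_sizes` and `naive_all_reduce` are hypothetical stand-ins, not FastDeploy or Paddle APIs.

# Sketch of the guard pattern (hypothetical names, not FastDeploy APIs):
# skip batch size 0 when planning CUDA graph captures, and short-circuit
# the all-reduce when there is nothing to reduce.

from typing import List


def plan_capture_sizes(candidate_sizes: List[int]) -> List[int]:
    # Only capture non-empty batch sizes; capturing 0 wastes graph memory.
    return [s for s in candidate_sizes if s > 0]


def naive_all_reduce(local_tensors: List[List[float]]) -> List[float]:
    # Stand-in for an all-reduce call: if the first dimension is 0, there is
    # nothing to reduce, so return the (empty) input as-is.
    if len(local_tensors[0]) == 0:
        return local_tensors[0]
    return [sum(vals) for vals in zip(*local_tensors)]


if __name__ == "__main__":
    print(plan_capture_sizes([0, 1, 2, 4, 8]))         # [1, 2, 4, 8]
    print(naive_all_reduce([[], []]))                  # []
    print(naive_all_reduce([[1.0, 2.0], [3.0, 4.0]]))  # [4.0, 6.0]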