From 01d758666160846ea613b960c93f9cc811fcb262 Mon Sep 17 00:00:00 2001 From: Longzhi Wang <583087864@qq.com> Date: Mon, 4 Aug 2025 18:06:18 +0800 Subject: [PATCH] [Bug fix] Fix cudagraph when use ep. (#3130) * fix cudagraph when use ep * fix typo * reduce full length to adapt to large bsz such as 128/256 --- fastdeploy/worker/gpu_model_runner.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index 4b67b595e..0b4b0b4a7 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -525,6 +525,12 @@ class GPUModelRunner(ModelRunnerBase): num_tokens // batch_size, self.parallel_config.max_model_len - max_dec_len, ) + + # NOTE(wanglongzhi): When the full length is too large, DeepEP's buffer is not large enough, which causes NaN values to appear in the result. + # TODO(wanglongzhi): Figure out the exact buffer size DeepEP requires. + if self.fd_config.parallel_config.enable_expert_parallel: + full_length = min(full_length, 32) + input_length = int(full_length * self.cache_config.kv_cache_ratio) block_num = ( input_length + self.cache_config.block_size - 1