diff --git a/fastdeploy/worker/gcu_model_runner.py b/fastdeploy/worker/gcu_model_runner.py index 433ec85cb..26522044f 100644 --- a/fastdeploy/worker/gcu_model_runner.py +++ b/fastdeploy/worker/gcu_model_runner.py @@ -576,8 +576,9 @@ class GCUModelRunner(ModelRunnerBase): ) # Update Batch type for cuda graph - is_decode_batch = not ((self.share_inputs["seq_lens_this_time"] > 1).sum() > 0) - self.forward_meta.step_use_cudagraph = self.use_cudagraph and is_decode_batch + self.forward_meta.step_use_cudagraph = self.use_cudagraph and ( + not ((self.share_inputs["seq_lens_this_time"] > 1).sum() > 0) + ) # Initialzie attention meta data for attn_backend in self.attn_backends: diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index 0be973530..3f9014dc2 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -793,16 +793,21 @@ class GPUModelRunner(ModelRunnerBase): # Update Batch type for cuda graph # TODO(gongshaotian): Use seq_lens_encoder to set is_decode_batch - is_decode_batch = not ((self.share_inputs["seq_lens_this_time"] > 1).sum() > 0) - + only_decode_batch = True + prefill_exists = None # mix ep in single node if self.fd_config.parallel_config.use_ep and self.fd_config.parallel_config.splitwise_role == "mixed": - is_decode_batch_list = [] - paddle.distributed.all_gather_object(is_decode_batch_list, is_decode_batch) - is_decode_batch = all(is_decode_batch_list) - self.fd_config.parallel_config.moe_phase.phase = "decode" if is_decode_batch else "prefill" + only_decode_batch_list = [] + prefill_exists = self.exist_prefill() + paddle.distributed.all_gather_object(only_decode_batch_list, not prefill_exists) + only_decode_batch = all(only_decode_batch_list) + self.fd_config.parallel_config.moe_phase.phase = "decode" if only_decode_batch else "prefill" - self.forward_meta.step_use_cudagraph = self.use_cudagraph and is_decode_batch + self.forward_meta.step_use_cudagraph = ( + self.use_cudagraph + and only_decode_batch + and not (prefill_exists if prefill_exists is not None else self.exist_prefill()) + ) # Initialzie attention meta data for attn_backend in self.attn_backends: