From 920df5be5a54cdab71f92d4a110e2165733d898e Mon Sep 17 00:00:00 2001
From: RAM
Date: Fri, 17 Oct 2025 14:22:05 +0800
Subject: [PATCH] [Graph Optimization][Speculative Decoding] Fix the bug of
 CUDAGraph + MTP + EP (#4430)

* Fix MTP dummy run bug

* Make the Target Model and the Draft Model use the same flag

* Avoid MoE bug in CUDAGraph padding

* In MTP, replace use_cudagraph with step_use_cudagraph

---
 custom_ops/gpu_ops/append_attention.cu |  6 ++--
 fastdeploy/spec_decode/mtp.py          | 48 +++++++++++---------------
 fastdeploy/worker/gpu_model_runner.py  |  8 +++--
 3 files changed, 29 insertions(+), 33 deletions(-)

diff --git a/custom_ops/gpu_ops/append_attention.cu b/custom_ops/gpu_ops/append_attention.cu
index fb325d51d..73194eea7 100644
--- a/custom_ops/gpu_ops/append_attention.cu
+++ b/custom_ops/gpu_ops/append_attention.cu
@@ -496,12 +496,12 @@ std::vector<paddle::Tensor> AppendAttention(
   paddle::Tensor fmha_out;
   if (out_linear_in_scale > 0.0) {
     if (fabs(quant_max_bound - 127.0f) < 0.000001) {
-      fmha_out = GetEmptyTensor(
+      fmha_out = paddle::zeros(
           {meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims},
           paddle::DataType::INT8,
           qkv.place());
     } else if (fabs(quant_max_bound - 448.0f) < 0.000001) {
-      fmha_out = GetEmptyTensor(
+      fmha_out = paddle::zeros(
           {meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims},
           paddle::DataType::FLOAT8_E4M3FN,
           qkv.place());
@@ -509,7 +509,7 @@ std::vector<paddle::Tensor> AppendAttention(
       PD_THROW("Only supported attr of quant_max_bound in ['127', '448'].");
     }
   } else {
-    fmha_out = GetEmptyTensor(
+    fmha_out = paddle::zeros(
         {meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims},
         dtype_id,
         qkv.place());
diff --git a/fastdeploy/spec_decode/mtp.py b/fastdeploy/spec_decode/mtp.py
index fb7d32645..20731d832 100644
--- a/fastdeploy/spec_decode/mtp.py
+++ b/fastdeploy/spec_decode/mtp.py
@@ -79,7 +79,7 @@ class MTPProposer(Proposer):
         self._init_model_inputs()
 
         # CUDA Graph
-        self.use_cudagraph = False  # self.graph_opt_config.use_cudagraph
+        self.use_cudagraph = False  # TODO(gongshaotian): Use Target Model flag
         self.cudagraph_capture_sizes = list(reversed(self.graph_opt_config.cudagraph_capture_sizes))
         self.sot_warmup_sizes = self.graph_opt_config.sot_warmup_sizes
 
@@ -117,6 +117,9 @@ class MTPProposer(Proposer):
             self.parallel_config.max_model_len - max_dec_len,
         )
 
+        if self.fd_config.parallel_config.enable_expert_parallel:
+            input_length = min(input_length, 32)
+
         block_num = (
             input_length + self.cache_config.block_size - 1
         ) // self.cache_config.block_size + self.cache_config.enc_dec_block_num
@@ -541,7 +544,7 @@ class MTPProposer(Proposer):
             self.model_inputs["not_need_stop"][0] = True
         self.model_inputs["seq_lens_this_time"] = self.seq_lens_this_time_buffer
 
-    def _initialize_forward_meta(self):
+    def _initialize_forward_meta(self, step_use_cudagraph: bool = False):
         """
         Initialize forward meta and attention meta data
         """
@@ -569,23 +572,8 @@ class MTPProposer(Proposer):
         for attn_backend in self.attn_backends:
             attn_backend.init_attention_metadata(self.forward_meta)
 
-        # Update Batch type for cuda graph
-        only_decode_batch = True
-        prefill_exists = None
-
-        # Mix ep in single node
-        if self.fd_config.parallel_config.use_ep and self.fd_config.parallel_config.splitwise_role == "mixed":
-            only_decode_batch_list = []
-            prefill_exists = self.exist_prefill()
-            paddle.distributed.all_gather_object(only_decode_batch_list, not prefill_exists)
-            only_decode_batch = all(only_decode_batch_list)
-            self.fd_config.model_config.moe_phase.phase = "decode" if only_decode_batch else "prefill"
-
-        self.forward_meta.step_use_cudagraph = (
-            self.use_cudagraph
-            and only_decode_batch
-            and not (prefill_exists if prefill_exists is not None else self.exist_prefill())
-        )
+        # TODO(gongshaotian): Use CUDAGraph with Draft Model
+        self.forward_meta.step_use_cudagraph = step_use_cudagraph and self.use_cudagraph
 
     def exist_prefill(self):
         """
@@ -671,9 +659,12 @@ class MTPProposer(Proposer):
             self.parallel_config.use_ep,
         )
 
-    def _propose(self):
+    def _propose(self, step_use_cudagraph: bool = False):
         """
-        Main process for MTP inference
+        Main process for MTP inference.
+        Args:
+            step_use_cudagraph: bool
+                Whether to run this step with CUDA Graph. Reuse the target model's flag to avoid hangs with EP.
         """
         for substep in range(self.num_model_steps):
             if self.model_inputs["not_need_stop"]:
@@ -697,7 +688,7 @@ class MTPProposer(Proposer):
 
             # Initialize forward meta data
             self.model_inputs["ids_remove_padding"].copy_(ids_remove_padding, False)
-            self.model_inputs["batch_id_per_token"].copy_(batch_id_per_token, False)
+            self.model_inputs["batch_id_per_token"][:] = -1
             self.model_inputs["cu_seqlens_q"].copy_(cu_seqlens_q, False)
             self.model_inputs["cu_seqlens_k"].copy_(cu_seqlens_k, False)
             # for speculative decoding
@@ -705,7 +696,8 @@ class MTPProposer(Proposer):
             self.model_inputs["output_cum_offsets"].copy_(output_cum_offsets, False)
             self.model_inputs["output_padding_offset"].copy_(output_padding_offset, False)
             # Initialize forward meta data
-            self._initialize_forward_meta()
+            self._initialize_forward_meta(step_use_cudagraph=step_use_cudagraph)
+            self.forward_meta.batch_id_per_token.copy_(batch_id_per_token, False)
 
             # Padding inputs for cuda graph
             self.padding_cudagraph_inputs()
@@ -733,7 +725,7 @@ class MTPProposer(Proposer):
                 previous_hidden_states=self.model_inputs["target_hidden_states"],
                 forward_meta=self.forward_meta,
             )
-            if self.use_cudagraph:
+            if self.forward_meta.step_use_cudagraph:
                 model_output = model_output[: self.real_token_num]
 
             hidden_states = rebuild_padding(
@@ -861,10 +853,10 @@ class MTPProposer(Proposer):
             self.target_model_inputs["draft_tokens"][:] = draft_tokens.cuda()
             self.target_model_inputs["seq_lens_this_time"][:] = seq_lens_this_time.cuda()
 
-    def _run_impl(self, full_hidden_states):
-        """"""
+    def _run_impl(self, full_hidden_states: paddle.Tensor, step_use_cudagraph: bool = False):
+        """Execute Draft Model"""
         self._prepare_inputs(full_hidden_states)
-        self._propose()
+        self._propose(step_use_cudagraph=step_use_cudagraph)
         self._update_status()
         if self.hybrid_mode:
             self._extend_draft_token_with_ngram_match()
@@ -881,7 +873,7 @@ class MTPProposer(Proposer):
 
         # In init_attention_metadata, the decode buffer has already been cleared
         # To adapt to CUDA Graph, keep the forward pass at the maximum batch size.
-        if self.use_cudagraph:
+        if self.forward_meta.step_use_cudagraph:
             self.forward_meta.seq_lens_this_time = self.seq_lens_this_time_buffer
             self.real_token_num = self.forward_meta.ids_remove_padding.shape[0]
         return
diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py
index 6e281fd5c..8c6d7ce46 100644
--- a/fastdeploy/worker/gpu_model_runner.py
+++ b/fastdeploy/worker/gpu_model_runner.py
@@ -1253,7 +1253,9 @@ class GPUModelRunner(ModelRunnerBase):
 
         if self.speculative_decoding:
             if self.speculative_method == "mtp":
-                self.proposer.run(full_hidden_states=model_output)
+                self.proposer.run(
+                    full_hidden_states=model_output, step_use_cudagraph=self.forward_meta.step_use_cudagraph
+                )
             else:
                 self.proposer.run(share_inputs=self.share_inputs)
 
@@ -1600,7 +1602,9 @@ class GPUModelRunner(ModelRunnerBase):
         # 6. Speculative decode
         if self.speculative_decoding:
             if self.speculative_method == "mtp":
-                self.proposer.run(full_hidden_states=model_output)
+                self.proposer.run(
+                    full_hidden_states=model_output, step_use_cudagraph=self.forward_meta.step_use_cudagraph
+                )
             else:
                 self.proposer.run(share_inputs=self.share_inputs)
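After this patch, the per-step CUDA Graph decision is made once by the target model and simply replayed by the MTP draft model, instead of each side recomputing it (the old all_gather-based check could diverge across EP ranks and hang). The sketch below only illustrates that flag flow; `ForwardMeta`, `DraftProposer`, and `TargetRunner` are hypothetical stand-ins reduced to the fields this patch touches, not FastDeploy classes.

```python
from dataclasses import dataclass


@dataclass
class ForwardMeta:
    """Reduced stand-in for the runner's forward metadata."""
    step_use_cudagraph: bool = False


class DraftProposer:
    """Stand-in for MTPProposer: it no longer decides on its own whether a
    step is capturable; it combines the target model's per-step decision
    with its own static switch."""

    def __init__(self, use_cudagraph: bool):
        self.use_cudagraph = use_cudagraph
        self.forward_meta = ForwardMeta()

    def run(self, full_hidden_states, step_use_cudagraph: bool = False) -> bool:
        # Mirrors _initialize_forward_meta: replay the target model's flag
        # instead of re-running a collective-based "only decode batch" check,
        # so EP ranks cannot disagree and hang.
        self.forward_meta.step_use_cudagraph = step_use_cudagraph and self.use_cudagraph
        return self.forward_meta.step_use_cudagraph


class TargetRunner:
    """Stand-in for GPUModelRunner."""

    def __init__(self, proposer: DraftProposer):
        self.forward_meta = ForwardMeta(step_use_cudagraph=True)
        self.proposer = proposer

    def step(self, model_output) -> bool:
        # After the target model's forward pass, hand its flag to the proposer.
        return self.proposer.run(
            full_hidden_states=model_output,
            step_use_cudagraph=self.forward_meta.step_use_cudagraph,
        )


if __name__ == "__main__":
    runner = TargetRunner(DraftProposer(use_cudagraph=False))
    # With use_cudagraph still hard-wired to False in the proposer (as in this
    # patch), the draft step falls back to eager mode even when the target
    # step was captured.
    print(runner.step(model_output=None))  # -> False
```

With `use_cudagraph` kept at `False` inside the proposer for now, the draft model always runs eagerly while the target model keeps its own capture decision; enabling CUDA Graph for the draft model is left as the TODO in the diff.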