From 77f8ba06e7c2520146b551128fb40467cdcb84cd Mon Sep 17 00:00:00 2001 From: zhang-chenyi <74278535+zhang-chenyi@users.noreply.github.com> Date: Mon, 15 Dec 2025 14:23:33 +0800 Subject: [PATCH] [Metax] fix release2.4 and support cudagraph (#5547) Co-authored-by: xiaozude --- fastdeploy/config.py | 2 +- .../metax/attention/flash_attn_backend.py | 28 +++-- .../metax/attention/mla_attn_metax_backend.py | 9 +- .../moe/fused_moe_cutlass_metax_backend.py | 118 +++--------------- fastdeploy/worker/metax_model_runner.py | 54 ++++++-- 5 files changed, 85 insertions(+), 126 deletions(-) diff --git a/fastdeploy/config.py b/fastdeploy/config.py index fd62598e0..974860e03 100644 --- a/fastdeploy/config.py +++ b/fastdeploy/config.py @@ -1743,7 +1743,7 @@ class FDConfig: logger.info( "Static Graph does not support to be started together with RL Training, and automatically switch to dynamic graph!" ) - if not current_platform.is_cuda(): + if not current_platform.is_cuda() and not current_platform.is_maca(): self.graph_opt_config.use_cudagraph = False logger.info("CUDAGraph currently only support on GPU!") if self.parallel_config.use_sequence_parallel_moe and self.graph_opt_config.use_cudagraph: diff --git a/fastdeploy/model_executor/layers/backends/metax/attention/flash_attn_backend.py b/fastdeploy/model_executor/layers/backends/metax/attention/flash_attn_backend.py index 646997e31..d39f6e615 100644 --- a/fastdeploy/model_executor/layers/backends/metax/attention/flash_attn_backend.py +++ b/fastdeploy/model_executor/layers/backends/metax/attention/flash_attn_backend.py @@ -64,6 +64,8 @@ class FlashAttentionMetadata(AttentionMetadata): encoder_block_shape_q: int = -1 decoder_block_shape_q: int = -1 _fuse_kernel_compute_dtype: str = "bf16" + seq_lens_dec: paddle.Tensor = None + block_table_dec: paddle.Tensor = None # pd_disaggregation kv_signal_metadata: Optional[paddle.Tensor] = None @@ -135,6 +137,12 @@ class FlashAttentionBackend(AttentionBackend): shape=[max_num_seqs, 1, 1, self.head_dim], dtype=self.dtype, ) + self.attention_metadata.seq_lens_dec = paddle.empty( + shape=[fd_config.scheduler_config.max_num_seqs, 1], dtype="int32" + ) + self.attention_metadata.block_table_dec = paddle.empty( + shape=[fd_config.scheduler_config.max_num_seqs, self.head_dim], dtype="int32" + ) def init_attention_metadata(self, forward_meta: ForwardMeta): """Initialize attntion metadata hence all layers in the forward pass can reuse it.""" @@ -229,8 +237,9 @@ class FlashAttentionBackend(AttentionBackend): self.batch_ids_prefill = paddle.to_tensor(self.prefill_info_dict["batch_ids"]) self.batch_ids_decode = paddle.to_tensor(self.decode_info_dict["batch_ids"]) - self.seq_lens_dec = forward_meta.seq_lens_decoder[self.batch_ids_decode, 0] - self.block_table_dec = forward_meta.block_tables[self.batch_ids_decode, :] + self.attention_metadata.seq_lens_dec.copy_(forward_meta.seq_lens_decoder[self.batch_ids_decode, 0]) + self.attention_metadata.block_table_dec.copy_(forward_meta.block_tables[self.batch_ids_decode, :]) + # update prefilling rope self.update_rotary_embs_prefill(forward_meta) # update decoding rope @@ -296,13 +305,18 @@ class FlashAttentionBackend(AttentionBackend): bs = self.batch_ids_decode.shape[0] if self.enable_mm: index = paddle.concat( - [self.batch_ids_decode.view([-1, 1]), self.seq_lens_dec.to("int64").view([-1, 1])], axis=1 + [self.batch_ids_decode.view([-1, 1]), self.attention_metadata.seq_lens_dec.to("int64").view([-1, 1])], + axis=1, ) rot_cos = paddle.gather_nd(forward_meta.rotary_embs[:, 0, 0, :, 0, :], 
index).view([bs, 1, 1, -1]) rot_sin = paddle.gather_nd(forward_meta.rotary_embs[:, 1, 0, :, 0, :], index).view([bs, 1, 1, -1]) else: - rot_cos = paddle.gather(forward_meta.rotary_embs[0, 0, :, 0, :], self.seq_lens_dec).view([bs, 1, 1, -1]) - rot_sin = paddle.gather(forward_meta.rotary_embs[1, 0, :, 0, :], self.seq_lens_dec).view([bs, 1, 1, -1]) + rot_cos = paddle.gather( + forward_meta.rotary_embs[0, 0, :, 0, :], self.attention_metadata.seq_lens_dec + ).view([bs, 1, 1, -1]) + rot_sin = paddle.gather( + forward_meta.rotary_embs[1, 0, :, 0, :], self.attention_metadata.seq_lens_dec + ).view([bs, 1, 1, -1]) self.attention_metadata.rotary_cos_decode[:bs].copy_( paddle.repeat_interleave(rot_cos, repeats=2, axis=-1).astype(self.dtype) ) @@ -476,8 +490,8 @@ class FlashAttentionBackend(AttentionBackend): q, forward_meta.caches[k_cache_id], forward_meta.caches[v_cache_id], - self.seq_lens_dec, - self.block_table_dec, + self.attention_metadata.seq_lens_dec, + self.attention_metadata.block_table_dec, k, v, rotary_cos=None, diff --git a/fastdeploy/model_executor/layers/backends/metax/attention/mla_attn_metax_backend.py b/fastdeploy/model_executor/layers/backends/metax/attention/mla_attn_metax_backend.py index cd7649ae7..6d17d3c38 100644 --- a/fastdeploy/model_executor/layers/backends/metax/attention/mla_attn_metax_backend.py +++ b/fastdeploy/model_executor/layers/backends/metax/attention/mla_attn_metax_backend.py @@ -221,12 +221,9 @@ class MetaxMLAAttentionBackend(AttentionBackend): """ Calculate kv cache shape for MLA """ - return ( - max_num_blocks, - 1, - self.block_size, - self.kv_lora_rank + self.qk_rope_head_dim, - ) + key_cache_shape = [max_num_blocks, 1, self.block_size, self.kv_lora_rank + self.qk_rope_head_dim] + value_cache_shape = [] + return key_cache_shape, value_cache_shape def compute_flash_mla( self, diff --git a/fastdeploy/model_executor/layers/backends/metax/moe/fused_moe_cutlass_metax_backend.py b/fastdeploy/model_executor/layers/backends/metax/moe/fused_moe_cutlass_metax_backend.py index 9b65e073c..742d6e60f 100644 --- a/fastdeploy/model_executor/layers/backends/metax/moe/fused_moe_cutlass_metax_backend.py +++ b/fastdeploy/model_executor/layers/backends/metax/moe/fused_moe_cutlass_metax_backend.py @@ -15,6 +15,7 @@ """ import os +from typing import Callable import paddle from paddle import nn @@ -66,25 +67,12 @@ class MetaxCutlassUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod): layer.up_gate_proj_bias.set_value(stacked_up_gate_proj_bias) layer.down_proj_bias.set_value(stacked_down_proj_bias) - def compute_ffn( - self, - layer: nn.Layer, - permute_input: paddle.Tensor, - token_nums_per_expert: paddle.Tensor, - expert_idx_per_token: paddle.Tensor, - used_in_ep_low_latency: bool = False, - estimate_total_token_nums: int = -1, - ): - """ - Paddle Cutlass compute Fused MoE. - """ - raise NotImplementedError - def apply_ep_prefill( self, layer: nn.Layer, x: paddle.Tensor, gate: nn.Layer, + topk_ids_hookfunc: Callable = None, ) -> paddle.Tensor: """ Apply the EP prefill method. @@ -96,6 +84,7 @@ class MetaxCutlassUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod): layer: nn.Layer, x: paddle.Tensor, gate: nn.Layer, + topk_ids_hookfunc: Callable = None, ) -> paddle.Tensor: """ Apply the EP decoder method. @@ -107,70 +96,12 @@ class MetaxCutlassUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod): layer: nn.Layer, x: paddle.Tensor, gate: nn.Layer, + topk_ids_hookfunc: Callable = None, ) -> paddle.Tensor: """ Paddle Cutlass compute Fused MoE. 
""" - """ - Paddle Cutlass compute Fused MoE. - """ - if layer.topk_method == "noaux_tc": - gate_out = gate(x.cast("float32")) - - gate_out, topk_weights, topk_idx = get_moe_scores( - gate_out, - layer.n_group, - layer.topk_group, - layer.top_k, - layer.routed_scaling_factor, - layer.gate_correction_bias, - getattr(layer, "renormalize", True), - ) - - ( - permute_input, - token_nums_per_expert, - permute_indices_per_token, - topk_weights, - topk_idx, - ) = moe_expert_dispatch( - x, - gate_out, - layer.top_k, - False, - True, - ) - - ffn_out = self.compute_ffn(layer, permute_input, token_nums_per_expert, None) - - fused_moe_out = moe_expert_reduce( - ffn_out, - topk_weights, - permute_indices_per_token, - topk_idx, - None, - False, - 1.0, - ) - else: - raise NotImplementedError - - fused_moe_out = fused_expert_moe( - x, - gate.weight, - getattr(layer, self.added_weight_attrs[0]), - getattr(layer, self.added_weight_attrs[1]), - None, - (layer.up_gate_proj_weight_scale if hasattr(layer, "up_gate_proj_weight_scale") else None), - None, - (layer.down_proj_weight_scale if hasattr(layer, "down_proj_weight_scale") else None), - "weight_only_int8", - layer.top_k, - True, - False, - ) - - return fused_moe_out + raise NotImplementedError class MetaxCutlassMoEMethod(MoEMethodBase): @@ -189,35 +120,12 @@ class MetaxCutlassMoEMethod(MoEMethodBase): layer.up_gate_proj_weight.set_value(stacked_up_gate_proj_weights) layer.down_proj_weight.set_value(stacked_down_proj_weights) - def compute_ffn( - self, - layer: nn.Layer, - permute_input: paddle.Tensor, - token_nums_per_expert: paddle.Tensor, - expert_idx_per_token: paddle.Tensor, - used_in_ep_low_latency: bool = False, - estimate_total_token_nums: int = -1, - ): - """ - Paddle Cutlass compute Fused MoE. - """ - return moe_expert_ffn( - permute_input, - token_nums_per_expert, - getattr(layer, self.added_weight_attrs[0]), - getattr(layer, self.added_weight_attrs[1]), - None, - (layer.up_gate_proj_weight_scale if hasattr(layer, "up_gate_proj_weight_scale") else None), - (layer.down_proj_weight_scale if hasattr(layer, "down_proj_weight_scale") else None), - expert_idx_per_token, # expert_idx_per_token: only for w4a8 - self.moe_quant_type, - ) - def apply_ep_prefill( self, layer: nn.Layer, x: paddle.Tensor, gate: nn.Layer, + topk_ids_hookfunc: Callable = None, ) -> paddle.Tensor: """ Apply the EP prefill method. @@ -229,6 +137,7 @@ class MetaxCutlassMoEMethod(MoEMethodBase): layer: nn.Layer, x: paddle.Tensor, gate: nn.Layer, + topk_ids_hookfunc: Callable = None, ) -> paddle.Tensor: """ Apply the EP decoder method. @@ -240,6 +149,7 @@ class MetaxCutlassMoEMethod(MoEMethodBase): layer: nn.Layer, x: paddle.Tensor, gate: nn.Layer, + topk_ids_hookfunc: Callable = None, ) -> paddle.Tensor: """ Paddle Cutlass compute Fused MoE. 
@@ -282,7 +192,17 @@ class MetaxCutlassMoEMethod(MoEMethodBase): else: expert_idx_per_token = expert_idx_per_token.cast("int64") - ffn_out = self.compute_ffn(layer, permute_input, token_nums_per_expert, expert_idx_per_token) + ffn_out = moe_expert_ffn( + permute_input, + token_nums_per_expert, + getattr(layer, self.added_weight_attrs[0]), + getattr(layer, self.added_weight_attrs[1]), + None, + (layer.up_gate_proj_weight_scale if hasattr(layer, "up_gate_proj_weight_scale") else None), + (layer.down_proj_weight_scale if hasattr(layer, "down_proj_weight_scale") else None), + expert_idx_per_token, # expert_idx_per_token: only for w4a8 + self.moe_quant_type, + ) fused_moe_out = moe_expert_reduce( ffn_out, diff --git a/fastdeploy/worker/metax_model_runner.py b/fastdeploy/worker/metax_model_runner.py index 4d27abdc5..26aa88def 100644 --- a/fastdeploy/worker/metax_model_runner.py +++ b/fastdeploy/worker/metax_model_runner.py @@ -104,11 +104,18 @@ class MetaxModelRunner(ModelRunnerBase): self.enable_early_stop = self.fd_config.early_stop_config.enable_early_stop self.is_pooling_model = self.fd_config.model_config.runner_type == "pooling" self.ori_vocab_size = self.fd_config.model_config.ori_vocab_size - self.max_logprobs = ( - self.ori_vocab_size if fd_config.model_config.max_logprobs == -1 else fd_config.model_config.max_logprobs - ) + self.max_logprobs = None + if self.enable_logprob: + self.max_logprobs = ( + self.ori_vocab_size + if fd_config.model_config.max_logprobs == -1 + else fd_config.model_config.max_logprobs + ) + self.temp_scaled_logprobs = True + self.top_p_normalized_logprobs = True self.prompt_logprobs_reqs: dict[str, Request] = {} self.in_progress_prompt_logprobs: dict[str, LogprobsTensors] = {} + self.forward_batch_reqs_list: list[Request] = [None for _ in range(self.scheduler_config.max_num_seqs)] # VL model config: if self.enable_mm: @@ -640,6 +647,7 @@ class MetaxModelRunner(ModelRunnerBase): # pooling model request.sampling_params is None if request.sampling_params is not None and request.sampling_params.prompt_logprobs is not None: self.prompt_logprobs_reqs[request.request_id] = request + self.forward_batch_reqs_list[idx] = request has_prefill_task = True # Routing Replay @@ -672,6 +680,7 @@ class MetaxModelRunner(ModelRunnerBase): self.share_inputs["is_block_step"][idx : idx + 1] = False self.prompt_logprobs_reqs.pop(request.request_id, None) self.in_progress_prompt_logprobs.pop(request.request_id, None) + self.forward_batch_reqs_list[idx] = None continue assert len(request.eos_token_ids) == self.model_config.eos_tokens_lens @@ -996,14 +1005,10 @@ class MetaxModelRunner(ModelRunnerBase): """ # NOTE(gongshaotian): The maximum decoding length is equal to the expected decoded tokens plus the eos token max_dec_len = expected_decode_len + 1 - if batch_size == 0: - # Note(ZKK): divided by 0 is invalid, here we give a input_length = 1 - input_length = 1 - else: - input_length = min( - num_tokens // (1 if capture_prefill else batch_size), - self.model_config.max_model_len - max_dec_len, - ) + input_length = min( + num_tokens // (1 if capture_prefill else batch_size), + self.model_config.max_model_len - max_dec_len, + ) # NOTE(wanglongzhi): When the full length is too large, DeepEP's buffer size will not be enough to cause the result to appear nan. # TODO(wanglongzhi): Figure out the accurate buffer size of DeepEP. 
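
Note on the `topk_ids_hookfunc` parameter threaded through apply_ep_prefill / apply_ep_decode / apply_tp above: this patch only adds it so the Metax MoE methods stay signature-compatible with the base class; none of the Metax paths invoke it yet. The presumed calling convention, a callback over the router's top-k expert ids (e.g. for the Routing Replay path referenced in metax_model_runner.py), is sketched below. This helper and its contract are an assumption for illustration, not part of the patch.

    from typing import Callable, Optional
    import paddle

    def maybe_apply_topk_ids_hook(
        topk_idx: paddle.Tensor,                      # [num_tokens, top_k] routed expert ids
        topk_ids_hookfunc: Optional[Callable] = None,
    ) -> paddle.Tensor:
        # Hypothetical helper: if a hook is registered (e.g. to record routing
        # decisions for replay, or to override them), let it observe the
        # top-k expert ids before dispatch; otherwise pass them through.
        if topk_ids_hookfunc is not None:
            topk_idx = topk_ids_hookfunc(topk_idx)
        return topk_idx
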
@@ -1321,6 +1326,24 @@ class MetaxModelRunner(ModelRunnerBase): self.cache_config.block_size, self.speculative_config.num_speculative_tokens if self.speculative_decoding else 0, ) + logprobs_reqs = [ + req + for req in self.forward_batch_reqs_list + if req is not None and req.sampling_params is not None and req.sampling_params.logprobs is not None + ] + if len(logprobs_reqs): + self.max_logprobs = max( + [ + self.ori_vocab_size if req.sampling_params.logprobs < 0 else req.sampling_params.logprobs + for req in logprobs_reqs + ] + ) + self.temp_scaled_logprobs = any(req.sampling_params.temp_scaled_logprobs for req in logprobs_reqs) + self.top_p_normalized_logprobs = any( + req.sampling_params.top_p_normalized_logprobs for req in logprobs_reqs + ) + else: + self.max_logprobs = None # Remove padding ( @@ -1376,9 +1399,11 @@ class MetaxModelRunner(ModelRunnerBase): min_dec_lens=self.share_inputs["min_dec_len"], bad_words_token_ids=self.share_inputs["bad_tokens"][:, :max_bad_tokens_len], eos_token_ids=self.share_inputs["eos_token_id"], - max_num_logprobs=self.max_logprobs if self.enable_logprob else None, + max_num_logprobs=self.max_logprobs, enable_early_stop=self.enable_early_stop, stop_flags=self.share_inputs["stop_flags"], + temp_scaled_logprobs_flag=self.temp_scaled_logprobs, + top_p_normalized_logprobs_flag=self.top_p_normalized_logprobs, temp_scaled_logprobs=self.share_inputs["temp_scaled_logprobs"], top_p_normalized_logprobs=self.share_inputs["top_p_normalized_logprobs"], logits_processors=self.share_inputs["logits_processors"], @@ -1466,7 +1491,9 @@ class MetaxModelRunner(ModelRunnerBase): # When support capture both prefill-only and decode-only, this will use [only_prefill_use_cudagraph or only_decode_use_cudagraph] self.forward_meta.step_use_cudagraph = ( - only_prefill_use_cudagraph if self.cudagraph_only_prefill else only_decode_use_cudagraph + only_prefill_use_cudagraph + if self.cudagraph_only_prefill + else only_decode_use_cudagraph and self.forward_meta.ids_remove_padding.shape[0] > 0 ) # Set forward_meta.is_dummy_or_profile_run to True to skip init_kv_signal_per_query for attention backends @@ -2634,6 +2661,7 @@ class MetaxModelRunner(ModelRunnerBase): # prompt_logprobs self.prompt_logprobs_reqs.clear() self.in_progress_prompt_logprobs.clear() + self.forward_batch_reqs_list = [None for _ in range(self.scheduler_config.max_num_seqs)] def update_parameters(self, pid): """Dynamic model loader use to update parameters use for RL"""
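
Closing note on the cudagraph support this patch enables on Metax (the is_maca() exemption in config.py): the FlashAttentionBackend hunks replace per-step tensor rebinding (`self.seq_lens_dec = ...`) with buffers preallocated at init and updated via in-place `copy_`. Graph capture-and-replay requires every kernel input to keep a stable device address across steps, which rebinding breaks. A minimal sketch of that pattern, mirroring the `copy_` and fancy-indexing calls used in the diff; the class and parameter names here are illustrative only.

    import paddle

    class DecodeBuffers:
        """Capture-safe decode buffers: allocate once, update in place (illustrative)."""

        def __init__(self, max_num_seqs: int, max_blocks_per_seq: int):
            # Fixed-address buffers sized for the largest batch, so one captured
            # graph can be replayed for any smaller decode batch.
            self.seq_lens_dec = paddle.zeros([max_num_seqs, 1], dtype="int32")
            self.block_table_dec = paddle.zeros([max_num_seqs, max_blocks_per_seq], dtype="int32")

        def update(self, seq_lens_decoder, block_tables, batch_ids_decode):
            bs = batch_ids_decode.shape[0]
            # In-place copy_ keeps the storage address the graph captured;
            # rebinding (self.seq_lens_dec = new_tensor) would allocate fresh
            # memory that the replayed kernels never see.
            self.seq_lens_dec[:bs].copy_(seq_lens_decoder[batch_ids_decode, 0].reshape([bs, 1]))
            self.block_table_dec[:bs].copy_(block_tables[batch_ids_decode])

The same stability concern appears to motivate the new `step_use_cudagraph` guard in metax_model_runner.py, which only replays a decode-only graph when the current step actually has tokens (`ids_remove_padding.shape[0] > 0`).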