[Metax] fix release2.4 and support cudagraph (#5547)
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled

Co-authored-by: xiaozude <xiaozude@outlook.com>
This commit is contained in:
zhang-chenyi
2025-12-15 14:23:33 +08:00
committed by GitHub
parent 4bd991aa17
commit 77f8ba06e7
5 changed files with 85 additions and 126 deletions

View File

@@ -1743,7 +1743,7 @@ class FDConfig:
logger.info(
"Static Graph does not support to be started together with RL Training, and automatically switch to dynamic graph!"
)
if not current_platform.is_cuda():
if not current_platform.is_cuda() and not current_platform.is_maca():
self.graph_opt_config.use_cudagraph = False
logger.info("CUDAGraph currently only support on GPU!")
if self.parallel_config.use_sequence_parallel_moe and self.graph_opt_config.use_cudagraph:

View File

@@ -64,6 +64,8 @@ class FlashAttentionMetadata(AttentionMetadata):
encoder_block_shape_q: int = -1
decoder_block_shape_q: int = -1
_fuse_kernel_compute_dtype: str = "bf16"
seq_lens_dec: paddle.Tensor = None
block_table_dec: paddle.Tensor = None
# pd_disaggregation
kv_signal_metadata: Optional[paddle.Tensor] = None
@@ -135,6 +137,12 @@ class FlashAttentionBackend(AttentionBackend):
shape=[max_num_seqs, 1, 1, self.head_dim],
dtype=self.dtype,
)
self.attention_metadata.seq_lens_dec = paddle.empty(
shape=[fd_config.scheduler_config.max_num_seqs, 1], dtype="int32"
)
self.attention_metadata.block_table_dec = paddle.empty(
shape=[fd_config.scheduler_config.max_num_seqs, self.head_dim], dtype="int32"
)
def init_attention_metadata(self, forward_meta: ForwardMeta):
"""Initialize attntion metadata hence all layers in the forward pass can reuse it."""
@@ -229,8 +237,9 @@ class FlashAttentionBackend(AttentionBackend):
self.batch_ids_prefill = paddle.to_tensor(self.prefill_info_dict["batch_ids"])
self.batch_ids_decode = paddle.to_tensor(self.decode_info_dict["batch_ids"])
self.seq_lens_dec = forward_meta.seq_lens_decoder[self.batch_ids_decode, 0]
self.block_table_dec = forward_meta.block_tables[self.batch_ids_decode, :]
self.attention_metadata.seq_lens_dec.copy_(forward_meta.seq_lens_decoder[self.batch_ids_decode, 0])
self.attention_metadata.block_table_dec.copy_(forward_meta.block_tables[self.batch_ids_decode, :])
# update prefilling rope
self.update_rotary_embs_prefill(forward_meta)
# update decoding rope
@@ -296,13 +305,18 @@ class FlashAttentionBackend(AttentionBackend):
bs = self.batch_ids_decode.shape[0]
if self.enable_mm:
index = paddle.concat(
[self.batch_ids_decode.view([-1, 1]), self.seq_lens_dec.to("int64").view([-1, 1])], axis=1
[self.batch_ids_decode.view([-1, 1]), self.attention_metadata.seq_lens_dec.to("int64").view([-1, 1])],
axis=1,
)
rot_cos = paddle.gather_nd(forward_meta.rotary_embs[:, 0, 0, :, 0, :], index).view([bs, 1, 1, -1])
rot_sin = paddle.gather_nd(forward_meta.rotary_embs[:, 1, 0, :, 0, :], index).view([bs, 1, 1, -1])
else:
rot_cos = paddle.gather(forward_meta.rotary_embs[0, 0, :, 0, :], self.seq_lens_dec).view([bs, 1, 1, -1])
rot_sin = paddle.gather(forward_meta.rotary_embs[1, 0, :, 0, :], self.seq_lens_dec).view([bs, 1, 1, -1])
rot_cos = paddle.gather(
forward_meta.rotary_embs[0, 0, :, 0, :], self.attention_metadata.seq_lens_dec
).view([bs, 1, 1, -1])
rot_sin = paddle.gather(
forward_meta.rotary_embs[1, 0, :, 0, :], self.attention_metadata.seq_lens_dec
).view([bs, 1, 1, -1])
self.attention_metadata.rotary_cos_decode[:bs].copy_(
paddle.repeat_interleave(rot_cos, repeats=2, axis=-1).astype(self.dtype)
)
@@ -476,8 +490,8 @@ class FlashAttentionBackend(AttentionBackend):
q,
forward_meta.caches[k_cache_id],
forward_meta.caches[v_cache_id],
self.seq_lens_dec,
self.block_table_dec,
self.attention_metadata.seq_lens_dec,
self.attention_metadata.block_table_dec,
k,
v,
rotary_cos=None,

View File

@@ -221,12 +221,9 @@ class MetaxMLAAttentionBackend(AttentionBackend):
"""
Calculate kv cache shape for MLA
"""
return (
max_num_blocks,
1,
self.block_size,
self.kv_lora_rank + self.qk_rope_head_dim,
)
key_cache_shape = [max_num_blocks, 1, self.block_size, self.kv_lora_rank + self.qk_rope_head_dim]
value_cache_shape = []
return key_cache_shape, value_cache_shape
def compute_flash_mla(
self,

View File

@@ -15,6 +15,7 @@
"""
import os
from typing import Callable
import paddle
from paddle import nn
@@ -66,25 +67,12 @@ class MetaxCutlassUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
layer.up_gate_proj_bias.set_value(stacked_up_gate_proj_bias)
layer.down_proj_bias.set_value(stacked_down_proj_bias)
def compute_ffn(
self,
layer: nn.Layer,
permute_input: paddle.Tensor,
token_nums_per_expert: paddle.Tensor,
expert_idx_per_token: paddle.Tensor,
used_in_ep_low_latency: bool = False,
estimate_total_token_nums: int = -1,
):
"""
Paddle Cutlass compute Fused MoE.
"""
raise NotImplementedError
def apply_ep_prefill(
self,
layer: nn.Layer,
x: paddle.Tensor,
gate: nn.Layer,
topk_ids_hookfunc: Callable = None,
) -> paddle.Tensor:
"""
Apply the EP prefill method.
@@ -96,6 +84,7 @@ class MetaxCutlassUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
layer: nn.Layer,
x: paddle.Tensor,
gate: nn.Layer,
topk_ids_hookfunc: Callable = None,
) -> paddle.Tensor:
"""
Apply the EP decoder method.
@@ -107,70 +96,12 @@ class MetaxCutlassUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
layer: nn.Layer,
x: paddle.Tensor,
gate: nn.Layer,
topk_ids_hookfunc: Callable = None,
) -> paddle.Tensor:
"""
Paddle Cutlass compute Fused MoE.
"""
"""
Paddle Cutlass compute Fused MoE.
"""
if layer.topk_method == "noaux_tc":
gate_out = gate(x.cast("float32"))
gate_out, topk_weights, topk_idx = get_moe_scores(
gate_out,
layer.n_group,
layer.topk_group,
layer.top_k,
layer.routed_scaling_factor,
layer.gate_correction_bias,
getattr(layer, "renormalize", True),
)
(
permute_input,
token_nums_per_expert,
permute_indices_per_token,
topk_weights,
topk_idx,
) = moe_expert_dispatch(
x,
gate_out,
layer.top_k,
False,
True,
)
ffn_out = self.compute_ffn(layer, permute_input, token_nums_per_expert, None)
fused_moe_out = moe_expert_reduce(
ffn_out,
topk_weights,
permute_indices_per_token,
topk_idx,
None,
False,
1.0,
)
else:
raise NotImplementedError
fused_moe_out = fused_expert_moe(
x,
gate.weight,
getattr(layer, self.added_weight_attrs[0]),
getattr(layer, self.added_weight_attrs[1]),
None,
(layer.up_gate_proj_weight_scale if hasattr(layer, "up_gate_proj_weight_scale") else None),
None,
(layer.down_proj_weight_scale if hasattr(layer, "down_proj_weight_scale") else None),
"weight_only_int8",
layer.top_k,
True,
False,
)
return fused_moe_out
raise NotImplementedError
class MetaxCutlassMoEMethod(MoEMethodBase):
@@ -189,35 +120,12 @@ class MetaxCutlassMoEMethod(MoEMethodBase):
layer.up_gate_proj_weight.set_value(stacked_up_gate_proj_weights)
layer.down_proj_weight.set_value(stacked_down_proj_weights)
def compute_ffn(
self,
layer: nn.Layer,
permute_input: paddle.Tensor,
token_nums_per_expert: paddle.Tensor,
expert_idx_per_token: paddle.Tensor,
used_in_ep_low_latency: bool = False,
estimate_total_token_nums: int = -1,
):
"""
Paddle Cutlass compute Fused MoE.
"""
return moe_expert_ffn(
permute_input,
token_nums_per_expert,
getattr(layer, self.added_weight_attrs[0]),
getattr(layer, self.added_weight_attrs[1]),
None,
(layer.up_gate_proj_weight_scale if hasattr(layer, "up_gate_proj_weight_scale") else None),
(layer.down_proj_weight_scale if hasattr(layer, "down_proj_weight_scale") else None),
expert_idx_per_token, # expert_idx_per_token: only for w4a8
self.moe_quant_type,
)
def apply_ep_prefill(
self,
layer: nn.Layer,
x: paddle.Tensor,
gate: nn.Layer,
topk_ids_hookfunc: Callable = None,
) -> paddle.Tensor:
"""
Apply the EP prefill method.
@@ -229,6 +137,7 @@ class MetaxCutlassMoEMethod(MoEMethodBase):
layer: nn.Layer,
x: paddle.Tensor,
gate: nn.Layer,
topk_ids_hookfunc: Callable = None,
) -> paddle.Tensor:
"""
Apply the EP decoder method.
@@ -240,6 +149,7 @@ class MetaxCutlassMoEMethod(MoEMethodBase):
layer: nn.Layer,
x: paddle.Tensor,
gate: nn.Layer,
topk_ids_hookfunc: Callable = None,
) -> paddle.Tensor:
"""
Paddle Cutlass compute Fused MoE.
@@ -282,7 +192,17 @@ class MetaxCutlassMoEMethod(MoEMethodBase):
else:
expert_idx_per_token = expert_idx_per_token.cast("int64")
ffn_out = self.compute_ffn(layer, permute_input, token_nums_per_expert, expert_idx_per_token)
ffn_out = moe_expert_ffn(
permute_input,
token_nums_per_expert,
getattr(layer, self.added_weight_attrs[0]),
getattr(layer, self.added_weight_attrs[1]),
None,
(layer.up_gate_proj_weight_scale if hasattr(layer, "up_gate_proj_weight_scale") else None),
(layer.down_proj_weight_scale if hasattr(layer, "down_proj_weight_scale") else None),
expert_idx_per_token, # expert_idx_per_token: only for w4a8
self.moe_quant_type,
)
fused_moe_out = moe_expert_reduce(
ffn_out,

View File

@@ -104,11 +104,18 @@ class MetaxModelRunner(ModelRunnerBase):
self.enable_early_stop = self.fd_config.early_stop_config.enable_early_stop
self.is_pooling_model = self.fd_config.model_config.runner_type == "pooling"
self.ori_vocab_size = self.fd_config.model_config.ori_vocab_size
self.max_logprobs = (
self.ori_vocab_size if fd_config.model_config.max_logprobs == -1 else fd_config.model_config.max_logprobs
)
self.max_logprobs = None
if self.enable_logprob:
self.max_logprobs = (
self.ori_vocab_size
if fd_config.model_config.max_logprobs == -1
else fd_config.model_config.max_logprobs
)
self.temp_scaled_logprobs = True
self.top_p_normalized_logprobs = True
self.prompt_logprobs_reqs: dict[str, Request] = {}
self.in_progress_prompt_logprobs: dict[str, LogprobsTensors] = {}
self.forward_batch_reqs_list: list[Request] = [None for _ in range(self.scheduler_config.max_num_seqs)]
# VL model config:
if self.enable_mm:
@@ -640,6 +647,7 @@ class MetaxModelRunner(ModelRunnerBase):
# pooling model request.sampling_params is None
if request.sampling_params is not None and request.sampling_params.prompt_logprobs is not None:
self.prompt_logprobs_reqs[request.request_id] = request
self.forward_batch_reqs_list[idx] = request
has_prefill_task = True
# Routing Replay
@@ -672,6 +680,7 @@ class MetaxModelRunner(ModelRunnerBase):
self.share_inputs["is_block_step"][idx : idx + 1] = False
self.prompt_logprobs_reqs.pop(request.request_id, None)
self.in_progress_prompt_logprobs.pop(request.request_id, None)
self.forward_batch_reqs_list[idx] = None
continue
assert len(request.eos_token_ids) == self.model_config.eos_tokens_lens
@@ -996,14 +1005,10 @@ class MetaxModelRunner(ModelRunnerBase):
"""
# NOTE(gongshaotian): The maximum decoding length is equal to the expected decoded tokens plus the eos token
max_dec_len = expected_decode_len + 1
if batch_size == 0:
# Note(ZKK): divided by 0 is invalid, here we give a input_length = 1
input_length = 1
else:
input_length = min(
num_tokens // (1 if capture_prefill else batch_size),
self.model_config.max_model_len - max_dec_len,
)
input_length = min(
num_tokens // (1 if capture_prefill else batch_size),
self.model_config.max_model_len - max_dec_len,
)
# NOTE(wanglongzhi): When the full length is too large, DeepEP's buffer size will not be enough to cause the result to appear nan.
# TODO(wanglongzhi): Figure out the accurate buffer size of DeepEP.
@@ -1321,6 +1326,24 @@ class MetaxModelRunner(ModelRunnerBase):
self.cache_config.block_size,
self.speculative_config.num_speculative_tokens if self.speculative_decoding else 0,
)
logprobs_reqs = [
req
for req in self.forward_batch_reqs_list
if req is not None and req.sampling_params is not None and req.sampling_params.logprobs is not None
]
if len(logprobs_reqs):
self.max_logprobs = max(
[
self.ori_vocab_size if req.sampling_params.logprobs < 0 else req.sampling_params.logprobs
for req in logprobs_reqs
]
)
self.temp_scaled_logprobs = any(req.sampling_params.temp_scaled_logprobs for req in logprobs_reqs)
self.top_p_normalized_logprobs = any(
req.sampling_params.top_p_normalized_logprobs for req in logprobs_reqs
)
else:
self.max_logprobs = None
# Remove padding
(
@@ -1376,9 +1399,11 @@ class MetaxModelRunner(ModelRunnerBase):
min_dec_lens=self.share_inputs["min_dec_len"],
bad_words_token_ids=self.share_inputs["bad_tokens"][:, :max_bad_tokens_len],
eos_token_ids=self.share_inputs["eos_token_id"],
max_num_logprobs=self.max_logprobs if self.enable_logprob else None,
max_num_logprobs=self.max_logprobs,
enable_early_stop=self.enable_early_stop,
stop_flags=self.share_inputs["stop_flags"],
temp_scaled_logprobs_flag=self.temp_scaled_logprobs,
top_p_normalized_logprobs_flag=self.top_p_normalized_logprobs,
temp_scaled_logprobs=self.share_inputs["temp_scaled_logprobs"],
top_p_normalized_logprobs=self.share_inputs["top_p_normalized_logprobs"],
logits_processors=self.share_inputs["logits_processors"],
@@ -1466,7 +1491,9 @@ class MetaxModelRunner(ModelRunnerBase):
# When support capture both prefill-only and decode-only, this will use [only_prefill_use_cudagraph or only_decode_use_cudagraph]
self.forward_meta.step_use_cudagraph = (
only_prefill_use_cudagraph if self.cudagraph_only_prefill else only_decode_use_cudagraph
only_prefill_use_cudagraph
if self.cudagraph_only_prefill
else only_decode_use_cudagraph and self.forward_meta.ids_remove_padding.shape[0] > 0
)
# Set forward_meta.is_dummy_or_profile_run to True to skip init_kv_signal_per_query for attention backends
@@ -2634,6 +2661,7 @@ class MetaxModelRunner(ModelRunnerBase):
# prompt_logprobs
self.prompt_logprobs_reqs.clear()
self.in_progress_prompt_logprobs.clear()
self.forward_batch_reqs_list = [None for _ in range(self.scheduler_config.max_num_seqs)]
def update_parameters(self, pid):
"""Dynamic model loader use to update parameters use for RL"""