mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
[Metax] fix release2.4 and support cudagraph (#5547)
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled
Co-authored-by: xiaozude <xiaozude@outlook.com>
@@ -1743,7 +1743,7 @@ class FDConfig:
             logger.info(
                 "Static Graph does not support to be started together with RL Training, and automatically switch to dynamic graph!"
             )
-        if not current_platform.is_cuda():
+        if not current_platform.is_cuda() and not current_platform.is_maca():
             self.graph_opt_config.use_cudagraph = False
             logger.info("CUDAGraph currently only support on GPU!")
         if self.parallel_config.use_sequence_parallel_moe and self.graph_opt_config.use_cudagraph:
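A minimal sketch (not part of the commit) of the gate this hunk loosens: CUDA Graph stays enabled when the platform reports CUDA or MACA (Metax) and is switched off with a log message otherwise. The Platform and GraphOptConfig classes below are simplified stand-ins for FastDeploy's current_platform and FDConfig.graph_opt_config.

from dataclasses import dataclass


@dataclass
class Platform:
    name: str

    def is_cuda(self) -> bool:
        return self.name == "cuda"

    def is_maca(self) -> bool:
        return self.name == "maca"  # Metax MACA devices


@dataclass
class GraphOptConfig:
    use_cudagraph: bool = True


def apply_cudagraph_gate(platform: Platform, graph_opt_config: GraphOptConfig) -> None:
    # Mirrors the changed condition: only CUDA and MACA keep CUDA Graph enabled.
    if not platform.is_cuda() and not platform.is_maca():
        graph_opt_config.use_cudagraph = False
        print("CUDAGraph currently only support on GPU!")


cfg = GraphOptConfig()
apply_cudagraph_gate(Platform("maca"), cfg)
assert cfg.use_cudagraph  # a Metax platform now keeps CUDA Graph enabled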
@@ -64,6 +64,8 @@ class FlashAttentionMetadata(AttentionMetadata):
     encoder_block_shape_q: int = -1
     decoder_block_shape_q: int = -1
     _fuse_kernel_compute_dtype: str = "bf16"
+    seq_lens_dec: paddle.Tensor = None
+    block_table_dec: paddle.Tensor = None
 
     # pd_disaggregation
     kv_signal_metadata: Optional[paddle.Tensor] = None
@@ -135,6 +137,12 @@ class FlashAttentionBackend(AttentionBackend):
             shape=[max_num_seqs, 1, 1, self.head_dim],
             dtype=self.dtype,
         )
+        self.attention_metadata.seq_lens_dec = paddle.empty(
+            shape=[fd_config.scheduler_config.max_num_seqs, 1], dtype="int32"
+        )
+        self.attention_metadata.block_table_dec = paddle.empty(
+            shape=[fd_config.scheduler_config.max_num_seqs, self.head_dim], dtype="int32"
+        )
 
     def init_attention_metadata(self, forward_meta: ForwardMeta):
         """Initialize attntion metadata hence all layers in the forward pass can reuse it."""
@@ -229,8 +237,9 @@ class FlashAttentionBackend(AttentionBackend):
 
         self.batch_ids_prefill = paddle.to_tensor(self.prefill_info_dict["batch_ids"])
         self.batch_ids_decode = paddle.to_tensor(self.decode_info_dict["batch_ids"])
-        self.seq_lens_dec = forward_meta.seq_lens_decoder[self.batch_ids_decode, 0]
-        self.block_table_dec = forward_meta.block_tables[self.batch_ids_decode, :]
+        self.attention_metadata.seq_lens_dec.copy_(forward_meta.seq_lens_decoder[self.batch_ids_decode, 0])
+        self.attention_metadata.block_table_dec.copy_(forward_meta.block_tables[self.batch_ids_decode, :])
 
         # update prefilling rope
         self.update_rotary_embs_prefill(forward_meta)
         # update decoding rope
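The two attention hunks above follow a common CUDA Graph pattern: buffers that captured kernels read must keep fixed device addresses, so seq_lens_dec and block_table_dec are pre-allocated once at max_num_seqs and refreshed in place with copy_ instead of being rebuilt each step. A hedged sketch of that pattern, with illustrative shapes and mirroring the single-argument copy_ usage shown above:

import paddle

max_num_seqs = 8

# Allocated once, so the buffer address stays valid across CUDA Graph replays.
seq_lens_dec = paddle.zeros(shape=[max_num_seqs, 1], dtype="int32")


def refresh_decode_metadata(step_seq_lens: paddle.Tensor) -> None:
    # In-place copy updates the existing buffer instead of rebinding the name
    # to a freshly allocated tensor, which a captured graph would never see.
    seq_lens_dec.copy_(step_seq_lens)


refresh_decode_metadata(paddle.ones(shape=[max_num_seqs, 1], dtype="int32"))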
@@ -296,13 +305,18 @@ class FlashAttentionBackend(AttentionBackend):
         bs = self.batch_ids_decode.shape[0]
         if self.enable_mm:
             index = paddle.concat(
-                [self.batch_ids_decode.view([-1, 1]), self.seq_lens_dec.to("int64").view([-1, 1])], axis=1
+                [self.batch_ids_decode.view([-1, 1]), self.attention_metadata.seq_lens_dec.to("int64").view([-1, 1])],
+                axis=1,
             )
             rot_cos = paddle.gather_nd(forward_meta.rotary_embs[:, 0, 0, :, 0, :], index).view([bs, 1, 1, -1])
             rot_sin = paddle.gather_nd(forward_meta.rotary_embs[:, 1, 0, :, 0, :], index).view([bs, 1, 1, -1])
         else:
-            rot_cos = paddle.gather(forward_meta.rotary_embs[0, 0, :, 0, :], self.seq_lens_dec).view([bs, 1, 1, -1])
-            rot_sin = paddle.gather(forward_meta.rotary_embs[1, 0, :, 0, :], self.seq_lens_dec).view([bs, 1, 1, -1])
+            rot_cos = paddle.gather(
+                forward_meta.rotary_embs[0, 0, :, 0, :], self.attention_metadata.seq_lens_dec
+            ).view([bs, 1, 1, -1])
+            rot_sin = paddle.gather(
+                forward_meta.rotary_embs[1, 0, :, 0, :], self.attention_metadata.seq_lens_dec
+            ).view([bs, 1, 1, -1])
         self.attention_metadata.rotary_cos_decode[:bs].copy_(
             paddle.repeat_interleave(rot_cos, repeats=2, axis=-1).astype(self.dtype)
         )
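For decode-step rope, the rewritten branch gathers one cos/sin row per decoding position and then duplicates each element along the last axis so the half-width table matches the full head dim. A toy sketch with made-up shapes (the real table lives in forward_meta.rotary_embs):

import paddle

bs, half_dim, max_position = 4, 8, 32

# Hypothetical cos table: one row per position, half the head dim wide.
cos_table = paddle.rand([max_position, half_dim], dtype="float32")
decode_positions = paddle.to_tensor([3, 7, 0, 15], dtype="int32")

# Pick the row for each decoding token, then repeat every element twice,
# as the repeat_interleave(..., repeats=2, axis=-1) call above does.
rot_cos = paddle.gather(cos_table, decode_positions).reshape([bs, 1, 1, -1])
rot_cos_full = paddle.repeat_interleave(rot_cos, repeats=2, axis=-1)
assert rot_cos_full.shape == [bs, 1, 1, 2 * half_dim]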
@@ -476,8 +490,8 @@ class FlashAttentionBackend(AttentionBackend):
             q,
             forward_meta.caches[k_cache_id],
             forward_meta.caches[v_cache_id],
-            self.seq_lens_dec,
-            self.block_table_dec,
+            self.attention_metadata.seq_lens_dec,
+            self.attention_metadata.block_table_dec,
             k,
             v,
             rotary_cos=None,
@@ -221,12 +221,9 @@ class MetaxMLAAttentionBackend(AttentionBackend):
         """
         Calculate kv cache shape for MLA
         """
-        return (
-            max_num_blocks,
-            1,
-            self.block_size,
-            self.kv_lora_rank + self.qk_rope_head_dim,
-        )
+        key_cache_shape = [max_num_blocks, 1, self.block_size, self.kv_lora_rank + self.qk_rope_head_dim]
+        value_cache_shape = []
+        return key_cache_shape, value_cache_shape
 
     def compute_flash_mla(
         self,
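The MLA backend now reports the key and value cache shapes separately, with an empty value shape since MLA keeps a single compressed latent cache (kv_lora_rank plus the rope head dim). A small sketch of the new return convention, using attribute names borrowed from MetaxMLAAttentionBackend:

def mla_kv_cache_shape(max_num_blocks, block_size, kv_lora_rank, qk_rope_head_dim):
    # Key cache holds the compressed latent + rope part; no separate value cache for MLA.
    key_cache_shape = [max_num_blocks, 1, block_size, kv_lora_rank + qk_rope_head_dim]
    value_cache_shape = []
    return key_cache_shape, value_cache_shape


print(mla_kv_cache_shape(max_num_blocks=1024, block_size=64, kv_lora_rank=512, qk_rope_head_dim=64))
# -> ([1024, 1, 64, 576], [])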
@@ -15,6 +15,7 @@
 """
 
 import os
+from typing import Callable
 
 import paddle
 from paddle import nn
@@ -66,25 +67,12 @@ class MetaxCutlassUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
         layer.up_gate_proj_bias.set_value(stacked_up_gate_proj_bias)
         layer.down_proj_bias.set_value(stacked_down_proj_bias)
 
-    def compute_ffn(
-        self,
-        layer: nn.Layer,
-        permute_input: paddle.Tensor,
-        token_nums_per_expert: paddle.Tensor,
-        expert_idx_per_token: paddle.Tensor,
-        used_in_ep_low_latency: bool = False,
-        estimate_total_token_nums: int = -1,
-    ):
-        """
-        Paddle Cutlass compute Fused MoE.
-        """
-        raise NotImplementedError
-
     def apply_ep_prefill(
         self,
         layer: nn.Layer,
         x: paddle.Tensor,
         gate: nn.Layer,
+        topk_ids_hookfunc: Callable = None,
     ) -> paddle.Tensor:
         """
         Apply the EP prefill method.
@@ -96,6 +84,7 @@ class MetaxCutlassUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
         layer: nn.Layer,
         x: paddle.Tensor,
         gate: nn.Layer,
+        topk_ids_hookfunc: Callable = None,
     ) -> paddle.Tensor:
         """
         Apply the EP decoder method.
@@ -107,70 +96,12 @@ class MetaxCutlassUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
         layer: nn.Layer,
         x: paddle.Tensor,
         gate: nn.Layer,
+        topk_ids_hookfunc: Callable = None,
     ) -> paddle.Tensor:
         """
         Paddle Cutlass compute Fused MoE.
         """
-        """
-        Paddle Cutlass compute Fused MoE.
-        """
-        if layer.topk_method == "noaux_tc":
-            gate_out = gate(x.cast("float32"))
-
-            gate_out, topk_weights, topk_idx = get_moe_scores(
-                gate_out,
-                layer.n_group,
-                layer.topk_group,
-                layer.top_k,
-                layer.routed_scaling_factor,
-                layer.gate_correction_bias,
-                getattr(layer, "renormalize", True),
-            )
-
-            (
-                permute_input,
-                token_nums_per_expert,
-                permute_indices_per_token,
-                topk_weights,
-                topk_idx,
-            ) = moe_expert_dispatch(
-                x,
-                gate_out,
-                layer.top_k,
-                False,
-                True,
-            )
-
-            ffn_out = self.compute_ffn(layer, permute_input, token_nums_per_expert, None)
-
-            fused_moe_out = moe_expert_reduce(
-                ffn_out,
-                topk_weights,
-                permute_indices_per_token,
-                topk_idx,
-                None,
-                False,
-                1.0,
-            )
-        else:
-            raise NotImplementedError
-
-        fused_moe_out = fused_expert_moe(
-            x,
-            gate.weight,
-            getattr(layer, self.added_weight_attrs[0]),
-            getattr(layer, self.added_weight_attrs[1]),
-            None,
-            (layer.up_gate_proj_weight_scale if hasattr(layer, "up_gate_proj_weight_scale") else None),
-            None,
-            (layer.down_proj_weight_scale if hasattr(layer, "down_proj_weight_scale") else None),
-            "weight_only_int8",
-            layer.top_k,
-            True,
-            False,
-        )
-
-        return fused_moe_out
+        raise NotImplementedError
 
 
 class MetaxCutlassMoEMethod(MoEMethodBase):
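The removed apply_tp body wired together dispatch (route each token to its top-k experts), per-expert FFN, and reduce (weighted combine) through FastDeploy's fused ops. A toy, unfused sketch of that flow, purely for orientation and not using the real moe_expert_dispatch / moe_expert_ffn / moe_expert_reduce kernels:

import paddle
import paddle.nn.functional as F

num_tokens, hidden, num_experts, top_k = 4, 8, 4, 2
x = paddle.rand([num_tokens, hidden])
gate_logits = paddle.rand([num_tokens, num_experts])

# Dispatch: softmax gate, take the top-k experts per token, renormalize their weights.
topk_weights, topk_idx = paddle.topk(F.softmax(gate_logits, axis=-1), k=top_k, axis=-1)
topk_weights = topk_weights / topk_weights.sum(axis=-1, keepdim=True)

# "Expert FFN": one distinct weight matrix per expert, applied to routed tokens.
expert_w = paddle.rand([num_experts, hidden, hidden])

# Reduce: weighted sum of each token's expert outputs.
out = paddle.zeros_like(x)
for t in range(num_tokens):
    for slot in range(top_k):
        e = int(topk_idx[t, slot])
        out[t] = out[t] + topk_weights[t, slot] * paddle.matmul(x[t], expert_w[e])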
@@ -189,35 +120,12 @@ class MetaxCutlassMoEMethod(MoEMethodBase):
         layer.up_gate_proj_weight.set_value(stacked_up_gate_proj_weights)
         layer.down_proj_weight.set_value(stacked_down_proj_weights)
 
-    def compute_ffn(
-        self,
-        layer: nn.Layer,
-        permute_input: paddle.Tensor,
-        token_nums_per_expert: paddle.Tensor,
-        expert_idx_per_token: paddle.Tensor,
-        used_in_ep_low_latency: bool = False,
-        estimate_total_token_nums: int = -1,
-    ):
-        """
-        Paddle Cutlass compute Fused MoE.
-        """
-        return moe_expert_ffn(
-            permute_input,
-            token_nums_per_expert,
-            getattr(layer, self.added_weight_attrs[0]),
-            getattr(layer, self.added_weight_attrs[1]),
-            None,
-            (layer.up_gate_proj_weight_scale if hasattr(layer, "up_gate_proj_weight_scale") else None),
-            (layer.down_proj_weight_scale if hasattr(layer, "down_proj_weight_scale") else None),
-            expert_idx_per_token,  # expert_idx_per_token: only for w4a8
-            self.moe_quant_type,
-        )
-
     def apply_ep_prefill(
         self,
         layer: nn.Layer,
         x: paddle.Tensor,
         gate: nn.Layer,
+        topk_ids_hookfunc: Callable = None,
     ) -> paddle.Tensor:
         """
         Apply the EP prefill method.
@@ -229,6 +137,7 @@ class MetaxCutlassMoEMethod(MoEMethodBase):
         layer: nn.Layer,
         x: paddle.Tensor,
         gate: nn.Layer,
+        topk_ids_hookfunc: Callable = None,
     ) -> paddle.Tensor:
         """
         Apply the EP decoder method.
@@ -240,6 +149,7 @@ class MetaxCutlassMoEMethod(MoEMethodBase):
         layer: nn.Layer,
         x: paddle.Tensor,
         gate: nn.Layer,
+        topk_ids_hookfunc: Callable = None,
     ) -> paddle.Tensor:
         """
         Paddle Cutlass compute Fused MoE.
@@ -282,7 +192,17 @@ class MetaxCutlassMoEMethod(MoEMethodBase):
             else:
                 expert_idx_per_token = expert_idx_per_token.cast("int64")
 
-            ffn_out = self.compute_ffn(layer, permute_input, token_nums_per_expert, expert_idx_per_token)
+            ffn_out = moe_expert_ffn(
+                permute_input,
+                token_nums_per_expert,
+                getattr(layer, self.added_weight_attrs[0]),
+                getattr(layer, self.added_weight_attrs[1]),
+                None,
+                (layer.up_gate_proj_weight_scale if hasattr(layer, "up_gate_proj_weight_scale") else None),
+                (layer.down_proj_weight_scale if hasattr(layer, "down_proj_weight_scale") else None),
+                expert_idx_per_token,  # expert_idx_per_token: only for w4a8
+                self.moe_quant_type,
+            )
 
             fused_moe_out = moe_expert_reduce(
                 ffn_out,
@@ -104,11 +104,18 @@ class MetaxModelRunner(ModelRunnerBase):
         self.enable_early_stop = self.fd_config.early_stop_config.enable_early_stop
         self.is_pooling_model = self.fd_config.model_config.runner_type == "pooling"
         self.ori_vocab_size = self.fd_config.model_config.ori_vocab_size
-        self.max_logprobs = (
-            self.ori_vocab_size if fd_config.model_config.max_logprobs == -1 else fd_config.model_config.max_logprobs
-        )
+        self.max_logprobs = None
+        if self.enable_logprob:
+            self.max_logprobs = (
+                self.ori_vocab_size
+                if fd_config.model_config.max_logprobs == -1
+                else fd_config.model_config.max_logprobs
+            )
+        self.temp_scaled_logprobs = True
+        self.top_p_normalized_logprobs = True
         self.prompt_logprobs_reqs: dict[str, Request] = {}
         self.in_progress_prompt_logprobs: dict[str, LogprobsTensors] = {}
+        self.forward_batch_reqs_list: list[Request] = [None for _ in range(self.scheduler_config.max_num_seqs)]
 
         # VL model config:
         if self.enable_mm:
@@ -640,6 +647,7 @@ class MetaxModelRunner(ModelRunnerBase):
             # pooling model request.sampling_params is None
             if request.sampling_params is not None and request.sampling_params.prompt_logprobs is not None:
                 self.prompt_logprobs_reqs[request.request_id] = request
+            self.forward_batch_reqs_list[idx] = request
             has_prefill_task = True
 
             # Routing Replay
@@ -672,6 +680,7 @@ class MetaxModelRunner(ModelRunnerBase):
                 self.share_inputs["is_block_step"][idx : idx + 1] = False
                 self.prompt_logprobs_reqs.pop(request.request_id, None)
                 self.in_progress_prompt_logprobs.pop(request.request_id, None)
+                self.forward_batch_reqs_list[idx] = None
                 continue
 
             assert len(request.eos_token_ids) == self.model_config.eos_tokens_lens
@@ -996,14 +1005,10 @@ class MetaxModelRunner(ModelRunnerBase):
         """
         # NOTE(gongshaotian): The maximum decoding length is equal to the expected decoded tokens plus the eos token
         max_dec_len = expected_decode_len + 1
-        if batch_size == 0:
-            # Note(ZKK): divided by 0 is invalid, here we give a input_length = 1
-            input_length = 1
-        else:
-            input_length = min(
-                num_tokens // (1 if capture_prefill else batch_size),
-                self.model_config.max_model_len - max_dec_len,
-            )
+        input_length = min(
+            num_tokens // (1 if capture_prefill else batch_size),
+            self.model_config.max_model_len - max_dec_len,
+        )
 
         # NOTE(wanglongzhi): When the full length is too large, DeepEP's buffer size will not be enough to cause the result to appear nan.
         # TODO(wanglongzhi): Figure out the accurate buffer size of DeepEP.
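The dummy-run length now comes straight from the min() formula, without the old batch_size == 0 special case. A quick worked sketch of the retained computation, with illustrative numbers:

def dummy_input_length(num_tokens, batch_size, capture_prefill, max_model_len, expected_decode_len):
    # Maximum decoding length is the expected decoded tokens plus the eos token.
    max_dec_len = expected_decode_len + 1
    return min(
        num_tokens // (1 if capture_prefill else batch_size),
        max_model_len - max_dec_len,
    )


print(dummy_input_length(num_tokens=2048, batch_size=8, capture_prefill=False,
                         max_model_len=4096, expected_decode_len=32))  # -> 256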
@@ -1321,6 +1326,24 @@ class MetaxModelRunner(ModelRunnerBase):
             self.cache_config.block_size,
             self.speculative_config.num_speculative_tokens if self.speculative_decoding else 0,
         )
+        logprobs_reqs = [
+            req
+            for req in self.forward_batch_reqs_list
+            if req is not None and req.sampling_params is not None and req.sampling_params.logprobs is not None
+        ]
+        if len(logprobs_reqs):
+            self.max_logprobs = max(
+                [
+                    self.ori_vocab_size if req.sampling_params.logprobs < 0 else req.sampling_params.logprobs
+                    for req in logprobs_reqs
+                ]
+            )
+            self.temp_scaled_logprobs = any(req.sampling_params.temp_scaled_logprobs for req in logprobs_reqs)
+            self.top_p_normalized_logprobs = any(
+                req.sampling_params.top_p_normalized_logprobs for req in logprobs_reqs
+            )
+        else:
+            self.max_logprobs = None
 
         # Remove padding
         (
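The added block derives the batch-wide logprobs settings from the live requests: a negative per-request logprobs value means the full vocabulary, the batch takes the maximum requested count, and the two normalization flags turn on if any request asks for them. A self-contained sketch of that reduction using hypothetical stand-in classes (not FastDeploy's Request/SamplingParams):

from dataclasses import dataclass
from typing import Optional


@dataclass
class SamplingParams:
    logprobs: Optional[int] = None
    temp_scaled_logprobs: bool = False
    top_p_normalized_logprobs: bool = False


ori_vocab_size = 100_000
batch = [SamplingParams(logprobs=5), SamplingParams(logprobs=-1, temp_scaled_logprobs=True), SamplingParams()]

logprobs_reqs = [p for p in batch if p.logprobs is not None]
if logprobs_reqs:
    # A negative value requests the whole vocabulary, so it dominates the max.
    max_logprobs = max(ori_vocab_size if p.logprobs < 0 else p.logprobs for p in logprobs_reqs)
    temp_scaled = any(p.temp_scaled_logprobs for p in logprobs_reqs)
    top_p_normalized = any(p.top_p_normalized_logprobs for p in logprobs_reqs)
else:
    max_logprobs = None

print(max_logprobs)  # 100000: one request asked for logprobs=-1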
@@ -1376,9 +1399,11 @@ class MetaxModelRunner(ModelRunnerBase):
             min_dec_lens=self.share_inputs["min_dec_len"],
             bad_words_token_ids=self.share_inputs["bad_tokens"][:, :max_bad_tokens_len],
             eos_token_ids=self.share_inputs["eos_token_id"],
-            max_num_logprobs=self.max_logprobs if self.enable_logprob else None,
+            max_num_logprobs=self.max_logprobs,
             enable_early_stop=self.enable_early_stop,
             stop_flags=self.share_inputs["stop_flags"],
+            temp_scaled_logprobs_flag=self.temp_scaled_logprobs,
+            top_p_normalized_logprobs_flag=self.top_p_normalized_logprobs,
             temp_scaled_logprobs=self.share_inputs["temp_scaled_logprobs"],
             top_p_normalized_logprobs=self.share_inputs["top_p_normalized_logprobs"],
             logits_processors=self.share_inputs["logits_processors"],
@@ -1466,7 +1491,9 @@ class MetaxModelRunner(ModelRunnerBase):
 
         # When support capture both prefill-only and decode-only, this will use [only_prefill_use_cudagraph or only_decode_use_cudagraph]
         self.forward_meta.step_use_cudagraph = (
-            only_prefill_use_cudagraph if self.cudagraph_only_prefill else only_decode_use_cudagraph
+            only_prefill_use_cudagraph
+            if self.cudagraph_only_prefill
+            else only_decode_use_cudagraph and self.forward_meta.ids_remove_padding.shape[0] > 0
         )
 
         # Set forward_meta.is_dummy_or_profile_run to True to skip init_kv_signal_per_query for attention backends
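The decode-only CUDA Graph path is now additionally gated on the step having at least one input token, presumably so a captured graph is never replayed over an empty batch. A minimal sketch of the changed expression with illustrative names:

def step_use_cudagraph(only_prefill_use_cudagraph, only_decode_use_cudagraph,
                       cudagraph_only_prefill, num_input_tokens):
    # Mirrors the new expression: decode-only capture also requires a non-empty step.
    return (
        only_prefill_use_cudagraph
        if cudagraph_only_prefill
        else only_decode_use_cudagraph and num_input_tokens > 0
    )


assert step_use_cudagraph(False, True, False, num_input_tokens=0) is False
assert step_use_cudagraph(False, True, False, num_input_tokens=16) is True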
@@ -2634,6 +2661,7 @@ class MetaxModelRunner(ModelRunnerBase):
         # prompt_logprobs
         self.prompt_logprobs_reqs.clear()
         self.in_progress_prompt_logprobs.clear()
+        self.forward_batch_reqs_list = [None for _ in range(self.scheduler_config.max_num_seqs)]
 
     def update_parameters(self, pid):
         """Dynamic model loader use to update parameters use for RL"""