From c45b3ccb52bbdde407ad872236d25737a5af91bc Mon Sep 17 00:00:00 2001
From: xiaozude
Date: Wed, 12 Nov 2025 16:43:46 +0800
Subject: [PATCH] [Metax] optimize flash mla (#4915)

---
 .../limit_thinking_content_length_v1.cu       |  2 +-
 .../limit_thinking_content_length_v2.cu       |  2 +-
 .../attention/flash_attention_interface.py    | 11 ++--
 .../metax/attention/mla_attn_metax_backend.py | 58 ++++++++++---------
 .../model_executor/models/deepseek_v3.py      |  2 +-
 5 files changed, 37 insertions(+), 38 deletions(-)

diff --git a/custom_ops/gpu_ops/limit_thinking_content_length_v1.cu b/custom_ops/gpu_ops/limit_thinking_content_length_v1.cu
index 8631ecb7d..45bf8f704 100644
--- a/custom_ops/gpu_ops/limit_thinking_content_length_v1.cu
+++ b/custom_ops/gpu_ops/limit_thinking_content_length_v1.cu
@@ -91,7 +91,7 @@ void LimitThinkingContentLengthV1(const paddle::Tensor &next_tokens,
                                   const int64_t think_end_id) {
     const int batch_size = next_tokens.shape()[0];
     const int eos_token_id_len = eos_token_ids.shape()[0];
-    limit_thinking_content_length_kernel_v1<<<1, 1024>>>(
+    limit_thinking_content_length_kernel_v1<<<1, 1024, 0, next_tokens.stream()>>>(
         const_cast(next_tokens.data()),
         max_think_lens.data(),
         step_idx.data(),
diff --git a/custom_ops/gpu_ops/limit_thinking_content_length_v2.cu b/custom_ops/gpu_ops/limit_thinking_content_length_v2.cu
index d2f0f513b..ea5f8c9c4 100644
--- a/custom_ops/gpu_ops/limit_thinking_content_length_v2.cu
+++ b/custom_ops/gpu_ops/limit_thinking_content_length_v2.cu
@@ -95,7 +95,7 @@ void LimitThinkingContentLengthV2(const paddle::Tensor &next_tokens,
                                   const int64_t think_end_id,
                                   const int64_t line_break_id) {
     const int batch_size = next_tokens.shape()[0];
-    limit_thinking_content_length_kernel_v2<<<1, 1024>>>(
+    limit_thinking_content_length_kernel_v2<<<1, 1024, 0, next_tokens.stream()>>>(
         const_cast(next_tokens.data()),
         max_think_lens.data(),
         step_idx.data(),
diff --git a/fastdeploy/model_executor/layers/backends/metax/attention/flash_attention_interface.py b/fastdeploy/model_executor/layers/backends/metax/attention/flash_attention_interface.py
index c1480170e..93d640493 100644
--- a/fastdeploy/model_executor/layers/backends/metax/attention/flash_attention_interface.py
+++ b/fastdeploy/model_executor/layers/backends/metax/attention/flash_attention_interface.py
@@ -46,8 +46,8 @@ def flash_attn_unpadded_func(
     v: Tensor,
     cu_seqlens_q: Tensor,
     cu_seqlens_k: Tensor,
-    max_seqlen_q: Union[int, float],
-    max_seqlen_k: Union[int, float],
+    max_seqlen_q: int,
+    max_seqlen_k: int,
     fixed_seed_offset: Optional[Tensor] = None,
     attn_mask: Optional[Tensor] = None,
     softmax_scale: float = 1.0,
@@ -57,9 +57,6 @@ def flash_attn_unpadded_func(
     is_test: bool = True,
     rng_name: str = "",
 ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
-    max_seqlen_q_t = paddle.to_tensor(max_seqlen_q, dtype="int64")
-    max_seqlen_k_t = paddle.to_tensor(max_seqlen_k, dtype="int64")
-
     outputs = paddle._C_ops.flash_attn_unpadded(
         q,
         k,
@@ -68,8 +65,8 @@ def flash_attn_unpadded_func(
         cu_seqlens_k,
         fixed_seed_offset,
         attn_mask,
-        max_seqlen_q_t,
-        max_seqlen_k_t,
+        max_seqlen_q,
+        max_seqlen_k,
         softmax_scale,
         dropout,
         causal,
diff --git a/fastdeploy/model_executor/layers/backends/metax/attention/mla_attn_metax_backend.py b/fastdeploy/model_executor/layers/backends/metax/attention/mla_attn_metax_backend.py
index 13568b632..ff1bce8bd 100644
--- a/fastdeploy/model_executor/layers/backends/metax/attention/mla_attn_metax_backend.py
+++ b/fastdeploy/model_executor/layers/backends/metax/attention/mla_attn_metax_backend.py
@@ -183,7 +183,7 @@ class MetaxMLAAttentionBackend(AttentionBackend):
         )
 
         # MLA
-        metadata.max_enc_len_this_time = forward_meta.max_len_tensor_cpu[1]
+        metadata.max_enc_len_this_time = forward_meta.max_len_tensor_cpu[1].item()
         metadata.max_dec_len_this_time = forward_meta.max_len_tensor_cpu[2]
         metadata.max_kv_len_this_time = forward_meta.max_len_tensor_cpu[8]
 
@@ -192,6 +192,20 @@ class MetaxMLAAttentionBackend(AttentionBackend):
 
         self.attention_metadata: AttentionMetadata = metadata
 
+        seq_lens_decoder = forward_meta.seq_lens_decoder.squeeze(-1)
+        seq_lens_this_time = forward_meta.seq_lens_this_time.squeeze(-1)
+        non_zero_index = seq_lens_this_time.nonzero().flatten()
+        seq_lens_decoder = seq_lens_decoder[non_zero_index]
+        seq_lens_this_time = seq_lens_this_time[non_zero_index]
+
+        self.seq_lens_this_time = list(seq_lens_this_time.cpu())
+        self.seq_lens_this_time_max = max(self.seq_lens_this_time)
+        self.seq_lens_this_time_min = min(self.seq_lens_this_time)
+        self.seq_lens = seq_lens_decoder + seq_lens_this_time
+        self.block_tables = forward_meta.block_tables[non_zero_index]
+
+        paddle.device.empty_cache()
+
     def get_attntion_meta(self) -> AttentionMetadata:
         """get_attntion_meta"""
         return self.attention_metadata
@@ -221,45 +235,34 @@ class MetaxMLAAttentionBackend(AttentionBackend):
 
         assert latent_cache is not None
 
-        seq_lens_decoder = forward_meta.seq_lens_decoder.squeeze(-1)
-        seq_lens_this_time = forward_meta.seq_lens_this_time.squeeze(-1)
-        non_zero_index = paddle.nonzero(seq_lens_this_time).flatten()
-        seq_lens_decoder = seq_lens_decoder[non_zero_index]
-        seq_lens_this_time = seq_lens_this_time[non_zero_index]
-
         latent_cache = latent_cache.transpose([0, 2, 1, 3])
-        block_tables = self.attention_metadata.block_tables[non_zero_index]
-        seq_lens = seq_lens_decoder + seq_lens_this_time
-        batch_size = block_tables.shape[0]
-        seq_len_q = seq_lens_this_time.max()
+        seq_len_q = self.seq_lens_this_time_max
         num_heads_q = self.num_heads
         num_heads_kv = latent_cache.shape[2]
         head_dim_v = self.kv_lora_rank
         head_dim_qk = self.kv_lora_rank + self.qk_rope_head_dim
 
-        if seq_len_q != seq_lens_this_time.min():
-            x = query.split(list(seq_lens_this_time), axis=0)
-            x = [paddle.concat([xi, paddle.zeros((seq_len_q - xi.shape[0], xi.shape[1]))], axis=0) for xi in x]
-            query = paddle.to_tensor(x)
+        if seq_len_q != self.seq_lens_this_time_min:
+            query = paddle.stack(
+                [
+                    paddle.concat([x, paddle.zeros((seq_len_q - x.shape[0], x.shape[1]), dtype=x.dtype)])
+                    for x in paddle.split(query, self.seq_lens_this_time)
+                ]
+            )
 
-        query = query.reshape([batch_size, seq_len_q, num_heads_q, head_dim_qk])
+        query = query.reshape([-1, seq_len_q, num_heads_q, head_dim_qk])
 
         tile_scheduler_metadata, num_splits = get_mla_metadata(
-            seq_lens, seq_len_q * num_heads_q // num_heads_kv, num_heads_kv
+            self.seq_lens, seq_len_q * num_heads_q // num_heads_kv, num_heads_kv
         )
 
-        if tile_scheduler_metadata.shape[0] == 0:
-            print(f"seq_lens: {seq_lens}")
-            print(f"seq_len_q: {seq_len_q}")
-            print(f"num_heads_q: {num_heads_q}")
-            print(f"num_heads_kv: {num_heads_kv}")
-            assert tile_scheduler_metadata.shape[0] != 0
+        assert tile_scheduler_metadata.shape[0] != 0
 
         out = flash_mla_with_kvcache(
             query,
             latent_cache,
-            block_tables,
-            seq_lens,
+            self.block_tables,
+            self.seq_lens,
             head_dim_v,
             tile_scheduler_metadata,
             num_splits,
@@ -267,9 +270,8 @@ class MetaxMLAAttentionBackend(AttentionBackend):
             causal=self.causal,
         )[0]
 
-        if seq_len_q != seq_lens_this_time.min():
-            x = [xi.split([num, seq_len_q - num], axis=0)[0] for xi, num in zip(out, seq_lens_this_time)]
-            out = paddle.concat(x, axis=0)
+        if seq_len_q != self.seq_lens_this_time_min:
+            out = paddle.concat([paddle.split(x, [n, seq_len_q - n])[0] for x, n in zip(out, self.seq_lens_this_time)])
         else:
             out = out.reshape([-1, num_heads_q, head_dim_v])
 
diff --git a/fastdeploy/model_executor/models/deepseek_v3.py b/fastdeploy/model_executor/models/deepseek_v3.py
index 07aaa28a1..bcace8dd5 100644
--- a/fastdeploy/model_executor/models/deepseek_v3.py
+++ b/fastdeploy/model_executor/models/deepseek_v3.py
@@ -728,7 +728,7 @@ class DeepseekV3ForCausalLM(ModelForCasualLM):
 
         seq_lens_decoder = forward_meta.seq_lens_decoder
         seq_lens_this_time = forward_meta.seq_lens_this_time
-        current_total_tokens = paddle.sum(seq_lens_this_time)
+        current_total_tokens = forward_meta.ids_remove_padding.shape[0]
         position_ids = self.position_ids_buffer[:current_total_tokens]
         mask_encoder_batch = self.mask_encoder_batch_buffer[:current_total_tokens]
 