[Metax] optimize flash mla (#4915)

This commit is contained in:
xiaozude
2025-11-12 16:43:46 +08:00
committed by GitHub
parent 9d9f5df8d0
commit c45b3ccb52
5 changed files with 37 additions and 38 deletions

View File

@@ -91,7 +91,7 @@ void LimitThinkingContentLengthV1(const paddle::Tensor &next_tokens,
const int64_t think_end_id) {
const int batch_size = next_tokens.shape()[0];
const int eos_token_id_len = eos_token_ids.shape()[0];
- limit_thinking_content_length_kernel_v1<<<1, 1024>>>(
+ limit_thinking_content_length_kernel_v1<<<1, 1024, 0, next_tokens.stream()>>>(
const_cast<int64_t *>(next_tokens.data<int64_t>()),
max_think_lens.data<int>(),
step_idx.data<int64_t>(),

View File

@@ -95,7 +95,7 @@ void LimitThinkingContentLengthV2(const paddle::Tensor &next_tokens,
const int64_t think_end_id,
const int64_t line_break_id) {
const int batch_size = next_tokens.shape()[0];
- limit_thinking_content_length_kernel_v2<<<1, 1024>>>(
+ limit_thinking_content_length_kernel_v2<<<1, 1024, 0, next_tokens.stream()>>>(
const_cast<int64_t *>(next_tokens.data<int64_t>()),
max_think_lens.data<int>(),
step_idx.data<int64_t>(),

View File

@@ -46,8 +46,8 @@ def flash_attn_unpadded_func(
v: Tensor,
cu_seqlens_q: Tensor,
cu_seqlens_k: Tensor,
- max_seqlen_q: Union[int, float],
- max_seqlen_k: Union[int, float],
+ max_seqlen_q: int,
+ max_seqlen_k: int,
fixed_seed_offset: Optional[Tensor] = None,
attn_mask: Optional[Tensor] = None,
softmax_scale: float = 1.0,
@@ -57,9 +57,6 @@ def flash_attn_unpadded_func(
is_test: bool = True,
rng_name: str = "",
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
- max_seqlen_q_t = paddle.to_tensor(max_seqlen_q, dtype="int64")
- max_seqlen_k_t = paddle.to_tensor(max_seqlen_k, dtype="int64")
outputs = paddle._C_ops.flash_attn_unpadded(
q,
k,
@@ -68,8 +65,8 @@ def flash_attn_unpadded_func(
cu_seqlens_k,
fixed_seed_offset,
attn_mask,
- max_seqlen_q_t,
- max_seqlen_k_t,
+ max_seqlen_q,
+ max_seqlen_k,
softmax_scale,
dropout,
causal,
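The hunk above stops wrapping max_seqlen_q / max_seqlen_k in int64 tensors and forwards them to paddle._C_ops.flash_attn_unpadded as plain Python ints. A minimal, hedged usage sketch of the patched wrapper on packed (unpadded) inputs; shapes, dtypes, and values are illustrative, and flash_attn_unpadded_func is assumed to be importable from this module:

```python
import paddle

# Made-up per-request query lengths for a packed batch.
seq_lens = [3, 5, 2]
cu_seqlens = paddle.to_tensor([0, 3, 8, 10], dtype="int32")  # exclusive prefix sums
max_seqlen = max(seq_lens)                                   # plain Python int now

num_heads, head_dim = 8, 64
q = paddle.randn([sum(seq_lens), num_heads, head_dim]).astype("float16")
k = paddle.randn([sum(seq_lens), num_heads, head_dim]).astype("float16")
v = paddle.randn([sum(seq_lens), num_heads, head_dim]).astype("float16")

outputs = flash_attn_unpadded_func(
    q, k, v,
    cu_seqlens_q=cu_seqlens,
    cu_seqlens_k=cu_seqlens,
    max_seqlen_q=max_seqlen,   # int, no paddle.to_tensor(..., dtype="int64") wrapper
    max_seqlen_k=max_seqlen,
    softmax_scale=head_dim**-0.5,
)
```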

View File

@@ -183,7 +183,7 @@ class MetaxMLAAttentionBackend(AttentionBackend):
)
# MLA
- metadata.max_enc_len_this_time = forward_meta.max_len_tensor_cpu[1]
+ metadata.max_enc_len_this_time = forward_meta.max_len_tensor_cpu[1].item()
metadata.max_dec_len_this_time = forward_meta.max_len_tensor_cpu[2]
metadata.max_kv_len_this_time = forward_meta.max_len_tensor_cpu[8]
@@ -192,6 +192,20 @@ class MetaxMLAAttentionBackend(AttentionBackend):
self.attention_metadata: AttentionMetadata = metadata
seq_lens_decoder = forward_meta.seq_lens_decoder.squeeze(-1)
seq_lens_this_time = forward_meta.seq_lens_this_time.squeeze(-1)
+ non_zero_index = seq_lens_this_time.nonzero().flatten()
+ seq_lens_decoder = seq_lens_decoder[non_zero_index]
+ seq_lens_this_time = seq_lens_this_time[non_zero_index]
+ self.seq_lens_this_time = list(seq_lens_this_time.cpu())
+ self.seq_lens_this_time_max = max(self.seq_lens_this_time)
+ self.seq_lens_this_time_min = min(self.seq_lens_this_time)
+ self.seq_lens = seq_lens_decoder + seq_lens_this_time
+ self.block_tables = forward_meta.block_tables[non_zero_index]
+ paddle.device.empty_cache()
def get_attntion_meta(self) -> AttentionMetadata:
"""get_attntion_meta"""
return self.attention_metadata
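The lines added to init_attention_metadata above cache the per-step metadata once, filtered down to active batch slots, so the forward pass no longer recomputes it for every layer. A small self-contained sketch of that filtering, with made-up lengths:

```python
import paddle

seq_lens_this_time = paddle.to_tensor([5, 0, 1, 0])  # query tokens this step, per slot
seq_lens_decoder = paddle.to_tensor([0, 7, 12, 3])   # tokens already in the KV cache

active = seq_lens_this_time.nonzero().flatten()      # -> [0, 2]; slots with 0 are idle
seq_lens_this_time = seq_lens_this_time[active]      # -> [5, 1]
seq_lens_decoder = seq_lens_decoder[active]          # -> [0, 12]

# Total KV length per active request, as later consumed by get_mla_metadata
# and flash_mla_with_kvcache.
seq_lens = seq_lens_decoder + seq_lens_this_time     # -> [5, 13]
```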
@@ -221,45 +235,34 @@ class MetaxMLAAttentionBackend(AttentionBackend):
assert latent_cache is not None
seq_lens_decoder = forward_meta.seq_lens_decoder.squeeze(-1)
seq_lens_this_time = forward_meta.seq_lens_this_time.squeeze(-1)
- non_zero_index = paddle.nonzero(seq_lens_this_time).flatten()
- seq_lens_decoder = seq_lens_decoder[non_zero_index]
- seq_lens_this_time = seq_lens_this_time[non_zero_index]
latent_cache = latent_cache.transpose([0, 2, 1, 3])
- block_tables = self.attention_metadata.block_tables[non_zero_index]
- seq_lens = seq_lens_decoder + seq_lens_this_time
- batch_size = block_tables.shape[0]
- seq_len_q = seq_lens_this_time.max()
+ seq_len_q = self.seq_lens_this_time_max
num_heads_q = self.num_heads
num_heads_kv = latent_cache.shape[2]
head_dim_v = self.kv_lora_rank
head_dim_qk = self.kv_lora_rank + self.qk_rope_head_dim
- if seq_len_q != seq_lens_this_time.min():
- x = query.split(list(seq_lens_this_time), axis=0)
- x = [paddle.concat([xi, paddle.zeros((seq_len_q - xi.shape[0], xi.shape[1]))], axis=0) for xi in x]
- query = paddle.to_tensor(x)
+ if seq_len_q != self.seq_lens_this_time_min:
+ query = paddle.stack(
+ [
+ paddle.concat([x, paddle.zeros((seq_len_q - x.shape[0], x.shape[1]), dtype=x.dtype)])
+ for x in paddle.split(query, self.seq_lens_this_time)
+ ]
+ )
- query = query.reshape([batch_size, seq_len_q, num_heads_q, head_dim_qk])
+ query = query.reshape([-1, seq_len_q, num_heads_q, head_dim_qk])
tile_scheduler_metadata, num_splits = get_mla_metadata(
- seq_lens, seq_len_q * num_heads_q // num_heads_kv, num_heads_kv
+ self.seq_lens, seq_len_q * num_heads_q // num_heads_kv, num_heads_kv
)
- if tile_scheduler_metadata.shape[0] == 0:
- print(f"seq_lens: {seq_lens}")
- print(f"seq_len_q: {seq_len_q}")
- print(f"num_heads_q: {num_heads_q}")
- print(f"num_heads_kv: {num_heads_kv}")
- assert tile_scheduler_metadata.shape[0] != 0
+ assert tile_scheduler_metadata.shape[0] != 0
out = flash_mla_with_kvcache(
query,
latent_cache,
- block_tables,
- seq_lens,
+ self.block_tables,
+ self.seq_lens,
head_dim_v,
tile_scheduler_metadata,
num_splits,
@@ -267,9 +270,8 @@ class MetaxMLAAttentionBackend(AttentionBackend):
causal=self.causal,
)[0]
- if seq_len_q != seq_lens_this_time.min():
- x = [xi.split([num, seq_len_q - num], axis=0)[0] for xi, num in zip(out, seq_lens_this_time)]
- out = paddle.concat(x, axis=0)
+ if seq_len_q != self.seq_lens_this_time_min:
+ out = paddle.concat([paddle.split(x, [n, seq_len_q - n])[0] for x, n in zip(out, self.seq_lens_this_time)])
else:
out = out.reshape([-1, num_heads_q, head_dim_v])
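The reworked branch above handles the ragged-batch case: when requests produce different numbers of query tokens this step, the packed query is zero-padded to the longest length before flash_mla_with_kvcache and the padding rows are stripped again afterwards. A standalone sketch of that pad/unpad round trip, with illustrative shapes and names:

```python
import paddle

seq_lens_this_time = [3, 1, 2]        # cached as plain ints in init_attention_metadata
seq_len_q = max(seq_lens_this_time)   # 3
hidden = 8
query = paddle.randn([sum(seq_lens_this_time), hidden])  # packed [total_tokens, hidden]

# Pad each request's slice with zero rows up to seq_len_q, then stack to
# [batch, seq_len_q, hidden] so the kernel sees a rectangular batch.
padded = paddle.stack(
    [
        x if x.shape[0] == seq_len_q
        else paddle.concat([x, paddle.zeros([seq_len_q - x.shape[0], hidden], dtype=x.dtype)])
        for x in paddle.split(query, seq_lens_this_time)
    ]
)

out = padded  # stand-in for the attention output, same leading shape

# Drop the padding rows and repack to [total_tokens, hidden].
out = paddle.concat([x[:n] for x, n in zip(out, seq_lens_this_time)])
assert out.shape[0] == sum(seq_lens_this_time)
```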

View File

@@ -728,7 +728,7 @@ class DeepseekV3ForCausalLM(ModelForCasualLM):
seq_lens_decoder = forward_meta.seq_lens_decoder
seq_lens_this_time = forward_meta.seq_lens_this_time
- current_total_tokens = paddle.sum(seq_lens_this_time)
+ current_total_tokens = forward_meta.ids_remove_padding.shape[0]
position_ids = self.position_ids_buffer[:current_total_tokens]
mask_encoder_batch = self.mask_encoder_batch_buffer[:current_total_tokens]
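In the DeepseekV3 change above, the packed-token count is now read from ids_remove_padding.shape[0], which is host-side metadata, instead of summing seq_lens_this_time on device and copying the result back just to slice two buffers. Both expressions give the same number; a tiny illustrative check (names and values are made up):

```python
import paddle

seq_lens_this_time = paddle.to_tensor([3, 0, 4])
ids_remove_padding = paddle.arange(7)          # stand-in for the packed token ids

total_from_sum = int(paddle.sum(seq_lens_this_time))   # device reduction + host copy
total_from_shape = ids_remove_padding.shape[0]         # plain Python int, no sync

assert total_from_sum == total_from_shape == 7
```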