Mirror of https://github.com/PaddlePaddle/FastDeploy.git
[Metax] optimize flash mla (#4915)
@@ -91,7 +91,7 @@ void LimitThinkingContentLengthV1(const paddle::Tensor &next_tokens,
                                   const int64_t think_end_id) {
   const int batch_size = next_tokens.shape()[0];
   const int eos_token_id_len = eos_token_ids.shape()[0];
-  limit_thinking_content_length_kernel_v1<<<1, 1024>>>(
+  limit_thinking_content_length_kernel_v1<<<1, 1024, 0, next_tokens.stream()>>>(
       const_cast<int64_t *>(next_tokens.data<int64_t>()),
       max_think_lens.data<int>(),
       step_idx.data<int64_t>(),
@@ -95,7 +95,7 @@ void LimitThinkingContentLengthV2(const paddle::Tensor &next_tokens,
                                   const int64_t think_end_id,
                                   const int64_t line_break_id) {
   const int batch_size = next_tokens.shape()[0];
-  limit_thinking_content_length_kernel_v2<<<1, 1024>>>(
+  limit_thinking_content_length_kernel_v2<<<1, 1024, 0, next_tokens.stream()>>>(
       const_cast<int64_t *>(next_tokens.data<int64_t>()),
       max_think_lens.data<int>(),
       step_idx.data<int64_t>(),
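Note (reviewer sketch, not part of the diff): the only change in the two hunks above is the extra launch-configuration arguments `0, next_tokens.stream()`, which enqueue the kernel on the stream that owns `next_tokens` instead of the default stream. Work queued on different CUDA streams is not ordered with respect to each other, so the old launch could race with the Paddle ops that produce or consume `next_tokens`. A minimal Python illustration of the same hazard, assuming Paddle's `paddle.device.cuda` stream API (shapes are arbitrary):

    import paddle

    # Work queued on two different CUDA streams has no ordering guarantee.
    side_stream = paddle.device.cuda.Stream()

    x = paddle.randn([4096, 4096])  # enqueued on the current (default) stream
    with paddle.device.cuda.stream_guard(side_stream):
        # Without explicit synchronization, this matmul may start before the
        # producer of `x` has finished on the other stream.
        y = paddle.matmul(x, x)
    paddle.device.cuda.synchronize()  # heavy-handed fix: drain every stream

    # Launching the custom kernel on next_tokens.stream(), as the diff does,
    # keeps it on the same stream as the surrounding framework ops, so stream
    # ordering alone guarantees correctness with no global synchronize.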
@@ -46,8 +46,8 @@ def flash_attn_unpadded_func(
     v: Tensor,
     cu_seqlens_q: Tensor,
     cu_seqlens_k: Tensor,
-    max_seqlen_q: Union[int, float],
-    max_seqlen_k: Union[int, float],
+    max_seqlen_q: int,
+    max_seqlen_k: int,
     fixed_seed_offset: Optional[Tensor] = None,
     attn_mask: Optional[Tensor] = None,
     softmax_scale: float = 1.0,
@@ -57,9 +57,6 @@ def flash_attn_unpadded_func(
     is_test: bool = True,
     rng_name: str = "",
 ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
-    max_seqlen_q_t = paddle.to_tensor(max_seqlen_q, dtype="int64")
-    max_seqlen_k_t = paddle.to_tensor(max_seqlen_k, dtype="int64")
-
     outputs = paddle._C_ops.flash_attn_unpadded(
         q,
         k,
@@ -68,8 +65,8 @@ def flash_attn_unpadded_func(
         cu_seqlens_k,
         fixed_seed_offset,
         attn_mask,
-        max_seqlen_q_t,
-        max_seqlen_k_t,
+        max_seqlen_q,
+        max_seqlen_k,
         softmax_scale,
         dropout,
         causal,

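Note (reviewer sketch, not part of the diff): the three hunks above make the wrapper pass `max_seqlen_q` / `max_seqlen_k` straight through as Python ints instead of first materializing two scalar int64 tensors per call. A hedged usage sketch of the wrapper, assuming it is importable and using illustrative shapes and dtypes (two packed sequences of 16 and 32 tokens; none of these values come from the diff):

    import paddle

    num_tokens, num_heads, head_dim = 48, 8, 128
    q = paddle.randn([num_tokens, num_heads, head_dim], dtype="float16")
    k = paddle.randn([num_tokens, num_heads, head_dim], dtype="float16")
    v = paddle.randn([num_tokens, num_heads, head_dim], dtype="float16")

    # cu_seqlens is the exclusive prefix sum of the per-sequence lengths [16, 32].
    cu_seqlens = paddle.to_tensor([0, 16, 48], dtype="int32")

    # The wrapper returns a 4-tuple per its annotation; the attention output is
    # expected to come first.
    outputs = flash_attn_unpadded_func(
        q, k, v,
        cu_seqlens_q=cu_seqlens,
        cu_seqlens_k=cu_seqlens,
        max_seqlen_q=32,  # plain ints now; the old wrapper wrapped these in tensors
        max_seqlen_k=32,
        softmax_scale=head_dim**-0.5,
    )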
@@ -183,7 +183,7 @@ class MetaxMLAAttentionBackend(AttentionBackend):
         )

         # MLA
-        metadata.max_enc_len_this_time = forward_meta.max_len_tensor_cpu[1]
+        metadata.max_enc_len_this_time = forward_meta.max_len_tensor_cpu[1].item()
         metadata.max_dec_len_this_time = forward_meta.max_len_tensor_cpu[2]
         metadata.max_kv_len_this_time = forward_meta.max_len_tensor_cpu[8]

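Note (reviewer sketch, not part of the diff): integer indexing of a Paddle tensor yields a 0-D Tensor, so the added `.item()` converts `max_enc_len_this_time` to a plain Python int once, rather than leaving the tensor-to-scalar conversion to every later use. A minimal toy:

    import paddle

    max_len_tensor_cpu = paddle.to_tensor([9, 7, 5], place=paddle.CPUPlace())
    as_tensor = max_len_tensor_cpu[1]       # 0-D paddle.Tensor
    as_int = max_len_tensor_cpu[1].item()   # plain Python int: 7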
@@ -192,6 +192,20 @@ class MetaxMLAAttentionBackend(AttentionBackend):

         self.attention_metadata: AttentionMetadata = metadata

+        seq_lens_decoder = forward_meta.seq_lens_decoder.squeeze(-1)
+        seq_lens_this_time = forward_meta.seq_lens_this_time.squeeze(-1)
+        non_zero_index = seq_lens_this_time.nonzero().flatten()
+        seq_lens_decoder = seq_lens_decoder[non_zero_index]
+        seq_lens_this_time = seq_lens_this_time[non_zero_index]
+
+        self.seq_lens_this_time = list(seq_lens_this_time.cpu())
+        self.seq_lens_this_time_max = max(self.seq_lens_this_time)
+        self.seq_lens_this_time_min = min(self.seq_lens_this_time)
+        self.seq_lens = seq_lens_decoder + seq_lens_this_time
+        self.block_tables = forward_meta.block_tables[non_zero_index]
+
+        paddle.device.empty_cache()
+
     def get_attntion_meta(self) -> AttentionMetadata:
         """get_attntion_meta"""
         return self.attention_metadata
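Note (reviewer sketch, not part of the diff): the added block hoists the per-step sequence-length bookkeeping into metadata preparation so it runs once per step instead of inside the attention call (compare the matching deletions in the next hunk). The filter keeps only batch slots that actually have tokens this step. A minimal toy of that filter:

    import paddle

    seq_lens_this_time = paddle.to_tensor([3, 0, 1, 0, 2])
    non_zero_index = seq_lens_this_time.nonzero().flatten()  # [0, 2, 4]
    active_lens = seq_lens_this_time[non_zero_index]         # [3, 1, 2]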
@@ -221,45 +235,34 @@ class MetaxMLAAttentionBackend(AttentionBackend):

         assert latent_cache is not None

-        seq_lens_decoder = forward_meta.seq_lens_decoder.squeeze(-1)
-        seq_lens_this_time = forward_meta.seq_lens_this_time.squeeze(-1)
-        non_zero_index = paddle.nonzero(seq_lens_this_time).flatten()
-        seq_lens_decoder = seq_lens_decoder[non_zero_index]
-        seq_lens_this_time = seq_lens_this_time[non_zero_index]
-
         latent_cache = latent_cache.transpose([0, 2, 1, 3])
-        block_tables = self.attention_metadata.block_tables[non_zero_index]
-        seq_lens = seq_lens_decoder + seq_lens_this_time
-        batch_size = block_tables.shape[0]
-        seq_len_q = seq_lens_this_time.max()
+        seq_len_q = self.seq_lens_this_time_max
         num_heads_q = self.num_heads
         num_heads_kv = latent_cache.shape[2]
         head_dim_v = self.kv_lora_rank
         head_dim_qk = self.kv_lora_rank + self.qk_rope_head_dim

-        if seq_len_q != seq_lens_this_time.min():
-            x = query.split(list(seq_lens_this_time), axis=0)
-            x = [paddle.concat([xi, paddle.zeros((seq_len_q - xi.shape[0], xi.shape[1]))], axis=0) for xi in x]
-            query = paddle.to_tensor(x)
+        if seq_len_q != self.seq_lens_this_time_min:
+            query = paddle.stack(
+                [
+                    paddle.concat([x, paddle.zeros((seq_len_q - x.shape[0], x.shape[1]), dtype=x.dtype)])
+                    for x in paddle.split(query, self.seq_lens_this_time)
+                ]
+            )

-        query = query.reshape([batch_size, seq_len_q, num_heads_q, head_dim_qk])
+        query = query.reshape([-1, seq_len_q, num_heads_q, head_dim_qk])

         tile_scheduler_metadata, num_splits = get_mla_metadata(
-            seq_lens, seq_len_q * num_heads_q // num_heads_kv, num_heads_kv
+            self.seq_lens, seq_len_q * num_heads_q // num_heads_kv, num_heads_kv
         )

-        if tile_scheduler_metadata.shape[0] == 0:
-            print(f"seq_lens: {seq_lens}")
-            print(f"seq_len_q: {seq_len_q}")
-            print(f"num_heads_q: {num_heads_q}")
-            print(f"num_heads_kv: {num_heads_kv}")
-            assert tile_scheduler_metadata.shape[0] != 0
+        assert tile_scheduler_metadata.shape[0] != 0

         out = flash_mla_with_kvcache(
             query,
             latent_cache,
-            block_tables,
-            seq_lens,
+            self.block_tables,
+            self.seq_lens,
             head_dim_v,
             tile_scheduler_metadata,
             num_splits,
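Note (reviewer sketch, not part of the diff): when the active sequences have unequal lengths this step, the rewritten branch pads each one with zero rows up to `seq_len_q` and stacks them into the dense [batch, seq_len_q, ...] layout the kernel consumes; the old code built the same thing less directly via `paddle.to_tensor` on a Python list. A minimal toy of the pad-and-stack, with illustrative sizes:

    import paddle

    seq_lens = [3, 1, 2]                      # per-sequence token counts
    query = paddle.randn([sum(seq_lens), 4])  # packed [total_tokens, hidden]
    max_len = max(seq_lens)

    padded = paddle.stack(
        [
            paddle.concat([x, paddle.zeros((max_len - x.shape[0], x.shape[1]), dtype=x.dtype)])
            for x in paddle.split(query, seq_lens)
        ]
    )  # -> [3, 3, 4]: one zero-padded row block per sequence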
@@ -267,9 +270,8 @@ class MetaxMLAAttentionBackend(AttentionBackend):
             causal=self.causal,
         )[0]

-        if seq_len_q != seq_lens_this_time.min():
-            x = [xi.split([num, seq_len_q - num], axis=0)[0] for xi, num in zip(out, seq_lens_this_time)]
-            out = paddle.concat(x, axis=0)
+        if seq_len_q != self.seq_lens_this_time_min:
+            out = paddle.concat([paddle.split(x, [n, seq_len_q - n])[0] for x, n in zip(out, self.seq_lens_this_time)])
         else:
             out = out.reshape([-1, num_heads_q, head_dim_v])

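Note (reviewer sketch, not part of the diff): this is the inverse of the padding above: keep the first n rows of each padded block and re-pack them. Continuing the toy from the previous note (plain slicing is equivalent to the `paddle.split(...)[0]` form used in the diff):

    # padded: [3, 3, 4] from the previous sketch; seq_lens = [3, 1, 2]
    unpadded = paddle.concat([x[:n] for x, n in zip(padded, seq_lens)])
    # -> [6, 4], restoring the original packed layout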
@@ -728,7 +728,7 @@ class DeepseekV3ForCausalLM(ModelForCasualLM):
         seq_lens_decoder = forward_meta.seq_lens_decoder
         seq_lens_this_time = forward_meta.seq_lens_this_time

-        current_total_tokens = paddle.sum(seq_lens_this_time)
+        current_total_tokens = forward_meta.ids_remove_padding.shape[0]
         position_ids = self.position_ids_buffer[:current_total_tokens]
         mask_encoder_batch = self.mask_encoder_batch_buffer[:current_total_tokens]

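Note (reviewer sketch, not part of the diff): both expressions count the packed tokens for this step, but `ids_remove_padding.shape[0]` is already a host-side Python int, while `paddle.sum` produces a 0-D device tensor whose later use as a slice bound forces a device-to-host copy. A minimal toy:

    import paddle

    seq_lens_this_time = paddle.to_tensor([3, 0, 1, 2])
    ids_remove_padding = paddle.zeros([6], dtype="int64")  # 3 + 0 + 1 + 2 packed ids

    total_as_tensor = paddle.sum(seq_lens_this_time)  # 0-D tensor on the device
    total_as_int = ids_remove_padding.shape[0]        # plain int 6, no sync needed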