Mirror of https://github.com/PaddlePaddle/FastDeploy.git, synced 2025-10-05 08:37:06 +08:00
[Bug Fix] Fix bug of MLA Attention Backend (#3178)
* fix typo
* fix mla attention backend
@@ -315,7 +315,7 @@ class DeepseekV3MLAAttention(nn.Layer):
             dtype=layernorm_out.dtype,
         )
 
-        if forward_meta.max_enc_len_this_time:
+        if forward_meta.max_len_tensor_cpu[1]: # max_enc_len_this_time
             query = self.q_a_proj(layernorm_out)
             query = self.q_a_layernorm(query)
             query = self.q_b_proj(query)
@@ -362,7 +362,7 @@ class DeepseekV3MLAAttention(nn.Layer):
             fmha_out_prefill = fmha_out_prefill * mask_encoder_batch.cast(fmha_out_prefill.dtype)
 
             fmha_out = fmha_out + fmha_out_prefill
-        if forward_meta.max_dec_len_this_time:
+        if forward_meta.max_len_tensor_cpu[2]: # max_dec_len_this_time
             query = self.q_a_proj(layernorm_out)
             query = self.q_a_layernorm(query)
             ln_out_or_q_c = query