qkv_a_proj horizontal fusion (#3591)

Support DSK qkv_a_proj horizontal fusion under the V0 Loader
Author: AIbin
Committed: 2025-08-26 14:25:57 +08:00 (via GitHub)
Parent: 75db0d1ae2
Commit: 0a0d2959b9
2 changed files with 20 additions and 18 deletions
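
For context: "horizontal fusion" here means concatenating the q_a_proj and kv_a_proj_with_mqa weight matrices along the output dimension so that a single GEMM replaces two. A minimal sketch of the idea with toy dimensions (the variable names below are illustrative, not the repo's API):

import paddle

hidden, q_rank, kv_rank, rope_dim = 8, 4, 2, 2
w_q = paddle.randn([hidden, q_rank])               # stands in for q_a_proj.weight
w_kv = paddle.randn([hidden, kv_rank + rope_dim])  # stands in for kv_a_proj_with_mqa.weight
w_fused = paddle.concat([w_q, w_kv], axis=-1)      # horizontal fusion: one weight, one GEMM

x = paddle.randn([3, hidden])
q, ckv, kpe = paddle.matmul(x, w_fused).split([q_rank, kv_rank, rope_dim], axis=-1)
# Numerically identical to running the two projections separately:
assert paddle.allclose(q, paddle.matmul(x, w_q), atol=1e-5)
assert paddle.allclose(kpe, paddle.matmul(x, w_kv)[:, kv_rank:], atol=1e-5)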


@@ -210,11 +210,12 @@ class DeepseekV3MLAAttention(nn.Layer):
         self.rms_norm_eps = fd_config.model_config.rms_norm_eps
         if self.q_lora_rank is not None:
-            self.q_a_proj = ReplicatedLinear(
+            # NOTE: (changwenbin) qkv_a_proj horizontal fusion
+            self.qkv_a_proj_with_mqa = ReplicatedLinear(
                 fd_config=fd_config,
-                prefix=f"{prefix}.q_a_proj",
+                prefix=f"{prefix}.qkv_a_proj_with_mqa",
                 input_size=self.hidden_size,
-                output_size=self.q_lora_rank,
+                output_size=self.q_lora_rank + self.kv_lora_rank + self.qk_rope_head_dim,
                 with_bias=False,
             )
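
The fused output width is just the sum of the two projections it replaces. With DeepSeek-V3-style sizes (assumed here for illustration: q_lora_rank=1536, kv_lora_rank=512, qk_rope_head_dim=64):

q_lora_rank, kv_lora_rank, qk_rope_head_dim = 1536, 512, 64
# Before: two GEMMs of width 1536 and 512 + 64 = 576.
# After: one GEMM of width 1536 + 512 + 64 = 2112.
fused_output_size = q_lora_rank + kv_lora_rank + qk_rope_head_dim
assert fused_output_size == 2112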
@@ -235,15 +236,6 @@ class DeepseekV3MLAAttention(nn.Layer):
         else:
             assert self.q_lora_rank is not None, "self.q_lora_rank is None, Please Check your config."
-        # No TP split; run W4A16 Gemm
-        self.kv_a_proj_with_mqa = ReplicatedLinear(
-            fd_config=fd_config,
-            prefix=f"{prefix}.kv_a_proj_with_mqa",
-            input_size=self.hidden_size,
-            output_size=self.kv_lora_rank + self.qk_rope_head_dim,
-            with_bias=False,
-        )
         self.kv_a_layernorm = RMSNorm(
             fd_config,
             hidden_size=self.kv_lora_rank,
@@ -331,14 +323,18 @@ class DeepseekV3MLAAttention(nn.Layer):
         # NOTE: (changwenbin) Bring out the public calculation in PD MIX to avoid repeated calculation.
         fmha_out = None
-        query = self.q_a_proj(hidden_states)
+        # NOTE: (changwenbin) qkv_a_proj horizontal fusion
+        qkv_a_out = self.qkv_a_proj_with_mqa(hidden_states)
+        query, compressed_kv, key_pe = qkv_a_out.split(
+            [self.q_lora_rank, self.kv_lora_rank, self.qk_rope_head_dim], axis=-1
+        )
         query = self.q_a_layernorm(query)
         query = self.q_b_proj(query)
         query = query.reshape([-1, self.num_attention_heads_tp, self.qk_head_dim])
         query_nope, query_pe = query.split([self.qk_nope_head_dim, self.qk_rope_head_dim], axis=-1)
-        compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
-        compressed_kv, key_pe = compressed_kv.split([self.kv_lora_rank, self.qk_rope_head_dim], axis=-1)
         key_pe = key_pe.reshape([-1, 1, self.qk_rope_head_dim])
         compressed_kv = self.kv_a_layernorm(compressed_kv)
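
The forward path now does one projection call and one three-way split; the split sizes must match the order in which the weights were concatenated. A self-contained sketch of the new dataflow with toy sizes (stub tensors instead of the repo's layers):

import paddle

q_lora_rank, kv_lora_rank, qk_rope_head_dim = 6, 4, 2
tokens = 3
qkv_a_out = paddle.randn([tokens, q_lora_rank + kv_lora_rank + qk_rope_head_dim])

# One split replaces the former q_a_proj / kv_a_proj_with_mqa outputs.
query, compressed_kv, key_pe = qkv_a_out.split(
    [q_lora_rank, kv_lora_rank, qk_rope_head_dim], axis=-1
)
key_pe = key_pe.reshape([-1, 1, qk_rope_head_dim])  # single rope head, shared MQA-style
print(query.shape, compressed_kv.shape, key_pe.shape)
# [3, 6] [3, 4] [3, 1, 2]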
@@ -417,9 +413,8 @@ class DeepseekV3MLAAttention(nn.Layer):
     def load_state_dict(self, state_dict):
         """ """
-        self.q_a_proj.load_state_dict(state_dict)
         self.q_a_layernorm.load_state_dict(state_dict)
-        self.kv_a_proj_with_mqa.load_state_dict(state_dict)
+        self.qkv_a_proj_with_mqa.load_state_dict(state_dict)
         self.kv_a_layernorm.load_state_dict(state_dict)
         self.q_b_proj.load_state_dict(state_dict)
         self.kv_b_proj_bmm.load_state_dict(state_dict)
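
Note that load_state_dict now looks up a single qkv_a_proj_with_mqa entry, so under the V0 Loader the two original checkpoint tensors must have been concatenated before this point. That pre-processing lives in the loader and is not part of this diff; a hedged sketch of what it presumably does (key names illustrative):

import paddle

def fuse_qkv_a_weights(state_dict, prefix):
    # Hypothetical helper: fold q_a_proj and kv_a_proj_with_mqa into one tensor.
    # Assumes weights are stored [input_size, output_size]; concatenating along
    # the last axis matches the forward-pass split order.
    w_q = state_dict.pop(f"{prefix}.q_a_proj.weight")
    w_kv = state_dict.pop(f"{prefix}.kv_a_proj_with_mqa.weight")
    state_dict[f"{prefix}.qkv_a_proj_with_mqa.weight"] = paddle.concat([w_q, w_kv], axis=-1)
    return state_dict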