【Inference Optimize】Update MergedReplicatedLinear for DSK qkv_a_proj_with_mqa. (#3673)

* support MergedReplicatedLinear

* update MergedReplicatedLinear to support DSK_wint4 V1_load

* update model name

* update linear class

* fix

* fix v0 moe_bias load

---------

Co-authored-by: bukejiyu <52310069+bukejiyu@users.noreply.github.com>

Author: AIbin
Date: 2025-09-05 12:16:05 +08:00
Committed by: GitHub
Parent: b23fc654d9
Commit: 41aee08982
4 changed files with 102 additions and 4 deletions

@@ -38,6 +38,7 @@ from fastdeploy.model_executor.layers.linear import (
     ColumnParallelLinear,
     KVBatchLinear,
     MergedColumnParallelLinear,
+    MergedReplicatedLinear,
     ReplicatedLinear,
     RowParallelLinear,
 )
@@ -169,6 +170,13 @@ class DeepSeekV3MoE(nn.Layer):
     def load_state_dict(self, state_dict):
         """ """
+        if self.experts.gate_correction_bias is not None:
+            gate_correction_bias_tensor = state_dict.pop(self.experts.gate_correction_bias_key)
+            if self.experts.gate_correction_bias.shape != gate_correction_bias_tensor.shape:
+                gate_correction_bias_tensor = gate_correction_bias_tensor.reshape(
+                    self.experts.gate_correction_bias.shape
+                )
+            self.experts.gate_correction_bias.set_value(gate_correction_bias_tensor)
         self.gate.load_state_dict(state_dict)
         self.experts.load_state_dict(state_dict)
         self.shared_experts.load_state_dict(state_dict)
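
The added branch guards against a layout difference between the checkpoint tensor and the runtime parameter before set_value is called. A minimal sketch of the situation, assuming (the diff does not state it) that the V0 checkpoint stores the correction bias with an extra leading dimension while the parameter is flat; num_experts and the shapes below are illustrative only:

import numpy as np

num_experts = 8  # hypothetical size, for illustration only

# Assumed checkpoint layout: same element count, extra leading dim.
ckpt_bias = np.zeros((1, num_experts), dtype=np.float32)
param_shape = (num_experts,)

if ckpt_bias.shape != param_shape:
    # A plain reshape is enough, mirroring the branch added in load_state_dict above.
    ckpt_bias = ckpt_bias.reshape(param_shape)

assert ckpt_bias.shape == param_shape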
@@ -211,11 +219,11 @@ class DeepseekV3MLAAttention(nn.Layer):
         if self.q_lora_rank is not None:
             # NOTE: (changwenbin) qkv_a_proj horizontal fusion
-            self.qkv_a_proj_with_mqa = ReplicatedLinear(
+            self.qkv_a_proj_with_mqa = MergedReplicatedLinear(
                 fd_config=fd_config,
                 prefix=f"{prefix}.qkv_a_proj_with_mqa",
                 input_size=self.hidden_size,
-                output_size=self.q_lora_rank + self.kv_lora_rank + self.qk_rope_head_dim,
+                output_sizes=[self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
                 with_bias=False,
             )
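
Swapping ReplicatedLinear for MergedReplicatedLinear matters for weight loading: the fused layer keeps one weight of width sum(output_sizes), replicated (not tensor-parallel sharded) on every rank, and the loader writes each source projection into its own column slice. A minimal sketch of that idea; MergedReplicatedLinearSketch, load_shard, and all sizes below are illustrative and not FastDeploy APIs:

import numpy as np

class MergedReplicatedLinearSketch:
    def __init__(self, input_size, output_sizes):
        self.output_sizes = list(output_sizes)
        # Replicated: every rank holds the full fused weight of width sum(output_sizes).
        self.weight = np.zeros((input_size, sum(output_sizes)), dtype=np.float32)
        # Column offset where each shard's slice begins.
        self.offsets = np.cumsum([0] + self.output_sizes[:-1]).tolist()

    def load_shard(self, shard_idx, shard_weight):
        # Copy one source projection into its column slice of the fused weight.
        start = self.offsets[shard_idx]
        end = start + self.output_sizes[shard_idx]
        assert shard_weight.shape == (self.weight.shape[0], self.output_sizes[shard_idx])
        self.weight[:, start:end] = shard_weight

# Fuse a 4-column "q_a" shard with a 6-column "kv_a" shard, as in the call above.
layer = MergedReplicatedLinearSketch(input_size=16, output_sizes=[4, 6])
layer.load_shard(0, np.ones((16, 4), dtype=np.float32))       # q_a_proj columns
layer.load_shard(1, np.full((16, 6), 2.0, dtype=np.float32))  # kv_a_proj_with_mqa columns

Passing output_sizes instead of a single output_size is what tells the loader where each shard's columns start, while the forward pass still runs one fused GEMM.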
@@ -636,6 +644,8 @@ class DeepseekV3ForCausalLM(ModelForCasualLM):
             ("embed_tokens.embeddings", "embed_tokens", None),
             ("lm_head.linear", "lm_head", None),
             ("experts.gate_correction_bias", "gate.e_score_correction_bias", None),
+            ("qkv_a_proj_with_mqa", "q_a_proj", "q_a"),
+            ("qkv_a_proj_with_mqa", "kv_a_proj_with_mqa", "kv_a"),
         ]
         # (param_name, weight_name, expert_id, shard_id)
         expert_params_mapping = FusedMoE.make_expert_params_mapping(
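
The two new mapping entries let the checkpoint's separate q_a_proj and kv_a_proj_with_mqa weights be routed into the fused qkv_a_proj_with_mqa parameter. A hedged sketch of how such a (param_name, weight_name, shard_id) table is typically consumed while iterating checkpoint names; this is illustrative pseudologic, not FastDeploy's exact load_weights loop:

# The rows below mirror the two entries added in the diff.
stacked_params_mapping = [
    ("qkv_a_proj_with_mqa", "q_a_proj", "q_a"),
    ("qkv_a_proj_with_mqa", "kv_a_proj_with_mqa", "kv_a"),
]

def route_checkpoint_name(loaded_name):
    # Map a checkpoint weight name onto the fused parameter plus a shard id.
    for param_name, weight_name, shard_id in stacked_params_mapping:
        if weight_name in loaded_name:
            return loaded_name.replace(weight_name, param_name), shard_id
    return loaded_name, None  # unfused weights keep their own name

# "model.layers.0.self_attn.q_a_proj.weight"
#     -> ("model.layers.0.self_attn.qkv_a_proj_with_mqa.weight", "q_a")
print(route_checkpoint_name("model.layers.0.self_attn.q_a_proj.weight"))

The returned shard id is what a merged layer's weight loader would use to pick the correct column slice, as in the sketch after the MLA attention hunk above.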