mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-06 00:57:33 +08:00
【Inference Optimize】Update MergedReplicatedLinear for DSK qkv_a_proj_with_mqa. (#3673)
* support MergedReplicatedLinear
* update MergedReplicatedLinear to support DSK_wint4 V1_load
* update model name
* update linear class
* fix
* fix v0 moe_bias load

---------

Co-authored-by: bukejiyu <52310069+bukejiyu@users.noreply.github.com>
This commit is contained in:
@@ -38,6 +38,7 @@ from fastdeploy.model_executor.layers.linear import (
|
||||
ColumnParallelLinear,
|
||||
KVBatchLinear,
|
||||
MergedColumnParallelLinear,
|
||||
MergedReplicatedLinear,
|
||||
ReplicatedLinear,
|
||||
RowParallelLinear,
|
||||
)
|
||||
@@ -169,6 +170,13 @@ class DeepSeekV3MoE(nn.Layer):
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
""" """
|
||||
if self.experts.gate_correction_bias is not None:
|
||||
gate_correction_bias_tensor = state_dict.pop(self.experts.gate_correction_bias_key)
|
||||
if self.experts.gate_correction_bias.shape != gate_correction_bias_tensor.shape:
|
||||
gate_correction_bias_tensor = gate_correction_bias_tensor.reshape(
|
||||
self.experts.gate_correction_bias.shape
|
||||
)
|
||||
self.experts.gate_correction_bias.set_value(gate_correction_bias_tensor)
|
||||
self.gate.load_state_dict(state_dict)
|
||||
self.experts.load_state_dict(state_dict)
|
||||
self.shared_experts.load_state_dict(state_dict)
|
||||
@@ -211,11 +219,11 @@ class DeepseekV3MLAAttention(nn.Layer):
|
||||
|
||||
if self.q_lora_rank is not None:
|
||||
# NOTE: (changwenbin) qkv_a_proj horizontal fusion
|
||||
self.qkv_a_proj_with_mqa = ReplicatedLinear(
|
||||
self.qkv_a_proj_with_mqa = MergedReplicatedLinear(
|
||||
fd_config=fd_config,
|
||||
prefix=f"{prefix}.qkv_a_proj_with_mqa",
|
||||
input_size=self.hidden_size,
|
||||
output_size=self.q_lora_rank + self.kv_lora_rank + self.qk_rope_head_dim,
|
||||
output_sizes=[self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
|
||||
with_bias=False,
|
||||
)
|
||||
|
||||
@@ -636,6 +644,8 @@ class DeepseekV3ForCausalLM(ModelForCasualLM):
|
||||
("embed_tokens.embeddings", "embed_tokens", None),
|
||||
("lm_head.linear", "lm_head", None),
|
||||
("experts.gate_correction_bias", "gate.e_score_correction_bias", None),
|
||||
("qkv_a_proj_with_mqa", "q_a_proj", "q_a"),
|
||||
("qkv_a_proj_with_mqa", "kv_a_proj_with_mqa", "kv_a"),
|
||||
]
|
||||
# (param_name, weight_name, expert_id, shard_id)
|
||||
expert_params_mapping = FusedMoE.make_expert_params_mapping(
|
||||
|
Reference in New Issue
Block a user