diff --git a/fastdeploy/model_executor/layers/linear.py b/fastdeploy/model_executor/layers/linear.py
index b864e4aa3..5f9c47c15 100644
--- a/fastdeploy/model_executor/layers/linear.py
+++ b/fastdeploy/model_executor/layers/linear.py
@@ -196,7 +196,14 @@ class LinearBase(nn.Layer):
         Args:
             state_dict (dict): A dictionary containing the weights
         """
-        weight_tensor = get_tensor(state_dict.pop(self.weight_key))
+        if "qkv_a_proj_with_mqa" in self.weight_key:
+            self.weight_key_q = self.weight_key.replace("qkv_a_proj_with_mqa", "q_a_proj")
+            self.weight_key_kv = self.weight_key.replace("qkv_a_proj_with_mqa", "kv_a_proj_with_mqa")
+            q_weight_tensor = get_tensor(state_dict.pop(self.weight_key_q))
+            kv_weight_tensor = get_tensor(state_dict.pop(self.weight_key_kv))
+            weight_tensor = paddle.concat([q_weight_tensor, kv_weight_tensor], axis=-1)
+        else:
+            weight_tensor = get_tensor(state_dict.pop(self.weight_key))
         self.quant_method.process_loaded_weights(self, weight_tensor)
 
     def load_state_dict(self, state_dict: dict):
diff --git a/fastdeploy/model_executor/models/deepseek_v3.py b/fastdeploy/model_executor/models/deepseek_v3.py
index 6b28226ed..6926f207f 100644
--- a/fastdeploy/model_executor/models/deepseek_v3.py
+++ b/fastdeploy/model_executor/models/deepseek_v3.py
@@ -210,11 +210,12 @@ class DeepseekV3MLAAttention(nn.Layer):
         self.rms_norm_eps = fd_config.model_config.rms_norm_eps
 
         if self.q_lora_rank is not None:
-            self.q_a_proj = ReplicatedLinear(
+            # NOTE: (changwenbin) qkv_a_proj horizontal fusion
+            self.qkv_a_proj_with_mqa = ReplicatedLinear(
                 fd_config=fd_config,
-                prefix=f"{prefix}.q_a_proj",
+                prefix=f"{prefix}.qkv_a_proj_with_mqa",
                 input_size=self.hidden_size,
-                output_size=self.q_lora_rank,
+                output_size=self.q_lora_rank + self.kv_lora_rank + self.qk_rope_head_dim,
                 with_bias=False,
             )
 
@@ -235,15 +236,6 @@ class DeepseekV3MLAAttention(nn.Layer):
         else:
             assert self.q_lora_rank is not None, "self.q_lora_rank is None, Please Check your config."
 
-        # No TP split, run the W4A16 Gemm
-        self.kv_a_proj_with_mqa = ReplicatedLinear(
-            fd_config=fd_config,
-            prefix=f"{prefix}.kv_a_proj_with_mqa",
-            input_size=self.hidden_size,
-            output_size=self.kv_lora_rank + self.qk_rope_head_dim,
-            with_bias=False,
-        )
-
         self.kv_a_layernorm = RMSNorm(
             fd_config,
             hidden_size=self.kv_lora_rank,
@@ -331,14 +323,18 @@ class DeepseekV3MLAAttention(nn.Layer):
 
         # NOTE: (changwenbin) Bring out the public calculation in PD MIX to avoid repeated calculation.
         fmha_out = None
-        query = self.q_a_proj(hidden_states)
+
+        # NOTE: (changwenbin) qkv_a_proj horizontal fusion
+        qkv_a_out = self.qkv_a_proj_with_mqa(hidden_states)
+        query, compressed_kv, key_pe = qkv_a_out.split(
+            [self.q_lora_rank, self.kv_lora_rank, self.qk_rope_head_dim], axis=-1
+        )
+
         query = self.q_a_layernorm(query)
         query = self.q_b_proj(query)
         query = query.reshape([-1, self.num_attention_heads_tp, self.qk_head_dim])
         query_nope, query_pe = query.split([self.qk_nope_head_dim, self.qk_rope_head_dim], axis=-1)
 
-        compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
-        compressed_kv, key_pe = compressed_kv.split([self.kv_lora_rank, self.qk_rope_head_dim], axis=-1)
         key_pe = key_pe.reshape([-1, 1, self.qk_rope_head_dim])
         compressed_kv = self.kv_a_layernorm(compressed_kv)
 
@@ -417,9 +413,8 @@ class DeepseekV3MLAAttention(nn.Layer):
 
     def load_state_dict(self, state_dict):
         """ """
-        self.q_a_proj.load_state_dict(state_dict)
         self.q_a_layernorm.load_state_dict(state_dict)
-        self.kv_a_proj_with_mqa.load_state_dict(state_dict)
+        self.qkv_a_proj_with_mqa.load_state_dict(state_dict)
         self.kv_a_layernorm.load_state_dict(state_dict)
         self.q_b_proj.load_state_dict(state_dict)
         self.kv_b_proj_bmm.load_state_dict(state_dict)
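
For reference, below is a minimal standalone sketch of what the new `load_weights` branch does for a fused `qkv_a_proj_with_mqa` key: it recovers the original `q_a_proj` and `kv_a_proj_with_mqa` checkpoint tensors by name substitution and concatenates them along the output dimension. The state-dict key names and all sizes here are illustrative placeholders, not the real DeepSeek-V3 checkpoint layout.

```python
# Minimal sketch (illustrative names and sizes) of the checkpoint-side fusion
# performed in LinearBase.load_weights above.
import paddle

hidden_size = 16
q_lora_rank = 8
kv_lora_rank = 4
qk_rope_head_dim = 2

# Separate checkpoint tensors, laid out as [input_size, output_size].
state_dict = {
    "layers.0.self_attn.q_a_proj.weight": paddle.randn([hidden_size, q_lora_rank]),
    "layers.0.self_attn.kv_a_proj_with_mqa.weight": paddle.randn([hidden_size, kv_lora_rank + qk_rope_head_dim]),
}

# The fused layer's key maps back to the two original keys by name substitution.
weight_key = "layers.0.self_attn.qkv_a_proj_with_mqa.weight"
q_key = weight_key.replace("qkv_a_proj_with_mqa", "q_a_proj")
kv_key = weight_key.replace("qkv_a_proj_with_mqa", "kv_a_proj_with_mqa")

# Horizontal fusion: concatenate along the output dimension so a single GEMM
# produces the q low-rank projection plus the compressed kv / rope parts.
fused_weight = paddle.concat([state_dict.pop(q_key), state_dict.pop(kv_key)], axis=-1)
assert fused_weight.shape == [hidden_size, q_lora_rank + kv_lora_rank + qk_rope_head_dim]
```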
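And a matching sketch of the forward-pass side: one fused projection followed by a three-way split whose section sizes and order mirror the load-time concatenation. A plain `paddle.matmul` stands in for `ReplicatedLinear`, and the sizes are again illustrative.

```python
# Minimal sketch of the fused forward path in DeepseekV3MLAAttention
# (plain matmul stands in for ReplicatedLinear; sizes are illustrative).
import paddle

hidden_size, q_lora_rank, kv_lora_rank, qk_rope_head_dim = 16, 8, 4, 2
fused_weight = paddle.randn([hidden_size, q_lora_rank + kv_lora_rank + qk_rope_head_dim])

hidden_states = paddle.randn([3, hidden_size])          # [num_tokens, hidden_size]
qkv_a_out = paddle.matmul(hidden_states, fused_weight)  # single fused GEMM

# Split order must match the load-time concatenation order: q, then kv, then rope.
query, compressed_kv, key_pe = qkv_a_out.split(
    [q_lora_rank, kv_lora_rank, qk_rope_head_dim], axis=-1
)
assert query.shape == [3, q_lora_rank]
assert compressed_kv.shape == [3, kv_lora_rank]
assert key_pe.shape == [3, qk_rope_head_dim]
```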