qkv_a_proj horizontal fusion (#3591)

Support DSK qkv_a_proj horizontal fusion under V0 Loder
This commit is contained in:
AIbin
2025-08-26 14:25:57 +08:00
committed by GitHub
parent 75db0d1ae2
commit 0a0d2959b9
2 changed files with 20 additions and 18 deletions

View File

@@ -196,7 +196,14 @@ class LinearBase(nn.Layer):
Args:
state_dict (dict): A dictionary containing the weights
"""
weight_tensor = get_tensor(state_dict.pop(self.weight_key))
if "qkv_a_proj_with_mqa" in self.weight_key:
self.weight_key_q = self.weight_key.replace("qkv_a_proj_with_mqa", "q_a_proj")
self.weight_key_kv = self.weight_key.replace("qkv_a_proj_with_mqa", "kv_a_proj_with_mqa")
q_weight_tensor = get_tensor(state_dict.pop(self.weight_key_q))
kv_weight_tensor = get_tensor(state_dict.pop(self.weight_key_kv))
weight_tensor = paddle.concat([q_weight_tensor, kv_weight_tensor], axis=-1)
else:
weight_tensor = get_tensor(state_dict.pop(self.weight_key))
self.quant_method.process_loaded_weights(self, weight_tensor)
def load_state_dict(self, state_dict: dict):