mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 08:37:06 +08:00
qkv_a_proj horizontal fusion (#3591)
Support DSK qkv_a_proj horizontal fusion under V0 Loader
This commit is contained in:
@@ -196,7 +196,14 @@ class LinearBase(nn.Layer):
|
||||
Args:
|
||||
state_dict (dict): A dictionary containing the weights
|
||||
"""
|
||||
weight_tensor = get_tensor(state_dict.pop(self.weight_key))
|
||||
if "qkv_a_proj_with_mqa" in self.weight_key:
|
||||
self.weight_key_q = self.weight_key.replace("qkv_a_proj_with_mqa", "q_a_proj")
|
||||
self.weight_key_kv = self.weight_key.replace("qkv_a_proj_with_mqa", "kv_a_proj_with_mqa")
|
||||
q_weight_tensor = get_tensor(state_dict.pop(self.weight_key_q))
|
||||
kv_weight_tensor = get_tensor(state_dict.pop(self.weight_key_kv))
|
||||
weight_tensor = paddle.concat([q_weight_tensor, kv_weight_tensor], axis=-1)
|
||||
else:
|
||||
weight_tensor = get_tensor(state_dict.pop(self.weight_key))
|
||||
self.quant_method.process_loaded_weights(self, weight_tensor)
|
||||
|
||||
def load_state_dict(self, state_dict: dict):
|
||||
|
Reference in New Issue
Block a user