mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 16:48:03 +08:00
[Feature] support tensor-parallel-size>num_key_value_heads for qwen3 (#2799)
This commit is contained in:
@@ -443,6 +443,13 @@ class QKVParallelLinear(ColumnParallelLinear):
|
||||
q_tensor = get_tensor(state_dict.pop(q_weight_key))
|
||||
k_tensor = get_tensor(state_dict.pop(k_weight_key))
|
||||
v_tensor = get_tensor(state_dict.pop(v_weight_key))
|
||||
|
||||
if self.kv_num_heads < self.nranks:
|
||||
sharedkv_index = (self.fd_config.parallel_config.tensor_parallel_rank * self.kv_num_heads) // self.nranks
|
||||
sharedkv_start = sharedkv_index * self.head_dim
|
||||
sharedkv_end = sharedkv_start + self.head_dim
|
||||
k_tensor = k_tensor[ : , sharedkv_start : sharedkv_end]
|
||||
v_tensor = v_tensor[ : , sharedkv_start : sharedkv_end]
|
||||
weight_tensor = paddle.concat([q_tensor, k_tensor, v_tensor],
|
||||
axis=-1).transpose([1, 0])
|
||||
weight_tensor = weight_tensor.reshape([
|
||||
|
Reference in New Issue
Block a user