[Feature] Support tensor-parallel-size > num_key_value_heads for Qwen3 (#2799)

This commit is contained in:
zhink
2025-07-11 15:09:43 +08:00
committed by GitHub
parent 2c3607407f
commit c08561c13a
4 changed files with 23 additions and 99 deletions

View File

@@ -443,6 +443,13 @@ class QKVParallelLinear(ColumnParallelLinear):
q_tensor = get_tensor(state_dict.pop(q_weight_key))
k_tensor = get_tensor(state_dict.pop(k_weight_key))
v_tensor = get_tensor(state_dict.pop(v_weight_key))
if self.kv_num_heads < self.nranks:
sharedkv_index = (self.fd_config.parallel_config.tensor_parallel_rank * self.kv_num_heads) // self.nranks
sharedkv_start = sharedkv_index * self.head_dim
sharedkv_end = sharedkv_start + self.head_dim
k_tensor = k_tensor[ : , sharedkv_start : sharedkv_end]
v_tensor = v_tensor[ : , sharedkv_start : sharedkv_end]
weight_tensor = paddle.concat([q_tensor, k_tensor, v_tensor],
axis=-1).transpose([1, 0])
weight_tensor = weight_tensor.reshape([