fix qwen3 235B tp 8 (#3697)

Author: bukejiyu
Date: 2025-08-28 23:46:25 +08:00
Committed by: GitHub
Parent: 4957908275
Commit: 0b51b9c35b


@@ -517,9 +517,11 @@ class QKVParallelLinear(ColumnParallelLinear):
         self.num_heads_per_rank = divide(self.num_heads, self.nranks)
         if self.kv_num_heads < self.nranks and self.nranks % self.kv_num_heads == 0:
             self.kv_num_heads_per_rank = 1
+            self.num_kv_head_replicas = divide(self.nranks, self.kv_num_heads)
             output_size = (self.num_heads + 2 * self.nranks) * self.head_dim
         else:
             self.kv_num_heads_per_rank = divide(self.kv_num_heads, self.nranks)
+            self.num_kv_head_replicas = 1
             output_size = (self.num_heads + 2 * self.kv_num_heads) * self.head_dim
         input_size = self.hidden_size
         super().__init__(
@@ -531,6 +533,14 @@ class QKVParallelLinear(ColumnParallelLinear):
             add_bias=add_bias,
         )
 
+    def _get_shard_size_mapping(self, loaded_shard_id: str):
+        shard_size_mapping = {
+            "q": self.num_heads_per_rank * self.head_dim,
+            "k": self.kv_num_heads_per_rank * self.head_dim,
+            "v": self.kv_num_heads_per_rank * self.head_dim,
+        }
+        return shard_size_mapping.get(loaded_shard_id)
+
     def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
         output_dim = getattr(param, "output_dim", None)
         assert output_dim is not None
@@ -557,14 +567,11 @@ class QKVParallelLinear(ColumnParallelLinear):
             assert loaded_shard_id in ["q", "k", "v"]
             # Tensor parallelism splits the weight along the output_dim
             if self.nranks != 1:
+                block_size = self._get_shard_size_mapping(loaded_shard_id)
                 dim = -1 if output_dim else 0
-                if isinstance(loaded_weight, (np.ndarray, paddle.Tensor)):
-                    size = loaded_weight.shape[dim]
-                else:
-                    size = loaded_weight.get_shape()[dim]
-                block_size = size // self.nranks
-                shard_offset = self.local_rank * block_size
-                shard_size = (self.local_rank + 1) * block_size
+                shard_id = self.local_rank if loaded_shard_id == "q" else self.local_rank // self.num_kv_head_replicas
+                shard_offset = shard_id * block_size
+                shard_size = (shard_id + 1) * block_size
                 loaded_weight = slice_fn(loaded_weight, output_dim, start=shard_offset, end=shard_size)
                 loaded_weight = get_tensor(loaded_weight)
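
For reference, a minimal standalone sketch of the replication bookkeeping that the first hunk adds to __init__, assuming an illustrative Qwen3-235B-style shape (64 query heads, 4 KV heads, head_dim 128; these numbers are assumptions for the example, not taken from this diff) and tensor parallelism across 8 ranks:

# Standalone sketch, no paddle needed; shapes are illustrative assumptions.
num_heads, kv_num_heads, head_dim, nranks = 64, 4, 128, 8

num_heads_per_rank = num_heads // nranks          # 8 query heads per rank
if kv_num_heads < nranks and nranks % kv_num_heads == 0:
    # Fewer KV heads than ranks: keep one KV head per rank and replicate
    # each physical KV head across nranks // kv_num_heads consecutive ranks.
    kv_num_heads_per_rank = 1
    num_kv_head_replicas = nranks // kv_num_heads  # 8 // 4 = 2 replicas per KV head
    output_size = (num_heads + 2 * nranks) * head_dim
else:
    # Enough KV heads to shard them directly, one group per rank.
    kv_num_heads_per_rank = kv_num_heads // nranks
    num_kv_head_replicas = 1
    output_size = (num_heads + 2 * kv_num_heads) * head_dim

print(num_kv_head_replicas, output_size)  # 2, (64 + 2 * 8) * 128 = 10240

Because each rank holds one (possibly replicated) K head and one V head, the padded global output counts 2 * nranks KV heads, which is why output_size switches from 2 * kv_num_heads to 2 * nranks in the replicated branch.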
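And a hedged sketch of the shard arithmetic the weight_loader change relies on: q shards follow local_rank directly, while k/v shards map local_rank onto the replicated KV head via integer division by num_kv_head_replicas. The shard_range helper below is hypothetical and only mimics the slicing bounds; the same illustrative shapes as above are assumed.

# Illustrative-only: which slice of a checkpoint q/k/v weight each rank loads.
num_heads, kv_num_heads, head_dim, nranks = 64, 4, 128, 8
num_heads_per_rank = num_heads // nranks
kv_num_heads_per_rank = 1
num_kv_head_replicas = nranks // kv_num_heads  # 2

def shard_range(loaded_shard_id: str, local_rank: int):
    """Return (start, end) into the full q/k/v weight along the output dim."""
    block_size = {
        "q": num_heads_per_rank * head_dim,     # 8 * 128 = 1024
        "k": kv_num_heads_per_rank * head_dim,  # 1 * 128 = 128
        "v": kv_num_heads_per_rank * head_dim,
    }[loaded_shard_id]
    # q: one block per rank; k/v: groups of num_kv_head_replicas consecutive
    # ranks share the same block, so they index by rank group.
    shard_id = local_rank if loaded_shard_id == "q" else local_rank // num_kv_head_replicas
    return shard_id * block_size, (shard_id + 1) * block_size

print(shard_range("q", 1))  # (1024, 2048)
print(shard_range("k", 0))  # (0, 128)
print(shard_range("k", 1))  # (0, 128)  -- rank 1 replicates the same KV head as rank 0
print(shard_range("k", 2))  # (128, 256)

This is what the old code got wrong for Qwen3-235B at TP=8: deriving block_size as size // nranks and offsetting by local_rank splits a 4-KV-head weight into 8 undersized slices, whereas the new per-shard block sizes plus the rank-group index load each full KV head onto every rank that replicates it.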