Mirror of https://github.com/PaddlePaddle/FastDeploy.git, synced 2025-10-04 16:22:57 +08:00
fix qwen3 235B tp 8 (#3697)
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled
@@ -517,9 +517,11 @@ class QKVParallelLinear(ColumnParallelLinear):
         self.num_heads_per_rank = divide(self.num_heads, self.nranks)
         if self.kv_num_heads < self.nranks and self.nranks % self.kv_num_heads == 0:
             self.kv_num_heads_per_rank = 1
+            self.num_kv_head_replicas = divide(self.nranks, self.kv_num_heads)
             output_size = (self.num_heads + 2 * self.nranks) * self.head_dim
         else:
             self.kv_num_heads_per_rank = divide(self.kv_num_heads, self.nranks)
+            self.num_kv_head_replicas = 1
             output_size = (self.num_heads + 2 * self.kv_num_heads) * self.head_dim
         input_size = self.hidden_size
         super().__init__(
@@ -531,6 +533,14 @@ class QKVParallelLinear(ColumnParallelLinear):
             add_bias=add_bias,
         )
 
+    def _get_shard_size_mapping(self, loaded_shard_id: str):
+        shard_size_mapping = {
+            "q": self.num_heads_per_rank * self.head_dim,
+            "k": self.kv_num_heads_per_rank * self.head_dim,
+            "v": self.kv_num_heads_per_rank * self.head_dim,
+        }
+        return shard_size_mapping.get(loaded_shard_id)
+
     def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
         output_dim = getattr(param, "output_dim", None)
         assert output_dim is not None
@@ -557,14 +567,11 @@ class QKVParallelLinear(ColumnParallelLinear):
             assert loaded_shard_id in ["q", "k", "v"]
             # Tensor parallelism splits the weight along the output_dim
             if self.nranks != 1:
+                block_size = self._get_shard_size_mapping(loaded_shard_id)
                 dim = -1 if output_dim else 0
-                if isinstance(loaded_weight, (np.ndarray, paddle.Tensor)):
-                    size = loaded_weight.shape[dim]
-                else:
-                    size = loaded_weight.get_shape()[dim]
-                block_size = size // self.nranks
-                shard_offset = self.local_rank * block_size
-                shard_size = (self.local_rank + 1) * block_size
+                shard_id = self.local_rank if loaded_shard_id == "q" else self.local_rank // self.num_kv_head_replicas
+                shard_offset = shard_id * block_size
+                shard_size = (shard_id + 1) * block_size
                 loaded_weight = slice_fn(loaded_weight, output_dim, start=shard_offset, end=shard_size)
 
             loaded_weight = get_tensor(loaded_weight)
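To make the new shard arithmetic concrete, the sketch below replays it outside the class with illustrative numbers (64 query heads, 4 KV heads, head_dim 128, tensor-parallel degree 8, i.e. the kv_num_heads < nranks case this fix targets). The function name and the head counts are assumptions for illustration only, not FastDeploy API.

# Minimal, standalone sketch (not FastDeploy code) of the sharding logic
# introduced by this commit: when there are fewer KV heads than ranks,
# each KV head is replicated across consecutive ranks and the K/V shard
# index becomes local_rank // num_kv_head_replicas.

def qkv_shard_ranges(num_heads, kv_num_heads, head_dim, nranks, local_rank):
    """Return {shard: (start, end)} slices along the output dim for one rank."""
    num_heads_per_rank = num_heads // nranks
    if kv_num_heads < nranks and nranks % kv_num_heads == 0:
        # Fewer KV heads than ranks: one KV head per rank, replicated.
        kv_num_heads_per_rank = 1
        num_kv_head_replicas = nranks // kv_num_heads
    else:
        kv_num_heads_per_rank = kv_num_heads // nranks
        num_kv_head_replicas = 1

    block = {
        "q": num_heads_per_rank * head_dim,
        "k": kv_num_heads_per_rank * head_dim,
        "v": kv_num_heads_per_rank * head_dim,
    }
    ranges = {}
    for shard in ("q", "k", "v"):
        # Query shards are unique per rank; K/V shards repeat every
        # num_kv_head_replicas ranks, matching the new shard_id expression.
        shard_id = local_rank if shard == "q" else local_rank // num_kv_head_replicas
        ranges[shard] = (shard_id * block[shard], (shard_id + 1) * block[shard])
    return ranges


if __name__ == "__main__":
    for rank in range(8):
        print(rank, qkv_shard_ranges(64, 4, 128, 8, rank))
    # With 4 KV heads over 8 ranks, ranks 0 and 1 load the same K/V slice
    # (0, 128), ranks 2 and 3 load (128, 256), and so on, while every rank
    # still gets a distinct Q slice.

Before this change, the K/V offsets were derived from local_rank and a block size computed as the full weight width divided by nranks, which goes wrong when KV heads must be replicated across ranks (as with Qwen3-235B at TP=8).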