diff --git a/fastdeploy/model_executor/layers/linear.py b/fastdeploy/model_executor/layers/linear.py
index ed2880f8e..6f5dc13ce 100644
--- a/fastdeploy/model_executor/layers/linear.py
+++ b/fastdeploy/model_executor/layers/linear.py
@@ -517,9 +517,11 @@ class QKVParallelLinear(ColumnParallelLinear):
         self.num_heads_per_rank = divide(self.num_heads, self.nranks)
         if self.kv_num_heads < self.nranks and self.nranks % self.kv_num_heads == 0:
             self.kv_num_heads_per_rank = 1
+            self.num_kv_head_replicas = divide(self.nranks, self.kv_num_heads)
             output_size = (self.num_heads + 2 * self.nranks) * self.head_dim
         else:
             self.kv_num_heads_per_rank = divide(self.kv_num_heads, self.nranks)
+            self.num_kv_head_replicas = 1
             output_size = (self.num_heads + 2 * self.kv_num_heads) * self.head_dim
         input_size = self.hidden_size
         super().__init__(
@@ -531,6 +533,14 @@ class QKVParallelLinear(ColumnParallelLinear):
             add_bias=add_bias,
         )
 
+    def _get_shard_size_mapping(self, loaded_shard_id: str):
+        shard_size_mapping = {
+            "q": self.num_heads_per_rank * self.head_dim,
+            "k": self.kv_num_heads_per_rank * self.head_dim,
+            "v": self.kv_num_heads_per_rank * self.head_dim,
+        }
+        return shard_size_mapping.get(loaded_shard_id)
+
     def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
         output_dim = getattr(param, "output_dim", None)
         assert output_dim is not None
@@ -557,14 +567,11 @@ class QKVParallelLinear(ColumnParallelLinear):
             assert loaded_shard_id in ["q", "k", "v"]
             # Tensor parallelism splits the weight along the output_dim
             if self.nranks != 1:
+                block_size = self._get_shard_size_mapping(loaded_shard_id)
                 dim = -1 if output_dim else 0
-                if isinstance(loaded_weight, (np.ndarray, paddle.Tensor)):
-                    size = loaded_weight.shape[dim]
-                else:
-                    size = loaded_weight.get_shape()[dim]
-                block_size = size // self.nranks
-                shard_offset = self.local_rank * block_size
-                shard_size = (self.local_rank + 1) * block_size
+                shard_id = self.local_rank if loaded_shard_id == "q" else self.local_rank // self.num_kv_head_replicas
+                shard_offset = shard_id * block_size
+                shard_size = (shard_id + 1) * block_size
                 loaded_weight = slice_fn(loaded_weight, output_dim, start=shard_offset, end=shard_size)
 
             loaded_weight = get_tensor(loaded_weight)
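
For reviewers, here is a minimal standalone sketch of the shard arithmetic this patch introduces. It is not FastDeploy code; `shard_bounds` is a hypothetical helper that mirrors the names from the diff (`nranks`, `local_rank`, `num_kv_head_replicas`) to make the offset math checkable in isolation:

```python
# Standalone sketch (assumption: not part of FastDeploy) of the replica-aware
# shard arithmetic from the patch. When kv_num_heads < nranks, each KV head is
# replicated across num_kv_head_replicas consecutive tensor-parallel ranks.

def shard_bounds(loaded_shard_id, local_rank, nranks, num_heads, kv_num_heads, head_dim):
    num_heads_per_rank = num_heads // nranks
    if kv_num_heads < nranks and nranks % kv_num_heads == 0:
        kv_num_heads_per_rank = 1
        num_kv_head_replicas = nranks // kv_num_heads
    else:
        kv_num_heads_per_rank = kv_num_heads // nranks
        num_kv_head_replicas = 1

    # Block size depends on which projection is being loaded, as in
    # _get_shard_size_mapping in the patch.
    block_size = {
        "q": num_heads_per_rank * head_dim,
        "k": kv_num_heads_per_rank * head_dim,
        "v": kv_num_heads_per_rank * head_dim,
    }[loaded_shard_id]

    # Q shards are unique per rank; K/V shards are shared by each group of
    # num_kv_head_replicas ranks, so those ranks map to the same shard id.
    shard_id = local_rank if loaded_shard_id == "q" else local_rank // num_kv_head_replicas
    return shard_id * block_size, (shard_id + 1) * block_size

# Example: 8 TP ranks, 32 query heads, 4 KV heads, head_dim 128.
# Ranks 0 and 1 form one replica group and load the same K slice.
print(shard_bounds("k", 0, 8, 32, 4, 128))  # (0, 128)
print(shard_bounds("k", 1, 8, 32, 4, 128))  # (0, 128)
print(shard_bounds("k", 2, 8, 32, 4, 128))  # (128, 256)
print(shard_bounds("q", 1, 8, 32, 4, 128))  # (512, 1024)
```

This also shows why the old code was wrong in the replicated case: it computed `block_size = size // self.nranks` from the checkpoint tensor's size, which underestimates the K/V block once the K/V weight has fewer head columns than there are ranks, whereas the patch derives the block size from the per-rank head counts and maps ranks in the same replica group to one shard id.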