Mirror of https://github.com/PaddlePaddle/FastDeploy.git
qwen loader (#3057)
@@ -22,6 +22,7 @@ from paddle import nn
from paddle.distributed import fleet

from fastdeploy.config import FDConfig
from fastdeploy.model_executor.models.utils import set_weight_attrs

from .utils import get_tensor

@@ -80,6 +81,7 @@ class VocabParallelEmbedding(nn.Layer):
                    initializer=nn.initializer.Normal(mean=0.0, std=self.initializer_range),
                ),
            )
            set_weight_attrs(self.embeddings.weight, {"output_dim": False})
        else:
            # column cut embedding
            self.embeddings = nn.Embedding(
@@ -89,6 +91,7 @@ class VocabParallelEmbedding(nn.Layer):

            self.embeddings.weight.is_distributed = True
            self.embeddings.weight.split_axis = 1
            set_weight_attrs(self.embeddings.weight, {"output_dim": True})

        self.prefix = prefix
        self.dropout = nn.Dropout(self.hidden_dropout_prob)
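
The `set_weight_attrs` calls above tag each parameter with loader metadata such as `output_dim`. As a minimal sketch of that pattern (an assumption about the helper's behavior, not the FastDeploy implementation), such a helper only needs to attach attributes to the parameter object so a later checkpoint-loading pass can read them back:

# Sketch only: illustrates the attribute-tagging pattern used by set_weight_attrs.
def set_weight_attrs_sketch(param, attrs: dict) -> None:
    for key, value in attrs.items():
        # Avoid silently overwriting metadata that was already attached.
        assert not hasattr(param, key), f"attribute {key} already set"
        setattr(param, key, value)


class FakeParam:
    """Stand-in for a paddle parameter in this sketch."""


weight = FakeParam()
set_weight_attrs_sketch(weight, {"output_dim": False})
assert weight.output_dim is False
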
@@ -14,11 +14,17 @@
# limitations under the License.
"""

from typing import Optional

import paddle
from paddle import nn

from fastdeploy.config import FDConfig
from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce
from fastdeploy.model_executor.models.utils import (
    default_weight_loader,
    set_weight_attrs,
)
from fastdeploy.platforms import current_platform

from .utils import _set_var_distributed, divide, get_tensor
@@ -107,6 +113,15 @@ class LinearBase(nn.Layer):
            default_initializer=paddle.nn.initializer.Constant(0),
        )

        set_weight_attrs(
            self.weight,
            {
                "weight_loader": (
                    self.weight_loader if hasattr(self, "weight_loader") else default_weight_loader(self.fd_config)
                )
            },
        )

        self.bias = None
        if self.with_bias:
            self.bias = self.create_parameter(
@@ -115,6 +130,15 @@ class LinearBase(nn.Layer):
                is_bias=True,
            )

            set_weight_attrs(
                self.weight,
                {
                    "weight_loader": (
                        self.weight_loader if hasattr(self, "weight_loader") else default_weight_loader(self.fd_config)
                    )
                },
            )

        # smooth quant
        self.linear_shift = None
        self.linear_smooth = None
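
The recurring `set_weight_attrs(..., {"weight_loader": ...})` blocks register a per-parameter loading callback, falling back to `default_weight_loader(self.fd_config)` when the layer does not define its own `weight_loader`. A hedged sketch of how a checkpoint loop could dispatch on that attribute (the function names and the plain-copy default below are illustrative assumptions, not FastDeploy code):

# Sketch only: dispatch checkpoint tensors through per-parameter weight_loader callbacks.
def default_weight_loader_sketch(fd_config=None):
    def loader(param, loaded_weight):
        # Plain copy for parameters that need no tensor-parallel splitting.
        param.data = loaded_weight
    return loader


def load_params_sketch(named_params, state_dict, fd_config=None):
    for name, param in named_params:
        if name in state_dict:
            loader = getattr(param, "weight_loader", default_weight_loader_sketch(fd_config))
            loader(param, state_dict[name])
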
@@ -273,6 +297,7 @@ class ColumnParallelLinear(LinearBase):
            add_bias=add_bias,
            skip_quant=skip_quant,
        )
        self.fd_config = fd_config
        self.nranks = fd_config.parallel_config.tensor_parallel_size
        self.input_size = input_size
        self.output_size = divide(output_size, self.nranks)  # Split the output_size across tensor-parallel ranks.
@@ -300,6 +325,15 @@ class ColumnParallelLinear(LinearBase):
        if self.nranks > 0:
            # col parallel
            _set_var_distributed(self.weight, split_axis=1)
            set_weight_attrs(
                self.weight,
                {
                    "output_dim": True,
                    "weight_loader": (
                        self.weight_loader if hasattr(self, "weight_loader") else default_weight_loader(self.fd_config)
                    ),
                },
            )

        self.bias = None
        if self.with_bias:
@@ -311,6 +345,17 @@ class ColumnParallelLinear(LinearBase):
            if self.nranks > 0:
                # col parallel
                _set_var_distributed(self.bias, split_axis=1)
                set_weight_attrs(
                    self.weight,
                    {
                        "output_dim": True,
                        "weight_loader": (
                            self.weight_loader
                            if hasattr(self, "weight_loader")
                            else default_weight_loader(self.fd_config)
                        ),
                    },
                )

        # smooth quant
        self.linear_shift = None
@@ -354,6 +399,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
        self.activation = activation
        self.hidden_size = fd_config.model_config.hidden_size
        self.nranks = fd_config.parallel_config.tensor_parallel_size
        self.output_size = output_size
        self.local_rank = fd_config.parallel_config.tensor_parallel_rank

        super().__init__(
            fd_config=fd_config,
@@ -365,6 +412,27 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
            skip_quant=skip_quant,
        )

    def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
        # 1. the fused gate_up weight is stored on disk
        # 2. split it into gate and up
        assert loaded_shard_id in ["gate", "up"]
        output_dim = getattr(param, "output_dim", None)
        # Tensor parallelism splits the weight along the output_dim
        if output_dim is not None:
            dim = -1
            size = loaded_weight.get_shape()[dim]
            block_size = size // self.nranks
            shard_offset = self.local_rank * block_size
            shard_size = (self.local_rank + 1) * block_size
            loaded_weight = loaded_weight[..., shard_offset:shard_size]

        loaded_weight = get_tensor(loaded_weight)

        if loaded_shard_id == "gate":
            param[:, : self.output_size // 2] = loaded_weight
        elif loaded_shard_id == "up":
            param[:, self.output_size // 2 :] = loaded_weight

    def load_state_dict(self, state_dict: dict):
        """
        Load the checkpoint state dictionary into the layer.
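
As a worked example of the gate/up split in `weight_loader` above (toy NumPy shapes chosen for illustration, not taken from the commit): each rank keeps `size // nranks` columns of the fused tensor stored on disk, then writes the gate shard into the left half and the up shard into the right half of its local parameter.

import numpy as np

# Assumed toy sizes: hidden_size=4, per-rank fused gate_up output_size=8, nranks=2.
hidden_size, output_size, nranks, local_rank = 4, 8, 2, 0

param = np.zeros((hidden_size, output_size))      # local fused gate_up parameter
gate_on_disk = np.ones((hidden_size, 8))          # full "gate" weight as stored on disk

block_size = gate_on_disk.shape[-1] // nranks     # columns owned by each rank
shard = gate_on_disk[..., local_rank * block_size : (local_rank + 1) * block_size]

param[:, : output_size // 2] = shard              # "up" would fill param[:, output_size // 2 :]
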
@@ -415,6 +483,7 @@ class QKVParallelLinear(ColumnParallelLinear):
        self.hidden_size = fd_config.model_config.hidden_size
        self.head_dim = fd_config.model_config.head_dim
        self.nranks = fd_config.parallel_config.tensor_parallel_size
        self.local_rank = fd_config.parallel_config.tensor_parallel_rank
        self.num_heads_per_rank = divide(self.num_heads, self.nranks)
        if self.kv_num_heads < self.nranks and self.nranks % self.kv_num_heads == 0:
            self.kv_num_heads_per_rank = 1
@@ -432,6 +501,34 @@ class QKVParallelLinear(ColumnParallelLinear):
            add_bias=add_bias,
        )

    def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
        # 1. the fused qkv weight is stored on disk
        # 2. split it into q, k and v
        assert loaded_shard_id in ["q", "k", "v"]
        output_dim = getattr(param, "output_dim", None)
        # Tensor parallelism splits the weight along the output_dim
        if output_dim is not None:
            dim = -1
            size = loaded_weight.get_shape()[dim]
            block_size = size // self.nranks
            shard_offset = self.local_rank * block_size
            shard_size = (self.local_rank + 1) * block_size
            loaded_weight = loaded_weight[..., shard_offset:shard_size]

        loaded_weight = get_tensor(loaded_weight)

        if loaded_shard_id == "q":
            param[:, : self.num_heads_per_rank * self.head_dim] = loaded_weight
        elif loaded_shard_id == "k":
            param[
                :,
                self.num_heads_per_rank
                * self.head_dim : (self.num_heads_per_rank + self.kv_num_heads_per_rank)
                * self.head_dim,
            ] = loaded_weight
        elif loaded_shard_id == "v":
            param[:, (self.num_heads_per_rank + self.kv_num_heads_per_rank) * self.head_dim :] = loaded_weight

    def load_weight(self, state_dict: dict):
        """
        Load the weight from the state dictionary.
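
For the q/k/v placement above, the column offsets into the local fused parameter follow directly from the per-rank head counts. A small arithmetic check with assumed values (`num_heads=16`, `kv_num_heads=4`, `head_dim=64`, `nranks=4`, none of which come from the commit):

# Assumed toy configuration for illustration only.
num_heads, kv_num_heads, head_dim, nranks = 16, 4, 64, 4

num_heads_per_rank = num_heads // nranks          # 4 query heads per rank
kv_num_heads_per_rank = kv_num_heads // nranks    # 1 kv head per rank

q_end = num_heads_per_rank * head_dim                             # 256
k_end = (num_heads_per_rank + kv_num_heads_per_rank) * head_dim   # 320

# Local fused qkv columns: q -> [0, 256), k -> [256, 320), v -> [320, ...)
assert (q_end, k_end) == (256, 320)
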
@@ -588,6 +685,18 @@ class RowParallelLinear(LinearBase):
            is_bias=False,
            default_initializer=paddle.nn.initializer.Constant(0),
        )
        if self.nranks > 0:
            # row parallel
            set_weight_attrs(
                self.weight,
                {
                    "output_dim": False,
                    "weight_loader": (
                        self.weight_loader if hasattr(self, "weight_loader") else default_weight_loader(self.fd_config)
                    ),
                },
            )
            _set_var_distributed(self.weight, split_axis=0)

        self.bias = None
        if self.with_bias:
@@ -596,10 +705,18 @@ class RowParallelLinear(LinearBase):
                dtype=self._dtype,
                is_bias=True,
            )

            if self.nranks > 0:
                # row parallel
                _set_var_distributed(self.weight, split_axis=0)
            if self.nranks > 0:
                set_weight_attrs(
                    self.bias,
                    {
                        "output_dim": False,
                        "weight_loader": (
                            self.weight_loader
                            if hasattr(self, "weight_loader")
                            else default_weight_loader(self.fd_config)
                        ),
                    },
                )

        # smooth quant
        self.linear_shift = None
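
Note the symmetry with `ColumnParallelLinear`: the column-parallel weight is distributed along `split_axis=1` with `output_dim: True`, while the row-parallel weight uses `split_axis=0` with `output_dim: False`. A small sketch of what that means for a weight stored as `[in_features, out_features]` (toy NumPy shapes; the layout is an assumption for illustration):

import numpy as np

# Assumed layout: weight stored as [in_features, out_features].
full_weight = np.arange(8 * 6).reshape(8, 6)
nranks, rank = 2, 0

col_shard = np.split(full_weight, nranks, axis=1)[rank]  # column parallel: shard outputs, shape (8, 3)
row_shard = np.split(full_weight, nranks, axis=0)[rank]  # row parallel: shard inputs, shape (4, 6)

# Row-parallel partial results are later combined across ranks with an all-reduce.
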
@@ -22,6 +22,7 @@ from paddle import nn
from paddle.distributed import fleet

from fastdeploy.config import FDConfig
from fastdeploy.model_executor.models.utils import set_weight_attrs

from .utils import get_tensor

@@ -83,6 +84,7 @@ class ParallelLMHead(nn.Layer):
                gather_output=need_gather,
                fuse_matmul_bias=False,  # False yields a smaller numerical diff
            )
            set_weight_attrs(self.linear.weight, {"output_dim": True})
        else:
            self.linear = RowParallelLinear(
                embedding_dim,
@@ -93,6 +95,7 @@ class ParallelLMHead(nn.Layer):
                input_is_parallel=False,
                fuse_matmul_bias=False,  # False yields a smaller numerical diff
            )
            set_weight_attrs(self.linear.weight, {"output_dim": False})

    def load_state_dict(self, state_dict: Dict[str, paddle.Tensor | np.ndarray]):
        """