qwen loader (#3057)

Author: bukejiyu
Date: 2025-07-30 19:09:38 +08:00
Committed by: GitHub
Parent: 28fff1b035
Commit: db698bda01
22 changed files with 494 additions and 92 deletions

View File

@@ -22,6 +22,7 @@ from paddle import nn
from paddle.distributed import fleet
from fastdeploy.config import FDConfig
from fastdeploy.model_executor.models.utils import set_weight_attrs
from .utils import get_tensor
@@ -80,6 +81,7 @@ class VocabParallelEmbedding(nn.Layer):
initializer=nn.initializer.Normal(mean=0.0, std=self.initializer_range),
),
)
set_weight_attrs(self.embeddings.weight, {"output_dim": False})
else:
# column cut embedding
self.embeddings = nn.Embedding(
@@ -89,6 +91,7 @@ class VocabParallelEmbedding(nn.Layer):
self.embeddings.weight.is_distributed = True
self.embeddings.weight.split_axis = 1
set_weight_attrs(self.embeddings.weight, {"output_dim": True})
self.prefix = prefix
self.dropout = nn.Dropout(self.hidden_dropout_prob)
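
The "output_dim" attributes recorded above tell a weight loader which axis of a checkpoint tensor is sharded across tensor-parallel ranks. A minimal sketch of that convention, under the assumption (suggested by the two embedding branches above) that True means the last, output axis and False the first axis; pick_shard is an illustrative helper, not FastDeploy API:

import paddle

def pick_shard(attrs, loaded_weight, rank, nranks):
    # output_dim=True: take this rank's column block; False: take its row block;
    # attribute absent: the weight is fully replicated.
    output_dim = attrs.get("output_dim", None)
    if output_dim is None:
        return loaded_weight
    block = loaded_weight.shape[-1 if output_dim else 0] // nranks
    if output_dim:
        return loaded_weight[..., rank * block:(rank + 1) * block]
    return loaded_weight[rank * block:(rank + 1) * block]

full = paddle.rand([8, 6])
print(pick_shard({"output_dim": True}, full, rank=0, nranks=2).shape)   # [8, 3]
print(pick_shard({"output_dim": False}, full, rank=1, nranks=2).shape)  # [4, 6]
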

View File

@@ -14,11 +14,17 @@
# limitations under the License.
"""
from typing import Optional
import paddle
from paddle import nn
from fastdeploy.config import FDConfig
from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce
from fastdeploy.model_executor.models.utils import (
default_weight_loader,
set_weight_attrs,
)
from fastdeploy.platforms import current_platform
from .utils import _set_var_distributed, divide, get_tensor
@@ -107,6 +113,15 @@ class LinearBase(nn.Layer):
default_initializer=paddle.nn.initializer.Constant(0),
)
set_weight_attrs(
self.weight,
{
"weight_loader": (
self.weight_loader if hasattr(self, "weight_loader") else default_weight_loader(self.fd_config)
)
},
)
self.bias = None
if self.with_bias:
self.bias = self.create_parameter(
@@ -115,6 +130,15 @@ class LinearBase(nn.Layer):
is_bias=True,
)
set_weight_attrs(
self.weight,
{
"weight_loader": (
self.weight_loader if hasattr(self, "weight_loader") else default_weight_loader(self.fd_config)
)
},
)
# smooth quant
self.linear_shift = None
self.linear_smooth = None
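
The attribute blocks above prefer a weight_loader defined by the subclass and otherwise fall back to default_weight_loader(self.fd_config). A hedged sketch of that fallback pattern; default_weight_loader_sketch is a stand-in with an assumed signature, and the real helper in fastdeploy.model_executor.models.utils may additionally handle dtype casts, transposes and TP sharding:

import paddle

def default_weight_loader_sketch(fd_config=None):
    # Returns a generic copy-into-param callable (signature is an assumption).
    def loader(param, loaded_weight):
        loaded_weight = paddle.cast(loaded_weight, param.dtype)
        assert list(param.shape) == list(loaded_weight.shape), "shape mismatch"
        param.set_value(loaded_weight)
    return loader

def resolve_weight_loader(layer, fd_config):
    # Subclass-defined weight_loader wins; otherwise use the generic loader.
    if hasattr(layer, "weight_loader"):
        return layer.weight_loader
    return default_weight_loader_sketch(fd_config)
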
@@ -273,6 +297,7 @@ class ColumnParallelLinear(LinearBase):
add_bias=add_bias,
skip_quant=skip_quant,
)
self.fd_config = fd_config
self.nranks = fd_config.parallel_config.tensor_parallel_size
self.input_size = input_size
self.output_size = divide(output_size, self.nranks) # Split the output_size using TP inference.
@@ -300,6 +325,15 @@ class ColumnParallelLinear(LinearBase):
if self.nranks > 0:
# col parallel
_set_var_distributed(self.weight, split_axis=1)
set_weight_attrs(
self.weight,
{
"output_dim": True,
"weight_loader": (
self.weight_loader if hasattr(self, "weight_loader") else default_weight_loader(self.fd_config)
),
},
)
self.bias = None
if self.with_bias:
@@ -311,6 +345,17 @@ class ColumnParallelLinear(LinearBase):
if self.nranks > 0:
# col parallel
_set_var_distributed(self.bias, split_axis=1)
set_weight_attrs(
self.bias,
{
"output_dim": True,
"weight_loader": (
self.weight_loader
if hasattr(self, "weight_loader")
else default_weight_loader(self.fd_config)
),
},
)
# smooth quant
self.linear_shift = None
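
A small numerical sketch of why the column-parallel weight is marked output_dim=True and split along axis 1: each rank keeps a column slice of the weight and produces a slice of the output, and concatenating the per-rank slices recovers the full result (shapes below are illustrative):

import paddle

input_size, output_size, nranks = 8, 6, 2
x = paddle.rand([2, input_size])
w = paddle.rand([input_size, output_size])

block = output_size // nranks
# per-rank outputs from per-rank column slices of the weight
slices = [paddle.matmul(x, w[:, r * block:(r + 1) * block]) for r in range(nranks)]
assert paddle.allclose(paddle.concat(slices, axis=-1), paddle.matmul(x, w), atol=1e-5).item()
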
@@ -354,6 +399,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
self.activation = activation
self.hidden_size = fd_config.model_config.hidden_size
self.nranks = fd_config.parallel_config.tensor_parallel_size
self.output_size = output_size
self.local_rank = fd_config.parallel_config.tensor_parallel_rank
super().__init__(
fd_config=fd_config,
@@ -365,6 +412,27 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
skip_quant=skip_quant,
)
def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
# Assemble the fused gate_up parameter from gate / up weights stored
# separately on disk, taking this rank's TP shard of each.
assert loaded_shard_id in ["gate", "up"]
output_dim = getattr(param, "output_dim", None)
# Tensor parallelism splits the weight along the output_dim
if output_dim is not None:
dim = -1
size = loaded_weight.get_shape()[dim]
block_size = size // self.nranks
shard_offset = self.local_rank * block_size
shard_size = (self.local_rank + 1) * block_size  # exclusive end index of this rank's shard
loaded_weight = loaded_weight[..., shard_offset:shard_size]
loaded_weight = get_tensor(loaded_weight)
if loaded_shard_id == "gate":
param[:, : self.output_size // 2] = loaded_weight
elif loaded_shard_id == "up":
param[:, self.output_size // 2 :] = loaded_weight
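
A self-contained sketch of the arithmetic in this loader, with plain paddle tensors standing in for the lazy checkpoint objects that expose get_shape() in the real code (all sizes are illustrative):

import paddle

hidden_size, intermediate_size, nranks, local_rank = 4, 6, 2, 0
# fused gate_up parameter held by this rank: [hidden, 2 * intermediate / nranks]
param = paddle.zeros([hidden_size, 2 * intermediate_size // nranks])

def load_gate_up(param, loaded_weight, shard_id):
    # slice this rank's columns out of the full gate (or up) projection
    block = loaded_weight.shape[-1] // nranks
    shard = loaded_weight[..., local_rank * block:(local_rank + 1) * block]
    half = param.shape[-1] // 2
    if shard_id == "gate":
        param[:, :half] = shard
    else:  # "up"
        param[:, half:] = shard

gate = paddle.rand([hidden_size, intermediate_size])
up = paddle.rand([hidden_size, intermediate_size])
load_gate_up(param, gate, "gate")
load_gate_up(param, up, "up")
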
def load_state_dict(self, state_dict: dict):
"""
Load the checkpoint state dictionary into the layer.
@@ -415,6 +483,7 @@ class QKVParallelLinear(ColumnParallelLinear):
self.hidden_size = fd_config.model_config.hidden_size
self.head_dim = fd_config.model_config.head_dim
self.nranks = fd_config.parallel_config.tensor_parallel_size
self.local_rank = fd_config.parallel_config.tensor_parallel_rank
self.num_heads_per_rank = divide(self.num_heads, self.nranks)
if self.kv_num_heads < self.nranks and self.nranks % self.kv_num_heads == 0:
self.kv_num_heads_per_rank = 1
@@ -432,6 +501,34 @@ class QKVParallelLinear(ColumnParallelLinear):
add_bias=add_bias,
)
def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
# Assemble the fused qkv parameter from q / k / v weights stored
# separately on disk, taking this rank's TP shard of each.
assert loaded_shard_id in ["q", "k", "v"]
output_dim = getattr(param, "output_dim", None)
# Tensor parallelism splits the weight along the output_dim
if output_dim is not None:
dim = -1
size = loaded_weight.get_shape()[dim]
block_size = size // self.nranks
shard_offset = self.local_rank * block_size
shard_size = (self.local_rank + 1) * block_size  # exclusive end index of this rank's shard
loaded_weight = loaded_weight[..., shard_offset:shard_size]
loaded_weight = get_tensor(loaded_weight)
if loaded_shard_id == "q":
param[:, : self.num_heads_per_rank * self.head_dim] = loaded_weight
elif loaded_shard_id == "k":
param[
:,
self.num_heads_per_rank
* self.head_dim : (self.num_heads_per_rank + self.kv_num_heads_per_rank)
* self.head_dim,
] = loaded_weight
elif loaded_shard_id == "v":
param[:, (self.num_heads_per_rank + self.kv_num_heads_per_rank) * self.head_dim :] = loaded_weight
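
The fused QKV parameter on each rank lays q, k and v out contiguously along the output dimension; the column ranges written by the branches above work out as follows (numbers are illustrative, not tied to any model config):

num_heads, kv_num_heads, head_dim, nranks = 8, 4, 16, 2
num_heads_per_rank = num_heads // nranks
kv_num_heads_per_rank = kv_num_heads // nranks

q_cols = (0, num_heads_per_rank * head_dim)
k_cols = (q_cols[1], q_cols[1] + kv_num_heads_per_rank * head_dim)
v_cols = (k_cols[1], k_cols[1] + kv_num_heads_per_rank * head_dim)
print(q_cols, k_cols, v_cols)  # (0, 64) (64, 96) (96, 128)
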
def load_weight(self, state_dict: dict):
"""
Load the weight from the state dictionary.
@@ -588,6 +685,18 @@ class RowParallelLinear(LinearBase):
is_bias=False,
default_initializer=paddle.nn.initializer.Constant(0),
)
if self.nranks > 0:
# row parallel
set_weight_attrs(
self.weight,
{
"output_dim": False,
"weight_loader": (
self.weight_loader if hasattr(self, "weight_loader") else default_weight_loader(self.fd_config)
),
},
)
_set_var_distributed(self.weight, split_axis=0)
self.bias = None
if self.with_bias:
@@ -596,10 +705,18 @@ class RowParallelLinear(LinearBase):
dtype=self._dtype,
is_bias=True,
)
if self.nranks > 0:
# row parallel
_set_var_distributed(self.weight, split_axis=0)
set_weight_attrs(
self.bias,
{
"output_dim": False,
"weight_loader": (
self.weight_loader
if hasattr(self, "weight_loader")
else default_weight_loader(self.fd_config)
),
},
)
# smooth quant
self.linear_shift = None
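
A short numerical sketch of why the row-parallel weight is marked output_dim=False and split along axis 0: each rank holds a slice of the input dimension, computes a partial matmul, and the partial results are summed across ranks (the real layer performs that sum with tensor_model_parallel_all_reduce; shapes are illustrative):

import paddle

input_size, output_size, nranks = 8, 4, 2
x = paddle.rand([2, input_size])
w = paddle.rand([input_size, output_size])

block = input_size // nranks
partials = [
    paddle.matmul(x[:, r * block:(r + 1) * block], w[r * block:(r + 1) * block, :])
    for r in range(nranks)
]
assert paddle.allclose(sum(partials), paddle.matmul(x, w), atol=1e-5).item()
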

View File

@@ -22,6 +22,7 @@ from paddle import nn
from paddle.distributed import fleet
from fastdeploy.config import FDConfig
from fastdeploy.model_executor.models.utils import set_weight_attrs
from .utils import get_tensor
@@ -83,6 +84,7 @@ class ParallelLMHead(nn.Layer):
gather_output=need_gather,
fuse_matmul_bias=False,  # False gives a smaller numerical diff
)
set_weight_attrs(self.linear.weight, {"output_dim": True})
else:
self.linear = RowParallelLinear(
embedding_dim,
@@ -93,6 +95,7 @@ class ParallelLMHead(nn.Layer):
input_is_parallel=False,
fuse_matmul_bias=False,  # False gives a smaller numerical diff
)
set_weight_attrs(self.linear.weight, {"output_dim": False})
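
The two branches above imply different shard layouts for the lm_head weight, which is what the output_dim attribute records; a small sketch of the per-rank shapes (hidden and vocab_size are illustrative assumptions):

hidden, vocab_size, nranks = 16, 32, 2

# ColumnParallelLinear head: per-rank weight [hidden, vocab_size // nranks],
# sharded along the output (vocab) dim -> output_dim=True.
col_shard_shape = [hidden, vocab_size // nranks]

# RowParallelLinear head: per-rank weight [hidden // nranks, vocab_size],
# sharded along the input (hidden) dim -> output_dim=False.
row_shard_shape = [hidden // nranks, vocab_size]
print(col_shard_shape, row_shard_shape)  # [16, 16] [8, 32]
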
def load_state_dict(self, state_dict: Dict[str, paddle.Tensor | np.ndarray]):
"""