qwen3_moe (#3084)
@@ -21,6 +21,7 @@ from paddle import nn
from fastdeploy.config import FDConfig
from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce
from fastdeploy.model_executor.layers.quantization.quant_base import QuantMethodBase
from fastdeploy.model_executor.models.utils import (
    default_weight_loader,
    set_weight_attrs,
@@ -30,6 +31,45 @@ from fastdeploy.platforms import current_platform
from .utils import _set_var_distributed, divide, get_tensor


class UnquantizedLinearMethod(QuantMethodBase):
    """Linear method without quantization."""

    def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
        """
        extra_weight_attrs is a dictionary that may include parameters like:
            - split_axis: specifies which axis to split the weight tensor on (for distributed weight partitioning)
            - output_dim: determines whether the split is applied along the output dimension (rows) or input dimension (columns)
            - weight_loader: a callable or method responsible for loading the weight data
        """
        layer.weight = layer.create_parameter(
            shape=layer.weight_shape,
            dtype=layer.weight_dtype,
            is_bias=False,
            default_initializer=paddle.nn.initializer.Constant(0),
        )
        set_weight_attrs(
            layer.weight,
            {"weight_loader": extra_weight_attrs.get("weight_loader", default_weight_loader(layer.fd_config))},
        )
        if hasattr(layer, "nranks") and layer.nranks > 0:
            split_axis = extra_weight_attrs.get("split_axis")
            _set_var_distributed(layer.weight, split_axis=split_axis)
            set_weight_attrs(layer.weight, {"output_dim": extra_weight_attrs.get("output_dim")})

    def process_loaded_weights(self, layer, weights) -> None:
        # mlp.gate.weight is precision-sensitive, so we cast it to float32 for computation
        if layer.weight.dtype != weights.dtype:
            weights = weights.cast(layer.weight.dtype)
        layer.weight.set_value(weights)

    def apply(self, layer: nn.Layer, x: paddle.Tensor) -> paddle.Tensor:

        linear_out = paddle.matmul(x, layer.weight)
        if layer.with_bias:
            linear_out = paddle.add(linear_out, layer.bias)
        return linear_out


class LinearBase(nn.Layer):
    """
    LinearBase Layer.
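The class above is the heart of this change: weight creation, checkpoint loading, and the matmul itself move out of the layer and into a pluggable quantization-method object, so quantized and unquantized layers can share one code path. Below is a minimal, self-contained sketch of that pattern using plain Paddle; ToyUnquantizedMethod and ToyLinear are illustrative stand-ins, not FastDeploy classes.

import paddle
from paddle import nn


class ToyUnquantizedMethod:
    # Simplified analogue of UnquantizedLinearMethod: it owns weight creation and the forward math.
    def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
        layer.weight = layer.create_parameter(
            shape=layer.weight_shape,
            dtype=layer.weight_dtype,
            is_bias=False,
            default_initializer=paddle.nn.initializer.Constant(0),
        )

    def apply(self, layer: nn.Layer, x: paddle.Tensor) -> paddle.Tensor:
        out = paddle.matmul(x, layer.weight)
        if layer.with_bias:
            out = paddle.add(out, layer.bias)
        return out


class ToyLinear(nn.Layer):
    def __init__(self, input_size: int, output_size: int):
        super().__init__()
        self.weight_shape = [input_size, output_size]
        self.weight_dtype = paddle.get_default_dtype()
        self.with_bias = False
        # The layer no longer creates its own weight; the method object does.
        self.quant_method = ToyUnquantizedMethod()
        self.quant_method.create_weights(self)

    def forward(self, x: paddle.Tensor) -> paddle.Tensor:
        # Quantized and unquantized layers would share this single dispatch point.
        return self.quant_method.apply(self, x)


layer = ToyLinear(8, 16)
print(layer(paddle.randn([2, 8])).shape)  # [2, 16]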
@@ -44,6 +84,8 @@ class LinearBase(nn.Layer):
        with_bias: bool = False,
        add_bias: bool = False,
        skip_quant: bool = False,
        weight_dtype: str = "",
        weight_key: str = "",
    ):
        """
        Initializes a linear layer and provides additional parameters required for inference and quantization.
@@ -81,46 +123,31 @@ class LinearBase(nn.Layer):
        self.add_bias = add_bias
        self.prefix = prefix
        # key
        self.weight_key = f"{prefix}.weight"
        if weight_key:
            self.weight_key = f"{prefix}.{weight_key}"
        else:
            self.weight_key = f"{prefix}.weight"
        self.bias_key = f"{prefix}.bias"
        self.shift_key = f"{prefix}.shift_bias"
        self.smooth_key = f"{prefix}.smooth_weight"
        self.out_scale_key = f"{prefix}.out_scale"

        self._dtype = self._helper.get_default_dtype()
        self.weight_dtype = self._dtype
        if weight_dtype:
            self.weight_dtype = weight_dtype
        elif self.skip_quant:
            self.weight_dtype = self._dtype
        else:
            self.weight_dtype = self._dtype
        self.weight_shape = [
            self.input_size,
            self.output_size,
        ]
        if fd_config.quant_config:

        if fd_config.quant_config and not skip_quant:
            self.quant_method = fd_config.quant_config.get_quant_method(self)
            if fd_config.model_config.is_quantized:
                self.weight_key = f"{prefix}.quant_weight"
                self.weight_scale_key = f"{prefix}.weight_scale"
                self.act_scale_key = f"{prefix}.activation_scale"

    def init_weight(self):
        """
        Initialize the weights and biases.
        """
        if self.skip_quant:
            self.weight_dtype = self._dtype
        self.weight = self.create_parameter(
            shape=self.weight_shape,
            dtype=self.weight_dtype,
            is_bias=False,
            default_initializer=paddle.nn.initializer.Constant(0),
        )

        set_weight_attrs(
            self.weight,
            {
                "weight_loader": (
                    self.weight_loader if hasattr(self, "weight_loader") else default_weight_loader(self.fd_config)
                )
            },
        )
        else:
            self.quant_method: Optional[QuantMethodBase] = UnquantizedLinearMethod()

        self.bias = None
        if self.with_bias:
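The new weight_key and weight_dtype arguments above only change which checkpoint entries and parameter dtype a layer resolves. As a hedged, self-contained summary of that key-derivation logic, the helper below is purely illustrative (make_keys and the example prefixes are not FastDeploy names):

def make_keys(prefix, weight_key="", is_quantized=False):
    keys = {
        "weight": f"{prefix}.{weight_key}" if weight_key else f"{prefix}.weight",
        "bias": f"{prefix}.bias",
        "shift": f"{prefix}.shift_bias",
        "smooth": f"{prefix}.smooth_weight",
        "out_scale": f"{prefix}.out_scale",
    }
    if is_quantized:
        # Pre-quantized checkpoints store the weight and its scales under different names.
        keys["weight"] = f"{prefix}.quant_weight"
        keys["weight_scale"] = f"{prefix}.weight_scale"
        keys["act_scale"] = f"{prefix}.activation_scale"
    return keys


print(make_keys("model.layers.0.mlp.gate", weight_key="weight"))
print(make_keys("model.layers.0.self_attn.qkv_proj", is_quantized=True))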
@@ -130,19 +157,15 @@ class LinearBase(nn.Layer):
                is_bias=True,
            )

        set_weight_attrs(
            self.weight,
            {
                "weight_loader": (
                    self.weight_loader if hasattr(self, "weight_loader") else default_weight_loader(self.fd_config)
                )
            },
        )

        # smooth quant
        self.linear_shift = None
        self.linear_smooth = None

        if fd_config.model_config.is_quantized:
            self.weight_key = f"{prefix}.quant_weight"
            self.weight_scale_key = f"{prefix}.weight_scale"
            self.act_scale_key = f"{prefix}.activation_scale"

    def load_prequant_weight(self, state_dict: dict):
        """
        Load the prequantized weight from the state dictionary.
@@ -160,11 +183,7 @@ class LinearBase(nn.Layer):
            state_dict (dict): A dictionary containing the weights
        """
        weight_tensor = get_tensor(state_dict.pop(self.weight_key))

        if self.fd_config.quant_config:
            self.quant_method.process_loaded_weights(self, weight_tensor)
        else:
            self.weight.set_value(weight_tensor)
        self.quant_method.process_loaded_weights(self, weight_tensor)

    def load_state_dict(self, state_dict: dict):
        """
@@ -199,12 +218,7 @@ class LinearBase(nn.Layer):
        Raises:
            NotImplementedError: If the weight dtype is not float8 or act dtype is not equal to weight dtype.
        """
        if self.fd_config.quant_config:
            linear_out = self.quant_method.apply(self, x)
        else:
            linear_out = paddle.matmul(x, self.weight)
            if self.with_bias:
                linear_out = paddle.add(linear_out, self.bias)
        linear_out = self.quant_method.apply(self, x)

        return linear_out
@@ -223,6 +237,8 @@ class ReplicatedLinear(LinearBase):
        with_bias: bool = False,
        add_bias: bool = False,
        skip_quant: bool = False,
        weight_dtype: str = "",
        weight_key: str = "",
    ):
        """
        Initializes a replicated linear layer.
@@ -245,6 +261,8 @@ class ReplicatedLinear(LinearBase):
            with_bias=with_bias,
            add_bias=add_bias,
            skip_quant=skip_quant,
            weight_dtype=weight_dtype,
            weight_key=weight_key,
        )

        self.hidden_size = fd_config.model_config.hidden_size
@@ -252,9 +270,14 @@ class ReplicatedLinear(LinearBase):
            self.input_size,
            self.output_size,
        ]
        if fd_config.quant_config:
            self.quant_method.create_weights(self)
        self.init_weight()

        assert self.quant_method is not None
        self.quant_method.create_weights(
            self,
            weight_loader=(
                self.weight_loader if hasattr(self, "weight_loader") else default_weight_loader(self.fd_config)
            ),
        )


class ColumnParallelLinear(LinearBase):
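The keyword arguments passed in these create_weights calls simply surface inside the method as extra_weight_attrs (see the UnquantizedLinearMethod hunk earlier). A tiny hedged sketch of that flow, using an illustrative stub rather than the real method:

def create_weights_stub(layer=None, **extra_weight_attrs):
    # The quantization method only looks up the keys it cares about.
    return (
        extra_weight_attrs.get("split_axis"),
        extra_weight_attrs.get("output_dim"),
        "weight_loader" in extra_weight_attrs,
    )


print(create_weights_stub(split_axis=1, output_dim=True, weight_loader=lambda *args: None))
# (1, True, True)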
@@ -306,60 +329,22 @@ class ColumnParallelLinear(LinearBase):
            self.input_size,
            self.output_size,
        ]
        if fd_config.quant_config:
            self.quant_method.create_weights(self)
        self.init_weight()

    def init_weight(self):
        """
        Initialize the weights and biases.
        """
        if self.skip_quant:
            self.weight_dtype = self._dtype
        self.weight = self.create_parameter(
            shape=self.weight_shape,
            dtype=self.weight_dtype,
            is_bias=False,
            default_initializer=paddle.nn.initializer.Constant(0),
        assert self.quant_method is not None
        self.quant_method.create_weights(
            self,
            split_axis=1,
            output_dim=True,
            weight_loader=(
                self.weight_loader if hasattr(self, "weight_loader") else default_weight_loader(self.fd_config)
            ),
        )
        if self.nranks > 0:
            # col parallel
            _set_var_distributed(self.weight, split_axis=1)
            set_weight_attrs(
                self.weight,
                {
                    "output_dim": True,
                    "weight_loader": (
                        self.weight_loader if hasattr(self, "weight_loader") else default_weight_loader(self.fd_config)
                    ),
                },
            )

        self.bias = None
        if self.with_bias:
            self.bias = self.create_parameter(
                shape=[self.output_size],
                dtype=self._dtype,
                is_bias=True,
            )
            if self.nranks > 0:
                # col parallel
                _set_var_distributed(self.bias, split_axis=1)
                set_weight_attrs(
                    self.weight,
                    {
                        "output_dim": True,
                        "weight_loader": (
                            self.weight_loader
                            if hasattr(self, "weight_loader")
                            else default_weight_loader(self.fd_config)
                        ),
                    },
                )

        # smooth quant
        self.linear_shift = None
        self.linear_smooth = None
                set_weight_attrs(self.bias, {"output_dim": True})


class MergedColumnParallelLinear(ColumnParallelLinear):
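For context, split_axis=1 and output_dim=True in the ColumnParallelLinear call above correspond to sharding the [input_size, output_size] weight along its output columns. A minimal sketch of that arithmetic, assuming two ranks and plain Paddle with no distributed launch (rank shards are just slices here, not part of this commit):

import paddle

hidden, out_features, nranks = 8, 12, 2
x = paddle.randn([4, hidden])
w = paddle.randn([hidden, out_features])  # weights are stored as [in, out], as in the layers above

shard = out_features // nranks
w_shards = [w[:, r * shard : (r + 1) * shard] for r in range(nranks)]  # split_axis=1: column slices

partial = [paddle.matmul(x, w_r) for w_r in w_shards]  # each rank produces a slice of the output
full = paddle.concat(partial, axis=-1)                 # concatenation recovers the full result

print(paddle.allclose(full, paddle.matmul(x, w)).item())  # True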
@@ -429,9 +414,14 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
        loaded_weight = get_tensor(loaded_weight)

        if loaded_shard_id == "gate":
            param[:, : self.output_size // 2] = loaded_weight
            param = param[:, : self.output_size // 2]
        elif loaded_shard_id == "up":
            param[:, self.output_size // 2 :] = loaded_weight
            param = param[:, self.output_size // 2 :]

        assert param.shape == loaded_weight.shape, (
            f" Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({param.shape})"
        )
        param.copy_(loaded_weight, False)

    def load_state_dict(self, state_dict: dict):
        """
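The merged loader above assumes a [gate | up] layout along the output dimension of the fused parameter, so each shard id maps to one half of the columns. A hedged, illustrative sketch of that layout (not part of the commit):

import paddle

hidden, intermediate = 8, 6
merged = paddle.zeros([hidden, 2 * intermediate])  # plays the role of the fused gate/up weight

gate_cols = merged[:, :intermediate]  # the view selected when loaded_shard_id == "gate"
up_cols = merged[:, intermediate:]    # the view selected when loaded_shard_id == "up"
print(gate_cols.shape, up_cols.shape)  # [8, 6] [8, 6]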
@@ -518,16 +508,21 @@ class QKVParallelLinear(ColumnParallelLinear):
        loaded_weight = get_tensor(loaded_weight)

        if loaded_shard_id == "q":
            param[:, : self.num_heads_per_rank * self.head_dim] = loaded_weight
            param = param[:, : self.num_heads_per_rank * self.head_dim]
        elif loaded_shard_id == "k":
            param[
            param = param[
                :,
                self.num_heads_per_rank
                * self.head_dim : (self.num_heads_per_rank + self.kv_num_heads_per_rank)
                * self.head_dim,
            ] = loaded_weight
            ]
        elif loaded_shard_id == "v":
            param[:, (self.num_heads_per_rank + self.kv_num_heads_per_rank) * self.head_dim :] = loaded_weight
            param = param[:, (self.num_heads_per_rank + self.kv_num_heads_per_rank) * self.head_dim :]

        assert param.shape == loaded_weight.shape, (
            f" Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({param.shape})"
        )
        param.copy_(loaded_weight, False)

    def load_weight(self, state_dict: dict):
        """
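The QKV loader above selects per-shard column ranges of the fused parameter based on the per-rank head counts. A small hedged sketch of those ranges; qkv_ranges is an illustrative helper, not a FastDeploy function:

def qkv_ranges(num_heads_per_rank, kv_num_heads_per_rank, head_dim):
    q_end = num_heads_per_rank * head_dim
    k_end = q_end + kv_num_heads_per_rank * head_dim
    v_end = k_end + kv_num_heads_per_rank * head_dim
    return {"q": (0, q_end), "k": (q_end, k_end), "v": (k_end, v_end)}


# e.g. 4 query heads and 2 KV heads per rank, head_dim = 128
print(qkv_ranges(4, 2, 128))  # {'q': (0, 512), 'k': (512, 768), 'v': (768, 1024)}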
@@ -665,62 +660,25 @@ class RowParallelLinear(LinearBase):
        ]
        self._dtype = self._helper.get_default_dtype()

        if fd_config.quant_config:
            self.quant_method = fd_config.quant_config.get_quant_method(self)
            self.quant_method.create_weights(self)

        self.reduce_results = reduce_results
        self.init_weight()

    def init_weight(self):
        """
        Initialize the weights and biases.
        """
        if self.skip_quant:
            self.weight_dtype = self._dtype

        self.weight = self.create_parameter(
            shape=self.weight_shape,
            dtype=self.weight_dtype,
            is_bias=False,
            default_initializer=paddle.nn.initializer.Constant(0),
        assert self.quant_method is not None
        self.quant_method.create_weights(
            self,
            split_axis=0,
            output_dim=False,
            weight_loader=(
                self.weight_loader if hasattr(self, "weight_loader") else default_weight_loader(self.fd_config)
            ),
        )
        if self.nranks > 0:
            # row parallel

            if self.with_bias:
                _set_var_distributed(self.bias, split_axis=0)
                set_weight_attrs(
                    self.weight,
                    self.bias,
                    {
                        "output_dim": False,
                        "weight_loader": (
                            self.weight_loader if hasattr(self, "weight_loader") else default_weight_loader(self.fd_config)
                        ),
                    },
                )
            _set_var_distributed(self.weight, split_axis=0)

        self.bias = None
        if self.with_bias:
            self.bias = self.create_parameter(
                shape=[self.hidden_size],
                dtype=self._dtype,
                is_bias=True,
            )
            if self.nranks > 0:
                set_weight_attrs(
                    self.bias,
                    {
                        "output_dim": False,
                        "weight_loader": (
                            self.weight_loader
                            if hasattr(self, "weight_loader")
                            else default_weight_loader(self.fd_config)
                        ),
                    },
                )

        # smooth quant
        self.linear_shift = None
        self.linear_smooth = None
        self.reduce_results = reduce_results

    def forward_cuda(self, x: paddle.Tensor) -> paddle.Tensor:
        if self.fd_config.quant_config:
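Correspondingly, split_axis=0 and output_dim=False in the RowParallelLinear call shard the input dimension, so each rank produces a partial sum that must be reduced across ranks. A minimal sketch of that arithmetic, assuming two ranks and plain Paddle with no distributed launch; the plain sum stands in for tensor_model_parallel_all_reduce and is not part of the commit:

import paddle

hidden, out_features, nranks = 8, 12, 2
x = paddle.randn([4, hidden])
w = paddle.randn([hidden, out_features])

shard = hidden // nranks  # split_axis=0: the input dimension is sharded
partials = [
    paddle.matmul(x[:, r * shard : (r + 1) * shard], w[r * shard : (r + 1) * shard, :])
    for r in range(nranks)
]
reduced = paddle.add_n(partials)  # stands in for the all-reduce across ranks

print(paddle.allclose(reduced, paddle.matmul(x, w)).item())  # True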