Files
FastDeploy/fastdeploy/model_executor/layers/linear.py
2025-06-16 00:04:48 +08:00

762 lines
28 KiB
Python

"""
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import os
import fastdeploy
from paddlenlp.utils.log import logger
import paddle
from paddle import nn
from fastdeploy.platforms import current_platform
from .utils import _set_var_distributed, divide, get_tensor
import fastdeploy.model_executor.ops.gpu.deep_gemm as deep_gemm
class LinearBase(nn.Layer):
"""
LinearBase Layer
"""
def __init__(
self,
llm_config,
prefix: str = "",
input_size: int = None,
output_size: int = None,
with_bias: bool = False,
add_bias: bool = False,
skip_quant: bool = False,
):
"""
Initializes a linear layer and provides additional parameters required for inference and quantization.
Args:
llm_config (LLMConfig): Inference-related parameters containing attributes such as
weight_dtype, act_dtype, mp_size, hidden_size, head_dim,
num_attention_heads, and ffn_hidden_size.
prefix (str): Unique name of the layer, used to name internal attributes.
Can be arbitrarily named.
input_size (int, optional): Number of input features. Defaults to None.
output_size (int, optional): Number of output features. Defaults to None.
weight_key (Any, optional): Key for weights. Defaults to None.
bias_key (Any, optional): Key for biases. Defaults to None.
skip_quant (bool, optional): Whether to skip quantization. Defaults to False.
Raises:
NotImplementedError: Raised if the current platform is not a CUDA platform.
"""
super().__init__()
if current_platform.is_cuda():
self.forward = self.forward_cuda
else:
raise NotImplementedError
self.llm_config = llm_config
self.skip_quant = skip_quant
self.use_smooth_quant = llm_config.model_config.use_smooth_quant
self.weight_dtype = llm_config.model_config.weight_dtype
self.act_dtype = llm_config.model_config.act_dtype
self.input_size = input_size
self.output_size = output_size
self.with_bias = with_bias
self.add_bias = add_bias
self.prefix = prefix
# key
self.weight_key = f"{prefix}.weight"
self.bias_key = f"{prefix}.bias"
self.shift_key = f"{prefix}.shift_bias"
self.smooth_key = f"{prefix}.smooth_weight"
self.out_scale_key = f"{prefix}.out_scale"
self._dtype = self._helper.get_default_dtype()
if llm_config.quant_config:
self.quant_method = llm_config.quant_config.get_quant_method(self)
self.use_offline_quant = llm_config.tmp_config.use_offline_quant
def is_y_transposed(self):
"""
Returns whether the y tensor should be transposed for inference.
Args:
None.
Returns:
bool, whether the y tensor should be transposed for inference.
"""
if self.weight_dtype == "int4":
return True
if self.weight_dtype == "int8":
return True
if "float8" in self.weight_dtype:
return True
# bf16/fp16/fp32 y is not transposed
return False
def init_weight_shape(self, trans=False):
"""
Initialize the weight shape for the first feedforward network layer.
Args:
trans (bool, optional): Whether to transpose the weight shape.
Defaults to False. If True, the shape will be reversed.
Returns:
None.
"""
self.linear_weight_shape = [
self.input_size,
self.output_size,
]
if trans:
self.linear_weight_shape.reverse()
if self.use_smooth_quant:
self.linear_shift_shape = [self.output_size]
self.linear_smooth_shape = [self.output_size]
if self.weight_dtype == "int4":
self.linear_weight_shape[0] //= 2
def init_weight(self):
"""
Initialize the weights and biases.
"""
self.init_weight_shape(self.is_y_transposed())
self.linear_weight = self.create_parameter(
shape=self.linear_weight_shape,
dtype=self.get_weight_create_dtype(),
is_bias=False,
default_initializer=paddle.nn.initializer.Constant(0),
)
self.linear_bias = None
if self.with_bias:
self.linear_bias = self.create_parameter(
shape=[self.output_size],
dtype=self._dtype,
is_bias=True,
)
# smooth quant
self.linear_shift = None
self.linear_smooth = None
if self.use_smooth_quant:
self.linear_shift = self.create_parameter(
shape=self.linear_shift_shape,
dtype=self._dtype,
is_bias=False,
)
self.linear_smooth = self.create_parameter(
shape=self.linear_smooth_shape,
dtype=self._dtype,
is_bias=False,
)
def get_weight_create_dtype(self):
"""
Get the data type for creating weights based on quantization settings.
Args:
self (object): The instance of the class where this method is defined.
Returns:
str: The data type for creating weights. It depends on the quantization settings:
- If `self.skip_quant` is True, returns the original data type `self._dtype`.
- If `self.weight_dtype` is "int4", returns "int8" to ensure compatibility or optimization.
- Otherwise, returns the specified weight data type `self.weight_dtype`.
"""
if self.skip_quant:
return self._dtype
if self.weight_dtype == "int4":
return "int8"
# TODO(wangzhe24) create_parameter not support FP8
if "float8" in self.weight_dtype:
return self._dtype
return self.weight_dtype
def load_offline_quant_state_dict(self, quant_weight, quant_scale=None):
"""
Load offline the checkpoint state dictionary into the layer.
"""
if quant_scale is None:
if "float8" in self.weight_dtype:
self.linear_weight.copy_(quant_weight, False)
else:
self.linear_weight.set_value(quant_weight)
else:
if self.inference_args.weight_block_size[0] != -1:
self.linear_weight.copy_(quant_weight.view(paddle.float8_e4m3fn), False)
else:
self.linear_weight.set_value(quant_weight)
self.linear_weight_scale.set_value(quant_scale)
def load_state_dict(self, state_dict):
"""
Load the checkpoint state dictionary into the layer.
Args:
state_dict (dict): A dictionary containing the checkpoint weights and biases.
"""
if self.use_offline_quant:
self.load_offline_quant_state_dict(
quant_weight=get_tensor(
state_dict.pop(self.weight_key + ".quant_weight")
),
quant_scale=get_tensor(
state_dict.pop(self.weight_key + ".quant_scale")
),
)
else:
# weight
assert self.weight_key is not None, 'weight_key should not be None.'
weight_tensor = get_tensor(state_dict.pop(self.weight_key))
if self.llm_config.quant_config:
self.quant_method.process_loaded_weights(self, weight_tensor)
else:
self.linear_weight.set_value(weight_tensor)
# bias
if self.with_bias:
bias_tensor = paddle.to_tensor(get_tensor(state_dict.pop(self.bias_key)))
self.linear_bias.set_value(bias_tensor)
# smooth quant
if self.use_smooth_quant:
if self.shift_key in state_dict:
shift_tensor = get_tensor(state_dict.pop(self.shift_key)).astype(
paddle.get_default_dtype()
)
else:
shift_tensor = paddle.zeros(
shape=self.linear_shift_shape,
dtype=paddle.get_default_dtype(),
)
self.linear_shift.set_value(shift_tensor)
if self.smooth_key in state_dict:
smooth_tensor = get_tensor(state_dict.pop(self.smooth_key)).astype(
paddle.get_default_dtype()
)
else:
smooth_tensor = paddle.ones(
shape=[self.linear_smooth_shape],
dtype=paddle.get_default_dtype(),
)
self.linear_smooth.set_value(smooth_tensor)
def forward_cuda(self, x):
"""
Forward function for ColumnParallelLinear.
Args:
x (Tensor): Input tensor to the ColumnParallelLinear layer.
Returns:
Tensor: Output tensor.
Raises:
NotImplementedError: If the weight dtype is not float8 or act dtype is not equal to weight dtype.
"""
if self.llm_config.quant_config:
linear_out = self.quant_method.apply(self, x)
else:
linear_out = paddle.matmul(x, self.linear_weight)
return linear_out
class ReplicatedLinear(LinearBase):
"""
ReplicatedLinear Layer
"""
def __init__(
self,
llm_config,
prefix: str = "",
input_size: int = None,
output_size: int = None,
with_bias: bool = False,
add_bias: bool = False,
skip_quant: bool = False,
):
"""
Initialize a linear layer with additional parameters for inference and quantization.
Args:
llm_config (LLMConfig): Arguments related to inference, containing
attributes such as weight_dtype, act_dtype, mp_size, hidden_size, head_dim,
num_attention_heads, and ffn_hidden_size.
prefix (str): Unique name of the layer, used for naming internal attributes,
you can give it any name you like.
layer_index (int): The index of the linear layer in the model
"""
super().__init__(llm_config=llm_config,
prefix=prefix,
input_size=input_size,
output_size=output_size,
with_bias=with_bias,
add_bias=add_bias,
skip_quant=skip_quant)
self.nranks = llm_config.parallel_config.mp_size
self.input_size = input_size
self.init_weight()
self.quant_method.create_weights(self)
def init_weight(self):
"""
Initialize the weights and biases.
"""
self.init_weight_shape(self.is_y_transposed())
self.linear_weight = self.create_parameter(
shape=self.linear_weight_shape,
dtype=self.get_weight_create_dtype(),
is_bias=False,
default_initializer=paddle.nn.initializer.Constant(0),
)
self.linear_bias = None
if self.with_bias:
self.linear_bias = self.create_parameter(
shape=[self.output_size],
dtype=self._dtype,
is_bias=True,
)
# smooth quant
self.linear_shift = None
self.linear_smooth = None
if self.use_smooth_quant:
self.linear_shift = self.create_parameter(
shape=self.linear_shift_shape,
dtype=self._dtype,
is_bias=False,
)
self.linear_smooth = self.create_parameter(
shape=self.linear_smooth_shape,
dtype=self._dtype,
is_bias=False,
)
class ColumnParallelLinear(LinearBase):
"""
ColumnParallelLinear Layer
"""
def __init__(
self,
llm_config,
prefix: str = "",
input_size: int = None,
output_size: int = None,
with_bias: bool = False,
add_bias: bool = False,
skip_quant: bool = False,
):
"""
Initialize a linear layer with additional parameters for inference and quantization.
Args:
llm_config (LLMConfig): Arguments related to inference, containing
attributes such as weight_dtype, act_dtype, mp_size, hidden_size, head_dim,
num_attention_heads, and ffn_hidden_size.
prefix (str): Unique name of the layer, used for naming internal attributes,
you can give it any name you like.
layer_index (int): The index of the linear layer in the model
"""
super().__init__(llm_config=llm_config,
prefix=prefix,
input_size=input_size,
output_size=output_size,
with_bias=with_bias,
add_bias=add_bias,
skip_quant=skip_quant)
self.nranks = llm_config.parallel_config.mp_size
self.input_size = input_size
self.output_size = divide(output_size, self.nranks)
self.init_weight()
self.quant_method.create_weights(self)
def init_weight(self):
"""
Initialize the weights and biases.
"""
self.init_weight_shape(self.is_y_transposed())
self.linear_weight = self.create_parameter(
shape=self.linear_weight_shape,
dtype=self.get_weight_create_dtype(),
is_bias=False,
default_initializer=paddle.nn.initializer.Constant(0),
)
if self.nranks > 0:
# col parallel
_set_var_distributed(self.linear_weight, split_axis=-1)
self.linear_bias = None
if self.with_bias:
self.linear_bias = self.create_parameter(
shape=[self.output_size],
dtype=self._dtype,
is_bias=True,
)
if self.nranks > 0:
# col parallel
_set_var_distributed(self.linear_bias, split_axis=-1)
# smooth quant
self.linear_shift = None
self.linear_smooth = None
if self.use_smooth_quant:
self.linear_shift = self.create_parameter(
shape=self.linear_shift_shape,
dtype=self._dtype,
is_bias=False,
)
self.linear_smooth = self.create_parameter(
shape=self.linear_smooth_shape,
dtype=self._dtype,
is_bias=False,
)
class MergedColumnParallelLinear(ColumnParallelLinear):
"""
MergedColumnParallelLinear Layer.
"""
def __init__(
self,
llm_config,
prefix,
with_bias=False,
add_bias=False,
activation="gelu",
use_fast_ffn=False,
skip_quant=False,
):
"""Packed linear layers with column parallelism.
Initialize the fused ffn1 Linear layer with given parameters.
Args:
llm_config (LLMConfig): Arguments related to inference, containing
attributes such as weight_dtype, act_dtype, mp_size, hidden_size, head_dim,
num_attention_heads, and ffn_hidden_size.
prefix (str): Unique name of the layer, used for naming weights and biases.
weight_key (str): Key name of weight in the pdparams state dict.
bias_key (str): Key name of bias in the pdparams state dict. Defaults to None, means no bias.
with_bias (bool, optional): Whether to include bias term. Defaults to True.
activation (str, optional): Activation function to use. Defaults to "gelu".
use_fast_ffn (bool, optional): Whether to use a faster FFN implementation.
Defaults to False.
skip_quant (bool, optional): Whether to skip quantization steps. Defaults to False.
"""
self.use_fast_ffn = use_fast_ffn
self.activation = activation
self.embed_dim = llm_config.model_config.hidden_size
self.dim_feedforward = llm_config.model_config.ffn_hidden_size
self.nranks = llm_config.parallel_config.mp_size
self.dim_feedforward_per_rank = divide(self.dim_feedforward,
self.nranks)
input_size = self.embed_dim
output_size = self.dim_feedforward * 2
super().__init__(llm_config=llm_config,
prefix=prefix,
input_size=input_size,
output_size=output_size,
with_bias=with_bias,
add_bias=add_bias,
skip_quant=skip_quant)
def load_state_dict(self, state_dict):
"""
Load the checkpoint state dictionary into the layer.
Args:
state_dict (dict): A dictionary containing the checkpoint weights and biases.
"""
# weight
assert self.weight_key is not None, 'weight_key should not be None.'
if self.weight_key in state_dict.keys():
weight_tensor = get_tensor(state_dict.pop(self.weight_key))
else:
gate_weight_key = self.weight_key.replace("up_gate_proj",
"gate_proj")
up_weight_key = self.weight_key.replace("up_gate_proj", "up_proj")
gate_tensor = get_tensor(state_dict.pop(gate_weight_key))
up_tensor = get_tensor(state_dict.pop(up_weight_key))
weight_tensor = paddle.concat([gate_tensor, up_tensor], axis=-1)
if self.with_bias:
gate_bias_key = self.bias_key.replace("up_gate_proj",
"gate_proj")
bias_tensor = get_tensor(state_dict.pop(gate_bias_key)).astype(
paddle.get_default_dtype())
converted_bias_tensor = paddle.zeros(shape=list(
bias_tensor.shape),
dtype=bias_tensor.dtype)
if not self.use_fast_ffn:
converted_bias_tensor = paddle.concat(
[bias_tensor[::2], bias_tensor[1::2]], axis=0)
else:
converted_bias_tensor = bias_tensor
state_dict[self.bias_key] = converted_bias_tensor
if not self.use_fast_ffn:
converted_weight_tensor = paddle.concat(
[weight_tensor[:, ::2], weight_tensor[:, 1::2]], axis=1)
else:
converted_weight_tensor = weight_tensor
state_dict[self.weight_key] = converted_weight_tensor
super().load_state_dict(state_dict)
class QKVParallelLinear(ColumnParallelLinear):
"""
QKVParallelLinear Layer.
"""
def __init__(self, llm_config, prefix, with_bias=False, add_bias=True):
"""
Initialize the QKV Linear layer with given parameters.
Args:
llm_config (LLMConfig): Arguments related to inference, containing
attributes such as weight_dtype, act_dtype, mp_size, hidden_size, head_dim,
num_attention_heads, and ffn_hidden_size.
prefix (str): Unique name of the layer, used for naming weights and biases.
weight_key (str): Key name of weight in the pdparams state dict.
bias_key (str): Key name of bias in the pdparams state dict. Defaults to None, means no bias.
with_bias (bool, optional): Whether to include bias term. Defaults to True.
skip_quant (bool, optional): Whether to skip quantization steps. Defaults to False.
"""
self.num_heads = llm_config.model_config.num_attention_heads
self.kv_num_heads = llm_config.model_config.num_key_value_heads
self.embed_dim = llm_config.model_config.hidden_size
self.head_dim = llm_config.model_config.head_dim
self.nranks = llm_config.parallel_config.mp_size
self.num_heads_per_rank = divide(self.num_heads, self.nranks)
self.kv_num_heads_per_rank = divide(self.kv_num_heads, self.nranks)
input_size = self.embed_dim
output_size = (self.num_heads + 2 * self.kv_num_heads) * self.head_dim
super().__init__(llm_config=llm_config,
prefix=prefix,
input_size=input_size,
output_size=output_size,
with_bias=with_bias,
add_bias=add_bias)
def load_state_dict(self, state_dict):
"""
Load the checkpoint state dictionary into the layer.
Args:
state_dict (dict): A dictionary containing the checkpoint weights and biases.
"""
# weight
assert self.weight_key is not None, 'weight_key should not be None.'
# qkv fused in disk
if self.weight_key in state_dict.keys():
weight_tensor = get_tensor(state_dict.pop(self.weight_key))
else:
q_weight_key = self.weight_key.replace("qkv_proj", "q_proj")
k_weight_key = self.weight_key.replace("qkv_proj", "k_proj")
v_weight_key = self.weight_key.replace("qkv_proj", "v_proj")
q_tensor = get_tensor(state_dict.pop(q_weight_key))
k_tensor = get_tensor(state_dict.pop(k_weight_key))
v_tensor = get_tensor(state_dict.pop(v_weight_key))
weight_tensor = paddle.concat([q_tensor, k_tensor, v_tensor],
axis=-1).transpose([1, 0])
weight_tensor = weight_tensor.reshape([
(self.num_heads_per_rank + 2 * self.kv_num_heads_per_rank) *
(self.head_dim),
self.embed_dim,
])
weight_tensor = paddle.transpose(weight_tensor, perm=[1, 0])
if self.llm_config.quant_config:
self.quant_method.process_loaded_weights(self, weight_tensor)
else:
self.linear_weight.set_value(weight_tensor)
# bias
if self.with_bias:
if self.bias_key in state_dict.keys():
bias_tensor = paddle.to_tensor(
get_tensor(state_dict.pop(self.bias_key)))
self.linear_bias.set_value(bias_tensor)
else:
q_bias_key = self.bias_key.replace("qkv_proj", "q_proj")
k_bias_key = self.bias_key.replace("qkv_proj", "k_proj")
v_bias_key = self.bias_key.replace("qkv_proj", "v_proj")
q_bias = get_tensor(state_dict.pop(q_bias_key))
k_bias = get_tensor(state_dict.pop(k_bias_key))
v_bias = get_tensor(state_dict.pop(v_bias_key))
qkv_bias = paddle.concat([q_bias, k_bias, v_bias], axis=-1)
self.linear_bias.set_value(qkv_bias)
# smooth quant
if self.use_smooth_quant:
if self.shift_key in state_dict:
shift_tensor = get_tensor(state_dict.pop(self.shift_key)).astype(
paddle.get_default_dtype()
)
else:
shift_tensor = paddle.zeros(
shape=self.linear_shift_shape,
dtype=paddle.get_default_dtype(),
)
self.linear_shift.set_value(shift_tensor)
if self.smooth_key in state_dict:
smooth_tensor = get_tensor(state_dict.pop(self.smooth_key)).astype(
paddle.get_default_dtype()
)
else:
smooth_tensor = paddle.ones(
shape=[self.linear_smooth_shape],
dtype=paddle.get_default_dtype(),
)
self.linear_smooth.set_value(smooth_tensor)
class RowParallelLinear(LinearBase):
"""
RowParallelLinear Layer
"""
def __init__(
self,
llm_config,
prefix: str = "",
input_size: int = None,
output_size: int = None,
with_bias: bool = False,
add_bias: bool = False,
skip_quant: bool = False,
):
"""
Initialize a linear layer with additional parameters for inference and quantization.
Args:
llm_config (LLMConfig): Arguments related to inference, containing
attributes such as weight_dtype, act_dtype, mp_size, hidden_size, head_dim,
num_attention_heads, and ffn_hidden_size.
prefix (str): Unique name of the layer, used for naming internal attributes,
you can give it any name you like.
layer_index (int): The index of the linear layer in the model
"""
super().__init__(llm_config=llm_config,
prefix=prefix,
input_size=input_size,
output_size=output_size,
with_bias=with_bias,
add_bias=add_bias,
skip_quant=skip_quant)
self.llm_config = llm_config
self.skip_quant = False
self.use_smooth_quant = llm_config.model_config.use_smooth_quant
self.weight_dtype = llm_config.model_config.weight_dtype
self.act_dtype = llm_config.model_config.act_dtype
self.nranks = llm_config.parallel_config.mp_size
self.embed_dim = llm_config.model_config.hidden_size
self.head_dim = llm_config.model_config.hidden_size // llm_config.model_config.num_attention_heads
self.num_heads = llm_config.model_config.num_attention_heads // self.nranks
self.dim_feedforward = llm_config.model_config.ffn_hidden_size // self.nranks
self.with_bias = with_bias
self.prefix = prefix
self.shift_key = f"{prefix}.shift_bias"
self.smooth_key = f"{prefix}.smooth_weight"
self.weight_key = f"{prefix}.weight"
self.bias_key = f"{prefix}.bias"
self.weight_only_scale_key = f"{prefix}.weight_only_scale"
self.out_scale_key = f"{prefix}.out_scale"
self._dtype = self._helper.get_default_dtype()
if llm_config.quant_config:
self.quant_method = llm_config.quant_config.get_quant_method(self)
self.quant_method.create_weights(self)
self.init_weight()
def init_weight(self):
"""
Initialize the weights and biases.
"""
self.init_weight_shape(self.is_y_transposed())
self.linear_weight = self.create_parameter(
shape=self.linear_weight_shape,
dtype=self.get_weight_create_dtype(),
is_bias=False,
default_initializer=paddle.nn.initializer.Constant(0),
)
self.linear_bias = None
if self.with_bias:
self.linear_bias = self.create_parameter(
shape=[self.embed_dim],
dtype=self._dtype,
is_bias=True,
)
if self.nranks > 0:
# row parallel
_set_var_distributed(self.linear_weight, split_axis=0)
# smooth quant
self.linear_shift = None
self.linear_smooth = None
if self.use_smooth_quant:
self.linear_shift = self.create_parameter(
shape=self.linear_shift_shape,
dtype=self._dtype,
is_bias=False,
)
self.linear_smooth = self.create_parameter(
shape=self.linear_smooth_shape,
dtype=self._dtype,
is_bias=False,
)
def forward_cuda(self, x):
if self.llm_config.quant_config:
out = self.quant_method.apply(self, x)
else:
out = paddle.matmul(x, self.linear_weight)
if self.nranks > 1:
from fastdeploy.distributed.communication_op import \
tensor_model_parallel_all_reduce
tensor_model_parallel_all_reduce(out)
return out