Sync v2.0 version of code to github repo
@@ -14,29 +14,25 @@
# limitations under the License.
"""

import os

import fastdeploy
from paddlenlp.utils.log import logger

import paddle
from paddle import nn

from fastdeploy.config import FDConfig
from fastdeploy.distributed.communication_op import \
    tensor_model_parallel_all_reduce
from fastdeploy.platforms import current_platform

from .utils import _set_var_distributed, divide, get_tensor

import fastdeploy.model_executor.ops.gpu.deep_gemm as deep_gemm


class LinearBase(nn.Layer):
    """
    LinearBase Layer
    LinearBase Layer.
    """

    def __init__(
        self,
        llm_config,
        fd_config: FDConfig,
        prefix: str = "",
        input_size: int = None,
        output_size: int = None,
@@ -48,31 +44,26 @@ class LinearBase(nn.Layer):
        Initializes a linear layer and provides additional parameters required for inference and quantization.

        Args:
            llm_config (LLMConfig): Inference-related parameters containing attributes such as
                weight_dtype, act_dtype, mp_size, hidden_size, head_dim,
                num_attention_heads, and ffn_hidden_size.
            fd_config (FDConfig): Inference-related parameters.
            prefix (str): Unique name of the layer, used to name internal attributes.
                Can be arbitrarily named.
            input_size (int, optional): Number of input features. Defaults to None.
            output_size (int, optional): Number of output features. Defaults to None.
            weight_key (Any, optional): Key for weights. Defaults to None.
            bias_key (Any, optional): Key for biases. Defaults to None.
            skip_quant (bool, optional): Whether to skip quantization. Defaults to False.
            input_size (int): Number of input features. Defaults to None.
            output_size (int): Number of output features. Defaults to None.
            with_bias (bool): Whether to include bias or not. Defaults to False.
            add_bias (bool): Whether to add bias in the current layer or in the pre/post layer. Defaults to False.
            skip_quant (bool): Whether to skip quantization. Defaults to False.

        Raises:
            NotImplementedError: Raised if the current platform is not a CUDA platform.
        """
        super().__init__()
        if current_platform.is_cuda():
        if current_platform.is_cuda() or current_platform.is_xpu():
            self.forward = self.forward_cuda
        else:
            raise NotImplementedError

        self.llm_config = llm_config
        self.fd_config = fd_config
        self.skip_quant = skip_quant
        self.use_smooth_quant = llm_config.model_config.use_smooth_quant
        self.weight_dtype = llm_config.model_config.weight_dtype
        self.act_dtype = llm_config.model_config.act_dtype
        self.input_size = input_size
        self.output_size = output_size
        self.with_bias = with_bias
@@ -86,61 +77,27 @@ class LinearBase(nn.Layer):
        self.out_scale_key = f"{prefix}.out_scale"

        self._dtype = self._helper.get_default_dtype()

        if llm_config.quant_config:
            self.quant_method = llm_config.quant_config.get_quant_method(self)
        self.use_offline_quant = llm_config.tmp_config.use_offline_quant

    def is_y_transposed(self):
        """
        Returns whether the y tensor should be transposed for inference.
        Args:
            None.

        Returns:
            bool, whether the y tensor should be transposed for inference.
        """
        if self.weight_dtype == "int4":
            return True
        if self.weight_dtype == "int8":
            return True
        if "float8" in self.weight_dtype:
            return True
        # bf16/fp16/fp32 y is not transposed
        return False

    def init_weight_shape(self, trans=False):
        """
        Initialize the weight shape for the first feedforward network layer.

        Args:
            trans (bool, optional): Whether to transpose the weight shape.
                Defaults to False. If True, the shape will be reversed.

        Returns:
            None.
        """
        self.weight_dtype = self._dtype
        self.linear_weight_shape = [
            self.input_size,
            self.output_size,
        ]
        if trans:
            self.linear_weight_shape.reverse()
        if self.use_smooth_quant:
            self.linear_shift_shape = [self.output_size]
            self.linear_smooth_shape = [self.output_size]
        if self.weight_dtype == "int4":
            self.linear_weight_shape[0] //= 2
        if fd_config.quant_config:
            self.quant_method = fd_config.quant_config.get_quant_method(self)
        if fd_config.model_config.is_quantized:
            self.weight_key = f"{prefix}.quant_weight"
            self.weight_scale_key = f"{prefix}.weight_scale"
            self.act_scale_key = f"{prefix}.activation_scale"

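A minimal sketch of how the weight-shape rules above play out, with hypothetical sizes; the helper below is illustrative only and not part of the layer:

# Illustrative helper (hypothetical sizes): mirrors the shape rules above.
def weight_shape(input_size, output_size, weight_dtype, trans):
    shape = [input_size, output_size]
    if trans:                     # int4 / int8 / float8 store y transposed
        shape.reverse()
    if weight_dtype == "int4":    # two int4 values are packed into one int8 element
        shape[0] //= 2
    return shape

print(weight_shape(4096, 12288, "int4", trans=True))       # [6144, 4096]
print(weight_shape(4096, 12288, "bfloat16", trans=False))  # [4096, 12288]
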
    def init_weight(self):
        """
        Initialize the weights and biases.
        """
        self.init_weight_shape(self.is_y_transposed())

        if self.skip_quant:
            self.weight_dtype = self._dtype
        self.linear_weight = self.create_parameter(
            shape=self.linear_weight_shape,
            dtype=self.get_weight_create_dtype(),
            dtype=self.weight_dtype,
            is_bias=False,
            default_initializer=paddle.nn.initializer.Constant(0),
        )
@@ -156,117 +113,57 @@ class LinearBase(nn.Layer):
        # smooth quant
        self.linear_shift = None
        self.linear_smooth = None
        if self.use_smooth_quant:
            self.linear_shift = self.create_parameter(
                shape=self.linear_shift_shape,
                dtype=self._dtype,
                is_bias=False,
            )
            self.linear_smooth = self.create_parameter(
                shape=self.linear_smooth_shape,
                dtype=self._dtype,
                is_bias=False,
            )

    def get_weight_create_dtype(self):
    def load_prequant_weight(self, state_dict: dict):
        """
        Get the data type for creating weights based on quantization settings.
        Load the prequantized weight from the state dictionary.

        Args:
            self (object): The instance of the class where this method is defined.

        Returns:
            str: The data type for creating weights. It depends on the quantization settings:
                - If `self.skip_quant` is True, returns the original data type `self._dtype`.
                - If `self.weight_dtype` is "int4", returns "int8" to ensure compatibility or optimization.
                - Otherwise, returns the specified weight data type `self.weight_dtype`.
            state_dict (dict): A dictionary containing the prequantized weights and scales.
        """
        if self.skip_quant:
            return self._dtype
        if self.weight_dtype == "int4":
            return "int8"
        # TODO(wangzhe24) create_parameter not support FP8
        if "float8" in self.weight_dtype:
            return self._dtype
        return self.weight_dtype
        self.quant_method.process_prequanted_weights(self, state_dict)

    def load_weight(self, state_dict: dict):
        """
        Load the weight from the state dictionary.

    def load_offline_quant_state_dict(self, quant_weight, quant_scale=None):
        Args:
            state_dict (dict): A dictionary containing the weights
        """
        Load offline the checkpoint state dictionary into the layer.
        """
        if quant_scale is None:
            if "float8" in self.weight_dtype:
                self.linear_weight.copy_(quant_weight, False)
            else:
                self.linear_weight.set_value(quant_weight)
        weight_tensor = get_tensor(state_dict.pop(self.weight_key))

        if self.fd_config.quant_config:
            self.quant_method.process_loaded_weights(self, weight_tensor)
        else:
            if self.inference_args.weight_block_size[0] != -1:
                self.linear_weight.copy_(quant_weight.view(paddle.float8_e4m3fn), False)
            else:
                self.linear_weight.set_value(quant_weight)
            self.linear_weight_scale.set_value(quant_scale)
            self.linear_weight.set_value(weight_tensor)

    def load_state_dict(self, state_dict):
    def load_state_dict(self, state_dict: dict):
        """
        Load the checkpoint state dictionary into the layer.

        Args:
            state_dict (dict): A dictionary containing the checkpoint weights and biases.
        """
        if self.use_offline_quant:
            self.load_offline_quant_state_dict(
                quant_weight=get_tensor(
                    state_dict.pop(self.weight_key + ".quant_weight")
                ),
                quant_scale=get_tensor(
                    state_dict.pop(self.weight_key + ".quant_scale")
                ),
            )
        # weight
        self.state_dict = state_dict
        assert self.weight_key is not None, 'weight_key should not be None.'
        if self.fd_config.model_config.is_quantized:
            self.load_prequant_weight(state_dict)
        else:
            # weight
            assert self.weight_key is not None, 'weight_key should not be None.'
            weight_tensor = get_tensor(state_dict.pop(self.weight_key))

        if self.llm_config.quant_config:
            self.quant_method.process_loaded_weights(self, weight_tensor)
        else:
            self.linear_weight.set_value(weight_tensor)
            self.load_weight(state_dict)

        # bias
        if self.with_bias:
            bias_tensor = paddle.to_tensor(get_tensor(state_dict.pop(self.bias_key)))
            bias_tensor = paddle.to_tensor(
                get_tensor(state_dict.pop(self.bias_key)))
            self.linear_bias.set_value(bias_tensor)

        # smooth quant
        if self.use_smooth_quant:
            if self.shift_key in state_dict:
                shift_tensor = get_tensor(state_dict.pop(self.shift_key)).astype(
                    paddle.get_default_dtype()
                )
            else:
                shift_tensor = paddle.zeros(
                    shape=self.linear_shift_shape,
                    dtype=paddle.get_default_dtype(),
                )
            self.linear_shift.set_value(shift_tensor)
            if self.smooth_key in state_dict:
                smooth_tensor = get_tensor(state_dict.pop(self.smooth_key)).astype(
                    paddle.get_default_dtype()
                )
            else:
                smooth_tensor = paddle.ones(
                    shape=[self.linear_smooth_shape],
                    dtype=paddle.get_default_dtype(),
                )
            self.linear_smooth.set_value(smooth_tensor)

    def forward_cuda(self, x):
    def forward_cuda(self, x: paddle.Tensor) -> paddle.Tensor:
        """
        Forward function for ColumnParallelLinear.
        Forward function for Linear.

        Args:
            x (Tensor): Input tensor to the ColumnParallelLinear layer.
            x (Tensor): Input tensor to the Linear.

        Returns:
            Tensor: Output tensor.
@@ -274,22 +171,24 @@ class LinearBase(nn.Layer):
        Raises:
            NotImplementedError: If the weight dtype is not float8 or act dtype is not equal to weight dtype.
        """
        if self.llm_config.quant_config:
        if self.fd_config.quant_config:
            linear_out = self.quant_method.apply(self, x)
        else:
            linear_out = paddle.matmul(x, self.linear_weight)
        if self.with_bias:
            linear_out = paddle.add(linear_out, self.linear_bias)

        return linear_out

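A minimal sketch of the unquantized forward path above (quant_config unset): the output is x @ W plus an optional bias; all shapes below are hypothetical:

import paddle

# Hypothetical shapes: y = x @ W (+ b), matching the unquantized branch above.
x = paddle.randn([2, 8, 4096], dtype="float32")
w = paddle.randn([4096, 1024], dtype="float32")
b = paddle.zeros([1024], dtype="float32")

linear_out = paddle.matmul(x, w)
linear_out = paddle.add(linear_out, b)
print(linear_out.shape)  # [2, 8, 1024]
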
class ReplicatedLinear(LinearBase):
    """
    ReplicatedLinear Layer
    ReplicatedLinear Layer.
    """

    def __init__(
        self,
        llm_config,
        fd_config: FDConfig,
        prefix: str = "",
        input_size: int = None,
        output_size: int = None,
@@ -298,74 +197,39 @@ class ReplicatedLinear(LinearBase):
        skip_quant: bool = False,
    ):
        """
        Initialize a linear layer with additional parameters for inference and quantization.
        Initializes a replicated linear layer.

        Args:
            llm_config (LLMConfig): Arguments related to inference, containing
                attributes such as weight_dtype, act_dtype, mp_size, hidden_size, head_dim,
                num_attention_heads, and ffn_hidden_size.
            prefix (str): Unique name of the layer, used for naming internal attributes,
                you can give it any name you like.
            layer_index (int): The index of the linear layer in the model

            fd_config (FDConfig): Inference-related parameters.
            prefix (str): Unique name of the layer, used to name internal attributes.
                Can be arbitrarily named.
            input_size (int): Number of input features. Defaults to None.
            output_size (int): Number of output features. Defaults to None.
            with_bias (bool): Whether to include bias or not. Defaults to False.
            add_bias (bool): Whether to add bias in the current layer or in the pre/post layer. Defaults to False.
            skip_quant (bool): Whether to skip quantization. Defaults to False.
        """
        super().__init__(llm_config=llm_config,
        super().__init__(fd_config=fd_config,
                         prefix=prefix,
                         input_size=input_size,
                         output_size=output_size,
                         with_bias=with_bias,
                         add_bias=add_bias,
                         skip_quant=skip_quant)
        self.nranks = llm_config.parallel_config.mp_size
        self.input_size = input_size
        self.init_weight()
        self.quant_method.create_weights(self)

    def init_weight(self):
        """
        Initialize the weights and biases.
        """
        self.init_weight_shape(self.is_y_transposed())

        self.linear_weight = self.create_parameter(
            shape=self.linear_weight_shape,
            dtype=self.get_weight_create_dtype(),
            is_bias=False,
            default_initializer=paddle.nn.initializer.Constant(0),
        )

        self.linear_bias = None
        if self.with_bias:
            self.linear_bias = self.create_parameter(
                shape=[self.output_size],
                dtype=self._dtype,
                is_bias=True,
            )

        # smooth quant
        self.linear_shift = None
        self.linear_smooth = None
        if self.use_smooth_quant:
            self.linear_shift = self.create_parameter(
                shape=self.linear_shift_shape,
                dtype=self._dtype,
                is_bias=False,
            )
            self.linear_smooth = self.create_parameter(
                shape=self.linear_smooth_shape,
                dtype=self._dtype,
                is_bias=False,
            )


class ColumnParallelLinear(LinearBase):
    """
    ColumnParallelLinear Layer
    ColumnParallelLinear Layer.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its second dimension as A = [A_1, ..., A_p].
    """

    def __init__(
        self,
        llm_config,
        fd_config: FDConfig,
        prefix: str = "",
        input_size: int = None,
        output_size: int = None,
@@ -374,40 +238,45 @@ class ColumnParallelLinear(LinearBase):
        skip_quant: bool = False,
    ):
        """
        Initialize a linear layer with additional parameters for inference and quantization.
        Initializes a linear layer and provides additional parameters required for inference and quantization.

        Args:
            llm_config (LLMConfig): Arguments related to inference, containing
                attributes such as weight_dtype, act_dtype, mp_size, hidden_size, head_dim,
                num_attention_heads, and ffn_hidden_size.
            prefix (str): Unique name of the layer, used for naming internal attributes,
                you can give it any name you like.
            layer_index (int): The index of the linear layer in the model

            fd_config (FDConfig): Inference-related parameters.
            prefix (str): Unique name of the layer, used to name internal attributes.
                Can be arbitrarily named.
            input_size (int): Number of input features. Defaults to None.
            output_size (int): Number of output features. Defaults to None.
            with_bias (bool): Whether to include bias or not. Defaults to False.
            add_bias (bool): Whether to add bias in the current layer or in the pre/post layer. Defaults to False.
            skip_quant (bool): Whether to skip quantization. Defaults to False.
        """
        super().__init__(llm_config=llm_config,
        super().__init__(fd_config=fd_config,
                         prefix=prefix,
                         input_size=input_size,
                         output_size=output_size,
                         with_bias=with_bias,
                         add_bias=add_bias,
                         skip_quant=skip_quant)
        self.nranks = llm_config.parallel_config.mp_size
        self.nranks = fd_config.parallel_config.tensor_parallel_degree
        self.input_size = input_size
        self.output_size = divide(output_size, self.nranks)
        self.linear_weight_shape = [
            self.input_size,
            self.output_size,
        ]
        if fd_config.quant_config:
            self.quant_method.create_weights(self)
        self.init_weight()

        self.quant_method.create_weights(self)

    def init_weight(self):
        """
        Initialize the weights and biases.
        """
        self.init_weight_shape(self.is_y_transposed())

        if self.skip_quant:
            self.weight_dtype = self._dtype
        self.linear_weight = self.create_parameter(
            shape=self.linear_weight_shape,
            dtype=self.get_weight_create_dtype(),
            dtype=self.weight_dtype,
            is_bias=False,
            default_initializer=paddle.nn.initializer.Constant(0),
        )
@@ -429,62 +298,51 @@ class ColumnParallelLinear(LinearBase):
        # smooth quant
        self.linear_shift = None
        self.linear_smooth = None
        if self.use_smooth_quant:
            self.linear_shift = self.create_parameter(
                shape=self.linear_shift_shape,
                dtype=self._dtype,
                is_bias=False,
            )
            self.linear_smooth = self.create_parameter(
                shape=self.linear_smooth_shape,
                dtype=self._dtype,
                is_bias=False,
            )

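A minimal sketch of the column-parallel math described in the docstring above, with the ranks simulated in a single process and hypothetical sizes:

import paddle

# Hypothetical sizes; ranks are simulated in one process.
nranks = 2
x = paddle.randn([4, 512])
a = paddle.randn([512, 1024])

a_shards = paddle.split(a, num_or_sections=nranks, axis=1)  # each [512, 512]
per_rank_out = [paddle.matmul(x, a_i) for a_i in a_shards]  # each [4, 512]
merged = paddle.concat(per_rank_out, axis=-1)               # [4, 1024]

print(paddle.allclose(merged, paddle.matmul(x, a), atol=1e-4).item())  # True
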
class MergedColumnParallelLinear(ColumnParallelLinear):
    """
    MergedColumnParallelLinear Layer.

    Similar to ColumnParallelLinear, but the weight matrix is concatenated
    along the output dimension. When the weight matrix is loaded, the
    different partitions are sharded separately.
    """

    def __init__(
        self,
        llm_config,
        prefix,
        with_bias=False,
        add_bias=False,
        activation="gelu",
        use_fast_ffn=False,
        skip_quant=False,
        fd_config: FDConfig,
        prefix: str,
        input_size: int = None,
        output_size: int = None,
        with_bias: bool = False,
        add_bias: bool = False,
        activation: str = "gelu",
        use_fast_ffn: bool = False,
        skip_quant: bool = False,
    ):
        """Packed linear layers with column parallelism.

        """
        Initialize the fused ffn1 Linear layer with given parameters.

        Args:
            llm_config (LLMConfig): Arguments related to inference, containing
                attributes such as weight_dtype, act_dtype, mp_size, hidden_size, head_dim,
                num_attention_heads, and ffn_hidden_size.

            prefix (str): Unique name of the layer, used for naming weights and biases.
            weight_key (str): Key name of weight in the pdparams state dict.
            bias_key (str): Key name of bias in the pdparams state dict. Defaults to None, means no bias.
            with_bias (bool, optional): Whether to include bias term. Defaults to True.
            activation (str, optional): Activation function to use. Defaults to "gelu".
            use_fast_ffn (bool, optional): Whether to use a faster FFN implementation.
            fd_config (FDConfig): Inference-related parameters.
            prefix (str): Unique name of the layer, used to name internal attributes.
                Can be arbitrarily named.
            input_size (int): Number of input features. Defaults to None.
            output_size (int): Number of output features. Defaults to None.
            with_bias (bool): Whether to include bias or not. Defaults to False.
            add_bias (bool): Whether to add bias in the current layer or in the pre/post layer. Defaults to False.
            activation (str): Activation function to use. Defaults to "gelu".
            use_fast_ffn (bool): Whether to use a faster FFN implementation.
                Defaults to False.
            skip_quant (bool, optional): Whether to skip quantization steps. Defaults to False.
            skip_quant (bool): Whether to skip quantization. Defaults to False.
        """
        self.use_fast_ffn = use_fast_ffn
        self.activation = activation
        self.embed_dim = llm_config.model_config.hidden_size
        self.dim_feedforward = llm_config.model_config.ffn_hidden_size
        self.nranks = llm_config.parallel_config.mp_size
        self.dim_feedforward_per_rank = divide(self.dim_feedforward,
                                               self.nranks)
        input_size = self.embed_dim
        output_size = self.dim_feedforward * 2
        super().__init__(llm_config=llm_config,
        self.embed_dim = fd_config.model_config.hidden_size
        self.nranks = fd_config.parallel_config.tensor_parallel_degree

        super().__init__(fd_config=fd_config,
                         prefix=prefix,
                         input_size=input_size,
                         output_size=output_size,
@@ -492,7 +350,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
                         add_bias=add_bias,
                         skip_quant=skip_quant)

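Illustrative sizing arithmetic for the fused ffn1 weight handled above, assuming the pre-change convention that the gate and up projections are concatenated along the output dimension; all numbers are hypothetical:

# Hypothetical numbers; assumes gate and up projections fused along the output dim.
hidden_size = 4096
ffn_hidden_size = 11008
nranks = 4

input_size = hidden_size
output_size = ffn_hidden_size * 2             # gate + up concatenated
output_size_per_rank = output_size // nranks  # each rank holds this many columns

print(input_size, output_size, output_size_per_rank)  # 4096 22016 5504
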
    def load_state_dict(self, state_dict):
    def load_state_dict(self, state_dict: dict):
        """
        Load the checkpoint state dictionary into the layer.

@@ -542,47 +400,40 @@ class QKVParallelLinear(ColumnParallelLinear):
    QKVParallelLinear Layer.
    """

    def __init__(self, llm_config, prefix, with_bias=False, add_bias=True):
    def __init__(self, fd_config, prefix, with_bias=False, add_bias=True):
        """
        Initialize the QKV Linear layer with given parameters.

        Args:
            llm_config (LLMConfig): Arguments related to inference, containing
                attributes such as weight_dtype, act_dtype, mp_size, hidden_size, head_dim,
                num_attention_heads, and ffn_hidden_size.

            prefix (str): Unique name of the layer, used for naming weights and biases.
            weight_key (str): Key name of weight in the pdparams state dict.
            bias_key (str): Key name of bias in the pdparams state dict. Defaults to None, means no bias.
            with_bias (bool, optional): Whether to include bias term. Defaults to True.
            skip_quant (bool, optional): Whether to skip quantization steps. Defaults to False.
            fd_config (FDConfig): Inference-related parameters.
            prefix (str): Unique name of the layer, used to name internal attributes.
                Can be arbitrarily named.
            with_bias (bool): Whether to include bias or not. Defaults to False.
            add_bias (bool): Whether to add bias in the current layer or in the pre/post layer. Defaults to True.
        """
        self.num_heads = llm_config.model_config.num_attention_heads
        self.kv_num_heads = llm_config.model_config.num_key_value_heads
        self.embed_dim = llm_config.model_config.hidden_size
        self.head_dim = llm_config.model_config.head_dim
        self.nranks = llm_config.parallel_config.mp_size
        self.num_heads = fd_config.model_config.num_attention_heads
        self.kv_num_heads = fd_config.model_config.num_key_value_heads
        self.embed_dim = fd_config.model_config.hidden_size
        self.head_dim = fd_config.model_config.head_dim
        self.nranks = fd_config.parallel_config.tensor_parallel_degree
        self.num_heads_per_rank = divide(self.num_heads, self.nranks)
        self.kv_num_heads_per_rank = divide(self.kv_num_heads, self.nranks)
        input_size = self.embed_dim
        output_size = (self.num_heads + 2 * self.kv_num_heads) * self.head_dim
        super().__init__(llm_config=llm_config,
        super().__init__(fd_config=fd_config,
                         prefix=prefix,
                         input_size=input_size,
                         output_size=output_size,
                         with_bias=with_bias,
                         add_bias=add_bias)

    def load_state_dict(self, state_dict):
    def load_weight(self, state_dict: dict):
        """
        Load the checkpoint state dictionary into the layer.
        Load the weight from the state dictionary.

        Args:
            state_dict (dict): A dictionary containing the checkpoint weights and biases.
            state_dict (dict): A dictionary containing the weights
        """
        # weight
        assert self.weight_key is not None, 'weight_key should not be None.'
        # qkv fused in disk
        if self.weight_key in state_dict.keys():
            weight_tensor = get_tensor(state_dict.pop(self.weight_key))
        else:
@@ -601,11 +452,27 @@ class QKVParallelLinear(ColumnParallelLinear):
            ])
        weight_tensor = paddle.transpose(weight_tensor, perm=[1, 0])

        if self.llm_config.quant_config:
        if self.fd_config.quant_config:
            self.quant_method.process_loaded_weights(self, weight_tensor)
        else:
            self.linear_weight.set_value(weight_tensor)

    def load_state_dict(self, state_dict: dict):
        """
        Load the checkpoint state dictionary into the layer.

        Args:
            state_dict (dict): A dictionary containing the checkpoint weights and biases.
        """
        # weight
        assert self.weight_key is not None, 'weight_key should not be None.'
        # qkv fused in disk

        if self.fd_config.model_config.is_quantized:
            self.load_prequant_weight(state_dict)
        else:
            self.load_weight(state_dict)

        # bias
        if self.with_bias:
            if self.bias_key in state_dict.keys():
@@ -622,38 +489,25 @@ class QKVParallelLinear(ColumnParallelLinear):
            qkv_bias = paddle.concat([q_bias, k_bias, v_bias], axis=-1)
            self.linear_bias.set_value(qkv_bias)

        # smooth quant
        if self.use_smooth_quant:
            if self.shift_key in state_dict:
                shift_tensor = get_tensor(state_dict.pop(self.shift_key)).astype(
                    paddle.get_default_dtype()
                )
            else:
                shift_tensor = paddle.zeros(
                    shape=self.linear_shift_shape,
                    dtype=paddle.get_default_dtype(),
                )
            self.linear_shift.set_value(shift_tensor)
            if self.smooth_key in state_dict:
                smooth_tensor = get_tensor(state_dict.pop(self.smooth_key)).astype(
                    paddle.get_default_dtype()
                )
            else:
                smooth_tensor = paddle.ones(
                    shape=[self.linear_smooth_shape],
                    dtype=paddle.get_default_dtype(),
                )
            self.linear_smooth.set_value(smooth_tensor)

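Illustrative arithmetic for the fused QKV output size computed above, with hypothetical GQA settings:

# Hypothetical GQA settings; Q contributes num_heads * head_dim,
# K and V each contribute kv_num_heads * head_dim.
num_heads, kv_num_heads, head_dim, nranks = 32, 8, 128, 4

output_size = (num_heads + 2 * kv_num_heads) * head_dim
output_size_per_rank = (num_heads // nranks + 2 * (kv_num_heads // nranks)) * head_dim

print(output_size, output_size_per_rank)  # 6144 1536
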
class RowParallelLinear(LinearBase):
    """
    RowParallelLinear Layer
    RowParallelLinear Layer.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its first dimension and X along its second dimension as:
               -   -
              | A_1 |
              | .   |
          A = | .   |        X = [X_1, ..., X_p]
              | .   |
              | A_p |
               -   -
    """

    def __init__(
        self,
        llm_config,
        fd_config: FDConfig,
        prefix: str = "",
        input_size: int = None,
        output_size: int = None,
@@ -665,57 +519,50 @@ class RowParallelLinear(LinearBase):
        Initialize a linear layer with additional parameters for inference and quantization.

        Args:
            llm_config (LLMConfig): Arguments related to inference, containing
                attributes such as weight_dtype, act_dtype, mp_size, hidden_size, head_dim,
                num_attention_heads, and ffn_hidden_size.
            prefix (str): Unique name of the layer, used for naming internal attributes,
                you can give it any name you like.
            layer_index (int): The index of the linear layer in the model

            fd_config (FDConfig): Inference-related parameters.
            prefix (str): Unique name of the layer, used to name internal attributes.
                Can be arbitrarily named.
            input_size (int): Number of input features. Defaults to None.
            output_size (int): Number of output features. Defaults to None.
            with_bias (bool): Whether to include bias or not. Defaults to False.
            add_bias (bool): Whether to add bias in the current layer or in the pre/post layer. Defaults to False.
            skip_quant (bool): Whether to skip quantization. Defaults to False.
        """
        super().__init__(llm_config=llm_config,
        super().__init__(fd_config=fd_config,
                         prefix=prefix,
                         input_size=input_size,
                         output_size=output_size,
                         with_bias=with_bias,
                         add_bias=add_bias,
                         skip_quant=skip_quant)
        self.llm_config = llm_config
        self.fd_config = fd_config
        self.skip_quant = False
        self.use_smooth_quant = llm_config.model_config.use_smooth_quant
        self.weight_dtype = llm_config.model_config.weight_dtype
        self.act_dtype = llm_config.model_config.act_dtype
        self.nranks = llm_config.parallel_config.mp_size
        self.embed_dim = llm_config.model_config.hidden_size
        self.head_dim = llm_config.model_config.hidden_size // llm_config.model_config.num_attention_heads
        self.num_heads = llm_config.model_config.num_attention_heads // self.nranks
        self.dim_feedforward = llm_config.model_config.ffn_hidden_size // self.nranks
        self.with_bias = with_bias
        self.prefix = prefix
        self.shift_key = f"{prefix}.shift_bias"
        self.smooth_key = f"{prefix}.smooth_weight"
        self.weight_key = f"{prefix}.weight"
        self.bias_key = f"{prefix}.bias"
        self.weight_only_scale_key = f"{prefix}.weight_only_scale"
        self.out_scale_key = f"{prefix}.out_scale"
        self.nranks = fd_config.parallel_config.tensor_parallel_degree
        self.embed_dim = fd_config.model_config.hidden_size
        self.head_dim = fd_config.model_config.head_dim
        self.num_heads = fd_config.model_config.num_attention_heads // self.nranks

        self.linear_weight_shape = [
            self.input_size,
            self.output_size,
        ]
        self._dtype = self._helper.get_default_dtype()

        if llm_config.quant_config:
            self.quant_method = llm_config.quant_config.get_quant_method(self)
        if fd_config.quant_config:
            self.quant_method = fd_config.quant_config.get_quant_method(self)
            self.quant_method.create_weights(self)

        self.init_weight()

    def init_weight(self):
        """
        Initialize the weights and biases.
        """
        self.init_weight_shape(self.is_y_transposed())
        if self.skip_quant:
            self.weight_dtype = self._dtype

        self.linear_weight = self.create_parameter(
            shape=self.linear_weight_shape,
            dtype=self.get_weight_create_dtype(),
            dtype=self.weight_dtype,
            is_bias=False,
            default_initializer=paddle.nn.initializer.Constant(0),
        )
@@ -735,27 +582,159 @@ class RowParallelLinear(LinearBase):
        # smooth quant
        self.linear_shift = None
        self.linear_smooth = None
        if self.use_smooth_quant:
            self.linear_shift = self.create_parameter(
                shape=self.linear_shift_shape,
                dtype=self._dtype,
                is_bias=False,
            )
            self.linear_smooth = self.create_parameter(
                shape=self.linear_smooth_shape,
                dtype=self._dtype,
                is_bias=False,
            )

    def forward_cuda(self, x):
        if self.llm_config.quant_config:
    def forward_cuda(self, x: paddle.Tensor) -> paddle.Tensor:
        if self.fd_config.quant_config:
            out = self.quant_method.apply(self, x)
        else:
            out = paddle.matmul(x, self.linear_weight)

        if self.nranks > 1:
            from fastdeploy.distributed.communication_op import \
                tensor_model_parallel_all_reduce
            tensor_model_parallel_all_reduce(out)

        return out

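A minimal sketch of the row-parallel math behind forward_cuda above, with the ranks and the all-reduce simulated in a single process and hypothetical sizes:

import paddle

# Hypothetical sizes; ranks and the all-reduce are simulated in one process.
nranks = 2
x = paddle.randn([4, 1024])
a = paddle.randn([1024, 512])

x_shards = paddle.split(x, num_or_sections=nranks, axis=1)   # each [4, 512]
a_shards = paddle.split(a, num_or_sections=nranks, axis=0)   # each [512, 512]
partials = [paddle.matmul(xs, a_s) for xs, a_s in zip(x_shards, a_shards)]
reduced = paddle.add_n(partials)  # stands in for tensor_model_parallel_all_reduce

print(paddle.allclose(reduced, paddle.matmul(x, a), atol=1e-3).item())  # True
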
class KVBatchLinear(LinearBase):
    """
    KVBatchLinear Layer for handling combined KV projections with bmm.
    """

    def __init__(
        self,
        fd_config: FDConfig,
        prefix: str = "",
        kv_lora_rank: int = None,
        num_attention_heads: int = None,
        qk_nope_head_dim: int = None,
        v_head_dim: int = None,
        with_bias: bool = False,
        skip_quant: bool = False,
    ):
        """
        Initializes a KV batch linear layer that internally splits into K and V projections.

        Args:
            fd_config (FDConfig): Inference-related parameters.
            prefix (str): Unique name of the layer, used to name internal attributes.
            kv_lora_rank (int): LoRA rank for KV projection. Defaults to None.
            num_attention_heads (int): Number of attention heads. Defaults to None.
            qk_nope_head_dim (int): Dimension for Q/K projection (nope part). Defaults to None.
            v_head_dim (int): Dimension for V projection. Defaults to None.
            with_bias (bool): Whether to include bias or not. Defaults to False.
            skip_quant (bool): Whether to skip quantization. Defaults to False.
        """
        self.nranks = fd_config.parallel_config.tensor_parallel_degree
        self.kv_lora_rank = kv_lora_rank
        self.num_attention_heads = num_attention_heads
        self.qk_nope_head_dim = qk_nope_head_dim
        self.v_head_dim = v_head_dim
        # Split num_attention_heads when using TP inference.
        self.num_heads_per_partition = divide(num_attention_heads, self.nranks)

        # Initialize parent with combined dimensions
        super().__init__(
            fd_config=fd_config,
            prefix=prefix,
            input_size=None,  # Will be determined from weight shape
            output_size=None,  # Will be determined from weight shape
            with_bias=with_bias,
            add_bias=False,
            skip_quant=skip_quant,
        )
        self.weight_dtype = self._dtype

        # Override weight keys to use the combined kv_b_proj
        self.weight_key = f"{prefix}.weight"  # e.g., "kv_b_proj.weight"
        self.k_weight_key = f"{prefix.replace('kv_b_proj', 'k_b_proj')}.weight"
        self.v_weight_key = f"{prefix.replace('kv_b_proj', 'v_b_proj')}.weight"

    def load_state_dict(self, state_dict: dict):
        """
        Load the combined KV weight and split it into K and V projections
        """
        # Get the combined KV weight
        # NOTE(Ryan): Do not pop weight_key here, it will be popped in other class
        kv_weight_tensor = get_tensor(state_dict[self.weight_key])

        # Reshape and split the weight
        w = kv_weight_tensor.reshape([
            self.kv_lora_rank,
            self.num_heads_per_partition,
            -1,
        ]).transpose(perm=[1, 2, 0])

        # Split into K and V weights
        # wk_b: [num_heads, qk_nope_head_dim, kv_lora_rank]
        wk_b = w[:, :self.qk_nope_head_dim, :]

        if self.v_head_dim is None:
            raise ValueError("self.v_head_dim should not be None")
        # wv_b: [num_heads, kv_lora_rank, v_head_dim]
        wv_b = w[:, -self.v_head_dim:, :].transpose(perm=[0, 2, 1])

        # Create K projection weight
        self.k_b_proj_weight = self.create_parameter(
            shape=wk_b.shape,
            dtype=self.weight_dtype,
            is_bias=False,
            default_initializer=paddle.nn.initializer.Constant(0),
        )

        # Create V projection weight
        self.v_b_proj_weight = self.create_parameter(
            shape=wv_b.shape,
            dtype=self.weight_dtype,
            is_bias=False,
            default_initializer=paddle.nn.initializer.Constant(0),
        )

        self.k_b_proj_weight.set_value(wk_b)
        self.v_b_proj_weight.set_value(wv_b)

    def forward_k_b(self, x: paddle.Tensor) -> paddle.Tensor:
        """
        Forward pass for K_b projection using bmm

        Args:
            x: Input tensor (e.g., query_nope.transpose([1, 0, 2]))

        Returns:
            K_b projection output
        """

        out = paddle.bmm(x, self.k_b_proj_weight)
        return out

    def forward_v_b(self, x: paddle.Tensor) -> paddle.Tensor:
        """
        Forward pass for V_b projection using bmm

        Args:
            x: Input tensor (e.g., fmha_out_decode)

        Returns:
            V_b projection output
        """
        out = paddle.bmm(x, self.v_b_proj_weight)
        return out

    def forward_cuda(self,
                     x: paddle.Tensor,
                     proj_type: str = 'k') -> paddle.Tensor:
        """
        Forward function that can handle both K and V projections

        Args:
            x: Input tensor
            proj_type: 'k' or 'v' to select which projection to use

        Returns:
            Projection output
        """
        if proj_type == 'k':
            return self.forward_k_b(x)
        elif proj_type == 'v':
            return self.forward_v_b(x)
        else:
            raise ValueError(f"proj_type must be 'k' or 'v', got {proj_type}")

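A minimal sketch of the reshape/split performed by load_state_dict above, with hypothetical MLA-style dimensions for a single tensor-parallel rank:

import paddle

# Hypothetical dimensions for one tensor-parallel rank.
kv_lora_rank, num_heads, qk_nope_head_dim, v_head_dim = 512, 16, 128, 128

kv_b = paddle.randn([kv_lora_rank, num_heads * (qk_nope_head_dim + v_head_dim)])

w = kv_b.reshape([kv_lora_rank, num_heads, -1]).transpose(perm=[1, 2, 0])
wk_b = w[:, :qk_nope_head_dim, :]                       # [16, 128, 512]
wv_b = w[:, -v_head_dim:, :].transpose(perm=[0, 2, 1])  # [16, 512, 128]

print(wk_b.shape, wv_b.shape)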