FastDeploy/fastdeploy/model_executor/layers/linear.py
[Feature] support rl_tp_degree (#3934)
2025-09-08 16:20:32 +08:00


"""
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from typing import Optional
import numpy as np
import paddle
from paddle import nn
from fastdeploy.config import FDConfig
from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce
from fastdeploy.model_executor.layers.quantization.quant_base import QuantMethodBase
from fastdeploy.model_executor.utils import (
default_weight_loader,
set_weight_attrs,
slice_fn,
)
from fastdeploy.platforms import current_platform
from .utils import _set_var_distributed, divide, get_tensor
class UnquantizedLinearMethod(QuantMethodBase):
"""Linear method without quantization."""
def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
"""
extra_weight_attrs is a dictionary that may include parameters like:
- split_axis: axis along which to split the tensor in a distributed environment
- output_dim: determines whether the split is applied along the output dimension (rows) or input dimension (columns)
- weight_loader: a callable or method responsible for loading the weight data
"""
layer.weight = layer.create_parameter(
shape=layer.weight_shape,
dtype=layer.weight_dtype,
is_bias=False,
default_initializer=paddle.nn.initializer.Constant(0),
)
split_axis = extra_weight_attrs.get("split_axis")
if hasattr(layer, "nranks") and layer.nranks > 0:
_set_var_distributed(layer.weight, split_axis=split_axis)
set_weight_attrs(
layer.weight,
{
**extra_weight_attrs,
"weight_loader": extra_weight_attrs.get("weight_loader", default_weight_loader(layer.fd_config)),
"model_format": extra_weight_attrs.get("model_format", ""),
},
)
def process_loaded_weights(self, layer, weights) -> None:
        # mlp.gate.weight is precision-sensitive and kept in float32, so cast the
        # loaded weights to the parameter's dtype before copying
if layer.weight.dtype != weights.dtype:
weights = weights.cast(layer.weight.dtype)
layer.weight.set_value(weights)
def apply(self, layer: nn.Layer, x: paddle.Tensor) -> paddle.Tensor:
linear_out = paddle.matmul(x, layer.weight)
if layer.with_bias:
linear_out = paddle.add(linear_out, layer.bias)
return linear_out
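# A minimal, hedged sketch (toy shapes, plain paddle ops) of what apply()
# computes; weights in this file are stored as [input_size, output_size]:
#
#     x = paddle.randn([4, 8])                 # [batch, input_size]
#     w = paddle.randn([8, 16])                # [input_size, output_size]
#     b = paddle.zeros([16])
#     y = paddle.add(paddle.matmul(x, w), b)   # == apply() with with_bias=True
#     assert y.shape == [4, 16]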
class LinearBase(nn.Layer):
"""
LinearBase Layer.
"""
def __init__(
self,
fd_config: FDConfig,
prefix: str = "",
input_size: int = None,
output_size: int = None,
with_bias: bool = False,
add_bias: bool = False,
skip_quant: bool = False,
weight_dtype: str = "",
weight_key: str = "",
):
"""
Initializes a linear layer and provides additional parameters required for inference and quantization.
Args:
fd_config (FDConfig): Inference-related parameters.
prefix (str): Unique name of the layer, used to name internal attributes.
Can be arbitrarily named.
input_size (int): Number of input features. Defaults to None.
output_size (int): Number of output features. Defaults to None.
with_bias (bool): Whether to include bias or not. Defaults to False.
add_bias (bool): Whether to add bias in the current layer or in the pre/post layer. Defaults to False.
            skip_quant (bool): Whether to skip quantization. Defaults to False.
            weight_dtype (str): Override for the weight dtype. Defaults to "" (use the layer's default dtype).
            weight_key (str): Override for the weight's checkpoint key suffix. Defaults to "".
        Raises:
            NotImplementedError: Raised if the current platform is not one of the
                supported backends (CUDA, XPU, Iluvatar, GCU, DCU, MACA).
"""
super().__init__()
if (
current_platform.is_cuda()
or current_platform.is_xpu()
or current_platform.is_iluvatar()
or current_platform.is_gcu()
or current_platform.is_dcu()
or current_platform.is_maca()
):
self.forward = self.forward_cuda
else:
raise NotImplementedError
self.fd_config = fd_config
self.skip_quant = skip_quant
self.input_size = input_size
self.output_size = output_size
self.with_bias = with_bias
self.add_bias = add_bias
self.prefix = prefix
# key
if weight_key:
self.weight_key = f"{prefix}.{weight_key}"
elif fd_config.model_config.is_quantized and not skip_quant:
self.weight_key = f"{prefix}.quant_weight"
self.weight_scale_key = f"{prefix}.weight_scale"
self.act_scale_key = f"{prefix}.activation_scale"
else:
self.weight_key = f"{prefix}.weight"
self.bias_key = f"{prefix}.bias"
self.shift_key = f"{prefix}.shift_bias"
self.smooth_key = f"{prefix}.smooth_weight"
self.out_scale_key = f"{prefix}.out_scale"
self._dtype = self._helper.get_default_dtype()
        if weight_dtype:
            self.weight_dtype = weight_dtype
        else:
            # both the skip_quant and default paths use the model's default dtype
            self.weight_dtype = self._dtype
self.weight_shape = [
self.input_size,
self.output_size,
]
if fd_config.quant_config and not skip_quant:
self.quant_method = fd_config.quant_config.get_quant_method(self)
else:
self.quant_method: Optional[QuantMethodBase] = UnquantizedLinearMethod()
self.bias = None
if self.with_bias:
self.bias = self.create_parameter(
shape=[self.output_size],
dtype=self._dtype,
is_bias=True,
)
setattr(
self.bias,
"weight_loader",
self.weight_loader if hasattr(self, "weight_loader") else default_weight_loader(self.fd_config),
)
# smooth quant
self.linear_shift = None
self.linear_smooth = None
def load_prequant_weight(self, state_dict: dict):
"""
Load the prequantized weight from the state dictionary.
Args:
state_dict (dict): A dictionary containing the prequantized weights and scales.
"""
if isinstance(self.quant_method, UnquantizedLinearMethod):
            # unquantized path (e.g. the MoE gate weight)
self.load_weight(state_dict)
else:
self.quant_method.process_prequanted_weights(self, state_dict)
def load_weight(self, state_dict: dict):
"""
Load the weight from the state dictionary.
Args:
state_dict (dict): A dictionary containing the weights
"""
if "qkv_a_proj_with_mqa" in self.weight_key:
self.weight_key_q = self.weight_key.replace("qkv_a_proj_with_mqa", "q_a_proj")
self.weight_key_kv = self.weight_key.replace("qkv_a_proj_with_mqa", "kv_a_proj_with_mqa")
q_weight_tensor = get_tensor(state_dict.pop(self.weight_key_q))
kv_weight_tensor = get_tensor(state_dict.pop(self.weight_key_kv))
weight_tensor = paddle.concat([q_weight_tensor, kv_weight_tensor], axis=-1)
else:
weight_tensor = get_tensor(state_dict.pop(self.weight_key))
self.quant_method.process_loaded_weights(self, weight_tensor)
def load_state_dict(self, state_dict: dict):
"""
Load the checkpoint state dictionary into the layer.
Args:
state_dict (dict): A dictionary containing the checkpoint weights and biases.
"""
# weight
self.state_dict = state_dict
assert self.weight_key is not None, "weight_key should not be None."
if self.fd_config.model_config.is_quantized:
self.load_prequant_weight(state_dict)
else:
self.load_weight(state_dict)
# bias
if self.with_bias:
bias_tensor = paddle.to_tensor(get_tensor(state_dict.pop(self.bias_key)))
self.bias.set_value(bias_tensor)
def forward_cuda(self, x: paddle.Tensor) -> paddle.Tensor:
"""
Forward function for Linear.
Args:
x (Tensor): Input tensor to the Linear.
Returns:
Tensor: Output tensor.
"""
linear_out = self.quant_method.apply(self, x)
return linear_out
class ReplicatedLinear(LinearBase):
"""
ReplicatedLinear Layer.
"""
def __init__(
self,
fd_config: FDConfig,
prefix: str = "",
input_size: int = None,
output_size: int = None,
with_bias: bool = False,
add_bias: bool = False,
skip_quant: bool = False,
weight_dtype: str = "",
weight_key: str = "",
):
"""
Initializes a replicated linear layer.
Args:
fd_config (FDConfig): Inference-related parameters.
prefix (str): Unique name of the layer, used to name internal attributes.
Can be arbitrarily named.
input_size (int): Number of input features. Defaults to None.
output_size (int): Number of output features. Defaults to None.
with_bias (bool): Whether to include bias or not. Defaults to False.
add_bias (bool): Whether to add bias in the current layer or in the pre/post layer. Defaults to False.
            skip_quant (bool): Whether to skip quantization. Defaults to False.
            weight_dtype (str): Override for the weight dtype. Defaults to "" (use the layer's default dtype).
            weight_key (str): Override for the weight's checkpoint key suffix. Defaults to "".
"""
super().__init__(
fd_config=fd_config,
prefix=prefix,
input_size=input_size,
output_size=output_size,
with_bias=with_bias,
add_bias=add_bias,
skip_quant=skip_quant,
weight_dtype=weight_dtype,
weight_key=weight_key,
)
self.hidden_size = fd_config.model_config.hidden_size
assert self.quant_method is not None
self.quant_method.create_weights(
self,
weight_loader=(
self.weight_loader if hasattr(self, "weight_loader") else default_weight_loader(self.fd_config)
),
model_format=fd_config.model_config.model_format,
)
class ColumnParallelLinear(LinearBase):
"""
ColumnParallelLinear Layer.
The linear layer is defined as Y = XA + b. A is parallelized along
its second dimension as A = [A_1, ..., A_p].
"""
def __init__(
self,
fd_config: FDConfig,
prefix: str = "",
input_size: int = None,
output_size: int = None,
with_bias: bool = False,
add_bias: bool = False,
skip_quant: bool = False,
):
"""
Initializes a linear layer and provides additional parameters required for inference and quantization.
Args:
fd_config (FDConfig): Inference-related parameters.
prefix (str): Unique name of the layer, used to name internal attributes.
Can be arbitrarily named.
input_size (int): Number of input features. Defaults to None.
output_size (int): Number of output features. Defaults to None.
with_bias (bool): Whether to include bias or not. Defaults to False.
add_bias (bool): Whether to add bias in the current layer or in the pre/post layer. Defaults to False.
skip_quant (bool): Whether to skip quantization. Defaults to False.
"""
self.fd_config = fd_config
self.nranks = fd_config.parallel_config.tensor_parallel_size
self.input_size = input_size
self.output_size = divide(output_size, self.nranks) # Split the output_size using TP inference.
self.hidden_size = fd_config.model_config.hidden_size
super().__init__(
fd_config=fd_config,
prefix=prefix,
input_size=self.input_size,
output_size=self.output_size,
with_bias=with_bias,
add_bias=add_bias,
skip_quant=skip_quant,
)
assert self.quant_method is not None
self.quant_method.create_weights(
self,
output_dim=True,
weight_loader=(
self.weight_loader if hasattr(self, "weight_loader") else default_weight_loader(self.fd_config)
),
model_format=fd_config.model_config.model_format,
)
if self.nranks > 0:
_set_var_distributed(self.weight, split_axis=-1)
if self.with_bias:
# col parallel
_set_var_distributed(self.bias, split_axis=0)
set_weight_attrs(self.bias, {"output_dim": True})
# set_rl_tp_degree
set_weight_attrs(
self.weight, {"rl_need_attr": {"rl_tp_degree": fd_config.parallel_config.tensor_parallel_size}}
)
if self.with_bias:
set_weight_attrs(
self.bias, {"rl_need_attr": {"rl_tp_degree": fd_config.parallel_config.tensor_parallel_size}}
)
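# Hedged numpy sketch of the column-parallel identity this class relies on:
# splitting A along its output (second) dimension and concatenating the
# per-rank results reproduces the full X @ A.
#
#     import numpy as np
#     x = np.random.randn(2, 512).astype("float32")
#     A = np.random.randn(512, 4096).astype("float32")
#     shards = np.split(A, 4, axis=-1)                   # A = [A_1, ..., A_p]
#     parallel = np.concatenate([x @ s for s in shards], axis=-1)
#     assert np.allclose(parallel, x @ A, atol=1e-4)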
class MergedColumnParallelLinear(ColumnParallelLinear):
"""
MergedColumnParallelLinear Layer.
Similar to ColumnParallelLinear, but the weight matrix is concatenated
along the output dimension. When the weight matrix is loaded, the
different partitions are sharded separately.
"""
def __init__(
self,
fd_config: FDConfig,
prefix: str,
input_size: int = None,
output_size: int = None,
with_bias: bool = False,
add_bias: bool = False,
activation: str = "gelu",
skip_quant: bool = False,
):
"""
Initialize the fused up_gate_proj Linear layer with given parameters.
Args:
fd_config (FDConfig): Inference-related parameters.
prefix (str): Unique name of the layer, used to name internal attributes.
Can be arbitrarily named.
input_size (int): Number of input features. Defaults to None.
output_size (int): Number of output features. Defaults to None.
with_bias (bool): Whether to include bias or not. Defaults to False.
add_bias (bool): Whether to add bias in the current layer or in the pre/post layer. Defaults to False.
activation (str): Activation function to use. Defaults to "gelu".
skip_quant (bool): Whether to skip quantization. Defaults to False.
"""
self.activation = activation
self.hidden_size = fd_config.model_config.hidden_size
self.nranks = fd_config.parallel_config.tensor_parallel_size
self.output_size = output_size
self.local_rank = fd_config.parallel_config.tensor_parallel_rank
super().__init__(
fd_config=fd_config,
prefix=prefix,
input_size=input_size,
output_size=output_size,
with_bias=with_bias,
add_bias=add_bias,
skip_quant=skip_quant,
)
def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
model_format = getattr(param, "model_format", "")
if model_format == "torch":
loaded_weight = get_tensor(loaded_weight)
loaded_weight = loaded_weight.transpose([1, 0])
output_dim = getattr(param, "output_dim", None)
assert output_dim is not None
shard_dim = -1 if output_dim else 0
output_size = param.shape[shard_dim]
if loaded_shard_id is None:
# Loaded weight is already fused on disk.
shard_offsets = [
# (shard_id, shard_offset, shard_size)
("gate", 0, output_size * self.nranks // 2),
("up", output_size * self.nranks // 2, output_size * self.nranks // 2),
]
for shard_id, shard_offset, shard_size in shard_offsets:
loaded_weight_shard = slice_fn(
loaded_weight, output_dim, start=shard_offset, end=shard_offset + shard_size
)
self.weight_loader(param, loaded_weight_shard, shard_id)
else:
# split gate up
assert loaded_shard_id in ["gate", "up"]
# Tensor parallelism splits the weight along the output_dim
if self.nranks != 1:
dim = -1 if output_dim else 0
if isinstance(loaded_weight, (np.ndarray, paddle.Tensor)):
size = loaded_weight.shape[dim]
else:
size = loaded_weight.get_shape()[dim]
block_size = size // self.nranks
                shard_offset = self.local_rank * block_size
                shard_end = (self.local_rank + 1) * block_size
                loaded_weight = slice_fn(loaded_weight, output_dim, start=shard_offset, end=shard_end)
loaded_weight = get_tensor(loaded_weight)
if not param._is_initialized():
param.initialize()
param_shard_size = output_size // 2
if loaded_shard_id == "gate":
param_shard_offset = 0
else:
# loaded_shard_id == "up"
param_shard_offset = param_shard_size
if hasattr(param, "tensor_track"):
param.tensor_track.mark(start=param_shard_offset, end=param_shard_offset + param_shard_size)
param = slice_fn(param, output_dim, start=param_shard_offset, end=param_shard_offset + param_shard_size)
            assert param.shape == loaded_weight.shape, (
                f"Attempted to load weight ({loaded_weight.shape}) into parameter ({param.shape})"
            )
param.copy_(loaded_weight, False)
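    # Hedged sketch of the fused gate/up layout weight_loader assumes: the full
    # matrix is [gate | up] along the output axis, and TP rank r takes the r-th
    # block of each half before the halves are re-fused per rank.
    #
    #     import numpy as np
    #     inter, nranks, rank = 8, 2, 1
    #     fused = np.random.randn(4, 2 * inter)
    #     gate, up = fused[:, :inter], fused[:, inter:]
    #     block = inter // nranks
    #     local = np.concatenate([gate[:, rank * block:(rank + 1) * block],
    #                             up[:, rank * block:(rank + 1) * block]], axis=-1)
    #     assert local.shape == (4, 2 * block)           # per-rank [gate_r | up_r]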
def load_state_dict(self, state_dict: dict):
"""
Load the checkpoint state dictionary into the layer.
Args:
state_dict (dict): A dictionary containing the checkpoint weights and biases.
"""
# weight
assert self.weight_key is not None, "weight_key should not be None."
if self.weight_key in state_dict.keys():
weight_tensor = get_tensor(state_dict.pop(self.weight_key))
else:
gate_weight_key = self.weight_key.replace("up_gate_proj", "gate_proj")
up_weight_key = self.weight_key.replace("up_gate_proj", "up_proj")
gate_tensor = get_tensor(state_dict.pop(gate_weight_key))
up_tensor = get_tensor(state_dict.pop(up_weight_key))
weight_tensor = paddle.concat([gate_tensor, up_tensor], axis=-1)
if self.with_bias:
gate_bias_key = self.bias_key.replace("up_gate_proj", "gate_proj")
bias_tensor = get_tensor(state_dict.pop(gate_bias_key)).astype(paddle.get_default_dtype())
state_dict[self.bias_key] = bias_tensor
state_dict[self.weight_key] = weight_tensor
super().load_state_dict(state_dict)
class QKVParallelLinear(ColumnParallelLinear):
"""
QKVParallelLinear Layer.
"""
def __init__(self, fd_config, prefix, with_bias=False, add_bias=True):
"""
Initialize the QKV Linear layer with given parameters.
Args:
fd_config (FDConfig): Inference-related parameters.
prefix (str): Unique name of the layer, used to name internal attributes.
Can be arbitrarily named.
with_bias (bool): Whether to include bias or not. Defaults to False.
add_bias (bool): Whether to add bias in the current layer or in the pre/post layer. Defaults to True.
"""
self.num_heads = fd_config.model_config.num_attention_heads
self.kv_num_heads = fd_config.model_config.num_key_value_heads
self.hidden_size = fd_config.model_config.hidden_size
self.head_dim = fd_config.model_config.head_dim
self.nranks = fd_config.parallel_config.tensor_parallel_size
self.local_rank = fd_config.parallel_config.tensor_parallel_rank
self.num_heads_per_rank = divide(self.num_heads, self.nranks)
if self.kv_num_heads < self.nranks and self.nranks % self.kv_num_heads == 0:
self.kv_num_heads_per_rank = 1
self.num_kv_head_replicas = divide(self.nranks, self.kv_num_heads)
output_size = (self.num_heads + 2 * self.nranks) * self.head_dim
else:
self.kv_num_heads_per_rank = divide(self.kv_num_heads, self.nranks)
self.num_kv_head_replicas = 1
output_size = (self.num_heads + 2 * self.kv_num_heads) * self.head_dim
input_size = self.hidden_size
super().__init__(
fd_config=fd_config,
prefix=prefix,
input_size=input_size,
output_size=output_size,
with_bias=with_bias,
add_bias=add_bias,
)
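    # Hedged example of the GQA sizing above: with num_heads=32, kv_num_heads=2,
    # head_dim=128 and TP=8, each rank keeps one KV head, replicated across
    # nranks // kv_num_heads ranks, and the fused output size counts one KV
    # pair per rank:
    #
    #     output_size = (32 + 2 * 8) * 128               # == 6144
    #     num_kv_head_replicas = 8 // 2                  # == 4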
def _get_shard_size_mapping(self, loaded_shard_id: str):
shard_size_mapping = {
"q": self.num_heads_per_rank * self.head_dim,
"k": self.kv_num_heads_per_rank * self.head_dim,
"v": self.kv_num_heads_per_rank * self.head_dim,
}
return shard_size_mapping.get(loaded_shard_id)
def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
output_dim = getattr(param, "output_dim", None)
assert output_dim is not None
dim = -1 if output_dim else 0
head_dim = param.shape[dim] // (self.num_heads_per_rank + 2 * self.kv_num_heads_per_rank)
model_format = getattr(param, "model_format", "")
if model_format == "torch":
loaded_weight = get_tensor(loaded_weight)
loaded_weight = loaded_weight.transpose([1, 0])
if loaded_shard_id is None:
# Loaded weight is already fused on disk
shard_offsets = [
# (shard_id, shard_offset, shard_size)
("q", 0, self.num_heads * head_dim),
("k", self.num_heads * head_dim, self.kv_num_heads * head_dim),
("v", (self.num_heads + self.kv_num_heads) * head_dim, self.kv_num_heads * head_dim),
]
for shard_id, shard_offset, shard_size in shard_offsets:
loaded_weight_shard = slice_fn(
loaded_weight, output_dim, start=shard_offset, end=shard_offset + shard_size
)
self.weight_loader(param, loaded_weight_shard, shard_id)
else:
# split q k v
assert loaded_shard_id in ["q", "k", "v"]
# Tensor parallelism splits the weight along the output_dim
if self.nranks != 1:
block_size = self._get_shard_size_mapping(loaded_shard_id)
shard_id = self.local_rank if loaded_shard_id == "q" else self.local_rank // self.num_kv_head_replicas
                shard_offset = shard_id * block_size
                shard_end = (shard_id + 1) * block_size
                loaded_weight = slice_fn(loaded_weight, output_dim, start=shard_offset, end=shard_end)
loaded_weight = get_tensor(loaded_weight)
if not param._is_initialized():
param.initialize()
if loaded_shard_id == "q":
param_shard_offset = 0
param_shard_size = self.num_heads_per_rank * head_dim
elif loaded_shard_id == "k":
param_shard_offset = self.num_heads_per_rank * head_dim
param_shard_size = self.kv_num_heads_per_rank * head_dim
else:
# loaded_shard_id == "v"
param_shard_offset = (self.num_heads_per_rank + self.kv_num_heads_per_rank) * head_dim
param_shard_size = self.kv_num_heads_per_rank * head_dim
if hasattr(param, "tensor_track"):
param.tensor_track.mark(start=param_shard_offset, end=param_shard_offset + param_shard_size)
param = slice_fn(param, output_dim, start=param_shard_offset, end=param_shard_offset + param_shard_size)
            assert param.shape == loaded_weight.shape, (
                f"Attempted to load weight ({loaded_weight.shape}) into parameter ({param.shape})"
            )
param.copy_(loaded_weight, False)
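    # Hedged sketch of the fused QKV column layout weight_loader assumes
    # (toy sizes: num_heads=8, kv_num_heads=2, head_dim=4):
    #
    #     q_size, kv_size = 8 * 4, 2 * 4
    #     offsets = {"q": (0, q_size),                   # columns [0, 32)
    #                "k": (q_size, kv_size),             # columns [32, 40)
    #                "v": (q_size + kv_size, kv_size)}   # columns [40, 48)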
def load_weight(self, state_dict: dict):
"""
Load the weight from the state dictionary.
Args:
state_dict (dict): A dictionary containing the weights
"""
if self.weight_key in state_dict.keys():
weight_tensor = get_tensor(state_dict.pop(self.weight_key))
else:
q_weight_key = self.weight_key.replace("qkv_proj", "q_proj")
k_weight_key = self.weight_key.replace("qkv_proj", "k_proj")
v_weight_key = self.weight_key.replace("qkv_proj", "v_proj")
q_tensor = get_tensor(state_dict.pop(q_weight_key))
k_tensor = get_tensor(state_dict.pop(k_weight_key))
v_tensor = get_tensor(state_dict.pop(v_weight_key))
if self.kv_num_heads < self.nranks:
sharedkv_index = (
self.fd_config.parallel_config.tensor_parallel_rank * self.kv_num_heads
) // self.nranks
sharedkv_start = sharedkv_index * self.head_dim
sharedkv_end = sharedkv_start + self.head_dim
k_tensor = k_tensor[:, sharedkv_start:sharedkv_end]
v_tensor = v_tensor[:, sharedkv_start:sharedkv_end]
weight_tensor = paddle.concat([q_tensor, k_tensor, v_tensor], axis=-1).transpose([1, 0])
weight_tensor = weight_tensor.reshape(
[
(self.num_heads_per_rank + 2 * self.kv_num_heads_per_rank) * (self.head_dim),
self.hidden_size,
]
)
weight_tensor = paddle.transpose(weight_tensor, perm=[1, 0])
if self.fd_config.quant_config:
self.quant_method.process_loaded_weights(self, weight_tensor)
else:
self.weight.set_value(weight_tensor)
def load_state_dict(self, state_dict: dict):
"""
Load the checkpoint state dictionary into the layer.
Args:
state_dict (dict): A dictionary containing the checkpoint weights and biases.
"""
# weight
assert self.weight_key is not None, "weight_key should not be None."
# qkv fused in disk
if self.fd_config.model_config.is_quantized:
self.load_prequant_weight(state_dict)
else:
self.load_weight(state_dict)
# bias
if self.with_bias:
if self.bias_key in state_dict.keys():
bias_tensor = paddle.to_tensor(get_tensor(state_dict.pop(self.bias_key)))
self.bias.set_value(bias_tensor)
else:
q_bias_key = self.bias_key.replace("qkv_proj", "q_proj")
k_bias_key = self.bias_key.replace("qkv_proj", "k_proj")
v_bias_key = self.bias_key.replace("qkv_proj", "v_proj")
q_bias = get_tensor(state_dict.pop(q_bias_key))
k_bias = get_tensor(state_dict.pop(k_bias_key))
v_bias = get_tensor(state_dict.pop(v_bias_key))
qkv_bias = paddle.concat([q_bias, k_bias, v_bias], axis=-1)
self.bias.set_value(qkv_bias)
class RowParallelLinear(LinearBase):
"""
RowParallelLinear Layer.
The linear layer is defined as Y = XA + b. A is parallelized along
its first dimension and X along its second dimension as:
              -   -
             | A_1 |
             |  .  |
        A =  |  .  |        X = [X_1, ..., X_p]
             |  .  |
             | A_p |
              -   -
"""
def __init__(
self,
fd_config: FDConfig,
prefix: str = "",
input_size: int = None,
output_size: int = None,
with_bias: bool = False,
add_bias: bool = False,
reduce_results: bool = True,
skip_quant: bool = False,
):
"""
Initialize a linear layer with additional parameters for inference and quantization.
Args:
fd_config (FDConfig): Inference-related parameters.
prefix (str): Unique name of the layer, used to name internal attributes.
Can be arbitrarily named.
input_size (int): Number of input features. Defaults to None.
output_size (int): Number of output features. Defaults to None.
with_bias (bool): Whether to include bias or not. Defaults to False.
add_bias (bool): Whether to add bias in the current layer or in the pre/post layer. Defaults to False.
skip_quant (bool): Whether to skip quantization. Defaults to False.
"""
self.fd_config = fd_config
self.nranks = fd_config.parallel_config.tensor_parallel_size
self.tp_group = fd_config.parallel_config.tp_group
self.hidden_size = fd_config.model_config.hidden_size
self.head_dim = fd_config.model_config.head_dim
self.num_heads = fd_config.model_config.num_attention_heads // self.nranks
# Split input_size when using TP inference.
self.input_size = divide(input_size, self.nranks)
self.output_size = output_size
super().__init__(
fd_config=fd_config,
prefix=prefix,
input_size=self.input_size,
output_size=self.output_size,
with_bias=with_bias,
add_bias=add_bias,
skip_quant=skip_quant,
)
if add_bias:
assert with_bias, "with_bias must be True when add_bias is True."
assert self.quant_method is not None
self.quant_method.create_weights(
self,
split_axis=0,
output_dim=False,
weight_loader=(
self.weight_loader if hasattr(self, "weight_loader") else default_weight_loader(self.fd_config)
),
model_format=fd_config.model_config.model_format,
)
if self.nranks > 0:
_set_var_distributed(self.weight, split_axis=0)
if self.with_bias:
            # row parallel
_set_var_distributed(self.bias, split_axis=0)
set_weight_attrs(
self.bias,
{
"output_dim": False,
},
)
self.reduce_results = reduce_results
# set_rl_tp_degree
set_weight_attrs(
self.weight, {"rl_need_attr": {"rl_tp_degree": fd_config.parallel_config.tensor_parallel_size}}
)
def forward_cuda(self, x: paddle.Tensor) -> paddle.Tensor:
if self.fd_config.quant_config:
out = self.quant_method.apply(self, x)
else:
out = paddle.matmul(x, self.weight)
if self.reduce_results and self.nranks > 1:
tensor_model_parallel_all_reduce(out, self.tp_group)
if not self.fd_config.quant_config and self.add_bias:
out = paddle.add(out, self.bias)
return out
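# Hedged numpy sketch of the row-parallel identity forward_cuda relies on:
# each rank multiplies its X_i slice by its A_i slice, and the all-reduce
# sums the partial products into the full X @ A.
#
#     import numpy as np
#     x = np.random.randn(2, 512).astype("float32")
#     A = np.random.randn(512, 256).astype("float32")
#     A_shards = np.split(A, 4, axis=0)                  # A stacked row-wise
#     x_shards = np.split(x, 4, axis=-1)                 # X = [X_1, ..., X_p]
#     reduced = sum(xs @ As for xs, As in zip(x_shards, A_shards))
#     assert np.allclose(reduced, x @ A, atol=1e-4)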
class KVBatchLinear(nn.Layer):
"""
KVBatchLinear Layer for handling combined KV projections with bmm.
"""
def __init__(
self,
fd_config: FDConfig,
kv_b_proj: nn.Layer,
prefix: str = "",
kv_lora_rank: int = None,
num_attention_heads: int = None,
qk_nope_head_dim: int = None,
v_head_dim: int = None,
):
"""
Initializes a KV batch linear layer that internally splits into K and V projections.
Args:
            fd_config (FDConfig): Inference-related parameters.
            kv_b_proj (nn.Layer): Combined kv_b_proj layer whose weight is split into the K and V projections.
            prefix (str): Unique name of the layer, used to name internal attributes.
            kv_lora_rank (int): LoRA rank for KV projection. Defaults to None.
            num_attention_heads (int): Number of attention heads. Defaults to None.
            qk_nope_head_dim (int): Dimension for Q/K projection (nope part). Defaults to None.
            v_head_dim (int): Dimension for V projection. Defaults to None.
"""
super().__init__()
self.nranks = fd_config.parallel_config.tensor_parallel_size
self.kv_lora_rank = kv_lora_rank
self.num_attention_heads = num_attention_heads
self.qk_nope_head_dim = qk_nope_head_dim
self.v_head_dim = v_head_dim
# Split num_attention_heads when using TP inference.
self.num_heads_per_partition = divide(num_attention_heads, self.nranks)
self.local_rank = fd_config.parallel_config.tensor_parallel_rank
self.kv_b_proj = kv_b_proj
self.weight_dtype = self._helper.get_default_dtype()
# Override weight keys to use the combined kv_b_proj
self.weight_key = f"{prefix}.weight" # e.g., "kv_b_proj.weight"
def process_weights_after_loading(self):
w = self.kv_b_proj.weight.reshape(
[
self.kv_lora_rank,
self.num_heads_per_partition,
-1,
]
).transpose(perm=[1, 2, 0])
self.kv_b_proj = None
if w.dtype != self.weight_dtype:
w = w.cast(self.weight_dtype)
# Split into K and V weights
# wk_b: [num_heads, qk_nope_head_dim, kv_lora_rank]
wk_b = w[:, : self.qk_nope_head_dim, :]
if self.v_head_dim is None:
raise ValueError("self.v_head_dim should not be None")
# wv_b: [num_heads, kv_lora_rank, v_head_dim]
wv_b = w[:, -self.v_head_dim :, :].transpose(perm=[0, 2, 1])
self.k_b_proj_weight = wk_b
self.v_b_proj_weight = wv_b
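    # Hedged shape sketch of the split above (toy sizes): kv_b_proj.weight is
    # [kv_lora_rank, heads * (qk_nope_head_dim + v_head_dim)]; after the
    # reshape/transpose, wk_b is [heads, qk_nope, lora] and wv_b is
    # [heads, lora, v_head_dim].
    #
    #     lora, heads, nope, vdim = 16, 2, 8, 8
    #     w = paddle.randn([lora, heads * (nope + vdim)])
    #     w = w.reshape([lora, heads, -1]).transpose(perm=[1, 2, 0])
    #     wk_b = w[:, :nope, :]                              # [2, 8, 16]
    #     wv_b = w[:, -vdim:, :].transpose(perm=[0, 2, 1])   # [2, 16, 8]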
def load_state_dict(self, state_dict: dict):
"""
Load the combined KV weight and split it into K and V projections
"""
# Get the combined KV weight
        # NOTE(Ryan): Do not pop weight_key here; it will be popped by another class.
kv_weight_tensor = get_tensor(state_dict[self.weight_key])
# Reshape and split the weight
w = kv_weight_tensor.reshape(
[
self.kv_lora_rank,
self.num_heads_per_partition,
-1,
]
).transpose(perm=[1, 2, 0])
# Split into K and V weights
# wk_b: [num_heads, qk_nope_head_dim, kv_lora_rank]
wk_b = w[:, : self.qk_nope_head_dim, :]
if self.v_head_dim is None:
raise ValueError("self.v_head_dim should not be None")
# wv_b: [num_heads, kv_lora_rank, v_head_dim]
wv_b = w[:, -self.v_head_dim :, :].transpose(perm=[0, 2, 1])
# Create K projection weight
self.k_b_proj_weight = self.create_parameter(
shape=wk_b.shape,
dtype=self.weight_dtype,
is_bias=False,
default_initializer=paddle.nn.initializer.Constant(0),
)
# Create V projection weight
self.v_b_proj_weight = self.create_parameter(
shape=wv_b.shape,
dtype=self.weight_dtype,
is_bias=False,
default_initializer=paddle.nn.initializer.Constant(0),
)
self.k_b_proj_weight.set_value(wk_b)
self.v_b_proj_weight.set_value(wv_b)
def forward_k_b(self, x: paddle.Tensor) -> paddle.Tensor:
"""
Forward pass for K_b projection using bmm
Args:
x: Input tensor (e.g., query_nope.transpose([1, 0, 2]))
Returns:
K_b projection output
"""
out = paddle.bmm(x, self.k_b_proj_weight)
return out
def forward_v_b(self, x: paddle.Tensor) -> paddle.Tensor:
"""
Forward pass for V_b projection using bmm
Args:
x: Input tensor (e.g., fmha_out_decode)
Returns:
V_b projection output
"""
out = paddle.bmm(x, self.v_b_proj_weight)
return out
def forward(self, x: paddle.Tensor, proj_type: str = "k") -> paddle.Tensor:
"""
Forward function that can handle both K and V projections
Args:
x: Input tensor
proj_type: 'k' or 'v' to select which projection to use
Returns:
Projection output
"""
if proj_type == "k":
return self.forward_k_b(x)
elif proj_type == "v":
return self.forward_v_b(x)
else:
raise ValueError(f"proj_type must be 'k' or 'v', got {proj_type}")
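# Hedged usage sketch of the batched projections above: forward_k_b is a
# per-head bmm, so for x of shape [heads, tokens, qk_nope_head_dim] and
# k_b_proj_weight of shape [heads, qk_nope_head_dim, kv_lora_rank]:
#
#     x = paddle.randn([2, 4, 8])
#     wk_b = paddle.randn([2, 8, 16])
#     out = paddle.bmm(x, wk_b)                          # [2, 4, 16]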