| """
 | |
| # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
 | |
| #
 | |
| # Licensed under the Apache License, Version 2.0 (the "License");
 | |
| # you may not use this file except in compliance with the License.
 | |
| # You may obtain a copy of the License at
 | |
| #
 | |
| #     http://www.apache.org/licenses/LICENSE-2.0
 | |
| #
 | |
| # Unless required by applicable law or agreed to in writing, software
 | |
| # distributed under the License is distributed on an "AS IS" BASIS,
 | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | |
| # See the License for the specific language governing permissions and
 | |
| # limitations under the License.
 | |
| """
 | |
| 
 | |
| import paddle
 | |
| from paddle import nn
 | |
| 
 | |
| from fastdeploy.config import FDConfig
 | |
| from fastdeploy.distributed.communication_op import \
 | |
|     tensor_model_parallel_all_reduce
 | |
| from fastdeploy.platforms import current_platform
 | |
| 
 | |
| from .utils import _set_var_distributed, divide, get_tensor
 | |
| 
 | |
| 
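# This module defines the linear layers used by FastDeploy model runners: a
# replicated (non-parallel) linear layer, column- and row-parallel variants for
# tensor-parallel inference, fused gate/up and QKV projections, and a batched
# KV projection layer driven by bmm.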
class LinearBase(nn.Layer):
    """
    LinearBase Layer.
    """

    def __init__(
        self,
        fd_config: FDConfig,
        prefix: str = "",
        input_size: int = None,
        output_size: int = None,
        with_bias: bool = False,
        add_bias: bool = False,
        skip_quant: bool = False,
    ):
        """
        Initializes a linear layer and provides additional parameters required for inference and quantization.

        Args:
            fd_config (FDConfig): Inference-related parameters.
            prefix (str): Unique name of the layer, used to name internal attributes.
                Can be arbitrarily named.
            input_size (int): Number of input features. Defaults to None.
            output_size (int): Number of output features. Defaults to None.
            with_bias (bool): Whether to include bias or not. Defaults to False.
            add_bias (bool): Whether to add bias in the current layer or in the pre/post layer. Defaults to False.
            skip_quant (bool): Whether to skip quantization. Defaults to False.

        Raises:
            NotImplementedError: Raised if the current platform is not supported
                (CUDA, XPU, Iluvatar, or GCU).
        """
        super().__init__()
        if (current_platform.is_cuda() or current_platform.is_xpu()
                or current_platform.is_iluvatar() or current_platform.is_gcu()):
            self.forward = self.forward_cuda
        else:
            raise NotImplementedError

        self.fd_config = fd_config
        self.skip_quant = skip_quant
        self.input_size = input_size
        self.output_size = output_size
        self.with_bias = with_bias
        self.add_bias = add_bias
        self.prefix = prefix
        # Checkpoint parameter keys.
        self.weight_key = f"{prefix}.weight"
        self.bias_key = f"{prefix}.bias"
        self.shift_key = f"{prefix}.shift_bias"
        self.smooth_key = f"{prefix}.smooth_weight"
        self.out_scale_key = f"{prefix}.out_scale"

        self._dtype = self._helper.get_default_dtype()
        self.weight_dtype = self._dtype
        self.weight_shape = [
            self.input_size,
            self.output_size,
        ]
        if fd_config.quant_config:
            self.quant_method = fd_config.quant_config.get_quant_method(self)
        if fd_config.model_config.is_quantized:
            self.weight_key = f"{prefix}.quant_weight"
            self.weight_scale_key = f"{prefix}.weight_scale"
            self.act_scale_key = f"{prefix}.activation_scale"

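    # NOTE: weights are kept in [input_size, output_size] layout, so the
    # unquantized forward below is a plain paddle.matmul(x, weight) plus an
    # optional bias add; quantized paths delegate to self.quant_method instead.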
    def init_weight(self):
        """
        Initialize the weights and biases.
        """
        if self.skip_quant:
            self.weight_dtype = self._dtype
        self.weight = self.create_parameter(
            shape=self.weight_shape,
            dtype=self.weight_dtype,
            is_bias=False,
            default_initializer=paddle.nn.initializer.Constant(0),
        )

        self.bias = None
        if self.with_bias:
            self.bias = self.create_parameter(
                shape=[self.output_size],
                dtype=self._dtype,
                is_bias=True,
            )

        # smooth quant
        self.linear_shift = None
        self.linear_smooth = None

    def load_prequant_weight(self, state_dict: dict):
        """
        Load the prequantized weight from the state dictionary.

        Args:
            state_dict (dict): A dictionary containing the prequantized weights and scales.
        """
        self.quant_method.process_prequanted_weights(self, state_dict)

    def load_weight(self, state_dict: dict):
        """
        Load the weight from the state dictionary.

        Args:
            state_dict (dict): A dictionary containing the weights
        """
        weight_tensor = get_tensor(state_dict.pop(self.weight_key))

        if self.fd_config.quant_config:
            self.quant_method.process_loaded_weights(self, weight_tensor)
        else:
            self.weight.set_value(weight_tensor)

    def load_state_dict(self, state_dict: dict):
        """
        Load the checkpoint state dictionary into the layer.

        Args:
            state_dict (dict): A dictionary containing the checkpoint weights and biases.
        """
        # weight
        self.state_dict = state_dict
        assert self.weight_key is not None, 'weight_key should not be None.'
        if self.fd_config.model_config.is_quantized:
            self.load_prequant_weight(state_dict)
        else:
            self.load_weight(state_dict)

        # bias
        if self.with_bias:
            bias_tensor = paddle.to_tensor(
                get_tensor(state_dict.pop(self.bias_key)))
            self.bias.set_value(bias_tensor)

    def forward_cuda(self, x: paddle.Tensor) -> paddle.Tensor:
        """
        Forward function for Linear.

        Args:
            x (Tensor): Input tensor to the Linear.

        Returns:
            Tensor: Output tensor.

        Raises:
            NotImplementedError: If the weight dtype is not float8 or act dtype is not equal to weight dtype.
        """
        if self.fd_config.quant_config:
            linear_out = self.quant_method.apply(self, x)
        else:
            linear_out = paddle.matmul(x, self.weight)
            if self.with_bias:
                linear_out = paddle.add(linear_out, self.bias)

        return linear_out


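# Illustrative sketch (not part of the original file): shape flow through the
# unquantized forward path above, assuming a hypothetical layer with
# input_size=1024 and output_size=4096.
#
#   x:      [batch_size, seq_len, 1024]
#   weight: [1024, 4096]                    # stored as [input_size, output_size]
#   out = paddle.matmul(x, weight)          # -> [batch_size, seq_len, 4096]
#   out = paddle.add(out, bias)             # bias: [4096], broadcast over leading dims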
class ReplicatedLinear(LinearBase):
    """
    ReplicatedLinear Layer.
    """

    def __init__(
        self,
        fd_config: FDConfig,
        prefix: str = "",
        input_size: int = None,
        output_size: int = None,
        with_bias: bool = False,
        add_bias: bool = False,
        skip_quant: bool = False,
    ):
        """
        Initializes a replicated linear layer.

        Args:
            fd_config (FDConfig): Inference-related parameters.
            prefix (str): Unique name of the layer, used to name internal attributes.
                Can be arbitrarily named.
            input_size (int): Number of input features. Defaults to None.
            output_size (int): Number of output features. Defaults to None.
            with_bias (bool): Whether to include bias or not. Defaults to False.
            add_bias (bool): Whether to add bias in the current layer or in the pre/post layer. Defaults to False.
            skip_quant (bool): Whether to skip quantization. Defaults to False.
        """
        super().__init__(fd_config=fd_config,
                         prefix=prefix,
                         input_size=input_size,
                         output_size=output_size,
                         with_bias=with_bias,
                         add_bias=add_bias,
                         skip_quant=skip_quant)

        self.hidden_size = fd_config.model_config.hidden_size
        self.weight_shape = [
            self.input_size,
            self.output_size,
        ]
        if fd_config.quant_config:
            self.quant_method.create_weights(self)
        self.init_weight()


class ColumnParallelLinear(LinearBase):
    """
    ColumnParallelLinear Layer.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its second dimension as A = [A_1, ..., A_p].
    """

    def __init__(
        self,
        fd_config: FDConfig,
        prefix: str = "",
        input_size: int = None,
        output_size: int = None,
        with_bias: bool = False,
        add_bias: bool = False,
        skip_quant: bool = False,
    ):
        """
        Initializes a linear layer and provides additional parameters required for inference and quantization.

        Args:
            fd_config (FDConfig): Inference-related parameters.
            prefix (str): Unique name of the layer, used to name internal attributes.
                Can be arbitrarily named.
            input_size (int): Number of input features. Defaults to None.
            output_size (int): Number of output features. Defaults to None.
            with_bias (bool): Whether to include bias or not. Defaults to False.
            add_bias (bool): Whether to add bias in the current layer or in the pre/post layer. Defaults to False.
            skip_quant (bool): Whether to skip quantization. Defaults to False.
        """
        super().__init__(fd_config=fd_config,
                         prefix=prefix,
                         input_size=input_size,
                         output_size=output_size,
                         with_bias=with_bias,
                         add_bias=add_bias,
                         skip_quant=skip_quant)
        self.nranks = fd_config.parallel_config.tensor_parallel_size
        self.input_size = input_size
        # Split output_size when using TP inference.
        self.output_size = divide(output_size, self.nranks)
        self.hidden_size = fd_config.model_config.hidden_size
        self.weight_shape = [
            self.input_size,
            self.output_size,
        ]
        if fd_config.quant_config:
            self.quant_method.create_weights(self)
        self.init_weight()

    def init_weight(self):
        """
        Initialize the weights and biases.
        """
        if self.skip_quant:
            self.weight_dtype = self._dtype
        self.weight = self.create_parameter(
            shape=self.weight_shape,
            dtype=self.weight_dtype,
            is_bias=False,
            default_initializer=paddle.nn.initializer.Constant(0),
        )
        if self.nranks > 0:
            # col parallel
            _set_var_distributed(self.weight, split_axis=1)

        self.bias = None
        if self.with_bias:
            self.bias = self.create_parameter(
                shape=[self.output_size],
                dtype=self._dtype,
                is_bias=True,
            )
            if self.nranks > 0:
                # col parallel
                _set_var_distributed(self.bias, split_axis=1)

        # smooth quant
        self.linear_shift = None
        self.linear_smooth = None


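# Illustrative sketch (not part of the original file): column-parallel sharding in
# ColumnParallelLinear above, assuming a hypothetical output_size=8192 and
# tensor_parallel_size=4.
#
#   The global weight A: [input_size, 8192] is split along axis 1 into
#   A = [A_1, ..., A_4]; each rank creates a local weight of shape
#   [input_size, 8192 // 4] (divide(output_size, self.nranks) above) and, when
#   present, a matching [2048] slice of the bias.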
class MergedColumnParallelLinear(ColumnParallelLinear):
    """
    MergedColumnParallelLinear Layer.

    Similar to ColumnParallelLinear, but the weight matrix is concatenated
    along the output dimension. When the weight matrix is loaded, the
    different partitions are sharded separately.
    """

    def __init__(
        self,
        fd_config: FDConfig,
        prefix: str,
        input_size: int = None,
        output_size: int = None,
        with_bias: bool = False,
        add_bias: bool = False,
        activation: str = "gelu",
        skip_quant: bool = False,
    ):
        """
        Initialize the fused up_gate_proj Linear layer with given parameters.

        Args:
            fd_config (FDConfig): Inference-related parameters.
            prefix (str): Unique name of the layer, used to name internal attributes.
                Can be arbitrarily named.
            input_size (int): Number of input features. Defaults to None.
            output_size (int): Number of output features. Defaults to None.
            with_bias (bool): Whether to include bias or not. Defaults to False.
            add_bias (bool): Whether to add bias in the current layer or in the pre/post layer. Defaults to False.
            activation (str): Activation function to use. Defaults to "gelu".
            skip_quant (bool): Whether to skip quantization. Defaults to False.
        """
        self.activation = activation
        self.hidden_size = fd_config.model_config.hidden_size
        self.nranks = fd_config.parallel_config.tensor_parallel_size

        super().__init__(fd_config=fd_config,
                         prefix=prefix,
                         input_size=input_size,
                         output_size=output_size,
                         with_bias=with_bias,
                         add_bias=add_bias,
                         skip_quant=skip_quant)

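    # The checkpoint may store the fused "up_gate_proj" weight directly, or the
    # separate "gate_proj"/"up_proj" weights; load_state_dict below handles both,
    # concatenating the two tensors along the output dimension when needed.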
    def load_state_dict(self, state_dict: dict):
        """
        Load the checkpoint state dictionary into the layer.

        Args:
            state_dict (dict): A dictionary containing the checkpoint weights and biases.
        """
        # weight
        assert self.weight_key is not None, 'weight_key should not be None.'
        if self.weight_key in state_dict.keys():
            weight_tensor = get_tensor(state_dict.pop(self.weight_key))
        else:
            gate_weight_key = self.weight_key.replace("up_gate_proj",
                                                      "gate_proj")
            up_weight_key = self.weight_key.replace("up_gate_proj", "up_proj")
            gate_tensor = get_tensor(state_dict.pop(gate_weight_key))
            up_tensor = get_tensor(state_dict.pop(up_weight_key))
            weight_tensor = paddle.concat([gate_tensor, up_tensor], axis=-1)

            if self.with_bias:
                gate_bias_key = self.bias_key.replace("up_gate_proj",
                                                      "gate_proj")
                bias_tensor = get_tensor(state_dict.pop(gate_bias_key)).astype(
                    paddle.get_default_dtype())

                state_dict[self.bias_key] = bias_tensor

        state_dict[self.weight_key] = weight_tensor

        super().load_state_dict(state_dict)


class QKVParallelLinear(ColumnParallelLinear):
    """
    QKVParallelLinear Layer.
    """

    def __init__(self, fd_config, prefix, with_bias=False, add_bias=True):
        """
        Initialize the QKV Linear layer with given parameters.

        Args:
            fd_config (FDConfig): Inference-related parameters.
            prefix (str): Unique name of the layer, used to name internal attributes.
                Can be arbitrarily named.
            with_bias (bool): Whether to include bias or not. Defaults to False.
            add_bias (bool): Whether to add bias in the current layer or in the pre/post layer. Defaults to True.
        """
        self.num_heads = fd_config.model_config.num_attention_heads
        self.kv_num_heads = fd_config.model_config.num_key_value_heads
        self.hidden_size = fd_config.model_config.hidden_size
        self.head_dim = fd_config.model_config.head_dim
        self.nranks = fd_config.parallel_config.tensor_parallel_size
        self.num_heads_per_rank = divide(self.num_heads, self.nranks)
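        # output_size below is the global (pre-TP-split) fused QKV width; the
        # parent ColumnParallelLinear divides it by nranks. When there are fewer
        # KV heads than ranks (and nranks is a multiple of kv_num_heads), each
        # rank keeps a single replicated KV head, hence the 2 * self.nranks term.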
        if self.kv_num_heads < self.nranks and self.nranks % self.kv_num_heads == 0:
            self.kv_num_heads_per_rank = 1
            output_size = (self.num_heads + 2 * self.nranks) * self.head_dim
        else:
            self.kv_num_heads_per_rank = divide(self.kv_num_heads, self.nranks)
            output_size = (self.num_heads +
                           2 * self.kv_num_heads) * self.head_dim
        input_size = self.hidden_size
        super().__init__(fd_config=fd_config,
                         prefix=prefix,
                         input_size=input_size,
                         output_size=output_size,
                         with_bias=with_bias,
                         add_bias=add_bias)

    def load_weight(self, state_dict: dict):
        """
        Load the weight from the state dictionary.

        Args:
            state_dict (dict): A dictionary containing the weights
        """
        if self.weight_key in state_dict.keys():
            weight_tensor = get_tensor(state_dict.pop(self.weight_key))
        else:
            q_weight_key = self.weight_key.replace("qkv_proj", "q_proj")
            k_weight_key = self.weight_key.replace("qkv_proj", "k_proj")
            v_weight_key = self.weight_key.replace("qkv_proj", "v_proj")
            q_tensor = get_tensor(state_dict.pop(q_weight_key))
            k_tensor = get_tensor(state_dict.pop(k_weight_key))
            v_tensor = get_tensor(state_dict.pop(v_weight_key))

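            # With replicated KV heads (kv_num_heads < nranks), keep only the
            # head_dim-wide column block of k_proj/v_proj that belongs to the KV
            # head this rank serves before fusing it with the Q shard.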
            if self.kv_num_heads < self.nranks:
                sharedkv_index = (self.fd_config.parallel_config.tensor_parallel_rank * self.kv_num_heads) // self.nranks
                sharedkv_start = sharedkv_index * self.head_dim
                sharedkv_end = sharedkv_start + self.head_dim
                k_tensor = k_tensor[:, sharedkv_start:sharedkv_end]
                v_tensor = v_tensor[:, sharedkv_start:sharedkv_end]
            weight_tensor = paddle.concat([q_tensor, k_tensor, v_tensor],
                                          axis=-1).transpose([1, 0])
            weight_tensor = weight_tensor.reshape([
                (self.num_heads_per_rank + 2 * self.kv_num_heads_per_rank) *
                (self.head_dim),
                self.hidden_size,
            ])
            weight_tensor = paddle.transpose(weight_tensor, perm=[1, 0])

        if self.fd_config.quant_config:
            self.quant_method.process_loaded_weights(self, weight_tensor)
        else:
            self.weight.set_value(weight_tensor)

    def load_state_dict(self, state_dict: dict):
        """
        Load the checkpoint state dictionary into the layer.

        Args:
            state_dict (dict): A dictionary containing the checkpoint weights and biases.
        """
        # weight
        assert self.weight_key is not None, 'weight_key should not be None.'
        # The QKV weight may already be fused on disk.

        if self.fd_config.model_config.is_quantized:
            self.load_prequant_weight(state_dict)
        else:
            self.load_weight(state_dict)

        # bias
        if self.with_bias:
            if self.bias_key in state_dict.keys():
                bias_tensor = paddle.to_tensor(
                    get_tensor(state_dict.pop(self.bias_key)))
                self.bias.set_value(bias_tensor)
            else:
                q_bias_key = self.bias_key.replace("qkv_proj", "q_proj")
                k_bias_key = self.bias_key.replace("qkv_proj", "k_proj")
                v_bias_key = self.bias_key.replace("qkv_proj", "v_proj")
                q_bias = get_tensor(state_dict.pop(q_bias_key))
                k_bias = get_tensor(state_dict.pop(k_bias_key))
                v_bias = get_tensor(state_dict.pop(v_bias_key))
                qkv_bias = paddle.concat([q_bias, k_bias, v_bias], axis=-1)
                self.bias.set_value(qkv_bias)


class RowParallelLinear(LinearBase):
    """
    RowParallelLinear Layer.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its first dimension and X along its second dimension as:
               -   -
              | A_1 |
              | .   |
          A = | .   |        X = [X_1, ..., X_p]
              | .   |
              | A_p |
               -   -
    """

    def __init__(
        self,
        fd_config: FDConfig,
        prefix: str = "",
        input_size: int = None,
        output_size: int = None,
        with_bias: bool = False,
        add_bias: bool = False,
        reduce_results: bool = True,
        skip_quant: bool = False,
    ):
        """
        Initialize a linear layer with additional parameters for inference and quantization.

        Args:
            fd_config (FDConfig): Inference-related parameters.
            prefix (str): Unique name of the layer, used to name internal attributes.
                Can be arbitrarily named.
            input_size (int): Number of input features. Defaults to None.
            output_size (int): Number of output features. Defaults to None.
            with_bias (bool): Whether to include bias or not. Defaults to False.
            add_bias (bool): Whether to add bias in the current layer or in the pre/post layer. Defaults to False.
            reduce_results (bool): Whether to all-reduce the output across tensor-parallel ranks. Defaults to True.
            skip_quant (bool): Whether to skip quantization. Defaults to False.
        """
        super().__init__(fd_config=fd_config,
                         prefix=prefix,
                         input_size=input_size,
                         output_size=output_size,
                         with_bias=with_bias,
                         add_bias=add_bias,
                         skip_quant=skip_quant)
        self.fd_config = fd_config
        self.skip_quant = False
        self.nranks = fd_config.parallel_config.tensor_parallel_size
        self.hidden_size = fd_config.model_config.hidden_size
        self.head_dim = fd_config.model_config.head_dim
        self.num_heads = fd_config.model_config.num_attention_heads // self.nranks

        # Split input_size when using TP inference.
        self.input_size = divide(input_size, self.nranks)
        self.output_size = output_size

        self.weight_shape = [
            self.input_size,
            self.output_size,
        ]
        self._dtype = self._helper.get_default_dtype()

        if fd_config.quant_config:
            self.quant_method = fd_config.quant_config.get_quant_method(self)
            self.quant_method.create_weights(self)

        self.reduce_results = reduce_results
        self.init_weight()

    def init_weight(self):
        """
        Initialize the weights and biases.
        """
        if self.skip_quant:
            self.weight_dtype = self._dtype

        self.weight = self.create_parameter(
            shape=self.weight_shape,
            dtype=self.weight_dtype,
            is_bias=False,
            default_initializer=paddle.nn.initializer.Constant(0),
        )

        self.bias = None
        if self.with_bias:
            self.bias = self.create_parameter(
                shape=[self.hidden_size],
                dtype=self._dtype,
                is_bias=True,
            )

        if self.nranks > 0:
            # row parallel
            _set_var_distributed(self.weight, split_axis=0)

        # smooth quant
        self.linear_shift = None
        self.linear_smooth = None

    def forward_cuda(self, x: paddle.Tensor) -> paddle.Tensor:
        if self.fd_config.quant_config:
            out = self.quant_method.apply(self, x)
        else:
            out = paddle.matmul(x, self.weight)

        if self.reduce_results and self.nranks > 1:
            tensor_model_parallel_all_reduce(out)

        return out


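# Illustrative sketch (not part of the original file): row-parallel math in
# RowParallelLinear above, assuming a hypothetical input_size=8192 and
# tensor_parallel_size=4.
#
#   Each rank holds A_i: [8192 // 4, output_size] together with the matching
#   input shard X_i: [..., 2048], computes the partial product X_i @ A_i, and,
#   when reduce_results is True, sums the partials across ranks with
#   tensor_model_parallel_all_reduce so every rank ends up with the full Y = X @ A.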
class KVBatchLinear(LinearBase):
    """
    KVBatchLinear Layer for handling combined KV projections with bmm.
    """

    def __init__(
        self,
        fd_config: FDConfig,
        prefix: str = "",
        kv_lora_rank: int = None,
        num_attention_heads: int = None,
        qk_nope_head_dim: int = None,
        v_head_dim: int = None,
        with_bias: bool = False,
        skip_quant: bool = False,
    ):
        """
        Initializes a KV batch linear layer that internally splits into K and V projections.

        Args:
            fd_config (FDConfig): Inference-related parameters.
            prefix (str): Unique name of the layer, used to name internal attributes.
            kv_lora_rank (int): LoRA rank for KV projection. Defaults to None.
            num_attention_heads (int): Number of attention heads. Defaults to None.
            qk_nope_head_dim (int): Dimension for Q/K projection (nope part). Defaults to None.
            v_head_dim (int): Dimension for V projection. Defaults to None.
            with_bias (bool): Whether to include bias or not. Defaults to False.
            skip_quant (bool): Whether to skip quantization. Defaults to False.
        """
        self.nranks = fd_config.parallel_config.tensor_parallel_size
        self.kv_lora_rank = kv_lora_rank
        self.num_attention_heads = num_attention_heads
        self.qk_nope_head_dim = qk_nope_head_dim
        self.v_head_dim = v_head_dim
        # Split num_attention_heads when using TP inference.
        self.num_heads_per_partition = divide(num_attention_heads, self.nranks)

        # Initialize parent with combined dimensions
        super().__init__(
            fd_config=fd_config,
            prefix=prefix,
            input_size=None,  # Will be determined from weight shape
            output_size=None,  # Will be determined from weight shape
            with_bias=with_bias,
            add_bias=False,
            skip_quant=skip_quant,
        )
        self.weight_dtype = self._dtype

        # Override weight keys to use the combined kv_b_proj
        self.weight_key = f"{prefix}.weight"  # e.g., "kv_b_proj.weight"
        self.k_weight_key = f"{prefix.replace('kv_b_proj', 'k_b_proj')}.weight"
        self.v_weight_key = f"{prefix.replace('kv_b_proj', 'v_b_proj')}.weight"

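    # load_state_dict below reshapes and transposes the fused kv_b_proj weight to
    # [num_heads_per_partition, qk_nope_head_dim + v_head_dim, kv_lora_rank] and
    # splits it into the per-head wk_b / wv_b matrices consumed by the bmm-based
    # forward_k_b / forward_v_b paths.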
    def load_state_dict(self, state_dict: dict):
        """
        Load the combined KV weight and split it into K and V projections.
        """
        # Get the combined KV weight
        # NOTE(Ryan): Do not pop weight_key here; it will be popped by another class.
        kv_weight_tensor = get_tensor(state_dict[self.weight_key])

        # Reshape and split the weight
        w = kv_weight_tensor.reshape([
            self.kv_lora_rank,
            self.num_heads_per_partition,
            -1,
        ]).transpose(perm=[1, 2, 0])

        # Split into K and V weights
        # wk_b: [num_heads, qk_nope_head_dim, kv_lora_rank]
        wk_b = w[:, :self.qk_nope_head_dim, :]

        if self.v_head_dim is None:
            raise ValueError("self.v_head_dim should not be None")
        # wv_b: [num_heads, kv_lora_rank, v_head_dim]
        wv_b = w[:, -self.v_head_dim:, :].transpose(perm=[0, 2, 1])

        # Create K projection weight
        self.k_b_proj_weight = self.create_parameter(
            shape=wk_b.shape,
            dtype=self.weight_dtype,
            is_bias=False,
            default_initializer=paddle.nn.initializer.Constant(0),
        )

        # Create V projection weight
        self.v_b_proj_weight = self.create_parameter(
            shape=wv_b.shape,
            dtype=self.weight_dtype,
            is_bias=False,
            default_initializer=paddle.nn.initializer.Constant(0),
        )

        self.k_b_proj_weight.set_value(wk_b)
        self.v_b_proj_weight.set_value(wv_b)

    def forward_k_b(self, x: paddle.Tensor) -> paddle.Tensor:
        """
        Forward pass for the K_b projection using bmm.

        Args:
            x: Input tensor (e.g., query_nope.transpose([1, 0, 2]))

        Returns:
            K_b projection output
        """
        out = paddle.bmm(x, self.k_b_proj_weight)
        return out

    def forward_v_b(self, x: paddle.Tensor) -> paddle.Tensor:
        """
        Forward pass for the V_b projection using bmm.

        Args:
            x: Input tensor (e.g., fmha_out_decode)

        Returns:
            V_b projection output
        """
        out = paddle.bmm(x, self.v_b_proj_weight)
        return out

    def forward_cuda(self,
                     x: paddle.Tensor,
                     proj_type: str = 'k') -> paddle.Tensor:
        """
        Forward function that can handle both K and V projections.

        Args:
            x: Input tensor
            proj_type: 'k' or 'v' to select which projection to use

        Returns:
            Projection output
        """
        if proj_type == 'k':
            return self.forward_k_b(x)
        elif proj_type == 'v':
            return self.forward_v_b(x)
        else:
            raise ValueError(f"proj_type must be 'k' or 'v', got {proj_type}")