mirror of
				https://github.com/PaddlePaddle/FastDeploy.git
				synced 2025-10-26 18:10:32 +08:00 
			
		
		
		
	 56e2d7e668
			
		
	
	56e2d7e668
	
	
	
		
			
			* adaptive rms_norm's dtype * adaptive rms_norm's dtype * add approve coverage --------- Co-authored-by: liuyuanle <liuyuanle@baidu.com>
		
			
				
	
	
		
			340 lines
		
	
	
		
			13 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			340 lines
		
	
	
		
			13 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| """
 | |
| # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
 | |
| #
 | |
| # Licensed under the Apache License, Version 2.0 (the "License");
 | |
| # you may not use this file except in compliance with the License.
 | |
| # You may obtain a copy of the License at
 | |
| #
 | |
| #     http://www.apache.org/licenses/LICENSE-2.0
 | |
| #
 | |
| # Unless required by applicable law or agreed to in writing, software
 | |
| # distributed under the License is distributed on an "AS IS" BASIS,
 | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | |
| # See the License for the specific language governing permissions and
 | |
| # limitations under the License.
 | |
| """
 | |
| 
 | |
| from typing import Callable, Dict, Optional
 | |
| 
 | |
| import numpy as np
 | |
| import paddle
 | |
| from paddle import nn
 | |
| 
 | |
| from fastdeploy.platforms import current_platform
 | |
| 
 | |
| if current_platform.is_gcu():
 | |
|     from fastdeploy.model_executor.ops.gcu import fused_add_rms_norm, rms_norm
 | |
| else:
 | |
|     from paddle.incubate.nn.functional import fused_layer_norm, fused_rms_norm
 | |
| 
 | |
| from fastdeploy.config import FDConfig
 | |
| 
 | |
| from .utils import get_tensor
 | |
| 
 | |
| 
 | |
| class RMSNorm(nn.Layer):
 | |
|     """
 | |
|     Normalization layer.
 | |
|     """
 | |
| 
 | |
|     def __init__(
 | |
|         self,
 | |
|         fd_config: FDConfig,
 | |
|         hidden_size: int,
 | |
|         eps: float = 1e-5,
 | |
|         prefix: str = "",
 | |
|         bias: paddle.Tensor = None,
 | |
|         quant_scale: float = None,
 | |
|         begin_norm_axis: int = 1,
 | |
|         dtype: str = None,
 | |
|     ) -> None:
 | |
|         """
 | |
|         Initializes the RMSNormalization layer.
 | |
| 
 | |
|         Args:
 | |
|             fd_config (FDConfig): Arguments related to inference, containing
 | |
|                 attributes such as weight_dtype, act_dtype, mp_size, hidden_size, head_dim,
 | |
|                 num_attention_heads, and ffn_hidden_size.
 | |
|             hidden_size (int) : size of hidden state.
 | |
|             eps:(float, optional): Small value added to the variance to avoid division by zero. Defaults to 1e-5.
 | |
|             prefix(str,optional):The name of current layer. Defaults to "".
 | |
|             bias (paddle.Tensor,optional): Initial bias value for the linear layer (if used). Defaults to None.
 | |
|             quant_scale(float,optional):Quantization scale, used in quantization scenarios. Defaults to -1, indicating no quantization.
 | |
|             begin_norm_axis (int, optional): The axis along which to perform normalization. Defaults to 1.
 | |
| 
 | |
|         Raises:
 | |
|             NotImplementedError: If the specified norm_type is not supported.
 | |
|         """
 | |
|         super().__init__()
 | |
|         self.fd_config = fd_config
 | |
|         self.prefix: str = prefix
 | |
|         self.hidden_size: int = hidden_size
 | |
|         if len(prefix) == 0:
 | |
|             self.weight_key: Optional[str] = None
 | |
|         else:
 | |
|             self.weight_key: Optional[str] = f"{prefix}.weight"
 | |
|         self.with_weight: bool = self.weight_key is not None
 | |
|         self.eps: float = eps
 | |
|         if current_platform.is_gcu():
 | |
|             self.norm_func: Callable = fused_add_rms_norm
 | |
|         else:
 | |
|             self.norm_func: Callable = fused_rms_norm
 | |
|         self.bias: Optional[paddle.Tensor] = bias
 | |
|         self.quant_scale: Optional[float] = quant_scale
 | |
| 
 | |
|         self._norm_weight_dtype = dtype
 | |
|         if self._norm_weight_dtype is None:
 | |
|             self._norm_weight_dtype = self._helper.get_default_dtype()
 | |
|         else:
 | |
|             assert dtype in [
 | |
|                 "float32",
 | |
|                 "bfloat16",
 | |
|                 "float16",
 | |
|             ], f"Unsupported dtype: {dtype}. Must be one of: float32, bfloat16, float16"
 | |
| 
 | |
|         self.quant_round_type: int = self.fd_config.quant_config.quant_round_type if fd_config.quant_config else 0
 | |
|         self.quant_max_bound: int = self.fd_config.quant_config.quant_max_bound if fd_config.quant_config else 0
 | |
|         self.quant_min_bound: int = self.fd_config.quant_config.quant_min_bound if fd_config.quant_config else 0
 | |
|         self.begin_norm_axis: int = begin_norm_axis
 | |
| 
 | |
|         self.init_weight()
 | |
| 
 | |
|     def init_weight(self):
 | |
|         """
 | |
|         Initialize the weights and biases.
 | |
|         """
 | |
| 
 | |
|         self.weight = None
 | |
|         if self.with_weight:
 | |
|             self.weight = self.create_parameter(
 | |
|                 shape=[self.hidden_size],
 | |
|                 default_initializer=nn.initializer.Constant(value=1.0),
 | |
|                 dtype=self._norm_weight_dtype,
 | |
|             )
 | |
| 
 | |
|     def load_state_dict(self, state_dict: Dict[str, paddle.Tensor | np.ndarray]):
 | |
|         """
 | |
|         Load the checkpoint state dictionary into the layer.
 | |
| 
 | |
|         Args:
 | |
|             state_dict (dict): A dictionary containing the checkpoint weights and biases.
 | |
|         """
 | |
| 
 | |
|         # weight
 | |
|         weight_tensor = get_tensor(state_dict.pop(self.weight_key))
 | |
|         self.weight.set_value(weight_tensor.astype(self._norm_weight_dtype))
 | |
| 
 | |
|     def forward(self, x, residual_input: Optional[paddle.Tensor] = None) -> paddle.Tensor:
 | |
|         """
 | |
|         Defines the forward computation of the layer.
 | |
| 
 | |
|         Args:
 | |
|             x (paddle.Tensor): Input tensor to be normalized.
 | |
|             residual_input (paddle.Tensor, optional): Residual input tensor for residual connection.
 | |
|                 Defaults to None. If provided, the normalization layer will also return the residual
 | |
|                 output for further computation.
 | |
| 
 | |
|         Returns:
 | |
|             paddle.Tensor or tuple of paddle.Tensor:
 | |
|                 - If `residual_input` is None, returns the normalized output tensor.
 | |
|                 - If `residual_input` is provided, returns a tuple of (normalized_output, residual_output).
 | |
|                   The `residual_output` is the result of applying the normalization and possibly other
 | |
|                   operations (like linear transformation) on the `residual_input`.
 | |
|         """
 | |
|         x_dtype = x.dtype
 | |
|         x = x.astype(self.weight.dtype)
 | |
|         if residual_input is not None:
 | |
|             residual_input_dtype = residual_input.dtype
 | |
|             residual_input = residual_input.astype(self.weight.dtype)
 | |
|         if current_platform.is_gcu():
 | |
|             if residual_input is None:
 | |
|                 norm_out = rms_norm(x, self.weight, self.eps)
 | |
|                 return norm_out.astype(x_dtype)
 | |
|             norm_out = self.norm_func(x, residual_input, self.weight, self.eps)
 | |
|         else:
 | |
|             norm_out = self.norm_func(
 | |
|                 x,
 | |
|                 norm_weight=self.weight,
 | |
|                 norm_bias=None,
 | |
|                 epsilon=self.eps,
 | |
|                 begin_norm_axis=self.begin_norm_axis,
 | |
|                 bias=self.bias,
 | |
|                 residual=residual_input,
 | |
|                 quant_scale=(-1 if self.quant_scale is None else self.quant_scale),
 | |
|                 quant_round_type=self.quant_round_type,
 | |
|                 quant_max_bound=self.quant_max_bound,
 | |
|                 quant_min_bound=self.quant_min_bound,
 | |
|             )
 | |
|         if residual_input is not None:
 | |
|             return norm_out[0].astype(x_dtype), norm_out[1].astype(residual_input_dtype)
 | |
|         else:
 | |
|             return norm_out[0].astype(x_dtype)
 | |
| 
 | |
| 
 | |
| class LayerNorm(nn.Layer):
 | |
|     """
 | |
|     Initializes the LayerNormalization layer
 | |
|     """
 | |
| 
 | |
|     def __init__(
 | |
|         self,
 | |
|         fd_config: FDConfig,
 | |
|         hidden_size: int,
 | |
|         eps: float = 1e-5,
 | |
|         prefix="",
 | |
|         bias: paddle.Tensor = None,
 | |
|         quant_scale: float = None,
 | |
|         with_bias: bool = False,
 | |
|     ):
 | |
|         """
 | |
|         Initializes the normalization layer.
 | |
| 
 | |
|         Args:
 | |
|             fd_config (FDConfig): Arguments related to inference, containing
 | |
|                 attributes such as weight_dtype, act_dtype, mp_size, hidden_size, head_dim,
 | |
|                 num_attention_heads, and ffn_hidden_size.
 | |
|             hidden_size (int) : size of hidden state.
 | |
|             eps:(float, optional): Small value added to the variance to avoid division by zero. Defaults to 1e-5.
 | |
|             prefix (str): Unique name of the layer, used for naming internal attributes,
 | |
|                 you can give it any name you like.
 | |
|             bias (float, optional): Initial bias value for the linear layer (if used). Defaults to None.
 | |
|             quant_scale(float,optional):Quantization scale, used in quantization scenarios. Defaults to -1, indicating no quantization.
 | |
|             with_bias (bool):Whether to include bias or not. Defaults to False.
 | |
|         Raises:
 | |
|             NotImplementedError: If the specified norm_type is not supported.
 | |
|         """
 | |
|         super().__init__()
 | |
|         self.fd_config = fd_config
 | |
|         self.prefix: str = prefix
 | |
|         self.hidden_size: int = hidden_size
 | |
|         if len(prefix) == 0:
 | |
|             self.weight_key: Optional[str] = None
 | |
|         else:
 | |
|             self.weight_key: Optional[str] = f"{prefix}.weight"
 | |
|         self.with_weight: bool = self.weight_key is not None
 | |
|         self.bias_key: str = f"{prefix}.bias"
 | |
|         self.with_bias: bool = with_bias
 | |
|         self.eps: float = eps
 | |
|         self.quant_scale: float = quant_scale
 | |
|         if current_platform.is_gcu():
 | |
|             self.norm_func: Callable = paddle.nn.functional.layer_norm
 | |
|         else:
 | |
|             self.norm_func: Callable = fused_layer_norm
 | |
|         self.bias: Optional[paddle.Tensor] = bias
 | |
|         self._norm_weight_dtype: str = "float32"
 | |
| 
 | |
|         self.quant_round_type: int = self.fd_config.quant_config.quant_round_type if fd_config.quant_config else 0
 | |
|         self.quant_max_bound: int = self.fd_config.quant_config.quant_max_bound if fd_config.quant_config else 0
 | |
|         self.quant_min_bound: int = self.fd_config.quant_config.quant_min_bound if fd_config.quant_config else 0
 | |
| 
 | |
|         self.init_weight()
 | |
| 
 | |
|     def init_weight(self):
 | |
|         """
 | |
|         Initialize the weights and biases.
 | |
|         """
 | |
| 
 | |
|         self.weight = None
 | |
|         if self.with_weight:
 | |
|             self.weight = self.create_parameter(
 | |
|                 shape=[self.hidden_size],
 | |
|                 default_initializer=nn.initializer.Constant(value=1.0),
 | |
|                 dtype=self._norm_weight_dtype,
 | |
|             )
 | |
|         self.bias = None
 | |
|         if self.with_bias:
 | |
|             self.bias = self.create_parameter(
 | |
|                 shape=[self.hidden_size],
 | |
|                 is_bias=True,
 | |
|                 dtype=self._norm_weight_dtype,
 | |
|             )
 | |
| 
 | |
|     def load_state_dict(self, state_dict: Dict[str, paddle.Tensor | np.ndarray]):
 | |
|         """
 | |
|         Load the checkpoint state dictionary into the layer.
 | |
| 
 | |
|         Args:
 | |
|             state_dict (dict): A dictionary containing the checkpoint weights and biases.
 | |
|         """
 | |
| 
 | |
|         # weight
 | |
|         weight_tensor = paddle.cast(get_tensor(state_dict.pop(self.weight_key)), self._norm_weight_dtype)
 | |
|         self.weight.set_value(weight_tensor)
 | |
| 
 | |
|         # bias
 | |
|         if self.with_bias:
 | |
|             bias_tensor = paddle.cast(
 | |
|                 get_tensor(state_dict.pop(self.bias_key)),
 | |
|                 self._norm_weight_dtype,
 | |
|             )
 | |
|             self.bias.set_value(bias_tensor)
 | |
| 
 | |
|     def forward(self, x, residual_input: Optional[paddle.Tensor] = None) -> paddle.Tensor:
 | |
|         """
 | |
|         Defines the forward computation of the layer.
 | |
| 
 | |
|         Args:
 | |
|             x (paddle.Tensor): Input tensor to be normalized.
 | |
|             residual_input (paddle.Tensor, optional): Residual input tensor for residual connection.
 | |
|                 Defaults to None. If provided, the normalization layer will also return the residual
 | |
|                 output for further computation.
 | |
| 
 | |
|         Returns:
 | |
|             paddle.Tensor or tuple of paddle.Tensor:
 | |
|                 - If `residual_input` is None, returns the normalized output tensor.
 | |
|                 - If `residual_input` is provided, returns a tuple of (normalized_output, residual_output).
 | |
|                   The `residual_output` is the result of applying the normalization and possibly other
 | |
|                   operations (like linear transformation) on the `residual_input`.
 | |
|         """
 | |
|         if current_platform.is_iluvatar():
 | |
|             if self.weight is None and self.bias is None:
 | |
|                 out = x
 | |
|                 if self.bias is not None:
 | |
|                     out += self.bias
 | |
|                 if residual_input is not None:
 | |
|                     out += residual_input
 | |
|                     return out, out
 | |
|                 else:
 | |
|                     return out
 | |
|             else:
 | |
|                 raise NotImplementedError("Iluvatar does not support yet!")
 | |
| 
 | |
|         if current_platform.is_gcu():
 | |
|             if residual_input is not None:
 | |
|                 y = x + residual_input
 | |
|                 out = self.norm_func(
 | |
|                     x=y,
 | |
|                     normalized_shape=y.shape[1:],
 | |
|                     weight=self.weight,
 | |
|                     bias=self.bias,
 | |
|                     epsilon=self.eps,
 | |
|                 )
 | |
|                 return out, y
 | |
|             else:
 | |
|                 out = self.norm_func(
 | |
|                     x=x,
 | |
|                     normalized_shape=x.shape[1:],
 | |
|                     weight=self.weight,
 | |
|                     bias=self.bias,
 | |
|                     epsilon=self.eps,
 | |
|                 )
 | |
|                 return out
 | |
|         else:
 | |
|             norm_out = self.norm_func(
 | |
|                 x,
 | |
|                 norm_weight=self.weight,
 | |
|                 norm_bias=self.bias,
 | |
|                 epsilon=self.eps,
 | |
|                 begin_norm_axis=1,
 | |
|                 bias=self.bias,
 | |
|                 residual=residual_input,
 | |
|                 quant_scale=(-1 if self.quant_scale is None else self.quant_scale),
 | |
|                 quant_round_type=self.quant_round_type,
 | |
|                 quant_max_bound=self.quant_max_bound,
 | |
|                 quant_min_bound=self.quant_min_bound,
 | |
|             )
 | |
|         if residual_input is not None:
 | |
|             return norm_out[0], norm_out[1]
 | |
|         else:
 | |
|             return norm_out[0]
 |