Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-09-26 20:41:53 +08:00)
adaptive rms_norm's dtype (#3617)
* adaptive rms_norm's dtype
* adaptive rms_norm's dtype
* add approve coverage

---------

Co-authored-by: liuyuanle <liuyuanle@baidu.com>
Changed files:
  .github/workflows/check-bypass.yml (vendored, 2 changes)
@@ -19,7 +19,7 @@ jobs:
     permissions:
       contents: read
     env:
-      CI_TEAM_MEMBERS: '["YuanRisheng","Jiang-Jia-Jun","DDDivano","XieYunshen"]'
+      CI_TEAM_MEMBERS: '["yuanlehome","YuanRisheng","Jiang-Jia-Jun","DDDivano","XieYunshen"]'
     outputs:
       can-skip: ${{ steps.check-bypass.outputs.can-skip }}
     steps:
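The workflow change only extends the bypass allow-list with the new reviewer. As a rough illustration (not the actual check-bypass step, whose implementation is not part of this commit), consuming such a JSON-encoded env var boils down to a membership test:

import json
import os

# CI_TEAM_MEMBERS is exported to the job as a JSON array string,
# as configured in the hunk above.
ci_team_members = json.loads(os.environ.get("CI_TEAM_MEMBERS", "[]"))

# Hypothetical variable name: the PR author whose bypass right is being checked.
pr_author = os.environ.get("PR_AUTHOR", "")

can_skip = pr_author in ci_team_members
print(f"can-skip: {str(can_skip).lower()}")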
@@ -46,6 +46,7 @@ class RMSNorm(nn.Layer):
         bias: paddle.Tensor = None,
         quant_scale: float = None,
         begin_norm_axis: int = 1,
+        dtype: str = None,
     ) -> None:
         """
         Initializes the RMSNormalization layer.
@@ -80,8 +81,17 @@ class RMSNorm(nn.Layer):
             self.norm_func: Callable = fused_rms_norm
         self.bias: Optional[paddle.Tensor] = bias
         self.quant_scale: Optional[float] = quant_scale
-        self._dtype: str = self._helper.get_default_dtype()
-        self._norm_weight_dtype: str = self._dtype
+
+        self._norm_weight_dtype = dtype
+        if self._norm_weight_dtype is None:
+            self._norm_weight_dtype = self._helper.get_default_dtype()
+        else:
+            assert dtype in [
+                "float32",
+                "bfloat16",
+                "float16",
+            ], f"Unsupported dtype: {dtype}. Must be one of: float32, bfloat16, float16"
+
         self.quant_round_type: int = self.fd_config.quant_config.quant_round_type if fd_config.quant_config else 0
         self.quant_max_bound: int = self.fd_config.quant_config.quant_max_bound if fd_config.quant_config else 0
         self.quant_min_bound: int = self.fd_config.quant_config.quant_min_bound if fd_config.quant_config else 0
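The new dtype argument falls back to the framework default and is otherwise restricted to the three supported floating-point types. A standalone sketch of that resolution logic, using paddle.get_default_dtype() in place of the layer's _helper and a hypothetical function name:

import paddle

def resolve_norm_weight_dtype(dtype: str = None) -> str:
    # Mirror of the constructor logic above: default when unset, validate otherwise.
    if dtype is None:
        return paddle.get_default_dtype()
    assert dtype in [
        "float32",
        "bfloat16",
        "float16",
    ], f"Unsupported dtype: {dtype}. Must be one of: float32, bfloat16, float16"
    return dtype

print(resolve_norm_weight_dtype())            # framework default, typically float32
print(resolve_norm_weight_dtype("bfloat16"))  # explicit norm-weight dtype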
@@ -111,8 +121,8 @@ class RMSNorm(nn.Layer):
         """

         # weight
-        weight_tensor = paddle.cast(get_tensor(state_dict.pop(self.weight_key)), self._norm_weight_dtype)
-        self.weight.set_value(weight_tensor)
+        weight_tensor = get_tensor(state_dict.pop(self.weight_key))
+        self.weight.set_value(weight_tensor.astype(self._norm_weight_dtype))

     def forward(self, x, residual_input: Optional[paddle.Tensor] = None) -> paddle.Tensor:
         """
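With the change above, the checkpoint tensor is loaded as-is and only then cast to the configured norm-weight dtype via Tensor.astype, instead of going through paddle.cast up front. A minimal sketch of that cast-on-load pattern, with a made-up checkpoint key and shape:

import paddle

# Hypothetical checkpoint entry: a float32 norm weight.
state_dict = {"model.norm.weight": paddle.ones([4096], dtype="float32")}

norm_weight_dtype = "bfloat16"  # e.g. the resolved _norm_weight_dtype

# Load first, cast afterwards, as in the new load path.
weight_tensor = state_dict.pop("model.norm.weight")
weight_tensor = weight_tensor.astype(norm_weight_dtype)
print(weight_tensor.dtype)  # paddle.bfloat16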
@@ -131,9 +141,15 @@ class RMSNorm(nn.Layer):
         The `residual_output` is the result of applying the normalization and possibly other
         operations (like linear transformation) on the `residual_input`.
         """
+        x_dtype = x.dtype
+        x = x.astype(self.weight.dtype)
+        if residual_input is not None:
+            residual_input_dtype = residual_input.dtype
+            residual_input = residual_input.astype(self.weight.dtype)
         if current_platform.is_gcu():
             if residual_input is None:
-                return rms_norm(x, self.weight, self.eps)
+                norm_out = rms_norm(x, self.weight, self.eps)
+                return norm_out.astype(x_dtype)
             norm_out = self.norm_func(x, residual_input, self.weight, self.eps)
         else:
             norm_out = self.norm_func(
@@ -150,9 +166,9 @@ class RMSNorm(nn.Layer):
                 quant_min_bound=self.quant_min_bound,
             )
         if residual_input is not None:
-            return norm_out[0], norm_out[1]
+            return norm_out[0].astype(x_dtype), norm_out[1].astype(residual_input_dtype)
         else:
-            return norm_out[0]
+            return norm_out[0].astype(x_dtype)


 class LayerNorm(nn.Layer):
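Taken together, the forward changes make the layer dtype-agnostic: the incoming dtypes are recorded, the computation runs in the weight's dtype, and the outputs are cast back before returning. A simplified sketch of that round-trip, with a plain-Python RMS norm standing in for the fused kernels and an assumed eps:

import paddle

def rms_norm_roundtrip(x: paddle.Tensor, weight: paddle.Tensor, eps: float = 1e-6) -> paddle.Tensor:
    # Record the caller's dtype, compute in the weight's dtype, cast back at the end.
    x_dtype = x.dtype
    x = x.astype(weight.dtype)
    variance = x.pow(2).mean(axis=-1, keepdim=True)
    norm_out = x * paddle.rsqrt(variance + eps) * weight
    return norm_out.astype(x_dtype)

x = paddle.randn([2, 8]).astype("float16")  # low-precision activations
w = paddle.ones([8], dtype="float32")       # norm weight kept in float32
print(rms_norm_roundtrip(x, w).dtype)       # paddle.float16, matching the input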
@@ -205,7 +221,6 @@ class LayerNorm(nn.Layer):
         else:
             self.norm_func: Callable = fused_layer_norm
         self.bias: Optional[paddle.Tensor] = bias
-        self._dtype: str = self._helper.get_default_dtype()
         self._norm_weight_dtype: str = "float32"

         self.quant_round_type: int = self.fd_config.quant_config.quant_round_type if fd_config.quant_config else 0