mirror of https://github.com/PaddlePaddle/FastDeploy.git
[GCU] Support gcu platform (#2702)
baseline: e7fa57ebae
Co-authored-by: yongqiangma <xing.wo@163.com>
@@ -19,9 +19,14 @@ from typing import Callable, Dict, Optional
 import numpy as np
 import paddle
 from paddle import nn
-from paddle.incubate.nn.functional import fused_layer_norm, fused_rms_norm
 
+from fastdeploy.platforms import current_platform
+
+if current_platform.is_gcu():
+    from fastdeploy.model_executor.ops.gcu import fused_add_rms_norm, rms_norm
+else:
+    from paddle.incubate.nn.functional import fused_layer_norm, fused_rms_norm
 
 from fastdeploy.config import FDConfig
 
 from .utils import get_tensor
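Note on the hunk above: the normalization kernels are now chosen once, at import time, based on the detected platform. A minimal standalone sketch of the same dispatch pattern follows; probe_gcu() is a hypothetical stand-in for current_platform.is_gcu() (the real check lives in fastdeploy.platforms), and falling back on ImportError is an assumption made for illustration, not how FastDeploy detects the device.

# Illustrative sketch only: pick a kernel implementation at import time.
# probe_gcu() is hypothetical; FastDeploy's actual probe is
# current_platform.is_gcu() from fastdeploy.platforms.
def probe_gcu() -> bool:
    try:
        # Assumption: the GCU custom-op module is only importable in GCU builds.
        import fastdeploy.model_executor.ops.gcu  # noqa: F401
        return True
    except ImportError:
        return False

if probe_gcu():
    from fastdeploy.model_executor.ops.gcu import fused_add_rms_norm, rms_norm
else:
    from paddle.incubate.nn.functional import fused_layer_norm, fused_rms_norm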
@@ -69,7 +74,10 @@ class RMSNorm(nn.Layer):
         self.weight_key: Optional[str] = f"{prefix}.weight"
         self.with_weight: bool = self.weight_key is not None
         self.eps: float = eps
-        self.norm_func: Callable = fused_rms_norm
+        if current_platform.is_gcu():
+            self.norm_func: Callable = fused_add_rms_norm
+        else:
+            self.norm_func: Callable = fused_rms_norm
         self.linear_bias: Optional[paddle.Tensor] = linear_bias
         self.quant_scale: Optional[float] = quant_scale
         self._dtype: str = self._helper.get_default_dtype()
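With this change, RMSNorm.__init__ binds self.norm_func to fused_add_rms_norm on GCU and to Paddle's fused_rms_norm elsewhere; the two callables take different arguments, which is why forward() also branches (next hunk). As a rough reference for what the GCU rms_norm kernel is assumed to compute (standard RMSNorm over the last axis), a plain-Paddle sketch:

# Assumed semantics of the GCU rms_norm op, for comparison only:
# y = x / sqrt(mean(x^2) + eps) * weight, reduced over the last axis.
import paddle

def rms_norm_reference(x: paddle.Tensor, weight: paddle.Tensor, eps: float) -> paddle.Tensor:
    x_fp32 = x.astype("float32")
    variance = paddle.mean(paddle.square(x_fp32), axis=-1, keepdim=True)
    normed = x_fp32 * paddle.rsqrt(variance + eps)
    return (normed * weight.astype("float32")).astype(x.dtype)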
@@ -129,19 +137,26 @@ class RMSNorm(nn.Layer):
         The `residual_output` is the result of applying the normalization and possibly other
         operations (like linear transformation) on the `residual_input`.
         """
-        norm_out = self.norm_func(
-            x,
-            norm_weight=self.ln_weight,
-            norm_bias=None,
-            epsilon=self.eps,
-            begin_norm_axis=self.begin_norm_axis,
-            bias=self.linear_bias,
-            residual=residual_input,
-            quant_scale=-1 if self.quant_scale is None else self.quant_scale,
-            quant_round_type=self.quant_round_type,
-            quant_max_bound=self.quant_max_bound,
-            quant_min_bound=self.quant_min_bound,
-        )
+        if current_platform.is_gcu():
+            if residual_input is None:
+                return rms_norm(x, self.ln_weight, self.eps)
+            norm_out = self.norm_func(
+                x, residual_input, self.ln_weight, self.eps
+            )
+        else:
+            norm_out = self.norm_func(
+                x,
+                norm_weight=self.ln_weight,
+                norm_bias=None,
+                epsilon=self.eps,
+                begin_norm_axis=self.begin_norm_axis,
+                bias=self.linear_bias,
+                residual=residual_input,
+                quant_scale=-1 if self.quant_scale is None else self.quant_scale,
+                quant_round_type=self.quant_round_type,
+                quant_max_bound=self.quant_max_bound,
+                quant_min_bound=self.quant_min_bound,
+            )
         if residual_input is not None:
             return norm_out[0], norm_out[1]
         else:
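In the forward hunk above, the GCU path calls rms_norm(x, weight, eps) when there is no residual and fused_add_rms_norm(x, residual, weight, eps) otherwise; unlike the fused Paddle path, it does not pass linear_bias or the quantization parameters. Judging from how the caller unpacks the result into (norm_out[0], norm_out[1]), the fused op is assumed to return both the normalized tensor and the pre-norm sum. A hedged, self-contained sketch of that assumed behavior:

# Assumed semantics of fused_add_rms_norm on GCU, inferred from the caller:
# add the residual first, RMS-normalize the sum, and return both the
# normalized output and the pre-norm sum (reused as the next residual).
import paddle

def fused_add_rms_norm_reference(x, residual, weight, eps):
    y = (x + residual).astype("float32")
    variance = paddle.mean(paddle.square(y), axis=-1, keepdim=True)
    normed = y * paddle.rsqrt(variance + eps) * weight.astype("float32")
    return normed.astype(x.dtype), y.astype(x.dtype)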
@@ -193,7 +208,10 @@ class LayerNorm(nn.Layer):
         self.with_bias: bool = with_bias
         self.eps: float = eps
         self.quant_scale: float = quant_scale
-        self.norm_func: Callable = fused_layer_norm
+        if current_platform.is_gcu():
+            self.norm_func: Callable = paddle.nn.functional.layer_norm
+        else:
+            self.norm_func: Callable = fused_layer_norm
         self.linear_bias: Optional[paddle.Tensor] = linear_bias
         self._dtype: str = self._helper.get_default_dtype()
         self._norm_weight_dtype: str = "float32"
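On GCU, LayerNorm falls back to the stock paddle.nn.functional.layer_norm instead of the fused kernel. A short usage sketch with the same keyword arguments the forward pass supplies; shapes are illustrative, and the 2-D [num_tokens, hidden_size] layout is assumed so that x.shape[1:] is a single hidden dimension matching the per-channel weight and bias:

# Usage sketch of paddle.nn.functional.layer_norm as the GCU branch calls it.
import paddle
import paddle.nn.functional as F

x = paddle.randn([4, 128], dtype="float32")        # [num_tokens, hidden_size]
weight = paddle.ones([128], dtype="float32")       # per-channel scale
bias = paddle.zeros([128], dtype="float32")        # per-channel shift
out = F.layer_norm(x, normalized_shape=x.shape[1:], weight=weight, bias=bias, epsilon=1e-6)
print(out.shape)  # [4, 128]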
@@ -279,19 +297,40 @@ class LayerNorm(nn.Layer):
         else:
             raise NotImplementedError("Iluvatar does not support yet!")
 
-        norm_out = self.norm_func(
-            x,
-            norm_weight=self.ln_weight,
-            norm_bias=self.ln_bias,
-            epsilon=self.eps,
-            begin_norm_axis=1,
-            bias=self.linear_bias,
-            residual=residual_input,
-            quant_scale=-1 if self.quant_scale is None else self.quant_scale,
-            quant_round_type=self.quant_round_type,
-            quant_max_bound=self.quant_max_bound,
-            quant_min_bound=self.quant_min_bound,
-        )
+        if current_platform.is_gcu():
+            if residual_input is not None:
+                y = x + residual_input
+                out = self.norm_func(
+                    x=y,
+                    normalized_shape=y.shape[1:],
+                    weight=self.ln_weight,
+                    bias=self.linear_bias,
+                    epsilon=self.eps,
+                )
+                return out, y
+            else:
+                out = self.norm_func(
+                    x=x,
+                    normalized_shape=x.shape[1:],
+                    weight=self.ln_weight,
+                    bias=self.linear_bias,
+                    epsilon=self.eps,
+                )
+                return out
+        else:
+            norm_out = self.norm_func(
+                x,
+                norm_weight=self.ln_weight,
+                norm_bias=self.ln_bias,
+                epsilon=self.eps,
+                begin_norm_axis=1,
+                bias=self.linear_bias,
+                residual=residual_input,
+                quant_scale=-1 if self.quant_scale is None else self.quant_scale,
+                quant_round_type=self.quant_round_type,
+                quant_max_bound=self.quant_max_bound,
+                quant_min_bound=self.quant_min_bound,
+            )
         if residual_input is not None:
             return norm_out[0], norm_out[1]
         else:
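The GCU branch of LayerNorm.forward composes the residual connection and the normalization manually: add first, layer-normalize the sum, and return both tensors so the caller can reuse the pre-norm sum as the next residual stream. A self-contained sketch of that composition, assuming the same add-then-normalize ordering as the code above:

# Sketch of the GCU LayerNorm path: residual add followed by layer_norm.
import paddle
import paddle.nn.functional as F

def layer_norm_with_residual_sketch(x, residual, weight, bias, eps):
    y = x + residual if residual is not None else x
    out = F.layer_norm(y, normalized_shape=y.shape[1:], weight=weight, bias=bias, epsilon=eps)
    # Return the pre-norm sum alongside the output only when a residual was given,
    # mirroring the (out, y) pair returned by the code in the hunk above.
    return (out, y) if residual is not None else out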