Sync v2.0 version of code to GitHub repo

Jiang-Jia-Jun
2025-06-29 23:29:37 +00:00
parent d151496038
commit 92c2cfa2e7
597 changed files with 78776 additions and 22905 deletions


@@ -28,18 +28,19 @@ class RMSNorm(nn.Layer):
def __init__(
self,
-llm_config,
+fd_config,
hidden_size,
eps=1e-5,
prefix="",
linear_bias=None,
quant_scale=None,
+begin_norm_axis=1,
):
"""
Initializes the normalization layer.
Args:
-llm_config (LLMConfig): Arguments related to inference, containing
+fd_config (FDConfig): Arguments related to inference, containing
attributes such as weight_dtype, act_dtype, mp_size, hidden_size, head_dim,
num_attention_heads, and ffn_hidden_size.
hidden_size (int) : size of hidden state.
@@ -52,7 +53,7 @@ class RMSNorm(nn.Layer):
NotImplementedError: If the specified norm_type is not supported.
"""
super().__init__()
-self.llm_config = llm_config
+self.fd_config = fd_config
self.prefix = prefix
self.hidden_size = hidden_size
if len(prefix) == 0:
@@ -66,6 +67,11 @@ class RMSNorm(nn.Layer):
self.quant_scale = quant_scale
self._dtype = self._helper.get_default_dtype()
self._norm_weight_dtype = self._dtype
+self.begin_norm_axis = begin_norm_axis
+self.quant_round_type = self.fd_config.quant_config.quant_round_type if fd_config.quant_config else 0
+self.quant_max_bound = self.fd_config.quant_config.quant_max_bound if fd_config.quant_config else 0
+self.quant_min_bound = self.fd_config.quant_config.quant_min_bound if fd_config.quant_config else 0
+self.begin_norm_axis = begin_norm_axis
self.init_weight()
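The fallback added above (defaulting the quant parameters to 0 when no quant_config is attached to fd_config) can be shown in isolation. A minimal sketch, assuming hypothetical FDConfig/QuantConfig stand-ins rather than the real FastDeploy classes:

from dataclasses import dataclass
from typing import Optional

@dataclass
class QuantConfig:  # stand-in, not the FastDeploy class
    quant_round_type: int = 0
    quant_max_bound: float = 127.0
    quant_min_bound: float = -127.0

@dataclass
class FDConfig:  # stand-in, not the FastDeploy class
    quant_config: Optional[QuantConfig] = None

def resolve_quant_params(fd_config: FDConfig):
    # Mirrors the pattern in the diff: fall back to 0 when quant_config is None.
    qc = fd_config.quant_config
    return (
        qc.quant_round_type if qc else 0,
        qc.quant_max_bound if qc else 0,
        qc.quant_min_bound if qc else 0,
    )

print(resolve_quant_params(FDConfig()))                               # (0, 0, 0)
print(resolve_quant_params(FDConfig(QuantConfig(1, 448.0, -448.0))))  # (1, 448.0, -448.0)

Caching these values on the layer at construction time is what lets the forward pass below pass self.quant_* directly instead of dereferencing a possibly missing quant_config on every call.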
@@ -118,13 +124,13 @@ class RMSNorm(nn.Layer):
norm_weight=self.ln_weight,
norm_bias=None,
epsilon=self.eps,
-begin_norm_axis=1,
+begin_norm_axis=self.begin_norm_axis,
bias=self.linear_bias,
residual=residual_input,
quant_scale=-1 if self.quant_scale is None else self.quant_scale,
-quant_round_type=self.llm_config.quant_config.quant_round_type,
-quant_max_bound=self.llm_config.quant_config.quant_max_bound,
-quant_min_bound=self.llm_config.quant_config.quant_min_bound,
+quant_round_type=self.quant_round_type,
+quant_max_bound=self.quant_max_bound,
+quant_min_bound=self.quant_min_bound,
)
if residual_input is not None:
return norm_out[0], norm_out[1]
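The newly plumbed begin_norm_axis controls which trailing axes the RMS statistic is computed over. A reference sketch in plain Paddle of that semantic (an assumption about the fused kernel's behaviour, not the kernel itself):

import paddle

def rms_norm_ref(x, weight, eps=1e-5, begin_norm_axis=1):
    # Reduce over every axis from begin_norm_axis onward; scale by weight on the
    # hidden dimension. Readability sketch only, not the fused op.
    axes = list(range(begin_norm_axis, x.ndim))
    variance = paddle.mean(paddle.square(x), axis=axes, keepdim=True)
    return x * paddle.rsqrt(variance + eps) * weight

x = paddle.randn([2, 8, 64])                      # [batch, seq, hidden]
w = paddle.ones([64])
y_hidden = rms_norm_ref(x, w, begin_norm_axis=2)  # statistics over hidden only
y_flat = rms_norm_ref(x, w, begin_norm_axis=1)    # statistics over seq and hidden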
@@ -139,7 +145,7 @@ class LayerNorm(nn.Layer):
def __init__(
self,
-llm_config,
+fd_config,
hidden_size,
eps=1e-5,
prefix="",
@@ -151,7 +157,7 @@ class LayerNorm(nn.Layer):
Initializes the normalization layer.
Args:
-llm_config (LLMConfig): Arguments related to inference, containing
+fd_config (FDConfig): Arguments related to inference, containing
attributes such as weight_dtype, act_dtype, mp_size, hidden_size, head_dim,
num_attention_heads, and ffn_hidden_size.
prefix (str): Unique name of the layer, used for naming internal attributes,
@@ -163,7 +169,7 @@ class LayerNorm(nn.Layer):
NotImplementedError: If the specified norm_type is not supported.
"""
super().__init__()
-self.llm_config = llm_config
+self.fd_config = fd_config
self.prefix = prefix
self.hidden_size = hidden_size
if len(prefix) == 0:
@@ -180,6 +186,10 @@ class LayerNorm(nn.Layer):
self._dtype = self._helper.get_default_dtype()
self._norm_weight_dtype = "float32"
+self.quant_round_type = self.fd_config.quant_config.quant_round_type if fd_config.quant_config else 0
+self.quant_max_bound = self.fd_config.quant_config.quant_max_bound if fd_config.quant_config else 0
+self.quant_min_bound = self.fd_config.quant_config.quant_min_bound if fd_config.quant_config else 0
self.init_weight()
def init_weight(self):
@@ -240,6 +250,7 @@ class LayerNorm(nn.Layer):
The `residual_output` is the result of applying the normalization and possibly other
operations (like linear transformation) on the `residual_input`.
"""
norm_out = self.norm_func(
x,
norm_weight=self.ln_weight,
@@ -249,9 +260,9 @@ class LayerNorm(nn.Layer):
bias=self.linear_bias,
residual=residual_input,
quant_scale=-1,
-quant_round_type=self.llm_config.quant_config.quant_round_type,
-quant_max_bound=self.llm_config.quant_config.quant_max_bound,
-quant_min_bound=self.llm_config.quant_config.quant_min_bound,
+quant_round_type=self.quant_round_type,
+quant_max_bound=self.quant_max_bound,
+quant_min_bound=self.quant_min_bound,
)
if residual_input is not None:
return norm_out[0], norm_out[1]
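With the rename, call sites construct these layers from fd_config instead of llm_config. A hedged call-site sketch; the module path, helper name, and prefix below are assumptions, and fd_config is whatever FDConfig object the surrounding model code already builds:

from fastdeploy.model_executor.layers.normalization import RMSNorm  # path is an assumption

def build_input_layernorm(fd_config, hidden_size=4096):
    # fd_config replaces the former llm_config argument.
    return RMSNorm(
        fd_config,
        hidden_size,
        eps=1e-5,
        prefix="decoder.layers.0.input_layernorm",  # hypothetical prefix
        begin_norm_axis=1,                          # keyword added in this commit
    )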