[BugFix] fix RMSNorm rms_norm_eps (#2804)

Author: lizexu123
Date: 2025-07-10 20:39:02 +08:00
Committed by: GitHub
Parent: 823a47e64a
Commit: e681e1e719

7 changed files with 18 additions and 16 deletions


@@ -84,6 +84,7 @@ class ModelConfig(PretrainedConfig):
         head_dim: Optional[int] = None,
         tie_word_embeddings: bool = False,
         is_quantized: bool = False,
+        rms_norm_eps: float = 1e-5,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -123,6 +124,7 @@ class ModelConfig(PretrainedConfig):
         self.dtype = dtype
         self.tie_word_embeddings = tie_word_embeddings
         self.is_quantized = is_quantized
+        self.rms_norm_eps = rms_norm_eps
 @dataclass
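
For reference, rms_norm_eps is the epsilon added to the mean of squares before the square root in RMSNorm; this commit only changes where that value comes from (ModelConfig instead of per-model literals). A minimal NumPy sketch of the operation, illustrative only and not FastDeploy's actual kernel:

    import numpy as np

    def rms_norm(x, weight, eps=1e-5):
        # RMSNorm: x / sqrt(mean(x**2) + eps) * weight along the last axis.
        # eps keeps the denominator away from zero; after this commit it is
        # read from ModelConfig.rms_norm_eps instead of a hard-coded 1e-5/1e-6.
        variance = np.mean(x * x, axis=-1, keepdims=True)
        return x / np.sqrt(variance + eps) * weight

    x = np.random.randn(2, 8).astype("float32")
    w = np.ones(8, dtype="float32")
    out = rms_norm(x, w, eps=1e-6)  # e.g. a Qwen-style eps supplied by the config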


@@ -288,14 +288,14 @@ class Ernie4_5_DecoderLayer(nn.Layer):
         self.input_layernorm = RMSNorm(
             fd_config,
             hidden_size=fd_config.model_config.hidden_size,
-            eps=1e-5,
+            eps=fd_config.model_config.rms_norm_eps,
             prefix=f"{prefix}.input_layernorm",
         )
         self.post_attention_layernorm = RMSNorm(
             fd_config,
             hidden_size=fd_config.model_config.hidden_size,
-            eps=1e-5,
+            eps=fd_config.model_config.rms_norm_eps,
             prefix=f"{prefix}.post_attention_layernorm",
         )
@@ -366,7 +366,7 @@ class Ernie4_5_Model(nn.Layer):
         self.norm = RMSNorm(
             fd_config,
             hidden_size=fd_config.model_config.hidden_size,
-            eps=1e-5,
+            eps=fd_config.model_config.rms_norm_eps,
             prefix=f"{fd_config.model_config.prefix_name}.norm",
         )


@@ -275,14 +275,14 @@ class Ernie4_5_MTPModel(nn.Layer):
         self.enorm = RMSNorm(
             fd_config,
             hidden_size=fd_config.model_config.hidden_size,
-            eps=1e-5,
+            eps=fd_config.model_config.rms_norm_eps,
             prefix="ernie.mtp_emb_norm.0",
         )
         self.hnorm = RMSNorm(
             fd_config,
             hidden_size=fd_config.model_config.hidden_size,
-            eps=1e-5,
+            eps=fd_config.model_config.rms_norm_eps,
             prefix="ernie.mtp_hidden_norm.0",
         )


@@ -271,14 +271,14 @@ class Ernie4_5_VLDecoderLayer(nn.Layer):
         self.input_layernorm = RMSNorm(
             fd_config,
             hidden_size=fd_config.model_config.hidden_size,
-            eps=1e-5,
+            eps=fd_config.model_config.rms_norm_eps,
             prefix=f"{prefix}.input_layernorm",
         )
         self.post_attention_layernorm = RMSNorm(
             fd_config,
             hidden_size=fd_config.model_config.hidden_size,
-            eps=1e-5,
+            eps=fd_config.model_config.rms_norm_eps,
             prefix=f"{prefix}.post_attention_layernorm",
         )
@@ -355,7 +355,7 @@ class Ernie4_5_VLModel(nn.Layer):
         self.norm = RMSNorm(
             fd_config,
             hidden_size=fd_config.model_config.hidden_size,
-            eps=1e-5,
+            eps=fd_config.model_config.rms_norm_eps,
             prefix=f"{fd_config.model_config.prefix_name}.norm",
         )


@@ -161,14 +161,14 @@ class Qwen2DecoderLayer(nn.Layer):
         self.input_layernorm = RMSNorm(
             fd_config,
             hidden_size=fd_config.model_config.hidden_size,
-            eps=1e-6,
+            eps=fd_config.model_config.rms_norm_eps,
             prefix=f"{prefix}.input_layernorm",
         )
         self.post_attention_layernorm = RMSNorm(
             fd_config,
             hidden_size=fd_config.model_config.hidden_size,
-            eps=1e-6,
+            eps=fd_config.model_config.rms_norm_eps,
             prefix=f"{prefix}.post_attention_layernorm",
         )
@@ -248,7 +248,7 @@ class Qwen2Model(nn.Layer):
         self.norm = RMSNorm(
             fd_config,
             hidden_size=fd_config.model_config.hidden_size,
-            eps=1e-5,
+            eps=fd_config.model_config.rms_norm_eps,
             prefix=f"{fd_config.model_config.prefix_name}.norm",
         )


@@ -79,12 +79,12 @@ class Qwen3Attention(nn.Layer):
         self.q_norm = RMSNorm(fd_config=fd_config,
                               hidden_size=fd_config.model_config.head_dim,
-                              eps=1e-6,
+                              eps=fd_config.model_config.rms_norm_eps,
                               prefix=f"{prefix}.q_norm",
                               begin_norm_axis=2)
         self.k_norm = RMSNorm(fd_config=fd_config,
                               hidden_size=fd_config.model_config.head_dim,
-                              eps=1e-6,
+                              eps=fd_config.model_config.rms_norm_eps,
                               prefix=f"{prefix}.k_norm",
                               begin_norm_axis=2)
@@ -184,7 +184,7 @@ class Qwen3Model(nn.Layer):
         self.norm = RMSNorm(
             fd_config,
             hidden_size=fd_config.model_config.hidden_size,
-            eps=1e-6,
+            eps=fd_config.model_config.rms_norm_eps,
             prefix=f"{fd_config.model_config.prefix_name}.norm",
         )


@@ -121,12 +121,12 @@ class Qwen3Attention(nn.Layer):
         self.q_norm = RMSNorm(fd_config,
                               hidden_size=self.head_dim,
-                              eps=1e-6,
+                              eps=fd_config.model_config.rms_norm_eps,
                               prefix=f"{prefix}.q_norm",
                               begin_norm_axis=2)
         self.k_norm = RMSNorm(fd_config,
                               hidden_size=self.head_dim,
-                              eps=1e-6,
+                              eps=fd_config.model_config.rms_norm_eps,
                               prefix=f"{prefix}.k_norm",
                               begin_norm_axis=2)
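
Before this change the epsilon could be inconsistent within a single model (for example, Qwen2 used 1e-6 in the decoder-layer norms but 1e-5 in the final Qwen2Model.norm); every RMSNorm now reads the same ModelConfig.rms_norm_eps. A hedged usage sketch with a stand-in class mirroring the ModelConfig hunk above (names are assumptions, not the actual import path):

    # Stand-in mirroring the ModelConfig change; the real class lives in FastDeploy.
    class ModelConfig:
        def __init__(self, rms_norm_eps: float = 1e-5, **kwargs):
            self.rms_norm_eps = rms_norm_eps

    # A checkpoint config that declares its own epsilon now takes effect:
    cfg = ModelConfig(**{"rms_norm_eps": 1e-6})
    print(cfg.rms_norm_eps)  # 1e-6, no longer overridden by a hard-coded default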