[Bug fix] Fixed the garbled text issues in Qwen3-8B (#2737)

* fix qwen3.py

* update

* update lm_head tie_word_embeddings

* update tie_word_embeddings

* fix

* fix tie_word_embedding not in config.json

---------

Co-authored-by: lizexu <lizexu@baidu.com>
This commit is contained in:
lizexu123
2025-07-08 14:15:27 +08:00
committed by GitHub
parent d0f4d6ba3a
commit 525be243e7
2 changed files with 10 additions and 5 deletions

View File

@@ -164,7 +164,6 @@ class Qwen3Model(nn.Layer):
self.num_layers = fd_config.model_config.num_layers
fd_config.model_config.prefix_name = "model"
fd_config.model_config.tie_word_embeddings = True
self.embeddings = VocabParallelEmbedding(
fd_config=fd_config,
@@ -240,14 +239,13 @@ class Qwen3ForCausalLM(ModelForCasualLM):
self.model = Qwen3Model(fd_config=fd_config)
self.ori_vocab_size = fd_config.model_config.ori_vocab_size
self.tie_word_embeddings = fd_config.model_config.tie_word_embeddings
self.lm_head = ParallelLMHead(
fd_config=fd_config,
embedding_dim=fd_config.model_config.hidden_size,
num_embeddings=fd_config.model_config.vocab_size,
prefix=(f"{fd_config.model_config.prefix_name}.embed_tokens"),
prefix="lm_head",
)
self.tie_word_embeddings = fd_config.model_config.tie_word_embeddings
@classmethod
def name(self):
@@ -269,7 +267,8 @@ class Qwen3ForCausalLM(ModelForCasualLM):
if self.tie_word_embeddings:
self.lm_head.out_linear.weight.set_value(
self.model.embeddings.word_embeddings.weight.transpose([1, 0]))
self.lm_head.load_state_dict(state_dict)
else:
self.lm_head.load_state_dict(state_dict)
def compute_logits(self, hidden_states: paddle.Tensor):
"""
@@ -324,6 +323,7 @@ class Qwen3PretrainedModel(PretrainedModel):
base_actions = {
# Row Linear
"lm_head.weight": partial(fn, is_column=True),
"embed_tokens.weight": partial(fn, is_column=False),
"layers.0.self_attn.o_proj.weight": partial(fn,
is_column=False),