[Precision] Change lm_head layer running in float32 (#3596)

* support lm_head fp32 bf16 fp16

* delete print

* code check

* check

* check

* code check

* check

* check
chen authored 2025-08-26 20:20:06 +08:00, committed by GitHub
parent 2136990144
commit d233e3c97c
14 changed files with 85 additions and 49 deletions


@@ -257,14 +257,14 @@ class Qwen3ForCausalLM(ModelForCasualLM):
"""
self.model.load_state_dict(state_dict)
if self.tie_word_embeddings:
self.lm_head.linear.weight.set_value(self.model.embed_tokens.embeddings.weight.transpose([1, 0]))
self.lm_head.load_state_dict({self.lm_head.weight_key: self.ernie.embed_tokens.embeddings.weight})
else:
self.lm_head.load_state_dict(state_dict)
def compute_logits(self, hidden_states: paddle.Tensor):
""" """
logits = self.lm_head(hidden_states)
logits = paddle.cast(logits, paddle.float32)
logits = logits.astype(paddle.float32)
logits[:, self.ori_vocab_size :] = -float("inf")
return logits
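
The tied-embeddings branch now routes the shared embedding matrix through the head's own load_state_dict call, keyed by weight_key, instead of manually transposing it with set_value. Below is a minimal Paddle sketch of that loading pattern; the class name TiedLMHead and the "lm_head.weight" key are illustrative assumptions, not FastDeploy's actual classes.

import paddle
import paddle.nn as nn


class TiedLMHead(nn.Layer):
    """Toy head that accepts a tied embedding table through load_state_dict."""

    def __init__(self, hidden_size: int, vocab_size: int):
        super().__init__()
        self.weight_key = "lm_head.weight"  # hypothetical key name
        # Paddle's nn.Linear stores its weight as [in_features, out_features].
        self.linear = nn.Linear(hidden_size, vocab_size, bias_attr=False)

    def load_state_dict(self, state_dict):
        # The embedding table arrives as [vocab_size, hidden_size]; transpose
        # it into the linear layer's [hidden_size, vocab_size] layout.
        weight = state_dict[self.weight_key]
        self.linear.weight.set_value(weight.transpose([1, 0]))

    def forward(self, hidden_states: paddle.Tensor) -> paddle.Tensor:
        return self.linear(hidden_states)


# Usage: tie the head to an embedding layer, as the diff does for Qwen3.
embed_tokens = nn.Embedding(151, 64)  # vocab_size=151, hidden_size=64
lm_head = TiedLMHead(hidden_size=64, vocab_size=151)
lm_head.load_state_dict({lm_head.weight_key: embed_tokens.weight})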
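
The compute_logits half of the hunk casts the projection output to float32 before masking the padded vocabulary slots. A minimal sketch of the overall precision idea follows, assuming a head whose projection dtype (fp32, bf16, or fp16) is configurable while the logits it returns are always float32; SimpleLMHead and its constructor arguments are illustrative assumptions that mirror the diff, not FastDeploy's actual API.

import paddle
import paddle.nn as nn


class SimpleLMHead(nn.Layer):
    """Toy lm_head: projection in a configurable dtype, logits always float32."""

    def __init__(self, hidden_size: int, vocab_size: int, ori_vocab_size: int,
                 dtype: str = "float32"):
        super().__init__()
        self.ori_vocab_size = ori_vocab_size
        self.head_dtype = dtype
        self.linear = nn.Linear(hidden_size, vocab_size, bias_attr=False)
        self.linear.to(dtype=dtype)  # keep weights in fp32 / bf16 / fp16

    def compute_logits(self, hidden_states: paddle.Tensor) -> paddle.Tensor:
        # Match the activations to the head's dtype for the matmul, then cast
        # up so the -inf masking and any later softmax run in float32.
        logits = self.linear(hidden_states.astype(self.head_dtype))
        logits = logits.astype(paddle.float32)
        logits[:, self.ori_vocab_size:] = -float("inf")
        return logits


# Usage: the vocab is padded from 151 to 160 entries; the padded tail is masked.
head = SimpleLMHead(hidden_size=64, vocab_size=160, ori_vocab_size=151)
hidden = paddle.randn([2, 64])
logits = head.compute_logits(hidden)
assert logits.dtype == paddle.float32

Returning float32 logits regardless of the head's compute dtype keeps the -inf masking and downstream sampling numerically stable even when the projection itself runs in bf16 or fp16.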