add tie_word_embeddings for lmhead (#4916)

Author: Ryan
Date: 2025-11-11 10:46:35 +08:00
Committed by: GitHub
parent 3f74281496
commit 07a82afcae


@@ -311,7 +311,7 @@ class Qwen2ForCausalLM(ModelForCasualLM):
         self.qwen2 = Qwen2Model(fd_config=fd_config)
         self.ori_vocab_size = fd_config.model_config.ori_vocab_size
+        self.tie_word_embeddings = fd_config.model_config.tie_word_embeddings
         self.lm_head = ParallelLMHead(
             fd_config=fd_config,
             embedding_dim=fd_config.model_config.hidden_size,
@@ -376,6 +376,8 @@ class Qwen2ForCausalLM(ModelForCasualLM):
                 weight_loader(param, loaded_weight)
             model_sublayer_name = re.sub(r"\.(weight)$", "", model_param_name)
             process_weights_after_loading_fn(model_sublayer_name, param)
+        if self.tie_word_embeddings:
+            self.lm_head.load_state_dict({self.lm_head.weight_key: self.qwen2.embed_tokens.embeddings.weight})
 
     @classmethod
     def name(self):
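
For context, the change reads the tie_word_embeddings flag from the model config and, once all checkpoint weights are loaded, copies the input embedding matrix into the lm_head via load_state_dict so the output projection shares it. Below is a minimal sketch (illustrative only, not FastDeploy code; vocab_size, hidden_size, and embed_tokens_weight are made-up names) of what tying word embeddings means numerically: one [vocab_size, hidden_size] matrix serves both the input lookup and the output logits, so the vocab-sized parameters are stored once and checkpoints that ship no separate lm_head weight still produce logits.

# Minimal sketch of tied word embeddings (illustrative, not FastDeploy code).
import numpy as np

vocab_size, hidden_size = 32, 8
rng = np.random.default_rng(0)

# A single shared [vocab_size, hidden_size] matrix plays both roles.
embed_tokens_weight = rng.standard_normal((vocab_size, hidden_size)).astype("float32")

token_ids = np.array([3, 17, 5])
hidden_states = embed_tokens_weight[token_ids]   # input embedding lookup

# Tied lm_head: project hidden states back onto the shared matrix to get logits.
logits = hidden_states @ embed_tokens_weight.T   # shape: [seq_len, vocab_size]
assert logits.shape == (3, vocab_size)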