add tie_word_embeddings for lmhead (#4916)
(commit to https://github.com/PaddlePaddle/FastDeploy.git)
@@ -311,7 +311,7 @@ class Qwen2ForCausalLM(ModelForCasualLM):
         self.qwen2 = Qwen2Model(fd_config=fd_config)
 
         self.ori_vocab_size = fd_config.model_config.ori_vocab_size
-
+        self.tie_word_embeddings = fd_config.model_config.tie_word_embeddings
         self.lm_head = ParallelLMHead(
             fd_config=fd_config,
             embedding_dim=fd_config.model_config.hidden_size,
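The new attribute records whether the model ties its input embedding and output projection weights. With tie_word_embeddings enabled, the LM head reuses the token embedding matrix instead of keeping an independent lm_head.weight. A minimal NumPy sketch of the idea (illustrative only, not FastDeploy API):

import numpy as np

# One shared matrix serves both directions:
#   input:  token id -> hidden vector (row lookup)
#   output: hidden vector -> vocab logits (matmul with the transpose)
vocab_size, hidden_size = 1000, 64
embedding = np.random.randn(vocab_size, hidden_size).astype("float32")

def embed(token_ids):
    return embedding[token_ids]          # [seq_len, hidden_size]

def lm_head(hidden_states):
    return hidden_states @ embedding.T   # [seq_len, vocab_size] logits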
@@ -376,6 +376,8 @@ class Qwen2ForCausalLM(ModelForCasualLM):
                 weight_loader(param, loaded_weight)
             model_sublayer_name = re.sub(r"\.(weight)$", "", model_param_name)
             process_weights_after_loading_fn(model_sublayer_name, param)
+        if self.tie_word_embeddings:
+            self.lm_head.load_state_dict({self.lm_head.weight_key: self.qwen2.embed_tokens.embeddings.weight})
 
     @classmethod
     def name(self):
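This guard matters at load time: checkpoints saved with tied weights typically ship no separate lm_head.weight tensor, so after all checkpoint tensors are copied in, the LM head is filled from the already-loaded embedding table. A minimal sketch of that fallback, assuming a flat name-to-tensor dict (the helper and key names are illustrative, not FastDeploy's):

import numpy as np

# Hypothetical tied checkpoint: it contains the embedding but no LM head.
checkpoint = {
    "qwen2.embed_tokens.weight": np.zeros((1000, 64), dtype="float32"),
}

def finish_loading(params, tie_word_embeddings):
    # Mirrors the new load_weights branch: once every checkpoint tensor is
    # in place, alias the LM head weight to the embedding table when tied.
    if tie_word_embeddings and "lm_head.weight" not in params:
        params["lm_head.weight"] = params["qwen2.embed_tokens.weight"]
    return params

params = finish_loading(dict(checkpoint), tie_word_embeddings=True)
assert params["lm_head.weight"] is params["qwen2.embed_tokens.weight"]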