Mirror of https://github.com/PaddlePaddle/FastDeploy.git, synced 2025-10-05 08:37:06 +08:00.
refactor rl get_name_mappings_to_training (#2847)
Some checks failed: Deploy GitHub Pages / deploy (push) has been cancelled.
* refactor rl get_name_mappings_to_training
* fix tp>1
* change variable name (ffn1 -> up_gate_proj / ffn2 -> down_proj)
* change variable name (linear_weight -> weight / linear_bias -> bias)
* add rl names mapping for vl
* fix ernie 0.3B error
* fix develop code
* fix
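For context, the renames in this commit feed the RL weight-sync path: get_name_mappings_to_training has to translate between the inference module's attribute names (up_gate_proj, down_proj, weight, bias) and the training-side names (ffn1, ffn2, linear_weight, linear_bias). Below is a purely illustrative sketch of such a mapping; the function name, prefixes, and key layout are assumptions for illustration, not the actual FastDeploy API.

# Illustrative sketch only: the real get_name_mappings_to_training is not shown
# in this diff; the prefixes and keys below are assumed for demonstration.
def build_name_mapping(num_layers: int) -> dict:
    """Map renamed inference attributes to hypothetical training parameter names."""
    mapping = {}
    for i in range(num_layers):
        infer = f"qwen2.layers.{i}.mlp"   # inference-side prefix (assumed)
        train = f"layers.{i}.mlp"         # training-side prefix (assumed)
        # ffn1 -> up_gate_proj, ffn2 -> down_proj (per the commit message)
        mapping[f"{infer}.up_gate_proj.weight"] = f"{train}.ffn1.weight"
        mapping[f"{infer}.down_proj.weight"] = f"{train}.ffn2.weight"
    return mapping

print(build_name_mapping(num_layers=2))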
@@ -48,7 +48,7 @@ class Qwen2MLP(nn.Layer):
     ) -> None:
         super().__init__()
         self.nranks = fd_config.parallel_config.tensor_parallel_size
-        self.gate_up_proj = MergedColumnParallelLinear(
+        self.up_gate_proj = MergedColumnParallelLinear(
             fd_config=fd_config,
             prefix=f"{prefix}.up_gate_proj",
             input_size=fd_config.model_config.hidden_size,
@@ -67,20 +67,20 @@ class Qwen2MLP(nn.Layer):
 
         self.act_fn = SiluAndMul(
             fd_config=fd_config,
-            bias=getattr(self.gate_up_proj, "linear_bias", None),
+            bias=getattr(self.up_gate_proj, "bias", None),
             act_method=fd_config.model_config.hidden_act,
         )
 
     def load_state_dict(self, state_dict):
         """
         """
-        self.gate_up_proj.load_state_dict(state_dict)
+        self.up_gate_proj.load_state_dict(state_dict)
         self.down_proj.load_state_dict(state_dict)
 
     def forward(self, x):
         """
         """
-        gate_up_out = self.gate_up_proj(x)
+        gate_up_out = self.up_gate_proj(x)
         act_out = self.act_fn(gate_up_out)
         down_out = self.down_proj(act_out)
         return down_out
@@ -230,7 +230,7 @@ class Qwen2Model(nn.Layer):
         self.num_layers = fd_config.model_config.num_hidden_layers
         fd_config.model_config.pretrained_config.prefix_name = "qwen2"
 
-        self.embeddings = VocabParallelEmbedding(
+        self.embed_tokens = VocabParallelEmbedding(
             fd_config=fd_config,
             num_embeddings=fd_config.model_config.vocab_size,
             embedding_dim=fd_config.model_config.hidden_size,
@@ -261,7 +261,7 @@ class Qwen2Model(nn.Layer):
             A dictionary containing model parameters, where keys are parameter names
             and values are NumPy arrays or PaddlePaddle tensors.
         """
-        self.embeddings.load_state_dict(state_dict)
+        self.embed_tokens.load_state_dict(state_dict)
         self.norm.load_state_dict(state_dict)
         for i in range(self.num_layers):
             logger.info(f"Start load layer {i}")
@@ -275,7 +275,7 @@ class Qwen2Model(nn.Layer):
         """
         """
 
-        hidden_states = self.embeddings(ids_remove_padding=ids_remove_padding)
+        hidden_states = self.embed_tokens(ids_remove_padding=ids_remove_padding)
 
         residual = None
 
@@ -303,7 +303,7 @@ class Qwen2ForCausalLM(ModelForCasualLM):
         super(Qwen2ForCausalLM, self).__init__(fd_config)
 
         self.fd_config =fd_config
-        self.model = Qwen2Model(fd_config=fd_config)
+        self.qwen2 = Qwen2Model(fd_config=fd_config)
 
         self.ori_vocab_size = fd_config.model_config.ori_vocab_size
 
@@ -330,7 +330,7 @@ class Qwen2ForCausalLM(ModelForCasualLM):
             A dictionary containing model parameters, where keys are parameter names
             and values are NumPy arrays or PaddlePaddle tensors.
         """
-        self.model.load_state_dict(state_dict)
+        self.qwen2.load_state_dict(state_dict)
         self.lm_head.load_state_dict(state_dict)
 
     def compute_logits(self, hidden_states: paddle.Tensor):
@@ -349,7 +349,7 @@ class Qwen2ForCausalLM(ModelForCasualLM):
     ):
         """
         """
-        hidden_states = self.model(ids_remove_padding=ids_remove_padding,
+        hidden_states = self.qwen2(ids_remove_padding=ids_remove_padding,
                                    forward_meta=forward_meta)
 
         return hidden_states
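One caller-facing consequence of the last three hunks: the wrapped decoder is now reachable as Qwen2ForCausalLM.qwen2 instead of Qwen2ForCausalLM.model, so any external code that derives parameter names from attribute paths (for example an RL name mapping like the sketch above) now sees keys starting with "qwen2." rather than "model.". A toy, non-FastDeploy illustration of that effect:

# Toy classes only; shows how the attribute rename changes the visible
# parameter path, not actual FastDeploy behaviour.
class Decoder:
    def parameter_names(self):
        return ["embed_tokens.weight", "norm.weight"]

class CausalLM:
    def __init__(self):
        self.qwen2 = Decoder()  # previously exposed as self.model

lm = CausalLM()
# Parameter paths now start with "qwen2." rather than "model."
print([f"qwen2.{name}" for name in lm.qwen2.parameter_names()])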