Mirror of https://github.com/PaddlePaddle/FastDeploy.git, synced 2025-10-05 08:37:06 +08:00
refactor rl get_name_mappings_to_training (#2847)
Some checks failed: Deploy GitHub Pages / deploy (push) has been cancelled
* refactor rl get_name_mappings_to_training
* fix tp>1
* change variable name (ffn1 -> up_gate_proj / ffn2 -> down_proj)
* change variable name (linear_weight -> weight / linear_bias -> bias)
* add rl names mapping for vl
* fix ernie 0.3B error
* fix develop code
* fix
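The renames in the commit message (ffn1 -> up_gate_proj, ffn2 -> down_proj, linear_weight -> weight, linear_bias -> bias) amount to a table mapping inference-side parameter names to their training-side counterparts. A minimal sketch of how such a mapping could be built; the function name, key layout, and num_layers argument are hypothetical, and only the rename pairs come from the commit message:

# Hypothetical sketch, not the actual get_name_mappings_to_training implementation.
def build_name_mapping(num_layers: int) -> dict:
    """Map old inference parameter names to the renamed training-side names."""
    mapping = {}
    for i in range(num_layers):
        prefix = f"layers.{i}.mlp"  # assumed key layout, for illustration only
        # ffn1 -> up_gate_proj, ffn2 -> down_proj; linear_weight/linear_bias -> weight/bias
        mapping[f"{prefix}.ffn1.linear_weight"] = f"{prefix}.up_gate_proj.weight"
        mapping[f"{prefix}.ffn1.linear_bias"] = f"{prefix}.up_gate_proj.bias"
        mapping[f"{prefix}.ffn2.linear_weight"] = f"{prefix}.down_proj.weight"
        mapping[f"{prefix}.ffn2.linear_bias"] = f"{prefix}.down_proj.bias"
    return mapping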
@@ -14,16 +14,13 @@
 # limitations under the License.
 """
 
-from paddleformers.utils.log import logger
-
 import paddle
 import paddle.nn.functional as F
 from paddle import nn
 from paddle.distributed import fleet
-from paddle.distributed.fleet.meta_parallel import (
-    ColumnParallelLinear,
-    VocabParallelEmbedding,
-)
+from paddle.distributed.fleet.meta_parallel import (ColumnParallelLinear,
+                                                    VocabParallelEmbedding)
+from paddleformers.utils.log import logger
 
 from .utils import get_tensor
 
@@ -130,7 +127,7 @@ class HydraHead(nn.Layer):
             ]
         )
 
-        self.word_embeddings = VocabParallelEmbedding(
+        self.embeddings = VocabParallelEmbedding(
             vocab_size,
             hidden_size,
             mp_group=fleet.get_hybrid_communicate_group().get_model_parallel_group(),
@@ -170,8 +167,8 @@ class HydraHead(nn.Layer):
                 get_tensor(state_dict.pop(f"1.{hydra_head_idx}.weight"))
             )
 
-        self.word_embeddings.weight.set_value(
-            get_tensor(state_dict.pop("word_embeddings.weight"))
+        self.embeddings.weight.set_value(
+            get_tensor(state_dict.pop("embeddings.weight"))
         )
 
     def set_state_dict(self, state_dict):
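Since the hunk above renames both the attribute and the checkpoint key it pops (word_embeddings.weight -> embeddings.weight), a checkpoint saved under the old key would no longer match. A minimal compatibility sketch, assuming a plain dict state dict; the helper is hypothetical and not part of this commit:

def remap_legacy_keys(state_dict: dict) -> dict:
    # Rewrite old checkpoint keys to the new names before calling set_state_dict.
    renames = {"word_embeddings.weight": "embeddings.weight"}
    return {renames.get(key, key): value for key, value in state_dict.items()}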
@@ -183,7 +180,7 @@ class HydraHead(nn.Layer):
         """
         is_custom = True
         for key in state_dict.keys():
-            if key != "word_embeddings.weight" and (
+            if key != "embeddings.weight" and (
                 "hydra_mlp" in key or "hydra_head" in key
             ):
                 is_custom = False
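The loop in this hunk treats the state dict as custom unless it contains stock hydra_mlp / hydra_head parameter names, with the embedding weight itself exempt from the check. The same logic restated as a standalone function, for illustration only:

def uses_custom_naming(state_dict: dict) -> bool:
    # Mirrors the is_custom loop above: any hydra_mlp/hydra_head key other
    # than the embedding weight marks the checkpoint as non-custom.
    for key in state_dict.keys():
        if key != "embeddings.weight" and ("hydra_mlp" in key or "hydra_head" in key):
            return False
    return True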
@@ -207,7 +204,7 @@ class HydraHead(nn.Layer):
             hidden_states: [batch_size, hidden_size] The hidden_states of the last accept_tokens
         """
         hydra_inputs = [hidden_states]
-        input_embeds = self.word_embeddings(input_ids)
+        input_embeds = self.embeddings(input_ids)
         for hydra_head_idx in range(self.hydra_num_heads):
             hydra_inputs.append(input_embeds)
             head_input = paddle.concat(hydra_inputs, axis=-1)
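In this loop, each drafted token's embedding is appended to hydra_inputs before the concat, so head k sees the last accepted hidden state plus k+1 embedding tensors, and the feature width grows by hidden_size per head. A shape-only sketch of that growth, with made-up sizes:

import paddle

batch_size, hidden_size = 2, 8
hidden_states = paddle.zeros([batch_size, hidden_size])
input_embeds = paddle.zeros([batch_size, hidden_size])

hydra_inputs = [hidden_states, input_embeds]       # grows by one entry per head
head_input = paddle.concat(hydra_inputs, axis=-1)  # width = hidden_size * len(hydra_inputs)
print(head_input.shape)  # [2, 16]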
@@ -217,4 +214,4 @@ class HydraHead(nn.Layer):
             _, topk_tokens = paddle.topk(probs, k=1, axis=-1)
             next_tokens[:, 1 + hydra_head_idx : 2 + hydra_head_idx] = topk_tokens[:]
 
-            input_embeds = self.word_embeddings(next_tokens[:, 1 + hydra_head_idx])
+            input_embeds = self.embeddings(next_tokens[:, 1 + hydra_head_idx])
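This final hunk is the step that feeds each greedy draft back in: paddle.topk with k=1 picks the highest-probability token, and the renamed embedding layer re-embeds it for the next head. A small sketch with an ordinary nn.Embedding standing in for the parallel embedding:

import paddle
from paddle import nn

embeddings = nn.Embedding(3, 4)              # stand-in for VocabParallelEmbedding
probs = paddle.to_tensor([[0.1, 0.7, 0.2]])  # [batch_size, vocab_size]
_, topk_tokens = paddle.topk(probs, k=1, axis=-1)   # greedy: token id 1
input_embeds = embeddings(topk_tokens.squeeze(-1))  # re-embed for the next head
print(topk_tokens.item(), input_embeds.shape)       # 1, [1, 4]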