Mirror of https://github.com/PaddlePaddle/FastDeploy.git, synced 2025-10-04 08:16:42 +08:00
refactor rl get_name_mappings_to_training (#2847)
Some checks failed
Deploy GitHub Pages / deploy (push) Has been cancelled
* refactor rl get_name_mappings_to_training
* fix tp>1
* change variable names (ffn1 -> up_gate_proj, ffn2 -> down_proj)
* change variable names (linear_weight -> weight, linear_bias -> bias)
* add rl name mappings for vl
* fix ernie 0.3B error
* fix develop code
* fix
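The refactor centers on mapping inference-side parameter names to their training-side counterparts after the renames above. A minimal sketch of the idea (hypothetical layer paths, not the actual FastDeploy implementation):

    # Hypothetical sketch: map renamed inference parameters back to training names.
    # The real get_name_mappings_to_training also covers attention, bias, and
    # VL-specific weights.
    def get_name_mappings_to_training(num_layers: int) -> dict:
        mappings = {}
        for i in range(num_layers):
            base = f"ernie.layers.{i}.mlp"
            # ffn1 -> up_gate_proj, ffn2 -> down_proj (renamed in this PR)
            mappings[f"{base}.up_gate_proj.weight"] = f"{base}.ffn1.weight"
            mappings[f"{base}.down_proj.weight"] = f"{base}.ffn2.weight"
        return mappings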
@@ -68,13 +68,13 @@ class VocabParallelEmbedding(nn.Layer):
         self.params_dtype: str = params_dtype

         if self.use_ep:
-            self.word_embeddings = nn.Embedding(
+            self.embeddings = nn.Embedding(
                 num_embeddings,
                 embedding_dim,
             )
         else:
             if not self.column_cut:
-                self.word_embeddings = fleet.meta_parallel.VocabParallelEmbedding(
+                self.embeddings = fleet.meta_parallel.VocabParallelEmbedding(
                     num_embeddings,
                     embedding_dim,
                     mp_group=fleet.get_hybrid_communicate_group().
@@ -85,13 +85,13 @@ class VocabParallelEmbedding(nn.Layer):
                 )
             else:
                 # column cut embedding
-                self.word_embeddings = nn.Embedding(
+                self.embeddings = nn.Embedding(
                     num_embeddings,
                     embedding_dim // self.world_size,
                 )

-                self.word_embeddings.weight.is_distributed = True
-                self.word_embeddings.weight.split_axis = 1
+                self.embeddings.weight.is_distributed = True
+                self.embeddings.weight.split_axis = 1

         if not self.use_rope:
             self.position_embeddings = nn.Embedding(
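With column cut, each of the world_size ranks stores only a 1/world_size slice of the embedding dimension (split along axis 1), so per-rank memory shrinks proportionally. A standalone illustration of the per-rank shard shape (illustrative sizes, single process):

    import paddle.nn as nn

    num_embeddings, embedding_dim, world_size = 32000, 4096, 8
    # Each rank holds only embedding_dim // world_size columns of the table.
    shard = nn.Embedding(num_embeddings, embedding_dim // world_size)
    print(shard.weight.shape)  # [32000, 512]; full rows are rebuilt via all_gather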
@@ -112,13 +112,12 @@ class VocabParallelEmbedding(nn.Layer):
         Args:
             state_dict (dict): A dictionary containing the checkpoint weights and biases.
         """
-        a = state_dict[self.prefix + ".weight"]
         if self.tie_word_embeddings:
-            self.word_embeddings.weight.set_value(
+            self.embeddings.weight.set_value(
                 get_tensor(state_dict[self.prefix + ".weight"]).astype(
                     paddle.get_default_dtype()))
         else:
-            self.word_embeddings.weight.set_value(
+            self.embeddings.weight.set_value(
                 get_tensor(state_dict.pop(self.prefix + ".weight")).astype(
                     paddle.get_default_dtype()))

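Note the asymmetry this hunk preserves: with tie_word_embeddings the weight is read but left in state_dict (presumably so the tied lm_head can consume the same tensor later), while the untied path pops it. A condensed sketch of that load logic, using the names from the diff:

    # Read without pop() when tied, so the tensor stays available for the lm_head;
    # pop() when untied, so consumed weights are removed from the checkpoint dict.
    key = self.prefix + ".weight"
    weight = state_dict[key] if self.tie_word_embeddings else state_dict.pop(key)
    self.embeddings.weight.set_value(
        get_tensor(weight).astype(paddle.get_default_dtype()))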
@@ -134,10 +133,10 @@ class VocabParallelEmbedding(nn.Layer):
             Tensor: Embedded tensor representation of the input IDs.
         """
         if self.use_ep:
-            input_embedings = self.word_embeddings(ids_remove_padding)
+            input_embedings = self.embeddings(ids_remove_padding)
         else:
             if self.column_cut:
-                input_embedings = self.word_embeddings(ids_remove_padding)
+                input_embedings = self.embeddings(ids_remove_padding)
                 inputs_embeds_temp = []
                 paddle.distributed.all_gather(
                     inputs_embeds_temp,
@@ -148,6 +147,6 @@ class VocabParallelEmbedding(nn.Layer):
                 )
                 input_embedings = paddle.concat(inputs_embeds_temp, -1)
             else:
-                input_embedings = self.word_embeddings(ids_remove_padding)
+                input_embedings = self.embeddings(ids_remove_padding)

         return input_embedings