polish code with new pre-commit rule (#2923)
@@ -75,11 +75,10 @@ class VocabParallelEmbedding(nn.Layer):
             self.embeddings = fleet.meta_parallel.VocabParallelEmbedding(
                 num_embeddings,
                 embedding_dim,
-                mp_group=fleet.get_hybrid_communicate_group().
-                get_model_parallel_group(),
+                mp_group=fleet.get_hybrid_communicate_group().get_model_parallel_group(),
                 weight_attr=paddle.ParamAttr(
-                    initializer=nn.initializer.Normal(
-                        mean=0.0, std=self.initializer_range), ),
+                    initializer=nn.initializer.Normal(mean=0.0, std=self.initializer_range),
+                ),
             )
         else:
             # column cut embedding
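
The hunk above is pure reformatting: the chained group lookup collapses onto one line and the nested ParamAttr arguments each get a trailing-comma line of their own. For reference, a minimal standalone sketch of the same construction in the new style (the helper name and the assumption that fleet.init(is_collective=True) has already set up the hybrid groups are illustrative, not from the repo):

import paddle
import paddle.nn as nn
from paddle.distributed import fleet

def build_vocab_parallel_embedding(num_embeddings, embedding_dim, initializer_range=0.02):
    # Assumes fleet.init(is_collective=True) ran earlier with a hybrid
    # parallel strategy, so the model-parallel group already exists.
    return fleet.meta_parallel.VocabParallelEmbedding(
        num_embeddings,
        embedding_dim,
        mp_group=fleet.get_hybrid_communicate_group().get_model_parallel_group(),
        weight_attr=paddle.ParamAttr(
            initializer=nn.initializer.Normal(mean=0.0, std=initializer_range),
        ),
    )
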
@@ -94,8 +93,7 @@ class VocabParallelEmbedding(nn.Layer):
         self.prefix = prefix
         self.dropout = nn.Dropout(self.hidden_dropout_prob)
 
-    def load_state_dict(self, state_dict: Dict[str,
-                                                paddle.Tensor | np.ndarray]):
+    def load_state_dict(self, state_dict: Dict[str, paddle.Tensor | np.ndarray]):
         """
         Load the checkpoint state dictionary into the layer.
 
@@ -104,12 +102,12 @@ class VocabParallelEmbedding(nn.Layer):
         """
         if self.tie_word_embeddings:
             self.embeddings.weight.set_value(
-                get_tensor(state_dict[self.prefix + ".weight"]).astype(
-                    paddle.get_default_dtype()))
+                get_tensor(state_dict[self.prefix + ".weight"]).astype(paddle.get_default_dtype())
+            )
         else:
             self.embeddings.weight.set_value(
-                get_tensor(state_dict.pop(self.prefix + ".weight")).astype(
-                    paddle.get_default_dtype()))
+                get_tensor(state_dict.pop(self.prefix + ".weight")).astype(paddle.get_default_dtype())
+            )
 
     def forward(self, ids_remove_padding=None) -> paddle.Tensor:
         """
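
load_state_dict behaves the same after the reflow: the checkpoint weight is cast to the running default dtype and copied into the parameter in place. A hypothetical single-process sketch of that cast-and-copy, substituting a plain nn.Embedding for the parallel layer and an invented "embed_tokens" prefix:

import numpy as np
import paddle
import paddle.nn as nn

state_dict = {"embed_tokens.weight": np.random.rand(1000, 64).astype("float32")}

layer = nn.Embedding(1000, 64)
# Mirror of the set_value calls above: cast to the default dtype, then copy in place.
weight = paddle.to_tensor(state_dict.pop("embed_tokens.weight")).astype(paddle.get_default_dtype())
layer.weight.set_value(weight)
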
@@ -131,8 +129,7 @@ class VocabParallelEmbedding(nn.Layer):
             paddle.distributed.all_gather(
                 inputs_embeds_temp,
                 input_embedings,
-                group=fleet.get_hybrid_communicate_group().
-                get_model_parallel_group(),
+                group=fleet.get_hybrid_communicate_group().get_model_parallel_group(),
                 sync_op=True,
             )
             input_embedings = paddle.concat(inputs_embeds_temp, -1)
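
The all_gather in this last hunk reassembles column-cut embedding shards into full-width embeddings along the hidden axis. A minimal multi-rank sketch of the same pattern, using the default communication group instead of the model-parallel group, with illustrative shapes and launch command:

import paddle
import paddle.distributed as dist

# Run under a multi-process launcher, e.g.:
#   python -m paddle.distributed.launch --gpus 0,1 gather_demo.py
dist.init_parallel_env()

# Each rank holds a [num_tokens, hidden // world_size] slice of the embeddings.
partial = paddle.rand([4, 8])

pieces = []
dist.all_gather(pieces, partial, sync_op=True)  # pieces[i] is rank i's slice
full = paddle.concat(pieces, -1)                # back to [num_tokens, hidden]
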