polish code with new pre-commit rule (#2923)

Zero Rains
2025-07-19 23:19:27 +08:00
committed by GitHub
parent b8676d71a8
commit 25698d56d1
424 changed files with 14307 additions and 13518 deletions


@@ -75,11 +75,10 @@ class VocabParallelEmbedding(nn.Layer):
             self.embeddings = fleet.meta_parallel.VocabParallelEmbedding(
                 num_embeddings,
                 embedding_dim,
-                mp_group=fleet.get_hybrid_communicate_group().
-                get_model_parallel_group(),
+                mp_group=fleet.get_hybrid_communicate_group().get_model_parallel_group(),
                 weight_attr=paddle.ParamAttr(
-                    initializer=nn.initializer.Normal(
-                        mean=0.0, std=self.initializer_range), ),
+                    initializer=nn.initializer.Normal(mean=0.0, std=self.initializer_range),
+                ),
             )
         else:
             # column cut embedding
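
For reference, the call being reformatted in this hunk constructs Paddle's row-parallel vocabulary embedding. A minimal standalone sketch of the same construction, assuming fleet has already been initialized under paddle.distributed.launch (the sizes below are placeholders, not values from this repo):

import paddle
import paddle.nn as nn
from paddle.distributed import fleet

# Placeholder sizes for illustration only.
num_embeddings, embedding_dim, initializer_range = 32000, 4096, 0.02

embeddings = fleet.meta_parallel.VocabParallelEmbedding(
    num_embeddings,
    embedding_dim,
    # The hybrid communicate group owns the model-parallel sub-group;
    # the vocabulary dimension is sharded across its ranks.
    mp_group=fleet.get_hybrid_communicate_group().get_model_parallel_group(),
    weight_attr=paddle.ParamAttr(
        initializer=nn.initializer.Normal(mean=0.0, std=initializer_range),
    ),
)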
@@ -94,8 +93,7 @@ class VocabParallelEmbedding(nn.Layer):
         self.prefix = prefix
         self.dropout = nn.Dropout(self.hidden_dropout_prob)

-    def load_state_dict(self, state_dict: Dict[str,
-                                                paddle.Tensor | np.ndarray]):
+    def load_state_dict(self, state_dict: Dict[str, paddle.Tensor | np.ndarray]):
         """
         Load the checkpoint state dictionary into the layer.
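
The reflowed signature keeps the PEP 604 union (paddle.Tensor | np.ndarray, Python 3.10+ syntax) on one line. A hedged usage sketch; get_tensor here is a hypothetical stand-in for the repository helper that normalizes either type to a paddle.Tensor:

from typing import Dict

import numpy as np
import paddle

def get_tensor(value: paddle.Tensor | np.ndarray) -> paddle.Tensor:
    # Hypothetical stand-in for the repo's helper.
    if isinstance(value, np.ndarray):
        return paddle.to_tensor(value)
    return value

state_dict: Dict[str, paddle.Tensor | np.ndarray] = {
    "embed_tokens.weight": np.zeros((8, 4), dtype="float32"),
}
weight = get_tensor(state_dict["embed_tokens.weight"]).astype(paddle.get_default_dtype())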
@@ -104,12 +102,12 @@ class VocabParallelEmbedding(nn.Layer):
"""
if self.tie_word_embeddings:
self.embeddings.weight.set_value(
get_tensor(state_dict[self.prefix + ".weight"]).astype(
paddle.get_default_dtype()))
get_tensor(state_dict[self.prefix + ".weight"]).astype(paddle.get_default_dtype())
)
else:
self.embeddings.weight.set_value(
get_tensor(state_dict.pop(self.prefix + ".weight")).astype(
paddle.get_default_dtype()))
get_tensor(state_dict.pop(self.prefix + ".weight")).astype(paddle.get_default_dtype())
)
def forward(self, ids_remove_padding=None) -> paddle.Tensor:
"""
@@ -131,8 +129,7 @@ class VocabParallelEmbedding(nn.Layer):
             paddle.distributed.all_gather(
                 inputs_embeds_temp,
                 input_embedings,
-                group=fleet.get_hybrid_communicate_group().
-                get_model_parallel_group(),
+                group=fleet.get_hybrid_communicate_group().get_model_parallel_group(),
                 sync_op=True,
             )
             input_embedings = paddle.concat(inputs_embeds_temp, -1)
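
In this column-cut path each rank produces embeddings for a slice of the hidden dimension, so the all_gather result is concatenated on the last axis to restore the full hidden size. A hedged sketch of the collective, assuming the same fleet setup as above (input_embedings keeps the source's spelling):

import paddle
import paddle.distributed
from paddle.distributed import fleet

mp_group = fleet.get_hybrid_communicate_group().get_model_parallel_group()

# Placeholder: the [batch, hidden // mp_degree] slice held by this rank.
input_embedings = paddle.zeros([2, 4096 // mp_group.nranks])

inputs_embeds_temp = []
paddle.distributed.all_gather(
    inputs_embeds_temp,  # populated with one tensor per rank in the group
    input_embedings,
    group=mp_group,
    sync_op=True,
)
# Stitch the per-rank hidden slices back into [batch, hidden].
input_embedings = paddle.concat(inputs_embeds_temp, -1)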