polish code with new pre-commit rule (#2923)

This commit is contained in:
Zero Rains
2025-07-19 23:19:27 +08:00
committed by GitHub
parent b8676d71a8
commit 25698d56d1
424 changed files with 14307 additions and 13518 deletions

View File

@@ -69,11 +69,9 @@ class ParallelEHProjection(nn.Layer):
self.linear = ColumnParallelLinear(
embedding_dim,
num_embeddings,
mp_group=fleet.get_hybrid_communicate_group().
get_model_parallel_group(),
mp_group=fleet.get_hybrid_communicate_group().get_model_parallel_group(),
weight_attr=None,
has_bias=True
if self.bias_key is not None else False,
has_bias=True if self.bias_key is not None else False,
gather_output=need_gather,
fuse_matmul_bias=False, # False diff更小
)
@@ -81,11 +79,9 @@ class ParallelEHProjection(nn.Layer):
self.linear = RowParallelLinear(
embedding_dim,
num_embeddings,
mp_group=fleet.get_hybrid_communicate_group().
get_model_parallel_group(),
mp_group=fleet.get_hybrid_communicate_group().get_model_parallel_group(),
weight_attr=None,
has_bias=True
if self.bias_key is not None else False,
has_bias=True if self.bias_key is not None else False,
input_is_parallel=False,
fuse_matmul_bias=False, # False diff更小
)
@@ -99,20 +95,15 @@ class ParallelEHProjection(nn.Layer):
"""
if self.use_ep:
self.weight.set_value(
get_tensor(state_dict.pop(self.weight_key)).astype(
paddle.get_default_dtype()))
self.weight.set_value(get_tensor(state_dict.pop(self.weight_key)).astype(paddle.get_default_dtype()))
else:
weight_tensor = get_tensor(
state_dict.pop(self.weight_key)).astype(
paddle.get_default_dtype())
weight_tensor = get_tensor(state_dict.pop(self.weight_key)).astype(paddle.get_default_dtype())
if self.linear.weight.shape != weight_tensor.shape:
weight_tensor = weight_tensor.transpose([1, 0])
self.linear.weight.set_value(weight_tensor)
if self.bias_key is not None:
bias = get_tensor(state_dict.pop(self.bias_key)).astype(
paddle.get_default_dtype())
bias = get_tensor(state_dict.pop(self.bias_key)).astype(paddle.get_default_dtype())
self.linear.bias.set_value(bias)
def forward(self, input):