Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-10-05 16:48:03 +08:00)
polish code with new pre-commit rule (#2923)
@@ -69,11 +69,9 @@ class ParallelEHProjection(nn.Layer):
             self.linear = ColumnParallelLinear(
                 embedding_dim,
                 num_embeddings,
-                mp_group=fleet.get_hybrid_communicate_group().
-                get_model_parallel_group(),
+                mp_group=fleet.get_hybrid_communicate_group().get_model_parallel_group(),
                 weight_attr=None,
-                has_bias=True
-                if self.bias_key is not None else False,
+                has_bias=True if self.bias_key is not None else False,
                 gather_output=need_gather,
                 fuse_matmul_bias=False,  # False gives a smaller diff
             )
@@ -81,11 +79,9 @@ class ParallelEHProjection(nn.Layer):
             self.linear = RowParallelLinear(
                 embedding_dim,
                 num_embeddings,
-                mp_group=fleet.get_hybrid_communicate_group().
-                get_model_parallel_group(),
+                mp_group=fleet.get_hybrid_communicate_group().get_model_parallel_group(),
                 weight_attr=None,
-                has_bias=True
-                if self.bias_key is not None else False,
+                has_bias=True if self.bias_key is not None else False,
                 input_is_parallel=False,
                 fuse_matmul_bias=False,  # False gives a smaller diff
             )
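For readers unfamiliar with the two layer types touched above: ColumnParallelLinear shards the weight along the output dimension (hence the gather_output flag), while RowParallelLinear shards it along the input dimension (hence input_is_parallel, with partial results summed across ranks). The following is a conceptual NumPy sketch of that difference, not Paddle's API; the shapes and rank count are made up for illustration.

import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 8))      # activations: [batch, embedding_dim]
w = rng.standard_normal((8, 6))      # weight: [embedding_dim, num_embeddings]
ranks = 2

# Column parallel: split the weight by output columns; each "rank" produces a
# slice of the output, and gathering concatenates the slices (gather_output).
col_shards = np.split(w, ranks, axis=1)
col_out = np.concatenate([x @ shard for shard in col_shards], axis=1)

# Row parallel: split the weight by input rows and the activations to match
# (input_is_parallel); the partial products are summed, as an all-reduce would.
row_shards = np.split(w, ranks, axis=0)
x_shards = np.split(x, ranks, axis=1)
row_out = sum(xs @ ws for xs, ws in zip(x_shards, row_shards))

# Both shardings reproduce the unsharded matmul.
assert np.allclose(col_out, x @ w)
assert np.allclose(row_out, x @ w)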
@@ -99,20 +95,15 @@ class ParallelEHProjection(nn.Layer):
         """

         if self.use_ep:
-            self.weight.set_value(
-                get_tensor(state_dict.pop(self.weight_key)).astype(
-                    paddle.get_default_dtype()))
+            self.weight.set_value(get_tensor(state_dict.pop(self.weight_key)).astype(paddle.get_default_dtype()))
         else:
-            weight_tensor = get_tensor(
-                state_dict.pop(self.weight_key)).astype(
-                    paddle.get_default_dtype())
+            weight_tensor = get_tensor(state_dict.pop(self.weight_key)).astype(paddle.get_default_dtype())
             if self.linear.weight.shape != weight_tensor.shape:
                 weight_tensor = weight_tensor.transpose([1, 0])
             self.linear.weight.set_value(weight_tensor)

         if self.bias_key is not None:
-            bias = get_tensor(state_dict.pop(self.bias_key)).astype(
-                paddle.get_default_dtype())
+            bias = get_tensor(state_dict.pop(self.bias_key)).astype(paddle.get_default_dtype())
             self.linear.bias.set_value(bias)

     def forward(self, input):
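The load_state_dict hunk above keeps the same behavior, only reflowed: pop the weight from the state dict, cast it to the default dtype, transpose it when the checkpoint layout does not match the layer, and optionally load a bias. Below is a minimal, framework-free sketch of that logic using NumPy in place of Paddle tensors; the helper name load_projection_weights, the key names, and the shapes are illustrative assumptions, not FastDeploy or Paddle APIs.

import numpy as np

def load_projection_weights(layer_weight_shape, state_dict, weight_key, bias_key=None):
    # Pop the weight, cast it to the target dtype, and transpose it if the
    # checkpoint stores it with the opposite orientation to the layer.
    weight = np.asarray(state_dict.pop(weight_key), dtype=np.float32)
    if tuple(layer_weight_shape) != weight.shape:
        weight = weight.transpose(1, 0)
    bias = None
    if bias_key is not None:
        bias = np.asarray(state_dict.pop(bias_key), dtype=np.float32)
    return weight, bias

# Usage: a checkpoint whose weight is stored transposed relative to the layer.
state = {
    "lm_head.weight": np.zeros((320, 64)),   # [num_embeddings, embedding_dim]
    "lm_head.bias": np.zeros(320),
}
weight, bias = load_projection_weights((64, 320), state, "lm_head.weight", "lm_head.bias")
assert weight.shape == (64, 320) and bias.shape == (320,)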