Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-10-05 08:37:06 +08:00)
[bugfix]fix blockwisefp8 and all_reduce (#3243)
* fix
* update
* fix linear for prequant loader
@@ -60,6 +60,7 @@ class ParallelLMHead(nn.Layer):
         self.bias_key: Optional[str] = None
         self.use_ep: bool = fd_config.parallel_config.use_ep
         self.column_cut = True
+        self.nranks = fd_config.parallel_config.tensor_parallel_size

         ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear
         RowParallelLinear = fleet.meta_parallel.RowParallelLinear
@@ -91,7 +92,8 @@ class ParallelLMHead(nn.Layer):
                 gather_output=need_gather,
                 fuse_matmul_bias=False,  # False gives a smaller diff
             )
-            set_weight_attrs(self.linear.weight, {"output_dim": True})
+            if self.nranks > 1:
+                set_weight_attrs(self.linear.weight, {"output_dim": True})
         else:
             self.linear = RowParallelLinear(
                 embedding_dim,
@@ -102,7 +104,8 @@ class ParallelLMHead(nn.Layer):
                 input_is_parallel=False,
                 fuse_matmul_bias=False,  # False gives a smaller diff
             )
-            set_weight_attrs(self.linear.weight, {"output_dim": False})
+            if self.nranks > 1:
+                set_weight_attrs(self.linear.weight, {"output_dim": False})

     def load_state_dict(self, state_dict: Dict[str, paddle.Tensor | np.ndarray]):
         """
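The net effect of the change: the output_dim loader hint is attached to the LM-head weight only when tensor parallelism is actually in use (nranks > 1), which the commit message ties to the prequant (blockwise FP8) loader path. Below is a minimal sketch of that pattern, not the actual FastDeploy code: set_weight_attrs here is a simplified stand-in for FastDeploy's helper, and tag_lm_head_weight is a hypothetical function used only for illustration.

# Minimal sketch of the guard introduced above; tag_lm_head_weight is a
# hypothetical helper, and set_weight_attrs is a simplified stand-in for
# FastDeploy's own helper of the same name.
import paddle


def set_weight_attrs(weight, weight_attrs):
    # Attach weight-loader hints (e.g. the output_dim flag) directly to the
    # parameter object as plain Python attributes.
    for key, value in (weight_attrs or {}).items():
        setattr(weight, key, value)


def tag_lm_head_weight(weight, nranks, column_cut=True):
    # After this commit the attribute is only set when tensor parallelism is
    # active; the column-parallel branch uses True and the row-parallel
    # branch False, matching the diff above.
    if nranks > 1:
        set_weight_attrs(weight, {"output_dim": column_cut})


w = paddle.create_parameter(shape=[8, 16], dtype="float32")
tag_lm_head_weight(w, nranks=1)
print(getattr(w, "output_dim", "untagged"))  # untagged: single-rank weight stays plain
tag_lm_head_weight(w, nranks=2)
print(getattr(w, "output_dim", "untagged"))  # True: tagged for tensor-parallel loading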