[bugfix] fix blockwisefp8 and all_reduce (#3243)

* fix

* update

* fix linear for prequant loader
bukejiyu
2025-08-06 23:54:33 +08:00
committed by GitHub
parent 3a15e0c53e
commit 9408e667a5
4 changed files with 37 additions and 24 deletions

@@ -60,6 +60,7 @@ class ParallelLMHead(nn.Layer):
         self.bias_key: Optional[str] = None
         self.use_ep: bool = fd_config.parallel_config.use_ep
         self.column_cut = True
+        self.nranks = fd_config.parallel_config.tensor_parallel_size
 
         ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear
         RowParallelLinear = fleet.meta_parallel.RowParallelLinear
@@ -91,7 +92,8 @@ class ParallelLMHead(nn.Layer):
                     gather_output=need_gather,
                     fuse_matmul_bias=False,  # False gives a smaller diff
                 )
-                set_weight_attrs(self.linear.weight, {"output_dim": True})
+                if self.nranks > 1:
+                    set_weight_attrs(self.linear.weight, {"output_dim": True})
             else:
                 self.linear = RowParallelLinear(
                     embedding_dim,
@@ -102,7 +104,8 @@ class ParallelLMHead(nn.Layer):
                     input_is_parallel=False,
                     fuse_matmul_bias=False,  # False gives a smaller diff
                 )
-                set_weight_attrs(self.linear.weight, {"output_dim": False})
+                if self.nranks > 1:
+                    set_weight_attrs(self.linear.weight, {"output_dim": False})
 
     def load_state_dict(self, state_dict: Dict[str, paddle.Tensor | np.ndarray]):
         """