[RL] Fix missing is_distributed attribute (#5150)

* fix: pass split_axis and is_distributed to create_weights only when nranks > 0

* update: remove the _set_var_distributed call from UnquantizedLinearMethod.create_weights
Author: bukejiyu
Committed: 2025-11-21 14:14:25 +08:00 (via GitHub)
Parent: 6ca2651995
Commit: 34f59d9800


@@ -56,9 +56,6 @@ class UnquantizedLinearMethod(QuantMethodBase):
             is_bias=False,
             default_initializer=paddle.nn.initializer.Constant(0),
         )
-        split_axis = extra_weight_attrs.get("split_axis")
-        if hasattr(layer, "nranks") and layer.nranks > 0:
-            _set_var_distributed(layer.weight, split_axis=split_axis)
         if self.model_format == "torch" and "output_dim" in extra_weight_attrs:
             extra_weight_attrs["output_dim"] = not extra_weight_attrs["output_dim"]
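
The three removed lines set the distributed attributes only on the unquantized path, so weights created by other quant methods were left without is_distributed; the hunk below moves that decision to the caller. For context, a minimal sketch of what a helper like _set_var_distributed conventionally does in Paddle-based code (an assumption about its behavior, not the actual FastDeploy implementation):

    # Hedged sketch (assumed behavior): tag a parameter so tensor-parallel
    # code knows it is sharded and along which axis.
    def _set_var_distributed(var, split_axis=None):
        var.is_distributed = True  # the attribute this PR's title refers to
        if split_axis is not None:
            var.split_axis = split_axis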
@@ -882,15 +879,18 @@ class RowParallelLinear(LinearBase):
         if add_bias:
             assert with_bias, "with_bias must be True when add_bias is True."
         assert self.quant_method is not None
-        self.quant_method.create_weights(
-            self,
-            split_axis=0,
+        create_weight_kwargs = dict(
+            layer=self,
             output_dim=None if self.split_token else False,
             weight_loader=(
                 self.weight_loader if hasattr(self, "weight_loader") else default_weight_loader(self.fd_config)
             ),
             model_format=fd_config.model_config.model_format,
         )
+        if self.nranks > 0:
+            create_weight_kwargs["split_axis"] = 0
+            create_weight_kwargs["is_distributed"] = True
+        self.quant_method.create_weights(**create_weight_kwargs)
         self.reduce_results = reduce_results
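
With this change, split_axis and is_distributed reach create_weights only when self.nranks > 0, so single-rank runs never carry the distributed-only kwargs. A self-contained sketch of the conditional-kwargs pattern the fix adopts (names and signatures below are illustrative, not the real FastDeploy API):

    # Build the kwargs dict first, add distributed-only keys when they apply,
    # then expand it into the call.
    def create_weights(layer, output_dim=None, split_axis=None, is_distributed=False):
        print(f"split_axis={split_axis}, is_distributed={is_distributed}")

    class Layer:
        nranks = 2  # nranks > 0 simulates a tensor-parallel launch

    layer = Layer()
    kwargs = dict(layer=layer, output_dim=False)
    if layer.nranks > 0:
        kwargs["split_axis"] = 0         # shard dimension for row parallelism
        kwargs["is_distributed"] = True  # mark the weight as sharded
    create_weights(**kwargs)  # prints: split_axis=0, is_distributed=True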