[bugfix]fix blockwisefp8 and all_reduce (#3243)

* fix

* update

* fix linear for prequant loader
Author: bukejiyu
Date: 2025-08-06 23:54:33 +08:00
Committed by: GitHub
Parent: 3a15e0c53e
Commit: 9408e667a5
4 changed files with 37 additions and 24 deletions

Changed file 1 of 4: embedding layer (VocabParallelEmbedding)

@@ -81,7 +81,8 @@ class VocabParallelEmbedding(nn.Layer):
                     initializer=nn.initializer.Normal(mean=0.0, std=self.initializer_range),
                 ),
             )
-            set_weight_attrs(self.embeddings.weight, {"output_dim": False})
+            if self.world_size > 1:
+                set_weight_attrs(self.embeddings.weight, {"output_dim": False})
         else:
             # column cut embedding
             self.embeddings = nn.Embedding(
@@ -91,7 +92,8 @@ class VocabParallelEmbedding(nn.Layer):
             self.embeddings.weight.is_distributed = True
             self.embeddings.weight.split_axis = 1
-            set_weight_attrs(self.embeddings.weight, {"output_dim": True})
+            if self.world_size > 1:
+                set_weight_attrs(self.embeddings.weight, {"output_dim": True})

         self.prefix = prefix
         self.dropout = nn.Dropout(self.hidden_dropout_prob)
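The same guard recurs throughout this commit: sharding hints such as "output_dim" are only attached to a weight when the run is actually tensor-parallel, so a single-GPU load keeps the full tensor instead of slicing it. A minimal, self-contained sketch of that idea; set_weight_attrs and shard_for_load below are simplified stand-ins for illustration, not the FastDeploy implementations:

import numpy as np

def set_weight_attrs(attrs: dict, new: dict) -> None:
    # Stand-in: record loader hints for a parameter in a plain dict.
    attrs.update(new)

def shard_for_load(full: np.ndarray, attrs: dict, rank: int, nranks: int) -> np.ndarray:
    # Illustrative loader: slice along the annotated dimension, otherwise keep the full tensor.
    if "output_dim" not in attrs:
        return full
    axis = 1 if attrs["output_dim"] else 0
    return np.split(full, nranks, axis=axis)[rank]

def embedding_attrs(nranks: int) -> dict:
    attrs: dict = {}
    if nranks > 1:  # the guard this commit adds around set_weight_attrs
        set_weight_attrs(attrs, {"output_dim": True})
    return attrs

full_weight = np.arange(24.0).reshape(4, 6)
print(shard_for_load(full_weight, embedding_attrs(nranks=2), rank=0, nranks=2).shape)  # (4, 3): rank 0's shard
print(shard_for_load(full_weight, embedding_attrs(nranks=1), rank=0, nranks=1).shape)  # (4, 6): full tensor, nothing sliced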

Changed file 2 of 4: linear layers (UnquantizedLinearMethod, LinearBase, ColumnParallelLinear, RowParallelLinear)

@@ -37,7 +37,6 @@ class UnquantizedLinearMethod(QuantMethodBase):
     def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
         """
         extra_weight_attrs is a dictionary that may include parameters like:
-            - split_axis: specifies which axis to split the weight tensor on (for distributed weight partitioning)
             - output_dim: determines whether the split is applied along the output dimension (rows) or input dimension (columns)
             - weight_loader: a callable or method responsible for loading the weight data
         """
@@ -51,9 +50,7 @@ class UnquantizedLinearMethod(QuantMethodBase):
layer.weight, layer.weight,
{"weight_loader": extra_weight_attrs.get("weight_loader", default_weight_loader(layer.fd_config))}, {"weight_loader": extra_weight_attrs.get("weight_loader", default_weight_loader(layer.fd_config))},
) )
if hasattr(layer, "nranks") and layer.nranks > 0: if hasattr(layer, "nranks") and layer.nranks > 1:
split_axis = extra_weight_attrs.get("split_axis")
_set_var_distributed(layer.weight, split_axis=split_axis)
set_weight_attrs(layer.weight, {"output_dim": extra_weight_attrs.get("output_dim")}) set_weight_attrs(layer.weight, {"output_dim": extra_weight_attrs.get("output_dim")})
def process_loaded_weights(self, layer, weights) -> None: def process_loaded_weights(self, layer, weights) -> None:
@@ -125,6 +122,10 @@ class LinearBase(nn.Layer):
         # key
         if weight_key:
             self.weight_key = f"{prefix}.{weight_key}"
+        elif fd_config.model_config.is_quantized and not skip_quant:
+            self.weight_key = f"{prefix}.quant_weight"
+            self.weight_scale_key = f"{prefix}.weight_scale"
+            self.act_scale_key = f"{prefix}.activation_scale"
         else:
             self.weight_key = f"{prefix}.weight"
         self.bias_key = f"{prefix}.bias"
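The new elif branch wires prequantized checkpoints into the key lookup: instead of {prefix}.weight, a quantized model stores the quantized weight plus its scales under dedicated suffixes. A small sketch of that naming scheme; the helper and the example prefix are hypothetical, only the three suffixes come from the diff:

def prequant_keys(prefix: str) -> dict:
    # Tensor names looked up when the checkpoint already contains quantized weights.
    return {
        "weight": f"{prefix}.quant_weight",         # the quantized weight tensor
        "weight_scale": f"{prefix}.weight_scale",   # dequantization scales for the weight
        "act_scale": f"{prefix}.activation_scale",  # scales for the layer's input activations
    }

# Hypothetical layer prefix, purely for illustration.
for name, key in prequant_keys("model.layers.0.mlp.down_proj").items():
    print(name, "->", key)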
@@ -173,7 +174,11 @@ class LinearBase(nn.Layer):
Args: Args:
state_dict (dict): A dictionary containing the prequantized weights and scales. state_dict (dict): A dictionary containing the prequantized weights and scales.
""" """
self.quant_method.process_prequanted_weights(self, state_dict) if isinstance(self.quant_method, UnquantizedLinearMethod):
# for gate
self.load_weight(state_dict)
else:
self.quant_method.process_prequanted_weights(self, state_dict)
def load_weight(self, state_dict: dict): def load_weight(self, state_dict: dict):
""" """
@@ -333,18 +338,18 @@ class ColumnParallelLinear(LinearBase):
         assert self.quant_method is not None
         self.quant_method.create_weights(
             self,
-            split_axis=1,
             output_dim=True,
             weight_loader=(
                 self.weight_loader if hasattr(self, "weight_loader") else default_weight_loader(self.fd_config)
             ),
         )
-
-        if self.with_bias:
-            if self.nranks > 0:
-                # col parallel
-                _set_var_distributed(self.bias, split_axis=1)
-                set_weight_attrs(self.bias, {"output_dim": True})
+        if self.nranks > 0:
+            _set_var_distributed(self.weight, split_axis=1)
+        if self.with_bias:
+            # col parallel
+            _set_var_distributed(self.bias, split_axis=1)
+            if self.nranks > 1:
+                set_weight_attrs(self.bias, {"output_dim": True})


 class MergedColumnParallelLinear(ColumnParallelLinear):
@@ -669,15 +674,19 @@ class RowParallelLinear(LinearBase):
                 self.weight_loader if hasattr(self, "weight_loader") else default_weight_loader(self.fd_config)
             ),
         )
-
-        if self.with_bias:
-            _set_var_distributed(self.bias, split_axis=0)
-            set_weight_attrs(
-                self.bias,
-                {
-                    "output_dim": False,
-                },
-            )
+        if self.nranks > 0:
+            _set_var_distributed(self.weight, split_axis=0)
+        if self.with_bias:
+            # col parallel
+            _set_var_distributed(self.bias, split_axis=0)
+            if self.nranks > 1:
+                set_weight_attrs(
+                    self.bias,
+                    {
+                        "output_dim": False,
+                    },
+                )

         self.reduce_results = reduce_results

     def forward_cuda(self, x: paddle.Tensor) -> paddle.Tensor:
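The split_axis values in the two hunks above encode the standard tensor-parallel layout: ColumnParallelLinear shards its weight along the output axis (split_axis=1) and leaves the activations sharded, while RowParallelLinear shards along the input axis (split_axis=0) and produces partial results that have to be summed across ranks; that sum is the all_reduce from the commit title, controlled by the reduce_results flag above. A small numpy sketch of that algebra, with a plain sum standing in for the collective:

import numpy as np

rng = np.random.default_rng(0)
nranks = 2
x = rng.standard_normal((3, 8))       # [batch, in_features]
w_col = rng.standard_normal((8, 6))   # column-parallel weight: split along axis 1 (outputs)
w_row = rng.standard_normal((6, 4))   # row-parallel weight: split along axis 0 (inputs)

# Column parallel: each rank holds a slice of the output features.
y_shards = [x @ w for w in np.split(w_col, nranks, axis=1)]

# Row parallel: each rank consumes its activation shard and emits a partial result;
# summing the partials is what the all_reduce does across ranks.
partials = [y @ w for y, w in zip(y_shards, np.split(w_row, nranks, axis=0))]
y_tp = sum(partials)

y_ref = (x @ w_col) @ w_row           # single-GPU reference
assert np.allclose(y_tp, y_ref)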

Changed file 3 of 4: LM head (ParallelLMHead)

@@ -60,6 +60,7 @@ class ParallelLMHead(nn.Layer):
         self.bias_key: Optional[str] = None
         self.use_ep: bool = fd_config.parallel_config.use_ep
         self.column_cut = True
+        self.nranks = fd_config.parallel_config.tensor_parallel_size

         ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear
         RowParallelLinear = fleet.meta_parallel.RowParallelLinear
@@ -91,7 +92,8 @@ class ParallelLMHead(nn.Layer):
                 gather_output=need_gather,
                 fuse_matmul_bias=False,  # False gives a smaller numerical diff
             )
-            set_weight_attrs(self.linear.weight, {"output_dim": True})
+            if self.nranks > 1:
+                set_weight_attrs(self.linear.weight, {"output_dim": True})
         else:
             self.linear = RowParallelLinear(
                 embedding_dim,
@@ -102,7 +104,8 @@ class ParallelLMHead(nn.Layer):
                 input_is_parallel=False,
                 fuse_matmul_bias=False,  # False gives a smaller numerical diff
             )
-            set_weight_attrs(self.linear.weight, {"output_dim": False})
+            if self.nranks > 1:
+                set_weight_attrs(self.linear.weight, {"output_dim": False})

     def load_state_dict(self, state_dict: Dict[str, paddle.Tensor | np.ndarray]):
         """

Changed file 4 of 4: blockwise FP8 quant method (BlockWiseFP8LinearMethod)

@@ -83,7 +83,7 @@ class BlockWiseFP8LinearMethod(QuantMethodBase):
     def create_weights(self, layer, **extra_weight_attrs):
         layer.weight_shape.reverse()
+        layer.weight_dtype = "float8_e4m3fn"
         layer.weight = layer.create_parameter(
             shape=layer.weight_shape,
             dtype=layer.weight_dtype,
@@ -101,7 +101,6 @@ class BlockWiseFP8LinearMethod(QuantMethodBase):
             dtype="float32",
             is_bias=False,
         )
-        layer.weight_dtype = "float8_e4m3fn"

     def process_loaded_weights(self, layer, weights) -> None:
         weight_tensor = weights.transpose([1, 0])
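The blockwise FP8 change is purely about ordering: layer.weight_dtype has to be assigned before layer.create_parameter(...) reads it, otherwise the weight is allocated in the layer's previous dtype rather than float8_e4m3fn. A tiny Paddle-free sketch of that ordering bug; LayerSketch is a made-up stand-in whose parameters capture the dtype at creation time:

class LayerSketch:
    def __init__(self):
        self.weight_dtype = "bfloat16"  # default compute dtype before any quant setup
        self.params: dict = {}

    def create_parameter(self, name: str, shape: list) -> None:
        # The dtype is read *now*; later reassignments of weight_dtype change nothing.
        self.params[name] = {"shape": shape, "dtype": self.weight_dtype}

buggy = LayerSketch()
buggy.create_parameter("weight", shape=[256, 128])
buggy.weight_dtype = "float8_e4m3fn"   # too late: the parameter is already allocated
assert buggy.params["weight"]["dtype"] == "bfloat16"

fixed = LayerSketch()
fixed.weight_dtype = "float8_e4m3fn"   # set before creating, as the patch now does
fixed.create_parameter("weight", shape=[256, 128])
assert fixed.params["weight"]["dtype"] == "float8_e4m3fn"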