[bugfix]fix blockwisefp8 and all_reduce (#3243)
* fix
* update
* fix linear for prequant loader
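At its core, the patch guards tensor-parallel weight metadata behind the rank count (`world_size > 1` / `nranks > 1`), moves `_set_var_distributed` out of `UnquantizedLinearMethod` and into the parallel linear layers, routes prequantized checkpoints through the right loader, and sets the FP8 weight dtype before the parameter is created. The sketch below is a hedged, self-contained illustration of the guard pattern only; `FakeWeight` and the simplified `set_weight_attrs` / `_set_var_distributed` stand-ins are not FastDeploy's real helpers.

# Hedged sketch of the rank-count guard this commit introduces.
# `set_weight_attrs` and `_set_var_distributed` below are simplified stand-ins,
# not FastDeploy's real helpers; `FakeWeight` replaces a paddle Parameter.

class FakeWeight:
    """Placeholder for a paddle Parameter."""

def set_weight_attrs(weight, attrs):
    # Attach loader metadata (e.g. output_dim) to the weight object.
    for key, value in attrs.items():
        setattr(weight, key, value)

def _set_var_distributed(weight, split_axis):
    # Mark the weight as split across tensor-parallel ranks along `split_axis`.
    weight.is_distributed = True
    weight.split_axis = split_axis

def create_column_parallel_weight(nranks: int) -> FakeWeight:
    weight = FakeWeight()
    if nranks > 0:
        # Column parallel: the output dimension (axis 1) is sharded.
        _set_var_distributed(weight, split_axis=1)
    if nranks > 1:
        # Only multi-rank runs need output_dim metadata for the weight loader.
        set_weight_attrs(weight, {"output_dim": True})
    return weight

print(hasattr(create_column_parallel_weight(1), "output_dim"))  # False
print(hasattr(create_column_parallel_weight(2), "output_dim"))  # True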
@@ -81,7 +81,8 @@ class VocabParallelEmbedding(nn.Layer):
                     initializer=nn.initializer.Normal(mean=0.0, std=self.initializer_range),
                 ),
             )
-            set_weight_attrs(self.embeddings.weight, {"output_dim": False})
+            if self.world_size > 1:
+                set_weight_attrs(self.embeddings.weight, {"output_dim": False})
         else:
             # column cut embedding
             self.embeddings = nn.Embedding(
@@ -91,7 +92,8 @@ class VocabParallelEmbedding(nn.Layer):

             self.embeddings.weight.is_distributed = True
             self.embeddings.weight.split_axis = 1
-            set_weight_attrs(self.embeddings.weight, {"output_dim": True})
+            if self.world_size > 1:
+                set_weight_attrs(self.embeddings.weight, {"output_dim": True})

         self.prefix = prefix
         self.dropout = nn.Dropout(self.hidden_dropout_prob)
@@ -37,7 +37,6 @@ class UnquantizedLinearMethod(QuantMethodBase):
     def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
        """
        extra_weight_attrs is a dictionary that may include parameters like:
-            - split_axis: specifies which axis to split the weight tensor on (for distributed weight partitioning)
            - output_dim: determines whether the split is applied along the output dimension (rows) or input dimension (columns)
            - weight_loader: a callable or method responsible for loading the weight data
        """
@@ -51,9 +50,7 @@ class UnquantizedLinearMethod(QuantMethodBase):
             layer.weight,
             {"weight_loader": extra_weight_attrs.get("weight_loader", default_weight_loader(layer.fd_config))},
         )
-        if hasattr(layer, "nranks") and layer.nranks > 0:
-            split_axis = extra_weight_attrs.get("split_axis")
-            _set_var_distributed(layer.weight, split_axis=split_axis)
+        if hasattr(layer, "nranks") and layer.nranks > 1:
             set_weight_attrs(layer.weight, {"output_dim": extra_weight_attrs.get("output_dim")})

     def process_loaded_weights(self, layer, weights) -> None:
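As a hedged illustration of how `create_weights` now consumes `extra_weight_attrs` (note that `split_axis` is gone and `output_dim` is only attached for multi-rank layers), here is a simplified, runnable stand-in; `DummyLayer` and `default_loader` are invented for the example and are not FastDeploy APIs.

# Hedged sketch of how create_weights consumes extra_weight_attrs after this change.
# `DummyLayer` and `default_loader` are illustrative stand-ins, not FastDeploy APIs.

def default_loader():
    def load(param, loaded_weight):
        param["data"] = loaded_weight
    return load

class DummyLayer:
    def __init__(self, nranks):
        self.nranks = nranks
        self.weight = {"attrs": {}}

def create_weights(layer, **extra_weight_attrs):
    # weight_loader falls back to a default when the caller does not supply one.
    loader = extra_weight_attrs.get("weight_loader", default_loader())
    layer.weight["attrs"]["weight_loader"] = loader
    # output_dim is only recorded for multi-rank (tensor-parallel) layers.
    if hasattr(layer, "nranks") and layer.nranks > 1:
        layer.weight["attrs"]["output_dim"] = extra_weight_attrs.get("output_dim")

layer = DummyLayer(nranks=2)
create_weights(layer, output_dim=True)
print(layer.weight["attrs"]["output_dim"])  # True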
@@ -125,6 +122,10 @@ class LinearBase(nn.Layer):
         # key
         if weight_key:
             self.weight_key = f"{prefix}.{weight_key}"
+        elif fd_config.model_config.is_quantized and not skip_quant:
+            self.weight_key = f"{prefix}.quant_weight"
+            self.weight_scale_key = f"{prefix}.weight_scale"
+            self.act_scale_key = f"{prefix}.activation_scale"
         else:
             self.weight_key = f"{prefix}.weight"
             self.bias_key = f"{prefix}.bias"
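For a concrete view of the new `is_quantized` branch, the snippet below prints the checkpoint keys it would look up; the `prefix` value is hypothetical.

# Keys produced by the new `is_quantized` branch for a hypothetical prefix.
prefix = "model.layers.0.mlp.up_gate_proj"  # example prefix, not from the diff

weight_key = f"{prefix}.quant_weight"
weight_scale_key = f"{prefix}.weight_scale"
act_scale_key = f"{prefix}.activation_scale"

print(weight_key)        # model.layers.0.mlp.up_gate_proj.quant_weight
print(weight_scale_key)  # model.layers.0.mlp.up_gate_proj.weight_scale
print(act_scale_key)     # model.layers.0.mlp.up_gate_proj.activation_scale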
@@ -173,7 +174,11 @@ class LinearBase(nn.Layer):
         Args:
             state_dict (dict): A dictionary containing the prequantized weights and scales.
         """
-        self.quant_method.process_prequanted_weights(self, state_dict)
+        if isinstance(self.quant_method, UnquantizedLinearMethod):
+            # for gate
+            self.load_weight(state_dict)
+        else:
+            self.quant_method.process_prequanted_weights(self, state_dict)

     def load_weight(self, state_dict: dict):
         """
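The new dispatch sends unquantized layers (the "# for gate" case) through the plain weight loader and everything else through the quant method's prequant path. A minimal sketch with dummy classes follows; the method name `load_prequant_weight` and the class bodies are assumptions for illustration, not the real implementations.

# Hedged sketch of the prequant-loading dispatch; classes are illustrative only.

class UnquantizedLinearMethod:
    pass

class BlockWiseFP8LinearMethod:
    def process_prequanted_weights(self, layer, state_dict):
        print("consume quant_weight / weight_scale from state_dict")

class Linear:
    def __init__(self, quant_method):
        self.quant_method = quant_method

    def load_weight(self, state_dict):
        print("plain (unquantized) weight load")

    def load_prequant_weight(self, state_dict):  # method name assumed
        if isinstance(self.quant_method, UnquantizedLinearMethod):
            # e.g. a gate layer that stays unquantized even in a quantized checkpoint
            self.load_weight(state_dict)
        else:
            self.quant_method.process_prequanted_weights(self, state_dict)

Linear(UnquantizedLinearMethod()).load_prequant_weight({})   # plain (unquantized) weight load
Linear(BlockWiseFP8LinearMethod()).load_prequant_weight({})  # consume quant_weight / weight_scale ...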
@@ -333,18 +338,18 @@ class ColumnParallelLinear(LinearBase):
         assert self.quant_method is not None
         self.quant_method.create_weights(
             self,
-            split_axis=1,
             output_dim=True,
             weight_loader=(
                 self.weight_loader if hasattr(self, "weight_loader") else default_weight_loader(self.fd_config)
             ),
         )
-        if self.with_bias:
-            if self.nranks > 0:
+        if self.nranks > 0:
+            _set_var_distributed(self.weight, split_axis=1)
+            if self.with_bias:
                 # col parallel
                 _set_var_distributed(self.bias, split_axis=1)
-                set_weight_attrs(self.bias, {"output_dim": True})
+                if self.nranks > 1:
+                    set_weight_attrs(self.bias, {"output_dim": True})


 class MergedColumnParallelLinear(ColumnParallelLinear):
@@ -669,15 +674,19 @@ class RowParallelLinear(LinearBase):
                 self.weight_loader if hasattr(self, "weight_loader") else default_weight_loader(self.fd_config)
             ),
         )
-        if self.with_bias:
-            _set_var_distributed(self.bias, split_axis=0)
-            set_weight_attrs(
-                self.bias,
-                {
-                    "output_dim": False,
-                },
-            )
+        if self.nranks > 0:
+            _set_var_distributed(self.weight, split_axis=0)
+            if self.with_bias:
+                # col parallel
+                _set_var_distributed(self.bias, split_axis=0)
+                if self.nranks > 1:
+                    set_weight_attrs(
+                        self.bias,
+                        {
+                            "output_dim": False,
+                        },
+                    )
+
         self.reduce_results = reduce_results

     def forward_cuda(self, x: paddle.Tensor) -> paddle.Tensor:
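Column and row parallel layers now mirror each other: the column variant shards the weight along axis 1 (`output_dim=True`), the row variant along axis 0 (`output_dim=False`), which is why row-parallel partial results later need an all_reduce. A pure-Python shape sketch follows; the shapes and the helper are illustrative, not FastDeploy code.

# Illustrative only: how split_axis maps onto a [in_features, out_features] weight.

def shard_shape(shape, split_axis, nranks):
    sharded = list(shape)
    sharded[split_axis] //= nranks  # each rank holds 1/nranks of that axis
    return sharded

full = [4096, 8192]  # hypothetical [in_features, out_features]

# ColumnParallelLinear: split_axis=1, each rank computes a slice of the outputs.
print(shard_shape(full, split_axis=1, nranks=4))  # [4096, 2048]

# RowParallelLinear: split_axis=0, each rank sees a slice of the inputs,
# so the per-rank matmul results must be summed (all_reduce) afterwards.
print(shard_shape(full, split_axis=0, nranks=4))  # [1024, 8192]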
@@ -60,6 +60,7 @@ class ParallelLMHead(nn.Layer):
         self.bias_key: Optional[str] = None
         self.use_ep: bool = fd_config.parallel_config.use_ep
         self.column_cut = True
+        self.nranks = fd_config.parallel_config.tensor_parallel_size

         ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear
         RowParallelLinear = fleet.meta_parallel.RowParallelLinear
@@ -91,7 +92,8 @@ class ParallelLMHead(nn.Layer):
                 gather_output=need_gather,
                 fuse_matmul_bias=False,  # False gives a smaller numerical diff
             )
-            set_weight_attrs(self.linear.weight, {"output_dim": True})
+            if self.nranks > 1:
+                set_weight_attrs(self.linear.weight, {"output_dim": True})
         else:
             self.linear = RowParallelLinear(
                 embedding_dim,
@@ -102,7 +104,8 @@ class ParallelLMHead(nn.Layer):
                 input_is_parallel=False,
                 fuse_matmul_bias=False,  # False gives a smaller numerical diff
             )
-            set_weight_attrs(self.linear.weight, {"output_dim": False})
+            if self.nranks > 1:
+                set_weight_attrs(self.linear.weight, {"output_dim": False})

     def load_state_dict(self, state_dict: Dict[str, paddle.Tensor | np.ndarray]):
         """
@@ -83,7 +83,7 @@ class BlockWiseFP8LinearMethod(QuantMethodBase):

     def create_weights(self, layer, **extra_weight_attrs):
         layer.weight_shape.reverse()
+        layer.weight_dtype = "float8_e4m3fn"
         layer.weight = layer.create_parameter(
             shape=layer.weight_shape,
             dtype=layer.weight_dtype,
@@ -101,7 +101,6 @@ class BlockWiseFP8LinearMethod(QuantMethodBase):
             dtype="float32",
             is_bias=False,
         )
-        layer.weight_dtype = "float8_e4m3fn"

     def process_loaded_weights(self, layer, weights) -> None:
         weight_tensor = weights.transpose([1, 0])
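The BlockWiseFP8 part of the fix is an ordering change: `layer.weight_dtype` must be `"float8_e4m3fn"` before `layer.create_parameter(..., dtype=layer.weight_dtype)` runs, otherwise the weight is allocated with the previous dtype and only the bookkeeping string changes afterwards. A small stand-in sketch follows; the `bfloat16` default, the shapes, and `FakeLayer` are assumptions, not paddle's real API.

# Hedged sketch: why weight_dtype has to be set before the parameter is created.

class FakeLayer:
    weight_dtype = "bfloat16"  # assumed default before quantization kicks in

    def create_parameter(self, shape, dtype):
        # Stand-in for paddle's create_parameter: dtype is fixed at creation time.
        return {"shape": shape, "dtype": dtype}

layer = FakeLayer()

# Old order: the parameter is created first, so it silently keeps the default dtype.
old_weight = layer.create_parameter(shape=[8192, 4096], dtype=layer.weight_dtype)
layer.weight_dtype = "float8_e4m3fn"
print(old_weight["dtype"])  # bfloat16 -- the FP8 dtype never reaches the tensor

# New order (this commit): set the dtype first, then allocate.
layer.weight_dtype = "float8_e4m3fn"
new_weight = layer.create_parameter(shape=[8192, 4096], dtype=layer.weight_dtype)
print(new_weight["dtype"])  # float8_e4m3fn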