[v1 loader] Qwen offline FP8 (#4036)

* support offline fp8

* update ut

* update ut

* update ut

* fix

* update

* update
Author: bukejiyu
Date: 2025-09-15 13:44:11 +08:00
Committed by: GitHub
Parent: b1a5b756a3
Commit: 29ed617f0f
21 changed files with 440 additions and 138 deletions


@@ -34,6 +34,72 @@ QUANTIZATION_METHODS: List[str] = [
]
def parse_quant_config(args, model_config, is_ernie, is_v1_loader):
    # 1. model_config.is_quantized
    # TODO(bukejiyu): model_config.is_quantized is v0-only and needs to be removed in the future
    if model_config.model_format == "torch":
        quantization_config = model_config.quantization_config
        if quantization_config is not None:
            model_config.is_quantized = True
    else:
        quantization_config = model_config.quantization_config
        if not model_config.is_quantized:
            if quantization_config is not None:
                if "is_quantized" in quantization_config:
                    model_config.is_quantized = quantization_config["is_quantized"]
                elif "kv_cache_quant_type" not in quantization_config:
                    model_config.is_quantized = True
        if quantization_config is not None and quantization_config.get("quantization", None) is None:
            raise ValueError(
                "quantization_config should have a key named 'quantization' to specify the quant config."
            )

    quant_config_name = None
    if quantization_config is not None:
        quant_config_name = _get_offline_quant_config_name(
            quantization_config, model_config.model_format == "torch", is_v1_loader
        )
    elif args.quantization is not None:
        quantization_config = {}
        try:
            quantization_config.update(args.quantization)
            quant_config_name = quantization_config["quantization"]
        except:
            # args.quantization was given as a plain quant-method name rather than a dict
            quant_config_name = args.quantization
            quantization_config["quantization"] = quant_config_name
        # Special handling for Ernie models
        if quant_config_name == "wint4" and is_ernie:
            quantization_config["dense_quant_type"] = "wint8"
            quantization_config["moe_quant_type"] = "wint4"
            quantization_config["quantization"] = "mix_quant"
            quant_config_name = "mix_quant"
    else:
        quant_config_name = None

    if quant_config_name is None:
        quant_config = None
    else:
        if not quantization_config.get("is_quantized"):
            quantization_config["is_quantized"] = model_config.is_quantized
        quant_cls = get_quantization_config(quant_config_name)
        quant_config = quant_cls.from_config(quantization_config)
    return quant_config


def _get_offline_quant_config_name(quantization_config, is_torch_weight, is_v1_loader):
    if is_torch_weight:
        # only support block_wise_fp8 now
        quant_method = quantization_config.get("quant_method")
        has_block_size = "weight_block_size" in quantization_config
        if quant_method == "fp8" and has_block_size:
            quant_config_name = "block_wise_fp8"
        else:
            raise ValueError("Torch weight offline quantization only supports block-wise FP8.")
    else:
        quant_config_name = quantization_config["quantization"]
    return quant_config_name


def get_quantization_config(quantization: str) -> Type[QuantConfigBase]:
    """
    Get the quantization config class by the quantization name.

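For reference, a minimal sketch of the torch-format case this parser now handles. The dict below is illustrative rather than copied from any particular checkpoint; it only uses the two keys the helper above actually checks (quant_method and weight_block_size):

# Illustrative only: HF-style quantization_config found in a torch-format checkpoint's config.json.
offline_fp8_quantization_config = {
    "quant_method": "fp8",
    "weight_block_size": [128, 128],
}
# With model_config.model_format == "torch", _get_offline_quant_config_name() resolves this to
# "block_wise_fp8", and parse_quant_config() then builds the quant config via
# get_quantization_config("block_wise_fp8").from_config(...) with is_quantized recorded as True.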

@@ -53,7 +53,7 @@ class BlockWiseFP8Config(QuantConfigBase):
    @classmethod
    def from_config(cls, config: dict) -> "BlockWiseFP8Config":
        weight_block_size = config.get("weight_block_size", [128, 128])
-       is_checkpoint_bf16 = config.get("is_checkpoint_bf16", False)
+       is_checkpoint_bf16 = not config.get("is_quantized", False)
        return cls(weight_block_size, is_checkpoint_bf16)

    def get_quant_method(self, layer) -> Optional[QuantMethodBase]:
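A hedged sketch of what the revised from_config implies for the two loading paths; the dicts below are illustrative:

# Illustrative only.
offline_cfg = {"quantization": "block_wise_fp8", "is_quantized": True, "weight_block_size": [128, 128]}
online_cfg = {"quantization": "block_wise_fp8"}  # BF16 checkpoint, no offline quantization

# BlockWiseFP8Config.from_config(offline_cfg) -> is_checkpoint_bf16 == False
#     (the checkpoint already stores FP8 weights plus block scales)
# BlockWiseFP8Config.from_config(online_cfg)  -> is_checkpoint_bf16 == True
#     (the checkpoint stores BF16 weights, presumably quantized at load time)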
@@ -89,13 +89,15 @@ class BlockWiseFP8LinearMethod(QuantMethodBase):
        self.quant_config = quant_config

    def create_weights(self, layer, **extra_weight_attrs):
-       if self.quant_config.is_checkpoint_bf16:
+       # TODO(bukejiyu): remove v1 loader check when v0 loader is removed
+       if self.quant_config.is_checkpoint_bf16 and layer.fd_config.load_config.load_choices == "default_v1":
            layer.weight = layer.create_parameter(
                shape=layer.weight_shape,
                dtype=layer.weight_dtype,
                is_bias=False,
                default_initializer=paddle.nn.initializer.Constant(0),
            )
+           extra_weight_attrs["weight_need_transpose"] = extra_weight_attrs.get("model_format") == "torch"
            quant_attrs = extra_weight_attrs
            if isinstance(layer, MergedColumnParallelLinear) or isinstance(layer, QKVParallelLinear):
                quant_attrs = {
@@ -120,14 +122,28 @@ class BlockWiseFP8LinearMethod(QuantMethodBase):
            layer.weight_scale_inv = layer.create_parameter(
                shape=[
-                   (layer.output_size + self.quant_config.weight_block_size[0] - 1)
+                   (layer.weight_shape[0] + self.quant_config.weight_block_size[0] - 1)
                    // self.quant_config.weight_block_size[0],
-                   (layer.input_size + self.quant_config.weight_block_size[1] - 1)
+                   (layer.weight_shape[1] + self.quant_config.weight_block_size[1] - 1)
                    // self.quant_config.weight_block_size[1],
                ],
                dtype="float32",
                is_bias=False,
            )
            extra_weight_attrs["output_dim"] = not extra_weight_attrs["output_dim"]
            extra_weight_attrs["weight_need_transpose"] = not extra_weight_attrs.get("model_format") == "torch"
            set_weight_attrs(
                layer.weight,
                extra_weight_attrs,
            )
            set_weight_attrs(
                layer.weight_scale_inv,
                {
                    **extra_weight_attrs,
                    "is_scale": True,
                },
            )
    def process_weights_after_loading(self, layer) -> None:
        if not self.quant_config.is_checkpoint_bf16:

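For intuition about the shape arithmetic above, a small self-contained sketch of the ceiling division used for the per-block scale tensor (the numbers are illustrative):

# Illustrative only: one float32 scale per weight_block_size block of the FP8 weight.
def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b

weight_shape = [4096, 12288]      # hypothetical layer.weight_shape
weight_block_size = [128, 128]
scale_shape = [ceil_div(weight_shape[0], weight_block_size[0]),
               ceil_div(weight_shape[1], weight_block_size[1])]
print(scale_shape)  # [32, 96]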

@@ -37,7 +37,7 @@ class MixQuantConfig(QuantConfigBase):
        is_channel_wise: bool = False,
        has_zero_point: bool = False,
        is_permuted: bool = True,
-       is_checkpoint_bf16: bool = False,
+       is_quantized: bool = False,
        hadamard_block_size: int = 128,
    ) -> None:
        super().__init__()
@@ -54,7 +54,8 @@ class MixQuantConfig(QuantConfigBase):
        self.quant_min_bound = 0
        self.quant_round_type = 0
        self.is_permuted = is_permuted
-       self.is_checkpoint_bf16 = is_checkpoint_bf16
+       self.is_checkpoint_bf16 = not is_quantized
+       self.is_quantized = is_quantized
        self.hadamard_block_size = hadamard_block_size

    def name(self) -> str:
@@ -70,7 +71,7 @@ class MixQuantConfig(QuantConfigBase):
config.get("is_channel_wise", False),
config.get("has_zero_point", False),
config.get("is_permuted", True),
config.get("is_checkpoint_bf16", False),
config.get("is_quantized", False),
config.get("hadamard_block_size", 128),
)
@@ -82,7 +83,7 @@ class MixQuantConfig(QuantConfigBase):
                .from_config(
                    {
                        "is_permuted": self.is_permuted,
-                       "is_checkpoint_bf16": self.is_checkpoint_bf16,
+                       "is_quantized": self.is_quantized,
                        "hadamard_block_size": self.hadamard_block_size,
                    }
                )
@@ -94,7 +95,7 @@ class MixQuantConfig(QuantConfigBase):
                .from_config(
                    {
                        "is_permuted": self.is_permuted,
-                       "is_checkpoint_bf16": self.is_checkpoint_bf16,
+                       "is_quantized": self.is_quantized,
                        "hadamard_block_size": self.hadamard_block_size,
                    }
                )
@@ -112,6 +113,6 @@ class MixQuantConfig(QuantConfigBase):
        else:
            return (
                get_quantization_config(self.dense_quant_type)
-               .from_config({"is_checkpoint_bf16": self.is_checkpoint_bf16})
+               .from_config({"is_quantized": self.is_quantized})
                .get_quant_method(layer)
            )

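To make the flag plumbing above concrete, a hedged sketch of how is_quantized now travels from a mix_quant config into its sub-configs; the dict values are illustrative, and dense_quant_type/moe_quant_type are the keys set by the Ernie handling in parse_quant_config:

# Illustrative only.
mix_cfg = {
    "quantization": "mix_quant",
    "dense_quant_type": "wint8",
    "moe_quant_type": "wint4",
    "is_quantized": True,      # offline-quantized checkpoint
}
# MixQuantConfig.from_config(mix_cfg) stores is_quantized = True and derives
# is_checkpoint_bf16 = not True = False; the dense/MoE sub-configs are then built with
# from_config({"is_permuted": ..., "is_quantized": True, "hadamard_block_size": ...}),
# so the weight-only configs below derive the same is_checkpoint_bf16 = False.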

@@ -65,7 +65,7 @@ class WeightOnlyConfig(QuantConfigBase):
    @classmethod
    def from_config(cls, config: dict) -> "WeightOnlyConfig":
        algo = config["algo"]
-       is_checkpoint_bf16 = config.get("is_checkpoint_bf16", False)
+       is_checkpoint_bf16 = not config.get("is_quantized", False)
        return cls(algo, is_checkpoint_bf16)

    def get_quant_method(self, layer) -> Optional[QuantMethodBase]:
@@ -162,7 +162,7 @@ class WINT8Config(WeightOnlyConfig):
    @classmethod
    def from_config(cls, config: dict) -> "WINT8Config":
-       is_checkpoint_bf16 = config.get("is_checkpoint_bf16", False)
+       is_checkpoint_bf16 = not config.get("is_quantized", False)
        return cls(is_checkpoint_bf16)

    def name(self) -> str:
@@ -182,7 +182,7 @@ class WINT4Config(WeightOnlyConfig):
    @classmethod
    def from_config(cls, config: dict) -> "WINT4Config":
-       is_checkpoint_bf16 = config.get("is_checkpoint_bf16", False)
+       is_checkpoint_bf16 = not config.get("is_quantized", False)
        return cls(is_checkpoint_bf16)

    def name(self) -> str:
@@ -202,13 +202,15 @@ class WeightOnlyLinearMethod(QuantMethodBase):
        self.quant_config = quant_config

    def create_weights(self, layer, **extra_weight_attrs):
-       if self.quant_config.is_checkpoint_bf16:
+       # TODO(bukejiyu): remove v1 loader check when v0 loader is removed
+       if self.quant_config.is_checkpoint_bf16 and layer.fd_config.load_config.load_choices == "default_v1":
            layer.weight = layer.create_parameter(
                shape=layer.weight_shape,
                dtype=layer.weight_dtype,
                is_bias=False,
                default_initializer=paddle.nn.initializer.Constant(0),
            )
+           extra_weight_attrs["weight_need_transpose"] = extra_weight_attrs.get("model_format") == "torch"
            quant_attrs = extra_weight_attrs
            if (
                isinstance(layer, MergedColumnParallelLinear)
@@ -256,6 +258,7 @@ class WeightOnlyLinearMethod(QuantMethodBase):
                {
                    "weight_loader": weight_loader,
                    "output_dim": output_dim,
+                   "weight_need_transpose": not extra_weight_attrs.get("model_format") == "torch",
                },
            )

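A hedged note on the two weight_need_transpose settings above: torch-format (PyTorch-style) Linear weights are laid out as [out_features, in_features], while Paddle Linear weights are [in_features, out_features], so a BF16 weight read from a torch-format checkpoint needs a transpose while a weight created in Paddle layout does not. A minimal shape-only illustration (names are illustrative):

import numpy as np

# Illustrative only: torch-format linear weight layout vs. Paddle-format layout.
out_features, in_features = 8, 4
torch_style = np.zeros((out_features, in_features), dtype=np.float32)   # [out, in]
paddle_style = torch_style.T                                            # [in, out]
print(torch_style.shape, paddle_style.shape)  # (8, 4) (4, 8)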

@@ -60,7 +60,7 @@ class WFP8AFP8Config(QuantConfigBase):
    @classmethod
    def from_config(cls, config: dict) -> "WFP8AFP8Config":
        """ """
-       is_checkpoint_bf16 = config.get("is_checkpoint_bf16", False)
+       is_checkpoint_bf16 = not config.get("is_quantized", False)
        return cls(is_checkpoint_bf16=is_checkpoint_bf16)

    def get_quant_method(self, layer) -> Optional[QuantMethodBase]:
@@ -92,13 +92,14 @@ class WFP8AFP8LinearMethod(QuantMethodBase):
                (weight_shape[i] + weight_block_size[i] - 1) // weight_block_size[i] if weight_block_size[i] > 0 else 1
            )
        scale_shape = scale_shape[::-1]
-       if self.quant_config.is_checkpoint_bf16:
+       if self.quant_config.is_checkpoint_bf16 and layer.fd_config.load_config.load_choices == "default_v1":
            layer.weight = layer.create_parameter(
                shape=weight_shape,
                dtype=layer.weight_dtype,
                is_bias=False,
                default_initializer=paddle.nn.initializer.Constant(0),
            )
+           extra_weight_attrs["weight_need_transpose"] = extra_weight_attrs.get("model_format") == "torch"
            quant_attrs = extra_weight_attrs
            if isinstance(layer, MergedColumnParallelLinear) or isinstance(layer, QKVParallelLinear):
                quant_attrs = {