[v1 loader]qwen Offline fp8 (#4036)

* support offline fp8

* update ut

* update ut

* update ut

* fix

* update

* update
This commit is contained in:
bukejiyu
2025-09-15 13:44:11 +08:00
committed by GitHub
parent b1a5b756a3
commit 29ed617f0f
21 changed files with 440 additions and 138 deletions

View File

@@ -53,7 +53,7 @@ class BlockWiseFP8Config(QuantConfigBase):
@classmethod
def from_config(cls, config: dict) -> "BlockWiseFP8Config":
weight_block_size = config.get("weight_block_size", [128, 128])
is_checkpoint_bf16 = config.get("is_checkpoint_bf16", False)
is_checkpoint_bf16 = not config.get("is_quantized", False)
return cls(weight_block_size, is_checkpoint_bf16)
def get_quant_method(self, layer) -> Optional[QuantMethodBase]:
@@ -89,13 +89,15 @@ class BlockWiseFP8LinearMethod(QuantMethodBase):
self.quant_config = quant_config
def create_weights(self, layer, **extra_weight_attrs):
if self.quant_config.is_checkpoint_bf16:
# TODO(bukejiyu): remove v1 loader check when v0 loader is removed
if self.quant_config.is_checkpoint_bf16 and layer.fd_config.load_config.load_choices == "default_v1":
layer.weight = layer.create_parameter(
shape=layer.weight_shape,
dtype=layer.weight_dtype,
is_bias=False,
default_initializer=paddle.nn.initializer.Constant(0),
)
extra_weight_attrs["weight_need_transpose"] = extra_weight_attrs.get("model_format") == "torch"
quant_attrs = extra_weight_attrs
if isinstance(layer, MergedColumnParallelLinear) or isinstance(layer, QKVParallelLinear):
quant_attrs = {
@@ -120,14 +122,28 @@ class BlockWiseFP8LinearMethod(QuantMethodBase):
layer.weight_scale_inv = layer.create_parameter(
shape=[
(layer.output_size + self.quant_config.weight_block_size[0] - 1)
(layer.weight_shape[0] + self.quant_config.weight_block_size[0] - 1)
// self.quant_config.weight_block_size[0],
(layer.input_size + self.quant_config.weight_block_size[1] - 1)
(layer.weight_shape[1] + self.quant_config.weight_block_size[1] - 1)
// self.quant_config.weight_block_size[1],
],
dtype="float32",
is_bias=False,
)
extra_weight_attrs["output_dim"] = not extra_weight_attrs["output_dim"]
extra_weight_attrs["weight_need_transpose"] = not extra_weight_attrs.get("model_format") == "torch"
set_weight_attrs(
layer.weight,
extra_weight_attrs,
)
set_weight_attrs(
layer.weight_scale_inv,
{
**extra_weight_attrs,
"is_scale": True,
},
)
def process_weights_after_loading(self, layer) -> None:
if not self.quant_config.is_checkpoint_bf16: