[v1 loader] support fp8 (#3593)

* support fp8

* update ci
Author: bukejiyu
Date: 2025-08-26 17:42:46 +08:00
Committed by: GitHub
Parent: 00898603c8
Commit: 3200a80de3

7 changed files with 463 additions and 160 deletions
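
In short: the v1 weight loader can now serve block-wise FP8 from a BF16 checkpoint. When is_checkpoint_bf16 is set, linear weights are created in the checkpoint dtype, loaded as-is, and quantized to float8_e4m3fn in process_weights_after_loading; when it is unset, the existing path (pre-quantized FP8 weights plus a per-block scale tensor, now named weight_scale_inv) is unchanged. A minimal sketch of how the new flag is consumed, assuming the module path implied by the imports in this diff:

    from fastdeploy.model_executor.layers.quantization.block_wise_fp8 import (
        BlockWiseFP8Config,
    )

    # Keys mirror from_config in the first file below.
    cfg = BlockWiseFP8Config.from_config(
        {
            "weight_block_size": [128, 128],  # per-block granularity of weight scales
            "is_checkpoint_bf16": True,  # checkpoint stores BF16; quantize after loading
        }
    )
    assert cfg.name() == "block_wise_fp8"
    assert cfg.is_checkpoint_bf16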


@@ -20,7 +20,12 @@ import paddle
 import fastdeploy
 from fastdeploy import envs
+from fastdeploy.model_executor.layers.linear import (
+    MergedColumnParallelLinear,
+    QKVParallelLinear,
+)
 from fastdeploy.model_executor.layers.moe import FusedMoE
+from fastdeploy.model_executor.utils import TensorTracker, set_weight_attrs
 
 from ..utils import get_tensor, per_block_cast_to_fp8
 from .quant_base import QuantConfigBase, QuantMethodBase
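
per_block_cast_to_fp8, imported above, is what turns a BF16 weight into the FP8 tensor plus per-block scales. A minimal sketch of the assumed recipe, using 128x128 tiles and the e4m3 maximum of 448 that this file uses as quant_max_bound; the real helper may pad, transpose, or run fused on the GPU:

    import paddle

    FP8_MAX = 448.0  # float8_e4m3fn max magnitude, matching quant_max_bound

    def per_block_cast_to_fp8_sketch(w, block=128):
        # w: [n, k] BF16/FP32 weight. Returns (fp8 weight, float32 scale grid
        # of shape [ceil(n / block), ceil(k / block)]) -- the weight_scale_inv
        # layout created in create_weights below.
        n, k = w.shape
        n_blk, k_blk = (n + block - 1) // block, (k + block - 1) // block
        scales = paddle.zeros([n_blk, k_blk], dtype="float32")
        out = paddle.zeros_like(w, dtype="float32")
        for i in range(n_blk):
            for j in range(k_blk):
                rs = slice(i * block, (i + 1) * block)
                cs = slice(j * block, (j + 1) * block)
                tile = w[rs, cs].astype("float32")
                scale = (tile.abs().max() / FP8_MAX).clip(min=1e-12)
                scales[i, j] = scale
                out[rs, cs] = tile / scale
        # The float8 cast requires a Paddle build with FP8 support,
        # as this file already assumes.
        return out.astype("float8_e4m3fn"), scales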
@@ -33,13 +38,14 @@ class BlockWiseFP8Config(QuantConfigBase):
     per-token quantization of activations during inference.
     """
 
-    def __init__(self, weight_block_size: list = [-1, -1]) -> None:
+    def __init__(self, weight_block_size: list = [-1, -1], is_checkpoint_bf16: bool = False) -> None:
         super().__init__()
         self.weight_block_size = weight_block_size
         self.quant_max_bound = 448
         self.quant_min_bound = -448
         self.quant_round_type = 1
         self.use_deep_gemm = bool(envs.FD_USE_DEEP_GEMM)
+        self.is_checkpoint_bf16 = is_checkpoint_bf16
 
     def name(self) -> str:
         return "block_wise_fp8"
@@ -47,7 +53,8 @@ class BlockWiseFP8Config(QuantConfigBase):
     @classmethod
     def from_config(cls, config: dict) -> "BlockWiseFP8Config":
         weight_block_size = config.get("weight_block_size", [128, 128])
-        return cls(weight_block_size)
+        is_checkpoint_bf16 = config.get("is_checkpoint_bf16", False)
+        return cls(weight_block_size, is_checkpoint_bf16)
 
     def get_quant_method(self, layer) -> Optional[QuantMethodBase]:
         """
@@ -82,31 +89,78 @@ class BlockWiseFP8LinearMethod(QuantMethodBase):
         self.quant_config = quant_config
 
     def create_weights(self, layer, **extra_weight_attrs):
-        layer.weight_shape.reverse()
-        layer.weight_dtype = "float8_e4m3fn"
+        if self.quant_config.is_checkpoint_bf16:
+            layer.weight = layer.create_parameter(
+                shape=layer.weight_shape,
+                dtype=layer.weight_dtype,
+                is_bias=False,
+                default_initializer=paddle.nn.initializer.Constant(0),
+            )
+            quant_attrs = extra_weight_attrs
+            if isinstance(layer, MergedColumnParallelLinear) or isinstance(layer, QKVParallelLinear):
+                quant_attrs = {
+                    **extra_weight_attrs,
+                    "tensor_track": TensorTracker(
+                        shape=layer.weight_shape, output_dim=extra_weight_attrs.get("output_dim")
+                    ),
+                }
+            set_weight_attrs(
+                layer.weight,
+                quant_attrs,
+            )
+        else:
+            layer.weight_shape.reverse()
+            layer.weight_dtype = "float8_e4m3fn"
+            layer.weight = layer.create_parameter(
+                shape=layer.weight_shape,
+                dtype=layer.weight_dtype,
+                is_bias=False,
+                default_initializer=paddle.nn.initializer.Constant(0),
+            )
+            layer.weight_scale_inv = layer.create_parameter(
+                shape=[
+                    (layer.output_size + self.quant_config.weight_block_size[0] - 1)
+                    // self.quant_config.weight_block_size[0],
+                    (layer.input_size + self.quant_config.weight_block_size[1] - 1)
+                    // self.quant_config.weight_block_size[1],
+                ],
+                dtype="float32",
+                is_bias=False,
+            )
+
+    def process_weights_after_loading(self, layer) -> None:
+        if not self.quant_config.is_checkpoint_bf16:
+            return
+        weight_tensor = layer.weight.transpose([1, 0])
+        quanted_weight_tensor, weight_block_scale_tensor = per_block_cast_to_fp8(weight_tensor)
+        if hasattr(layer.weight, "tensor_track"):
+            layer.weight.tensor_track = None
+        layer.weight.value().get_tensor()._clear()
+        del layer.weight
         layer.weight = layer.create_parameter(
-            shape=layer.weight_shape,
-            dtype=layer.weight_dtype,
+            shape=quanted_weight_tensor.shape,
+            dtype="float8_e4m3fn",
             is_bias=False,
             default_initializer=paddle.nn.initializer.Constant(0),
         )
-        layer.weight_scale = layer.create_parameter(
-            shape=[
-                (layer.output_size + self.quant_config.weight_block_size[0] - 1)
-                // self.quant_config.weight_block_size[0],
-                (layer.input_size + self.quant_config.weight_block_size[1] - 1)
-                // self.quant_config.weight_block_size[1],
-            ],
-            dtype="float32",
-            is_bias=False,
-        )
+        layer.weight_scale_inv = layer.create_parameter(
+            shape=weight_block_scale_tensor.shape,
+            dtype="float32",
+            is_bias=False,
+            default_initializer=paddle.nn.initializer.Constant(0),
+        )
+        layer.weight.copy_(quanted_weight_tensor, False)
+        layer.weight_scale_inv.copy_(weight_block_scale_tensor, False)
 
     def process_loaded_weights(self, layer, weights) -> None:
         weight_tensor = weights.transpose([1, 0])
         quanted_weight_tensor, weight_block_scale_tensor = per_block_cast_to_fp8(weight_tensor)
         layer.weight.copy_(quanted_weight_tensor, False)
-        layer.weight_scale.set_value(weight_block_scale_tensor)
+        layer.weight_scale_inv.set_value(weight_block_scale_tensor)
 
     def process_prequanted_weights(self, layer, state_dict, is_rearrange: bool = False):
         """
@@ -119,7 +173,7 @@ class BlockWiseFP8LinearMethod(QuantMethodBase):
         layer.weight.copy_(quant_weight.view("float8_e4m3fn"), False)
         weight_scale = weight_scale.transpose([1, 0])
-        layer.weight_scale.set_value(weight_scale)
+        layer.weight_scale_inv.set_value(weight_scale)
 
     def apply(self, layer, x):
         x, x_scale_tensor = fastdeploy.model_executor.ops.gpu.per_token_quant_padding(
@@ -130,7 +184,7 @@ class BlockWiseFP8LinearMethod(QuantMethodBase):
             deep_gemm.gemm_fp8_fp8_bf16_nt(
                 (x, x_scale_tensor),
-                (layer.weight, layer.weight_scale),
+                (layer.weight, layer.weight_scale_inv),
                 linear_out,
             )
         if layer.with_bias:
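
The rename only touches the weight-side scales; activations are still quantized on the fly by per_token_quant_padding before the DeepGEMM call. Roughly what that op computes, collapsed to one scale per token (the CUDA kernel emits scales at the per-token, per-128-channel granularity DeepGEMM expects and pads the token dimension):

    import paddle

    FP8_MAX = 448.0  # matches quant_max_bound above

    def per_token_quant_sketch(x):
        # x: [num_tokens, hidden_size]; one float32 scale per token row.
        amax = x.abs().max(axis=-1, keepdim=True).astype("float32").clip(min=1e-12)
        scale = amax / FP8_MAX
        x_fp8 = (x.astype("float32") / scale).astype("float8_e4m3fn")
        return x_fp8, scale  # mirrors the (x, x_scale_tensor) fed to deep_gemm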


@@ -37,6 +37,7 @@ class MixQuantConfig(QuantConfigBase):
         is_channel_wise: bool = False,
         has_zero_point: bool = False,
         is_permuted: bool = True,
+        is_checkpoint_bf16: bool = False,
     ) -> None:
         super().__init__()
         self.dense_quant_type = dense_quant_type
@@ -52,6 +53,7 @@ class MixQuantConfig(QuantConfigBase):
         self.quant_min_bound = 0
         self.quant_round_type = 0
         self.is_permuted = is_permuted
+        self.is_checkpoint_bf16 = is_checkpoint_bf16
 
     def name(self) -> str:
         return "mix_quant"
@@ -66,6 +68,7 @@ class MixQuantConfig(QuantConfigBase):
             config.get("is_channel_wise", False),
             config.get("has_zero_point", False),
             config.get("is_permuted", True),
+            config.get("is_checkpoint_bf16", False),
         )
 
     def get_quant_method(self, layer) -> Optional[QuantMethodBase]:
@@ -73,13 +76,13 @@ class MixQuantConfig(QuantConfigBase):
             if layer.moe_tag == "Image":
                 return (
                     get_quantization_config(self.image_moe_quant_type)
-                    .from_config({"is_permuted": self.is_permuted})
+                    .from_config({"is_permuted": self.is_permuted, "is_checkpoint_bf16": self.is_checkpoint_bf16})
                     .get_quant_method(layer)
                 )
             else:
                 return (
                     get_quantization_config(self.moe_quant_type)
-                    .from_config({"is_permuted": self.is_permuted})
+                    .from_config({"is_permuted": self.is_permuted, "is_checkpoint_bf16": self.is_checkpoint_bf16})
                     .get_quant_method(layer)
                 )
         elif isinstance(layer, Attention):
@@ -92,4 +95,8 @@ class MixQuantConfig(QuantConfigBase):
             else:
                 return None
         else:
-            return get_quantization_config(self.dense_quant_type).from_config({}).get_quant_method(layer)
+            return (
+                get_quantization_config(self.dense_quant_type)
+                .from_config({"is_checkpoint_bf16": self.is_checkpoint_bf16})
+                .get_quant_method(layer)
+            )
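
Note that each sub-config reads the flag with config.get("is_checkpoint_bf16", ...), so the dictionaries passed to from_config must use the plain key "is_checkpoint_bf16", not an attribute-qualified name. A hypothetical end-to-end configuration; only the last two keys appear verbatim in this diff, the quant-type key names are assumed to mirror the attributes set in __init__:

    from fastdeploy.model_executor.layers.quantization.mix_quant import MixQuantConfig

    mix = MixQuantConfig.from_config(
        {
            "dense_quant_type": "block_wise_fp8",
            "moe_quant_type": "block_wise_fp8",
            "is_permuted": True,
            "is_checkpoint_bf16": True,  # forwarded to every sub-config above
        }
    )
    # Dense layers then receive a BlockWiseFP8 method that creates BF16
    # weights at load time and quantizes them after loading, per the first
    # file in this commit.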