Sync v2.0 of the codebase to the GitHub repo

Jiang-Jia-Jun
2025-06-29 23:29:37 +00:00
parent d151496038
commit 92c2cfa2e7
597 changed files with 78776 additions and 22905 deletions

View File

@@ -19,11 +19,18 @@ from typing import Dict, List, Type
from .quant_base import QuantConfigBase
QUANTIZATION_METHODS: List[str] = [
"wint2",
"wint4",
"wint8",
"weight_only",
"block_wise",
"block_wise_fp8",
"w4afp8",
"w8a8",
"w4a8",
"wfp8afp8",
"mix_quant",
"tensor_wise_fp8",
"kvcache",
]
@@ -34,20 +41,30 @@ def get_quantization_config(quantization: str) -> Type[QuantConfigBase]:
if quantization not in QUANTIZATION_METHODS:
raise ValueError(f"Invalid quantization method: {quantization}")
from .block_wise import BlockWiseConfig
from .block_wise_fp8 import BlockWiseFP8Config
from .kv_cache import KvCacheQuantConfig
from .mix_quant import MixQuantConfig
from .tensor_wise_fp8 import TensorWiseFP8Config
from .w4a8 import W4A8Config
from .w4afp8 import W4AFP8Config
from .w8a8 import W8A8Config
from .weight_only import WeightOnlyConfig
from .weight_only import WeightOnlyConfig, WINT4Config, WINT8Config
from .wfp8afp8 import WFP8AFP8Config
from .kv_cache import KvCacheQuantConfig
from .wint2 import WINT2Config
method_to_config: Dict[str, Type[QuantConfigBase]] = {
"wint2": WINT2Config,
"wint4": WINT4Config,
"wint8": WINT8Config,
"weight_only": WeightOnlyConfig,
"block_wise": BlockWiseConfig,
"block_wise_fp8": BlockWiseFP8Config,
"w4afp8": W4AFP8Config,
"w8a8": W8A8Config,
"w4a8": W4A8Config,
"wfp8afp8": WFP8AFP8Config,
"kvcache": KvCacheQuantConfig
"tensor_wise_fp8": TensorWiseFP8Config,
"kvcache": KvCacheQuantConfig,
"mix_quant": MixQuantConfig,
}
return method_to_config[quantization]
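For orientation, a minimal usage sketch of the registry above; it assumes the list and lookup function live in the quantization package's __init__, as the relative imports elsewhere in this commit suggest:

from fastdeploy.model_executor.layers.quantization import (
    QUANTIZATION_METHODS, get_quantization_config)

assert "block_wise_fp8" in QUANTIZATION_METHODS

# Resolve a config class by name and build it from a plain dict.
cfg_cls = get_quantization_config("block_wise_fp8")
cfg = cfg_cls.from_config({"weight_block_size": [128, 128]})
print(cfg.name())  # -> "block_wise_fp8"

# Unknown names raise ValueError before any backend import happens.
try:
    get_quantization_config("int3")
except ValueError as err:
    print(err)  # Invalid quantization method: int3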

View File

@@ -18,16 +18,13 @@ from typing import Optional
import paddle
import fastdeploy
import fastdeploy.model_executor.ops.gpu.deep_gemm as deep_gemm
from fastdeploy.model_executor.layers.moe import FusedMoE
from ..utils import per_block_cast_to_fp8
from ..utils import per_block_cast_to_fp8, get_tensor
from .quant_base import QuantConfigBase, QuantMethodBase
QUANT_ALIGNMENT_OFFSET = 127
QUANT_BLOCK_SIZE = 128
class BlockWiseConfig(QuantConfigBase):
class BlockWiseFP8Config(QuantConfigBase):
"""
Block-wise quantization config; only supports FP8 quantization and only supports loading weights in BF16 format.
After loading the weights, it will automatically compute quantization sparsity and dynamically perform
@@ -37,41 +34,55 @@ class BlockWiseConfig(QuantConfigBase):
def __init__(self, weight_block_size: list = [-1, -1]) -> None:
super().__init__()
self.weight_block_size = weight_block_size
self.quant_max_bound = 448
self.quant_min_bound = -448
self.quant_round_type = 1
def get_name(self) -> str:
return "block_wise"
def name(self) -> str:
return "block_wise_fp8"
@classmethod
def from_config(cls, config: dict) -> "BlockWiseConfig":
weight_block_size = config["weight_block_size"]
def from_config(cls, config: dict) -> "BlockWiseFP8Config":
weight_block_size = config.get("weight_block_size", [128, 128])
return cls(weight_block_size)
def get_quant_method(self, layer) -> Optional[QuantMethodBase]:
return BlockWiseLinearMethod(self)
'''
Get quantization method.
'''
if isinstance(layer, FusedMoE):
from fastdeploy.model_executor.layers.moe.fused_moe_deepgemm_backend import \
DeepGemmFusedMoeMethod
return DeepGemmFusedMoeMethod(self)
else:
return BlockWiseFP8LinearMethod(self)
class BlockWiseLinearMethod(QuantMethodBase):
class BlockWiseFP8LinearMethod(QuantMethodBase):
"""
block wise quantization method for linear
"""
def __init__(
self,
quant_config: BlockWiseConfig,
quant_config: BlockWiseFP8Config,
) -> None:
super().__init__()
self.quant_config = quant_config
def create_weights(self, layer):
layer.linear_weight_scale = self.create_parameter(
layer.linear_weight_shape.reverse()
layer.linear_weight_scale = layer.create_parameter(
shape=[
(layer.embed_dim + QUANT_ALIGNMENT_OFFSET) // QUANT_BLOCK_SIZE,
(layer.num_heads * layer.head_dim + QUANT_ALIGNMENT_OFFSET) //
QUANT_BLOCK_SIZE,
(layer.output_size + self.quant_config.weight_block_size[0] -
1) // self.quant_config.weight_block_size[0],
(layer.input_size + self.quant_config.weight_block_size[1] - 1)
// self.quant_config.weight_block_size[1],
],
dtype="float32",
is_bias=False,
)
layer.weight_dtype = "float8_e4m3fn"
def process_loaded_weights(self, layer, weights) -> None:
weight_tensor = weights.transpose([1, 0])
@@ -80,15 +91,30 @@ class BlockWiseLinearMethod(QuantMethodBase):
layer.linear_weight.copy_(quanted_weight_tensor, False)
layer.linear_weight_scale.set_value(weight_block_scale_tensor)
def process_prequanted_weights(self, layer, state_dict):
"""
process_prequanted_weights
"""
quant_weight = get_tensor(state_dict.pop(layer.weight_key))
weight_scale = get_tensor(state_dict.pop(layer.weight_scale_key))
quant_weight = quant_weight.transpose([1, 0]).contiguous()
layer.linear_weight.copy_(quant_weight.view("float8_e4m3fn"), False)
weight_scale = weight_scale.transpose([1, 0])
layer.linear_weight_scale.set_value(weight_scale)
def apply(self, layer, x):
x, x_scale_tensor = fastdeploy.model_executor.ops.gpu.per_token_quant_padding(
x, self.quant_config.weight_block_size[0])
linear_out = paddle.empty(
(x.shape[0], layer.llm_config.model_config.hidden_size),
dtype=paddle.bfloat16)
linear_out = paddle.empty((x.shape[0], layer.output_size),
dtype=paddle.bfloat16)
import fastdeploy.model_executor.ops.gpu.deep_gemm as deep_gemm
deep_gemm.gemm_fp8_fp8_bf16_nt(
(x, x_scale_tensor),
(layer.linear_weight, layer.linear_weight_scale),
linear_out,
)
if layer.with_bias:
linear_out = paddle.add(linear_out, layer.linear_bias)
return linear_out
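A small illustration of the scale-tensor shape created in create_weights above: one float32 entry per weight block, with ceil division on both dimensions (the layer sizes below are hypothetical):

def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b

output_size, input_size = 4096, 8192      # hypothetical layer sizes
weight_block_size = [128, 128]
scale_shape = [ceil_div(output_size, weight_block_size[0]),
               ceil_div(input_size, weight_block_size[1])]
print(scale_shape)  # -> [32, 64], i.e. one scale per 128x128 weight block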

View File

@@ -13,38 +13,66 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from paddle import nn
import os
import paddle
from .quant_base import QuantConfigBase, QuantMethodBase
from enum import Enum
from typing import Optional
import paddle
from paddle import nn
from fastdeploy.model_executor.layers.utils import get_tensor
from ..utils import create_and_set_parameter
from .quant_base import QuantConfigBase, QuantMethodBase
class KvCacheQuantzationTypes(str, Enum):
"""
KvCacheQuantzationTypes
"""
INT8 = "int8"
FP8 = "float8_e4m3fn"
INT8_ZP = "int8_zp"
FP8_ZP = "float8_e4m3fn_zp"
class KvCacheQuantConfig(QuantConfigBase):
"""
quantization config for the KV cache
"""
def __init__(self, cachekv_scale_dict) -> None:
def __init__(self, kv_cache_quant_type: str) -> None:
"""
__init__
"""
super().__init__()
self.cachekv_scale_dict = cachekv_scale_dict
self.kv_cache_quant_type = kv_cache_quant_type
def get_name(self) -> str:
try:
self.quant_type = KvCacheQuantzationTypes(kv_cache_quant_type)
except ValueError:
raise ValueError(f'Invalid Kvcache type: {kv_cache_quant_type}')
self.has_zero_point = "zp" in kv_cache_quant_type
if self.quant_type == KvCacheQuantzationTypes.INT8 or self.quant_type == KvCacheQuantzationTypes.INT8_ZP:
self.max_bound = 127.0
elif self.quant_type == KvCacheQuantzationTypes.FP8 or self.quant_type == KvCacheQuantzationTypes.FP8_ZP:
self.max_bound = 448.0
else:
raise ValueError(f'Invalid Kvcache type: {kv_cache_quant_type}')
def name(self) -> str:
"""
name
"""
return "kvcache"
@classmethod
def from_config(cls, config: dict) -> "KvCacheQuantConfig":
def from_config(cls, kv_cache_quant_type: str) -> "KvCacheQuantConfig":
"""
from_config
"""
cachekv_scale_dict = config["cachekv_scale_dict"]
return cls(cachekv_scale_dict)
return cls(kv_cache_quant_type)
def get_quant_method(self, layer) -> Optional[QuantMethodBase]:
"""
@@ -66,197 +94,63 @@ class KVCacheMethodBase(QuantMethodBase):
KVCacheMethodBase __init__
"""
super().__init__()
self.quant_config = quant_config
self.cache_quant_config = quant_config
def load_zp(self, layer: nn.Layer):
def load_zp(self, layer: nn.Layer, state_dict):
"""
load_zp
"""
if self.cache_k_zp_name in self.quant_config.cachekv_scale_dict:
cache_k_zp = paddle.cast(
paddle.to_tensor(
self.quant_config.cachekv_scale_dict[self.cache_k_zp_name]
),
self.cache_scale_dtype,
)
else:
cache_k_zp = paddle.zeros(
(
[self.kv_num_heads * self.head_dim]
if self.quant_config.is_channel_wise
else [self.kv_num_heads]
),
dtype=self.cache_scale_dtype,
)
if self.cache_v_zp_name in self.quant_config.cachekv_scale_dict:
cache_v_zp = paddle.cast(
paddle.to_tensor(
self.quant_config.cachekv_scale_dict[self.cache_v_zp_name]
),
self.cache_scale_dtype,
)
else:
cache_v_zp = paddle.zeros(
(
[self.kv_num_heads * self.head_dim]
if self.quant_config.is_channel_wise
else [self.kv_num_heads]
),
dtype=self.cache_scale_dtype,
)
layer.cache_k_zp.set_value(cache_k_zp)
layer.cache_v_zp.set_value(cache_v_zp)
cache_k_zeropoint = get_tensor(state_dict.pop(self.cache_k_zp_name))
cache_v_zeropoint = get_tensor(state_dict.pop(self.cache_v_zp_name))
def load_scale(self, layer: nn.Layer):
create_and_set_parameter(layer, "cache_k_zp", cache_k_zeropoint)
create_and_set_parameter(layer, "cache_v_zp", cache_v_zeropoint)
def load_scale(self, layer: nn.Layer, state_dict):
"""
load_scale
"""
if self.cache_k_scale_name in self.quant_config.cachekv_scale_dict:
cache_k_scale = paddle.cast(
paddle.to_tensor(
self.quant_config.cachekv_scale_dict[self.cache_k_scale_name]
),
self.cache_scale_dtype,
)
cache_k_out_scale = 1.0 / cache_k_scale
else:
raise KeyError(
f"{self.cache_k_scale_name} not found in scale dict")
cache_k_scale_tensor = get_tensor(
state_dict.pop(self.cache_k_scale_name)).cast(
paddle.get_default_dtype()).reshape_([-1])
cache_v_scale_tensor = get_tensor(
state_dict.pop(self.cache_v_scale_name)).cast(
paddle.get_default_dtype()).reshape_([-1])
if self.cache_v_scale_name in self.quant_config.cachekv_scale_dict:
cache_v_scale = paddle.cast(
paddle.to_tensor(
self.quant_config.cachekv_scale_dict[self.cache_v_scale_name]
),
self.cache_scale_dtype,
)
cache_v_out_scale = 1.0 / cache_v_scale
else:
raise KeyError(
f"{self.cache_v_scale_name} not found in scale dict")
cache_k_scale = self.cache_quant_config.max_bound / cache_k_scale_tensor
cache_v_scale = self.cache_quant_config.max_bound / cache_v_scale_tensor
cache_k_out_scale = cache_k_scale_tensor / self.cache_quant_config.max_bound
cache_v_out_scale = cache_v_scale_tensor / self.cache_quant_config.max_bound
if self.cache_v_scale_name in self.quant_config.cachekv_scale_dict:
cache_v_scale = paddle.cast(
paddle.to_tensor(
self.quant_config.cachekv_scale_dict[self.cache_v_scale_name]
),
self.cache_scale_dtype,
)
cache_v_out_scale = 1.0 / cache_v_scale
else:
raise KeyError(
f"{self.cache_v_scale_name} not found in scale dict")
create_and_set_parameter(layer, "cache_k_scale", cache_k_scale)
create_and_set_parameter(layer, "cache_v_scale", cache_v_scale)
create_and_set_parameter(layer, "cache_k_out_scale", cache_k_out_scale)
create_and_set_parameter(layer, "cache_v_out_scale", cache_v_out_scale)
layer.cache_k_scale.set_value(cache_k_scale)
layer.cache_v_scale.set_value(cache_v_scale)
layer.cache_k_out_scale.set_value(cache_k_out_scale)
layer.cache_v_out_scale.set_value(cache_v_out_scale)
def create_scale(self, layer: nn.Layer):
"""
create_scale
"""
layer.cache_k_scale = layer.create_parameter(
shape=(
[layer.kv_num_heads * layer.head_dim]
if self.quant_config.is_channel_wise
else [layer.kv_num_heads]
),
dtype=self.cache_scale_dtype,
is_bias=False,
)
layer.cache_v_scale = layer.create_parameter(
shape=(
[layer.kv_num_heads * layer.head_dim]
if self.quant_config.is_channel_wise
else [layer.kv_num_heads]
),
dtype=self.cache_scale_dtype,
is_bias=False,
)
layer.cache_k_out_scale = layer.create_parameter(
shape=(
[layer.kv_num_heads * layer.head_dim]
if self.quant_config.is_channel_wise
else [layer.kv_num_heads]
),
attr=None,
dtype=self.cache_scale_dtype,
is_bias=False,
)
layer.cache_v_out_scale = layer.create_parameter(
shape=(
[layer.kv_num_heads * layer.head_dim]
if self.quant_config.is_channel_wise
else [layer.kv_num_heads]
),
attr=None,
dtype=self.cache_scale_dtype,
is_bias=False,
)
def create_zp(self, layer: nn.Layer):
"""
create_zp
"""
layer.cache_k_zp = layer.create_parameter(
shape=(
[layer.kv_num_heads * layer.head_dim]
if self.quant_config.is_channel_wise
else [layer.kv_num_heads]
),
dtype=self.cache_scale_dtype,
is_bias=False,
)
layer.cache_v_zp = layer.create_parameter(
shape=(
[layer.kv_num_heads * layer.head_dim]
if self.quant_config.is_channel_wise
else [layer.kv_num_heads]
),
dtype=self.cache_scale_dtype,
is_bias=False,
)
def create_weights(self, layer: nn.Layer):
def create_weights(self, layer: nn.Layer, state_dict):
"""
create_weights
"""
self.prefix = layer.prefix
self.cache_k_scale_name = layer.prefix + ".cachek_matmul.activation_quanter"
self.cache_v_scale_name = layer.prefix + ".cachev_matmul.activation_quanter"
self.cache_k_zp_name = layer.cache_k_scale_name + ".zero_point"
self.cache_v_zp_name = layer.cache_v_scale_name + ".zero_point"
self.cache_k_scale_name = layer.prefix + ".cachek_matmul.activation_scale"
self.cache_v_scale_name = layer.prefix + ".cachev_matmul.activation_scale"
self.cache_k_zp_name = layer.prefix + ".cachek_matmul.activation_zero_point"
self.cache_v_zp_name = layer.prefix + ".cachev_matmul.activation_zero_point"
layer.cache_k_zp = None
layer.cache_v_zp = None
layer.cache_k_scale = None
layer.cache_v_scale = None
layer.cache_k_out_scale = None
layer.cache_v_out_scale = None
if self.cache_quant_config.quant_type == KvCacheQuantzationTypes.INT8:
setattr(layer, "cache_quant_type_str", "cache_int8")
setattr(layer, "quant_max_bound", 127.0)
setattr(layer, "quant_min_bound", -127.0)
elif self.cache_quant_config.quant_type == KvCacheQuantzationTypes.FP8:
setattr(layer, "cache_quant_type_str", "cache_fp8")
setattr(layer, "quant_max_bound", 448.0)
setattr(layer, "quant_min_bound", -448.0)
else:
raise NotImplementedError(f"{self.cache_quant_config.quant_type} is not implemented")
self._dtype = layer._dtype
if self._dtype not in ("bfloat16", "float16", "float32"):
raise ValueError(
f"Only bfloat16, float16 and float32 are supported as the default dtype, but received {self._dtype}")
self.cache_scale_dtype = (
self._dtype if self.quant_config.use_append_attn else "float32"
)
if not self.quant_config.use_dynamic_cachekv_quant:
if (
self.quant_config.cachekv_dtype == "int8"
or self.quant_config.cachekv_dtype == "int4"
or self.quant_config.cachekv_dtype == "float8_e4m3fn"
):
self.create_scale(layer)
self.load_scale(layer)
if self.quant_config.has_zero_point:
self.create_zp(layer)
self.load_zp(layer)
layer.cache_quant_type_str = self.quant_config.cache_quant_type
self.load_scale(layer, state_dict)
if self.cache_quant_config.has_zero_point:
self.load_zp(layer, state_dict)
def apply(self, layer):
"""
@@ -264,4 +158,3 @@ class KVCacheMethodBase(QuantMethodBase):
"""
raise RuntimeError(
f"{self.__class__.__name__}.apply should not be called.")
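A short sketch of how the string-typed KV-cache config resolves, based on the enum and __init__ above (the module path is assumed from the relative import in the registry):

from fastdeploy.model_executor.layers.quantization.kv_cache import (
    KvCacheQuantConfig, KvCacheQuantzationTypes)

cfg = KvCacheQuantConfig.from_config("int8")
print(cfg.quant_type is KvCacheQuantzationTypes.INT8)  # True
print(cfg.has_zero_point, cfg.max_bound)               # False 127.0

cfg_fp8_zp = KvCacheQuantConfig.from_config("float8_e4m3fn_zp")
print(cfg_fp8_zp.has_zero_point, cfg_fp8_zp.max_bound)  # True 448.0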

View File

@@ -0,0 +1,75 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from typing import Optional
from ..attention import Attention
from ..moe import FusedMoE
from . import get_quantization_config
from .quant_base import QuantConfigBase, QuantMethodBase
class MixQuantConfig(QuantConfigBase):
"""
Quantization config for layers that have different quantization methods.
"""
def __init__(
self,
dense_quant_type: str,
moe_quant_type: str,
kv_cache_quant_type: str = None,
image_moe_quant_type: str = None,
) -> None:
super().__init__()
self.dense_quant_type = dense_quant_type
self.moe_quant_type = moe_quant_type
self.kv_cache_quant_type = kv_cache_quant_type
if image_moe_quant_type is None:
self.image_moe_quant_type = moe_quant_type
else:
self.image_moe_quant_type = image_moe_quant_type
self.quant_max_bound = 0
self.quant_min_bound = 0
self.quant_round_type = 0
def name(self) -> str:
return "mix_quant"
@classmethod
def from_config(cls, config: dict) -> "MixQuantConfig":
return cls(config['dense_quant_type'], config['moe_quant_type'],
config.get('kv_cache_quant_type', None),
config.get('image_moe_quant_type', None))
def get_quant_method(self, layer) -> Optional[QuantMethodBase]:
if isinstance(layer, FusedMoE):
if layer.moe_tag == "Image":
return get_quantization_config(
self.image_moe_quant_type).from_config(
{}).get_quant_method(layer)
else:
return get_quantization_config(
self.moe_quant_type).from_config(
{}).get_quant_method(layer)
elif isinstance(layer, Attention):
if self.kv_cache_quant_type is not None:
return (get_quantization_config("kvcache").from_config(
self.kv_cache_quant_type).get_quant_method(layer))
else:
return None
else:
return get_quantization_config(self.dense_quant_type).from_config(
{}).get_quant_method(layer)
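A hypothetical config dict in the shape from_config above expects, showing how mix_quant fans out to the per-layer-type configs (values are illustrative):

from fastdeploy.model_executor.layers.quantization.mix_quant import MixQuantConfig

mix_cfg = {
    "dense_quant_type": "wint8",
    "moe_quant_type": "wint4",
    "kv_cache_quant_type": "int8",   # optional; Attention layers get no method if omitted
    # "image_moe_quant_type" falls back to moe_quant_type when omitted
}
cfg = MixQuantConfig.from_config(mix_cfg)
print(cfg.name())                # -> "mix_quant"
print(cfg.image_moe_quant_type)  # -> "wint4" (inherited from moe_quant_type)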

View File

@@ -0,0 +1,22 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from .cutlass_scaled_mm import cutlass_scaled_mm
from .scaled_fp8_quant import scaled_fp8_quant
__all__ = [
"cutlass_scaled_mm",
"scaled_fp8_quant",
]

View File

@@ -0,0 +1,126 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from typing import Optional
import paddle
import fastdeploy
def cutlass_scaled_mm(a: paddle.Tensor,
b: paddle.Tensor,
scale_a: paddle.Tensor,
scale_b: paddle.Tensor,
out_dtype: paddle.dtype,
bias: Optional[paddle.Tensor] = None) -> paddle.Tensor:
"""
`cutlass_scaled_mm` implements a fused version of
`output = paddle.mm((scale_a * a), (scale_b * b)).to(out_dtype)`
where scale_a * a and scale_b * b are implemented using numpy-style
broadcasting.
In order to support blockwise scaling like that found in DeepSeek V3, we also
support extended "group" broadcast rules. We extend the numpy-style
broadcasting rules with the following rule:
"if the extent of a dimension in the source shape is between 1 and the
corresponding extent in the target shape, we repeat each element along
that dimension target_shape[dim] // src_shape[dim] times consecutively"
For example, if we have:
a = [[1, 2], and target_shape = (2, 4)
[3, 4]]
then we would expand a to:
a = [[1, 1, 2, 2],
[3, 3, 4, 4]]
currently we only support the case:
scale_a.shape * [1, 128] == a.shape
scale_b.shape * [128, 128] == b.shape
"""
assert (out_dtype == paddle.bfloat16 or out_dtype == paddle.float16)
assert bias is None or bias.shape[0] == b.shape[
0] and bias.dtype == out_dtype
# Ensure input tensors have valid shapes
# assert a.numel() > 0, "Input tensor 'a' must not be empty"
# assert b.numel() > 0, "Input tensor 'b' must not be empty"
# assert scale_a.numel() > 0, "Scale tensor 'scale_a' must not be empty"
# assert scale_b.numel() > 0, "Scale tensor 'scale_b' must not be empty"
m = a.shape[0]
n = b.shape[0]
cutlass_compatible_b = (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
assert cutlass_compatible_b
out = paddle.empty([m, n], dtype=out_dtype)
fastdeploy.model_executor.ops.gpu.cutlass_scaled_mm(
out, a, b, scale_a, scale_b, bias)
return out
def scaled_fp8_quant(
input: paddle.Tensor,
scale: Optional[paddle.Tensor] = None,
num_token_padding: Optional[int] = None,
scale_ub: float = 0,
use_per_token_if_dynamic: bool = False,
) -> tuple[paddle.Tensor, paddle.Tensor]:
"""
Quantize input tensor to FP8 and return quantized tensor and scale.
This function supports both static and dynamic quantization: If you
provide the scale, it will use static scaling and if you omit it,
the scale will be determined dynamically. The function also allows
optional padding of the output tensors for downstream kernels that
will benefit from padding.
Args:
input: The input tensor to be quantized to FP8
scale: Optional scaling factor for the FP8 quantization
scale_ub: Optional upper bound for scaling factor in dynamic
per token case
num_token_padding: If specified, pad the first dimension
of the output to at least this value.
use_per_token_if_dynamic: Whether to do per_tensor or per_token
in the dynamic quantization case.
Returns:
tuple[paddle.Tensor, paddle.Tensor]: The output tensor in FP8 and
scaling factor.
"""
# This code assumes batch_dim and num_tokens are flattened
assert (input.ndim == 2)
shape = input.shape
if num_token_padding:
shape = (max(num_token_padding, input.shape[0]), shape[1])
output = paddle.empty(shape, dtype=paddle.float8_e4m3fn)
if scale is None:
if use_per_token_if_dynamic:
scale = paddle.empty([shape[0], 1], dtype=paddle.float32)
from fastdeploy.model_executor.ops.gpu import \
dynamic_per_token_scaled_fp8_quant
dynamic_per_token_scaled_fp8_quant(output, input, scale, scale_ub)
else:
scale = paddle.zeros([1], dtype=paddle.float32)
from fastdeploy.model_executor.ops.gpu import \
dynamic_scaled_fp8_quant
dynamic_scaled_fp8_quant(output, input, scale)
else:
# num_token_padding not implemented for this case
# assert (scale.numel() == 1 or num_token_padding is None)
from fastdeploy.model_executor.ops.gpu import static_scaled_fp8_quant
static_scaled_fp8_quant(output, input, scale)
return output, scale
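A hedged usage sketch of the two ops above, pairing dynamic per-token activation quantization with a per-tensor weight scale; it assumes a CUDA build of FastDeploy with these custom ops compiled, and the shapes are illustrative:

import paddle
from fastdeploy.model_executor.layers.quantization.ops import (
    cutlass_scaled_mm, scaled_fp8_quant)

x = paddle.randn([16, 512]).astype("bfloat16")    # activations [m, k]
w = paddle.randn([1024, 512]).astype("bfloat16")  # weights [n, k]; n and k divisible by 16

x_q, x_scale = scaled_fp8_quant(x, use_per_token_if_dynamic=True)  # per-token scales [m, 1]
w_q, w_scale = scaled_fp8_quant(w)                                  # per-tensor scale [1]

out = cutlass_scaled_mm(x_q, w_q, x_scale, w_scale, paddle.bfloat16)
print(out.shape)  # [16, 1024]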

View File

@@ -0,0 +1,75 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from typing import Optional
import paddle
def scaled_fp8_quant(
input: paddle.Tensor,
scale: Optional[paddle.Tensor] = None,
num_token_padding: Optional[int] = None,
scale_ub: float = 0,
use_per_token_if_dynamic: bool = False,
) -> tuple[paddle.Tensor, paddle.Tensor]:
"""
Quantize input tensor to FP8 and return quantized tensor and scale.
This function supports both static and dynamic quantization: If you
provide the scale, it will use static scaling and if you omit it,
the scale will be determined dynamically. The function also allows
optional padding of the output tensors for downstream kernels that
will benefit from padding.
Args:
input: The input tensor to be quantized to FP8
scale: Optional scaling factor for the FP8 quantization
scale_ub: Optional upper bound for scaling factor in dynamic
per token case
num_token_padding: If specified, pad the first dimension
of the output to at least this value.
use_per_token_if_dynamic: Whether to do per_tensor or per_token
in the dynamic quantization case.
Returns:
tuple[paddle.Tensor, paddle.Tensor]: The output tensor in FP8 and
scaling factor.
"""
# This code assumes batch_dim and num_tokens are flattened
assert (input.ndim == 2)
shape = input.shape
if num_token_padding:
shape = (max(num_token_padding, input.shape[0]), shape[1])
output = paddle.empty(shape, dtype=paddle.float8_e4m3fn)
if scale is None:
if use_per_token_if_dynamic:
scale = paddle.empty([shape[0], 1], dtype=paddle.float32)
from fastdeploy.model_executor.ops.gpu import \
dynamic_per_token_scaled_fp8_quant
dynamic_per_token_scaled_fp8_quant(output, input, scale, scale_ub)
else:
scale = paddle.zeros([1], dtype=paddle.float32)
from fastdeploy.model_executor.ops.gpu import \
dynamic_scaled_fp8_quant
dynamic_scaled_fp8_quant(output, input, scale)
else:
# num_token_padding not implemented for this case
# assert (scale.numel() == 1 or num_token_padding is None)
from fastdeploy.model_executor.ops.gpu import static_scaled_fp8_quant
static_scaled_fp8_quant(output, input, scale)
return output, scale

View File

@@ -47,12 +47,9 @@ class QuantConfigBase(ABC):
def __init__(self):
super().__init__()
self.quant_round_type = None
self.quant_max_bound = None
self.quant_min_bound = None
@abstractmethod
def get_name(self) -> str:
def name(self) -> str:
"""Name of the quantization method."""
raise NotImplementedError

View File

@@ -0,0 +1,135 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from typing import Optional
import paddle
from fastdeploy.model_executor.layers.moe import FusedMoE
from ..utils import get_tensor
from .quant_base import QuantConfigBase, QuantMethodBase
class TensorWiseFP8Config(QuantConfigBase):
"""
Quantization config for weight and activation with FP8.
"""
def __init__(self) -> None:
"""
Nothing else to do!
"""
super().__init__()
def name(self) -> str:
"""
Nothing else to do!
"""
return "tensor_wise_fp8"
@classmethod
def from_config(cls, config: dict) -> "TensorWiseFP8Config":
"""
Nothing else to do!
"""
return cls()
def get_quant_method(self, layer) -> Optional[QuantMethodBase]:
"""
return method according to this config!
"""
if isinstance(layer, FusedMoE):
from fastdeploy.model_executor.layers.moe.fused_moe_triton_backend import \
TensorWiseFP8MoEMethod
return TensorWiseFP8MoEMethod(self)
else:
return TensorWiseFP8LinearMethod(self)
class TensorWiseFP8LinearMethod(QuantMethodBase):
"""
Weight and activation quantization method for linear layer with per tensor FP8
"""
def __init__(
self,
quant_config: TensorWiseFP8Config,
) -> None:
"""
Nothing special to do!
"""
super().__init__()
self.quant_config = quant_config
self.quant_max_bound = 448
self.quant_min_bound = -448
self.quant_round_type = 1
self.weight_dtype = "float8_e4m3fn"
def create_weights(self, layer):
"""
Nothing to do!
"""
pass
def process_prequanted_weights(self, layer, state_dict) -> None:
"""
Process pre-quantized weights before applying them to the model
Args:
layer: The layer that owns the weights.
state_dict: State dict providing the pre-quantized weight, weight scale and activation scale.
"""
quant_weight = get_tensor(state_dict.pop(layer.weight_key))
weight_scale = get_tensor(state_dict.pop(layer.weight_scale_key))
act_scale = get_tensor(state_dict.pop(layer.act_scale_key))
quant_weight = quant_weight.transpose([1, 0]).contiguous()
layer.linear_weight.copy_(quant_weight.view("float8_e4m3fn"), False)
self.act_scale = act_scale.item()
self.total_scale = (act_scale * weight_scale).item()
def process_loaded_weights(self, layer, weights, state_dict) -> None:
"""
Read fp8 weight, act scale, weight scale
"""
pass
def apply(self, layer, x):
"""
compute!
"""
from fastdeploy.model_executor.ops.gpu import \
cutlass_fp8_fp8_half_gemm_fused
from ..utils import create_hadamard_matrix_map
hadamard_matrix = create_hadamard_matrix_map[x.shape[-1]]
new_x = paddle.matmul(x.cast("float32"), hadamard_matrix)
fp8_x = new_x / self.act_scale
fp8_x = fp8_x.astype("float8_e4m3fn")
linear_out = cutlass_fp8_fp8_half_gemm_fused(
fp8_x,
layer.linear_weight,
transpose_x=False,
transpose_y=True,
bias=None,
scale=self.total_scale,
output_dtype="bfloat16",
activation_type="identity")
return linear_out
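For intuition, a plain-Paddle restatement of the per-tensor scaling that process_prequanted_weights and apply rely on; the Hadamard rotation and the fused FP8 GEMM are omitted, and the scales are made-up calibration values:

import paddle

act_scale, weight_scale = 0.02, 0.01     # hypothetical calibration scales
total_scale = act_scale * weight_scale   # what process_prequanted_weights stores

x = paddle.randn([4, 64])
w = paddle.randn([32, 64])

x_q = (x / act_scale).astype("float8_e4m3fn")      # per-tensor quantized activation
w_q = (w / weight_scale).astype("float8_e4m3fn")   # per-tensor quantized weight (as stored on disk)

# The fused kernel computes x_q @ w_q^T and rescales by total_scale to recover ~ x @ w^T.
out = paddle.matmul(x_q.astype("float32"), w_q.astype("float32"), transpose_y=True) * total_scale
ref = paddle.matmul(x, w, transpose_y=True)        # out ~= ref up to FP8 rounding error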

View File

@@ -0,0 +1,42 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from typing import Optional
from ..moe import FusedMoE
from .quant_base import QuantConfigBase, QuantMethodBase
class W4A8Config(QuantConfigBase):
"""
quantization config for weight 4bits and activation 8bits
"""
def __init__(self) -> None:
super().__init__()
def name(self) -> str:
return "w4a8"
@classmethod
def from_config(cls, config: dict) -> "W4A8Config":
return cls()
def get_quant_method(self, layer) -> Optional[QuantMethodBase]:
if isinstance(layer, FusedMoE):
from fastdeploy.model_executor.layers.moe.fused_moe_cutlass_backend import CutlassW4A8MoEMethod
return CutlassW4A8MoEMethod(self)
else:
raise ValueError(f"Unsupported layer type {type(layer)} for w4a8")

View File

@@ -23,16 +23,21 @@ from .quant_base import QuantConfigBase, QuantMethodBase
QUANT_SCALING_FACTOR = 448
class W4AFP8Config(QuantConfigBase):
"""
quantization config for weight 4bits and activation fp8
"""
def __init__(self, weight_scale_dict, act_scale_dict) -> None:
super().__init__()
self.weight_scale_dict = weight_scale_dict
self.act_scale_dict = act_scale_dict
self.quant_max_bound = 448
self.quant_min_bound = -448
self.quant_round_type = 1
def get_name(self) -> str:
def name(self) -> str:
return "w4afp8"
@classmethod
@@ -49,6 +54,7 @@ class W4AFP8LinearMethod(QuantMethodBase):
"""
W4 AFP8 quant method for linear
"""
def __init__(
self,
quant_config: W4AFP8Config,
@@ -57,6 +63,9 @@ class W4AFP8LinearMethod(QuantMethodBase):
self.quant_config = quant_config
def create_weights(self, layer):
layer.linear_weight_shape.reverse()
layer.linear_weight_shape[0] //= 2
layer.weight_dtype = "int8"
pass
def process_loaded_weights(self, layer, weights) -> None:
@@ -78,11 +87,11 @@ class W4AFP8LinearMethod(QuantMethodBase):
layer.linear_weight_scale,
zero_points=None,
bias=layer.linear_bias if layer.add_bias else None,
out_scale=self.quant_config.weight_scale_dict.get(
layer.prefix + ".weight_quanter") /
(self.quant_config.act_scale_dict.get(layer.prefix +
".activation_quanter") *
QUANT_SCALING_FACTOR * QUANT_SCALING_FACTOR),
out_scale=self.quant_config.weight_scale_dict.get(layer.prefix +
".weight_scale")
/ (self.quant_config.act_scale_dict.get(layer.prefix +
".activation_scale") *
QUANT_SCALING_FACTOR * QUANT_SCALING_FACTOR),
groupsize=0,
out_dtype=layer._dtype,
)
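For reference, the out_scale above folds both calibration scales and the FP8 range (QUANT_SCALING_FACTOR = 448, defined at the top of this file) into a single dequantization factor; restated with hypothetical numbers:

QUANT_SCALING_FACTOR = 448
weight_scale = 0.5   # hypothetical value stored under "<prefix>.weight_scale"
act_scale = 0.02     # hypothetical value stored under "<prefix>.activation_scale"
out_scale = weight_scale / (act_scale * QUANT_SCALING_FACTOR * QUANT_SCALING_FACTOR)
print(out_scale)     # ~1.25e-4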

View File

@@ -16,11 +16,12 @@
from typing import Optional
import paddle
from paddlenlp.utils.log import logger
from paddleformers.utils.log import logger
import fastdeploy
from fastdeploy.platforms.utils import convert_to_npu_dequant_scale
from ..utils import get_tensor
from .quant_base import QuantConfigBase, QuantMethodBase
@@ -29,14 +30,18 @@ class W8A8Config(QuantConfigBase):
quantization config for weight 8bits and activation 8bits
"""
def __init__(self, weight_scale_dict, act_scale_dict,
use_gemm_dequant) -> None:
def __init__(self, weight_scale_dict, act_scale_dict, use_gemm_dequant,
use_smooth_quant) -> None:
super().__init__()
self.weight_scale_dict = weight_scale_dict
self.act_scale_dict = act_scale_dict
self.use_gemm_dequant = use_gemm_dequant
self.use_smooth_quant = use_smooth_quant
self.quant_max_bound = 127
self.quant_min_bound = -127
self.quant_round_type = 0
def get_name(self) -> str:
def name(self) -> str:
return "w8a8"
@classmethod
@@ -61,12 +66,17 @@ class W8A8LinearMethod(QuantMethodBase):
) -> None:
super().__init__()
self.quant_config = quant_config
self.smooth_quant_method = SmoothQuantLinearMethod(quant_config)
def create_weights(self, layer):
weight_scale = self.quant_config.weight_scale_dict.get(
layer.prefix + ".weight_quanter")
layer.linear_weight_shape.reverse()
layer.weight_dtype = "int8"
if self.quant_config.use_smooth_quant:
self.smooth_quant_method.create_weights(layer)
weight_scale = self.quant_config.weight_scale_dict.get(layer.prefix +
".weight_scale")
in_scale = self.quant_config.act_scale_dict.get(layer.prefix +
".activation_quanter")
".activation_scale")
self.skip_quant = False
if weight_scale is None or in_scale is None:
self.skip_quant = True
@@ -86,13 +96,15 @@ class W8A8LinearMethod(QuantMethodBase):
convert_to_npu_dequant_scale(linear_out_scale))
def process_loaded_weights(self, layer, weights) -> None:
if self.quant_config.use_smooth_quant:
self.smooth_quant_method.process_loaded_weights(layer, weights)
if self.skip_quant:
logger.debug(f"{layer.prefix} skip quant")
weight_tensor = weights.cast(layer._dtype)
layer.linear_weight.set_value(weight_tensor)
else:
weight_tensor = weights.transpose([1, 0])
weight_tensor = paddle.cast(weight_tensor, layer.weight_dtype)
weight_tensor = paddle.cast(weight_tensor, "int8")
layer.linear_weight.set_value(weight_tensor)
def apply(self, layer, x):
@@ -107,3 +119,53 @@ class W8A8LinearMethod(QuantMethodBase):
linear_out = fastdeploy.model_executor.ops.gpu.dequant_int8(
linear_out, layer.linear_out_scale, layer._dtype)
return linear_out
class SmoothQuantLinearMethod(QuantMethodBase):
"""
SmoothQuant Method
"""
def __init__(
self,
quant_config: QuantConfigBase,
) -> None:
super().__init__()
self.quant_config = quant_config
def create_weights(self, layer):
linear_shift_shape = [layer.output_size]
linear_smooth_shape = [layer.output_size]
layer.linear_shift = self.create_parameter(
shape=linear_shift_shape,
dtype=layer._dtype,
is_bias=False,
)
layer.linear_smooth = layer.create_parameter(
shape=linear_smooth_shape,
dtype=layer._dtype,
is_bias=False,
)
def process_loaded_weights(self, layer, weights) -> None:
if layer.shift_key in layer.state_dict:
shift_tensor = get_tensor(layer.state_dict.pop(
layer.shift_key)).astype(paddle.get_default_dtype())
else:
shift_tensor = paddle.zeros(
shape=layer.linear_shift_shape,
dtype=paddle.get_default_dtype(),
)
layer.linear_shift.set_value(shift_tensor)
if layer.smooth_key in layer.state_dict:
smooth_tensor = get_tensor(layer.state_dict.pop(
layer.smooth_key)).astype(paddle.get_default_dtype())
else:
smooth_tensor = paddle.ones(
shape=[layer.linear_smooth_shape],
dtype=paddle.get_default_dtype(),
)
layer.linear_smooth.set_value(smooth_tensor)
def apply(self, layer, x):
pass
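For background, a generic illustration (not the fused FastDeploy kernel) of what the shift and smooth vectors loaded above are for: the activation is shifted and scaled per channel before int8 quantization, and the zeros/ones defaults leave it unchanged:

import paddle

hidden_size = 64
x = paddle.randn([8, hidden_size])

shift = paddle.zeros([hidden_size])   # default when shift_key is absent -> no-op
smooth = paddle.ones([hidden_size])   # default when smooth_key is absent -> no-op

x_smoothed = (x + shift) * smooth     # per-channel shift/smooth applied before quantization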

View File

@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import os
from abc import abstractmethod
from typing import Optional
@@ -21,6 +22,8 @@ from paddle.nn.quant import weight_only_linear, weight_quantize
from fastdeploy.platforms import current_platform
from ..moe import FusedMoE
from ..utils import get_tensor
from .quant_base import QuantConfigBase, QuantMethodBase
@@ -28,34 +31,92 @@ class WeightOnlyConfig(QuantConfigBase):
"""
Quantization config for weight only
Args:
weight_only_linear_arch: The architecture of weight only linear layer
algo: The quant algorithm("weight_only_int8" or "weight_only_int4") used for weight only linear layer
"""
def __init__(
self,
weight_only_linear_arch: int,
algo: str,
) -> None:
super().__init__()
self.weight_only_linear_arch = weight_only_linear_arch
self.algo = algo
# arch (int): The compute arch of the target device, e.g. A100 is 80 and V100 is 70.
# If arch is not assigned, it will be detected from your device (default: None).
self.weight_only_linear_arch = os.getenv(
"FLAGS_weight_only_linear_arch")
if self.weight_only_linear_arch is not None:
self.weight_only_linear_arch = int(self.weight_only_linear_arch)
self.quant_max_bound = 0
self.quant_min_bound = 0
self.quant_round_type = 0
def get_name(self) -> str:
def name(self) -> str:
return "weight_only"
@classmethod
def from_config(cls, config: dict) -> "WeightOnlyConfig":
weight_only_linear_arch = config["weight_only_linear_arch"]
algo = config["algo"]
return cls(weight_only_linear_arch, algo)
return cls(algo)
def get_quant_method(self, layer) -> Optional[QuantMethodBase]:
if current_platform.is_xpu():
from fastdeploy.model_executor.layers.backends import XPUWeightOnlyLinearMethod
return XPUWeightOnlyLinearMethod(self)
from fastdeploy.model_executor.layers.backends import (
XPUWeightOnlyLinearMethod, XPUWeightOnlyMoEMethod)
if isinstance(layer, FusedMoE):
return XPUWeightOnlyMoEMethod(self)
else:
return XPUWeightOnlyLinearMethod(self)
else:
return GPUWeightOnlyLinearMethod(self)
if isinstance(layer, FusedMoE):
if layer.use_method == "cutlass":
from fastdeploy.model_executor.layers.moe.fused_moe_cutlass_backend import \
CutlassWeightOnlyMoEMethod
return CutlassWeightOnlyMoEMethod(self)
elif layer.use_method == "triton":
from fastdeploy.model_executor.layers.moe.fused_moe_triton_backend import \
TritonWeightOnlyMoEMethod
return TritonWeightOnlyMoEMethod(self)
elif layer.use_method == "marlin":
from fastdeploy.model_executor.layers.moe.fused_moe_marlin_backend import \
MarlinWeightOnlyMoEMethod
return MarlinWeightOnlyMoEMethod(self)
else:
raise ValueError(
f"Unsupported MOE backend {layer.use_method}")
else:
return GPUWeightOnlyLinearMethod(self)
class WINT8Config(WeightOnlyConfig):
"""
weight only int8 config
"""
def __init__(self, ) -> None:
super().__init__("weight_only_int8")
@classmethod
def from_config(cls, config: dict) -> "WINT8Config":
return cls()
def name(self) -> str:
return "wint8"
class WINT4Config(WeightOnlyConfig):
"""
weight only int4 config
"""
def __init__(self, ) -> None:
super().__init__("weight_only_int4")
@classmethod
def from_config(cls, config: dict) -> "WINT4Config":
return cls()
def name(self) -> str:
return "wint4"
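For reference, the wint8/wint4 path above ultimately relies on Paddle's public weight-only APIs; a standalone sketch, assuming a CUDA build of Paddle where these ops are available and with illustrative shapes:

import paddle
from paddle.nn.quant import weight_only_linear, weight_quantize

x = paddle.randn([4, 512]).astype("float16")
w = paddle.randn([512, 1024]).astype("float16")

quant_w, scale = weight_quantize(w, algo="weight_only_int8")  # int8 weight + per-channel scale
out = weight_only_linear(x, quant_w, weight_scale=scale, weight_dtype="int8")
print(out.shape)  # [4, 1024]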
class WeightOnlyLinearMethod(QuantMethodBase):
@@ -71,12 +132,17 @@ class WeightOnlyLinearMethod(QuantMethodBase):
self.quant_config = quant_config
def create_weights(self, layer):
weight_only_scale_name = layer.prefix + ".weight_only_scale"
layer.linear_weight_shape.reverse()
if self.quant_config.name() == "wint4":
layer.linear_weight_shape[0] //= 2
layer.weight_dtype = "int8"
linear_weight_scale_shape = [layer.embed_dim]
if hasattr(layer, "linear_weight_shape"):
if isinstance(layer.linear_weight_shape, list):
layer_weight_shape = layer.linear_weight_shape
linear_weight_scale_shape = layer_weight_shape[:1]
if self.quant_config.name() == "wint4":
linear_weight_scale_shape[0] *= 2
layer.linear_weight_scale = layer.create_parameter(
shape=linear_weight_scale_shape,
@@ -94,7 +160,8 @@ class WeightOnlyLinearMethod(QuantMethodBase):
weight=layer.linear_weight,
bias=layer.linear_bias if layer.add_bias else None,
weight_scale=layer.linear_weight_scale,
weight_dtype=layer.weight_dtype,
weight_dtype="int8"
if self.quant_config.name() == "wint8" else "int4",
arch=self.quant_config.weight_only_linear_arch,
)
return linear_out
@@ -113,6 +180,20 @@ class GPUWeightOnlyLinearMethod(WeightOnlyLinearMethod):
) -> None:
super().__init__(quant_config)
def process_prequanted_weights(self, layer, state_dict) -> None:
"""
Process pre-quantized weights before applying them to the model
Args:
layer: The layer that owns the weights.
state_dict: State dict providing the pre-quantized weight and its scale.
"""
quant_weight = get_tensor(state_dict.pop(layer.weight_key))
weight_scale = get_tensor(state_dict.pop(layer.weight_scale_key))
layer.linear_weight.set_value(quant_weight)
layer.linear_weight_scale.set_value(
weight_scale.astype(paddle.get_default_dtype()))
def process_loaded_weights(self, layer, weight) -> None:
quanted_weight_tensor, weight_scale_tensor = weight_quantize(
weight,

View File

@@ -17,10 +17,10 @@ from typing import Optional
import paddle
import fastdeploy
from fastdeploy.platforms.utils import convert_to_npu_dequant_scale
from .quant_base import QuantConfigBase, QuantMethodBase
from fastdeploy.model_executor.layers.quantization.ops import (
cutlass_scaled_mm, scaled_fp8_quant)
from fastdeploy.model_executor.layers.quantization.quant_base import (
QuantConfigBase, QuantMethodBase)
class WFP8AFP8Config(QuantConfigBase):
@@ -32,17 +32,26 @@ class WFP8AFP8Config(QuantConfigBase):
super().__init__()
self.weight_scale_dict = weight_scale_dict
self.act_scale_dict = act_scale_dict
self.quant_max_bound = 448
self.quant_min_bound = -448
self.quant_round_type = 1
def get_name(self) -> str:
def name(self) -> str:
"""
"""
return "wfp8afp8"
@classmethod
def from_config(cls, config: dict) -> "WFP8AFP8Config":
weight_scale_dict = config["weight_scale_dict"]
act_scale_dict = config["act_scale_dict"]
"""
"""
weight_scale_dict = config.get("weight_scale_dict", None)
act_scale_dict = config.get("act_scale_dict", None)
return cls(weight_scale_dict, act_scale_dict)
def get_quant_method(self, layer) -> Optional[QuantMethodBase]:
"""
"""
return WFP8AFP8LinearMethod(self)
@@ -59,58 +68,49 @@ class WFP8AFP8LinearMethod(QuantMethodBase):
self.quant_config = quant_config
def create_weights(self, layer):
"""
"""
layer.linear_weight_shape.reverse()
layer.weight_dtype = "float8_e4m3fn"
# TODO(YuanRisheng): set weight logic should be moved to process_loaded_weights func
weight_scale = self.quant_config.weight_scale_dict.get(
layer.prefix + ".weight_quanter")
in_scale = self.quant_config.act_scale_dict.get(layer.prefix +
".activation_quanter")
self.skip_quant = False
# we will skip quant if weight_scale is not found or in_scale is not found
if weight_scale is None or in_scale is None:
self.skip_quant = True
else:
max_range = 448.0
layer.scalar_scale_name = layer.prefix + ".scalar_weight_quanter"
layer.scalar_scale = layer.create_parameter(
shape=([1]),
dtype="float32",
)
layer.scalar_scale.set_value(
paddle.to_tensor([1.0 / (max_range * in_scale)],
dtype="float32"))
linear_out_scale = paddle.to_tensor(weight_scale /
max_range).astype("float32")
layer.linear_out_scale = layer.create_parameter(
shape=[layer.embed_dim],
dtype="float32",
is_bias=False,
default_initializer=paddle.nn.initializer.Constant(0),
)
layer.linear_out_scale.set_value(
convert_to_npu_dequant_scale(linear_out_scale))
layer.linear_weight_scale = layer.create_parameter(
shape=[1],
dtype="float32",
is_bias=False,
default_initializer=paddle.nn.initializer.Constant(0),
)
def process_loaded_weights(self, layer, weights) -> None:
# TODO(YuanRisheng): We should abstract the skip_quant logic to adapt to more quant methods
"""
"""
if self.skip_quant:
weight_tensor = weights.cast(layer._dtype)
layer.linear_weight.set_value(weight_tensor)
return
weight_tensor = weights.transpose([1, 0])
weight_tensor = paddle.cast(weight_tensor, self.weight_dtype)
self.linear_weight.copy_(weight_tensor, False)
if weights.dtype != paddle.float8_e4m3fn:
self.use_per_token_if_dynamic = True
weight_tensor = weights.transpose([1, 0]).contiguous()
qweight, weight_scale = scaled_fp8_quant(
weight_tensor,
use_per_token_if_dynamic=False,
)
layer.linear_weight.copy_(qweight, False)
layer.linear_weight_scale.set_value(weight_scale)
def apply(self, layer, x):
"""
"""
if self.skip_quant:
linear_out = paddle.matmul(x, layer.linear_weight, False, True)
return linear_out
linear_out = fastdeploy.model_executor.ops.gpu.per_channel_fp8_fp8_half_gemm_fused(
x,
layer.linear_weight,
bias=layer.linear_bias if layer.add_bias else None,
scalar_scale=layer.scalar_scale,
channel_scale=layer.linear_out_scale,
transpose_x=False,
transpose_y=True,
output_dtype=layer._dtype,
)
if self.use_per_token_if_dynamic:
out_type = x.dtype
a_q, a_scales = scaled_fp8_quant(
x, use_per_token_if_dynamic=self.use_per_token_if_dynamic)
linear_out = cutlass_scaled_mm(a_q, layer.linear_weight, a_scales,
layer.linear_weight_scale, out_type,
layer.linear_bias)
else:
raise NotImplementedError
return linear_out

View File

@@ -0,0 +1,142 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from typing import Optional
from ..moe import FusedMoE
from . import get_quantization_config
from .quant_base import QuantConfigBase, QuantMethodBase
class WINT2Config(QuantConfigBase):
"""
Quantization config for wint8 linear and w4w2 MoE.
"""
def __init__(
self,
dense_quant_type: str,
dense_quant_granularity: str,
moe_quant_type: str,
moe_w4_quant_type: str,
moe_w4_quant_granularity: str,
moe_w4_quant_start_layer: int,
moe_w4_quant_end_layer: int,
moe_w2_quant_type: str,
moe_w2_quant_granularity: str,
moe_w2_quant_group_size: int,
moe_w2_quant_start_layer: int,
moe_w2_quant_end_layer: int,
) -> None:
super().__init__()
self.quant_max_bound = 0
self.quant_min_bound = 0
self.quant_round_type = 0
# wint2 quantization config
self.dense_quant_type = dense_quant_type
self.dense_quant_granularity = dense_quant_granularity
self.moe_quant_type = moe_quant_type
self.moe_w4_quant_type = moe_w4_quant_type
self.moe_w4_quant_granularity = moe_w4_quant_granularity
self.moe_w4_quant_start_layer = moe_w4_quant_start_layer
self.moe_w4_quant_end_layer = moe_w4_quant_end_layer
self.moe_w2_quant_type = moe_w2_quant_type
self.moe_w2_quant_granularity = moe_w2_quant_granularity
self.moe_w2_quant_group_size = moe_w2_quant_group_size
self.moe_w2_quant_start_layer = moe_w2_quant_start_layer
self.moe_w2_quant_end_layer = moe_w2_quant_end_layer
def name(self) -> str:
"""
Get the name of the quantization configuration.
Returns:
str: The name of the quantization configuration.
"""
return "wint2"
@classmethod
def from_config(cls, config: dict) -> "WINT2Config":
"""
Create a new instance of `WINT2Config` using the provided configuration dictionary.
Args:
config (dict): A dictionary containing the configuration parameters for the new instance.
Returns:
WINT2Config: The newly created instance of `WINT2Config`.
"""
dense_quant_type = config.get("dense_quant_config", "wint8")
dense_quant_granularity = config.get("dense_quant_granularity",
"per_channel")
moe_quant_config = config.get("moe_quant_config", {})
moe_quant_type = moe_quant_config.get("quant_type", "w4w2")
moe_w4_quant_config = moe_quant_config.get("moe_w4_quant_config", {})
moe_w4_quant_type = moe_w4_quant_config.get("quant_type",
"wint4")
moe_w4_quant_granularity = moe_w4_quant_config.get(
"quant_granularity", "per_channel")
moe_w4_quant_start_layer = moe_w4_quant_config.get(
"quant_start_layer", 0)
moe_w4_quant_end_layer = moe_w4_quant_config.get("quant_end_layer", 6)
moe_w2_quant_config = moe_quant_config.get("moe_w2_quant_config", {})
moe_w2_quant_type = moe_w2_quant_config.get("quant_type", "wint2")
moe_w2_quant_granularity = moe_w2_quant_config.get(
"quant_granularity", "pp_acc")
moe_w2_quant_group_size = moe_w2_quant_config.get(
"quant_group_size", 0)
moe_w2_quant_start_layer = moe_w2_quant_config.get(
"quant_start_layer", 0)
moe_w2_quant_end_layer = moe_w2_quant_config.get("quant_end_layer", 0)
return cls(
dense_quant_type,
dense_quant_granularity,
moe_quant_type,
moe_w4_quant_type,
moe_w4_quant_granularity,
moe_w4_quant_start_layer,
moe_w4_quant_end_layer,
moe_w2_quant_type,
moe_w2_quant_granularity,
moe_w2_quant_group_size,
moe_w2_quant_start_layer,
moe_w2_quant_end_layer,
)
def get_quant_method(self, layer) -> Optional[QuantMethodBase]:
"""
Get the quantization method associated with the given layer based on the current quantization configuration.
Args:
layer (Layer): The layer for which the quantization method should be retrieved.
Returns:
QuantMethodBase: The quantization method associated with the given layer.
"""
if isinstance(layer, FusedMoE):
if layer.layer_idx <= self.moe_w4_quant_end_layer:
return get_quantization_config(
self.moe_w4_quant_type).from_config(
{}).get_quant_method(layer)
else:
from fastdeploy.model_executor.layers.moe.fused_moe_wint2_backend import \
TritonWint2FusedMoeMethod
return TritonWint2FusedMoeMethod(self)
else:
return get_quantization_config(self.dense_quant_type).from_config(
{}).get_quant_method(layer)
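Finally, a hypothetical config dict in the nested shape parsed by from_config above; the key names mirror the .get(...) calls and the values are illustrative only:

from fastdeploy.model_executor.layers.quantization.wint2 import WINT2Config

wint2_cfg = {
    "dense_quant_config": "wint8",
    "dense_quant_granularity": "per_channel",
    "moe_quant_config": {
        "quant_type": "w4w2",
        "moe_w4_quant_config": {
            "quant_type": "wint4",
            "quant_granularity": "per_channel",
            "quant_start_layer": 0,
            "quant_end_layer": 6,
        },
        "moe_w2_quant_config": {
            "quant_type": "wint2",
            "quant_granularity": "pp_acc",
            "quant_group_size": 64,
            "quant_start_layer": 7,
            "quant_end_layer": 60,
        },
    },
}
cfg = WINT2Config.from_config(wint2_cfg)
print(cfg.name())  # -> "wint2"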