mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-06 00:57:33 +08:00
Sync v2.0 version of code to github repo
@@ -19,11 +19,18 @@ from typing import Dict, List, Type

from .quant_base import QuantConfigBase

QUANTIZATION_METHODS: List[str] = [
    "wint2",
    "wint4",
    "wint8",
    "weight_only",
-   "block_wise",
+   "block_wise_fp8",
    "w4afp8",
    "w8a8",
    "w4a8",
    "wfp8afp8",
    "mix_quant",
    "tensor_wise_fp8",
    "kvcache",
]


@@ -34,20 +41,30 @@ def get_quantization_config(quantization: str) -> Type[QuantConfigBase]:
    if quantization not in QUANTIZATION_METHODS:
        raise ValueError(f"Invalid quantization method: {quantization}")

-   from .block_wise import BlockWiseConfig
+   from .block_wise_fp8 import BlockWiseFP8Config
    from .kv_cache import KvCacheQuantConfig
    from .mix_quant import MixQuantConfig
    from .tensor_wise_fp8 import TensorWiseFP8Config
    from .w4a8 import W4A8Config
    from .w4afp8 import W4AFP8Config
    from .w8a8 import W8A8Config
-   from .weight_only import WeightOnlyConfig
+   from .weight_only import WeightOnlyConfig, WINT4Config, WINT8Config
    from .wfp8afp8 import WFP8AFP8Config
    from .wint2 import WINT2Config

    method_to_config: Dict[str, Type[QuantConfigBase]] = {
        "wint2": WINT2Config,
        "wint4": WINT4Config,
        "wint8": WINT8Config,
        "weight_only": WeightOnlyConfig,
-       "block_wise": BlockWiseConfig,
+       "block_wise_fp8": BlockWiseFP8Config,
        "w4afp8": W4AFP8Config,
        "w8a8": W8A8Config,
        "w4a8": W4A8Config,
        "wfp8afp8": WFP8AFP8Config,
-       "kvcache": KvCacheQuantConfig
        "tensor_wise_fp8": TensorWiseFP8Config,
+       "kvcache": KvCacheQuantConfig,
        "mix_quant": MixQuantConfig,
    }

    return method_to_config[quantization]
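For orientation, a minimal usage sketch of the registry above (the package path follows the file paths shown later in this diff):

    from fastdeploy.model_executor.layers.quantization import get_quantization_config

    # Resolve the config class by method name, then build it from a raw dict.
    cfg_cls = get_quantization_config("block_wise_fp8")
    quant_config = cfg_cls.from_config({"weight_block_size": [128, 128]})
    print(quant_config.name())  # -> "block_wise_fp8"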
@@ -18,16 +18,13 @@ from typing import Optional

import paddle

import fastdeploy
import fastdeploy.model_executor.ops.gpu.deep_gemm as deep_gemm
from fastdeploy.model_executor.layers.moe import FusedMoE

-from ..utils import per_block_cast_to_fp8
+from ..utils import per_block_cast_to_fp8, get_tensor
from .quant_base import QuantConfigBase, QuantMethodBase

-QUANT_ALIGNMENT_OFFSET = 127
-QUANT_BLOCK_SIZE = 128


-class BlockWiseConfig(QuantConfigBase):
+class BlockWiseFP8Config(QuantConfigBase):
    """
    Block-wise quantization config. It only supports FP8 quantization and only supports
    loading weights in BF16 format. After loading the weights, it automatically computes
    quantization scales and dynamically performs
@@ -37,41 +34,55 @@ class BlockWiseConfig(QuantConfigBase):
    def __init__(self, weight_block_size: list = [-1, -1]) -> None:
        super().__init__()
        self.weight_block_size = weight_block_size
        self.quant_max_bound = 448
        self.quant_min_bound = -448
        self.quant_round_type = 1

-   def get_name(self) -> str:
-       return "block_wise"
+   def name(self) -> str:
+       return "block_wise_fp8"

    @classmethod
-   def from_config(cls, config: dict) -> "BlockWiseConfig":
-       weight_block_size = config["weight_block_size"]
+   def from_config(cls, config: dict) -> "BlockWiseFP8Config":
+       weight_block_size = config.get("weight_block_size", [128, 128])
        return cls(weight_block_size)

    def get_quant_method(self, layer) -> Optional[QuantMethodBase]:
-       return BlockWiseLinearMethod(self)
+       '''
+       Get the quantization method for this layer.
+       '''
+       if isinstance(layer, FusedMoE):
+           from fastdeploy.model_executor.layers.moe.fused_moe_deepgemm_backend import \
+               DeepGemmFusedMoeMethod
+           return DeepGemmFusedMoeMethod(self)
+       else:
+           return BlockWiseFP8LinearMethod(self)


-class BlockWiseLinearMethod(QuantMethodBase):
+class BlockWiseFP8LinearMethod(QuantMethodBase):
    """
    Block-wise quantization method for linear layers.
    """

    def __init__(
        self,
-       quant_config: BlockWiseConfig,
+       quant_config: BlockWiseFP8Config,
    ) -> None:
        super().__init__()
        self.quant_config = quant_config

    def create_weights(self, layer):
-       layer.linear_weight_scale = self.create_parameter(
+       layer.linear_weight_shape.reverse()
+       layer.linear_weight_scale = layer.create_parameter(
            shape=[
-               (layer.embed_dim + QUANT_ALIGNMENT_OFFSET) // QUANT_BLOCK_SIZE,
-               (layer.num_heads * layer.head_dim + QUANT_ALIGNMENT_OFFSET) //
-               QUANT_BLOCK_SIZE,
+               (layer.output_size + self.quant_config.weight_block_size[0] -
+                1) // self.quant_config.weight_block_size[0],
+               (layer.input_size + self.quant_config.weight_block_size[1] - 1)
+               // self.quant_config.weight_block_size[1],
            ],
            dtype="float32",
            is_bias=False,
        )
        layer.weight_dtype = "float8_e4m3fn"

    def process_loaded_weights(self, layer, weights) -> None:
        weight_tensor = weights.transpose([1, 0])
@@ -80,15 +91,30 @@ class BlockWiseLinearMethod(QuantMethodBase):
        layer.linear_weight.copy_(quanted_weight_tensor, False)
        layer.linear_weight_scale.set_value(weight_block_scale_tensor)

+   def process_prequanted_weights(self, layer, state_dict):
+       """
+       Load weights that are already quantized: an FP8 weight plus its block scales.
+       """
+       quant_weight = get_tensor(state_dict.pop(layer.weight_key))
+       weight_scale = get_tensor(state_dict.pop(layer.weight_scale_key))
+
+       quant_weight = quant_weight.transpose([1, 0]).contiguous()
+       layer.linear_weight.copy_(quant_weight.view("float8_e4m3fn"), False)
+
+       weight_scale = weight_scale.transpose([1, 0])
+       layer.linear_weight_scale.set_value(weight_scale)

    def apply(self, layer, x):
        x, x_scale_tensor = fastdeploy.model_executor.ops.gpu.per_token_quant_padding(
            x, self.quant_config.weight_block_size[0])
-       linear_out = paddle.empty(
-           (x.shape[0], layer.llm_config.model_config.hidden_size),
-           dtype=paddle.bfloat16)
+       linear_out = paddle.empty((x.shape[0], layer.output_size),
+                                 dtype=paddle.bfloat16)
+       import fastdeploy.model_executor.ops.gpu.deep_gemm as deep_gemm
        deep_gemm.gemm_fp8_fp8_bf16_nt(
            (x, x_scale_tensor),
            (layer.linear_weight, layer.linear_weight_scale),
            linear_out,
        )
        if layer.with_bias:
            linear_out = paddle.add(linear_out, layer.linear_bias)
        return linear_out
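A rough illustration of the block-wise scheme above (a sketch, not FastDeploy's actual kernel): every 128x128 weight block gets one float32 scale, chosen so the block fits the FP8 E4M3 range of +-448.

    import paddle

    def per_block_scales_sketch(w: paddle.Tensor, block: int = 128) -> paddle.Tensor:
        # Assumes both dims divide evenly by `block`; the real code pads instead.
        out_blocks, in_blocks = w.shape[0] // block, w.shape[1] // block
        scales = paddle.empty([out_blocks, in_blocks], dtype="float32")
        for i in range(out_blocks):
            for j in range(in_blocks):
                blk = w[i * block:(i + 1) * block, j * block:(j + 1) * block]
                scales[i, j] = blk.abs().max().cast("float32") / 448.0
        return scales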
@@ -13,38 +13,66 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""
-from paddle import nn
-import paddle
-from .quant_base import QuantConfigBase, QuantMethodBase
+import os
+from enum import Enum
+from typing import Optional
+
+import paddle
+from paddle import nn
+
+from fastdeploy.model_executor.layers.utils import get_tensor
+
+from ..utils import create_and_set_parameter
+from .quant_base import QuantConfigBase, QuantMethodBase


class KvCacheQuantzationTypes(str, Enum):
    """
    Supported KV-cache quantization types.
    """
    INT8 = "int8"
    FP8 = "float8_e4m3fn"
    INT8_ZP = "int8_zp"
    FP8_ZP = "float8_e4m3fn_zp"


class KvCacheQuantConfig(QuantConfigBase):
    """
    Quantization config for the KV cache.
    """

-   def __init__(self, cachekv_scale_dict) -> None:
+   def __init__(self, kv_cache_quant_type: str) -> None:
        """
        __init__
        """
        super().__init__()
-       self.cachekv_scale_dict = cachekv_scale_dict
+       self.kv_cache_quant_type = kv_cache_quant_type
+
+       try:
+           self.quant_type = KvCacheQuantzationTypes(kv_cache_quant_type)
+       except ValueError:
+           raise ValueError(f'Invalid Kvcache type: {kv_cache_quant_type}')
+
+       self.has_zero_point = "zp" in kv_cache_quant_type
+
+       if self.quant_type == KvCacheQuantzationTypes.INT8 or self.quant_type == KvCacheQuantzationTypes.INT8_ZP:
+           self.max_bound = 127.0
+       elif self.quant_type == KvCacheQuantzationTypes.FP8 or self.quant_type == KvCacheQuantzationTypes.FP8_ZP:
+           self.max_bound = 448.0
+       else:
+           raise ValueError(f'Invalid Kvcache type: {kv_cache_quant_type}')

-   def get_name(self) -> str:
+   def name(self) -> str:
        """
        get_name
        """
        return "kvcache"

    @classmethod
-   def from_config(cls, config: dict) -> "KvCacheQuantConfig":
+   def from_config(cls, kv_cache_quant_type: str) -> "KvCacheQuantConfig":
        """
        from_config
        """
-       cachekv_scale_dict = config["cachekv_scale_dict"]
-       return cls(cachekv_scale_dict)
+       return cls(kv_cache_quant_type)

    def get_quant_method(self, layer) -> Optional[QuantMethodBase]:
        """
@@ -66,197 +94,63 @@ class KVCacheMethodBase(QuantMethodBase):
        KVCacheMethodBase __init__
        """
        super().__init__()
-       self.quant_config = quant_config
+       self.cache_quant_config = quant_config

-   def load_zp(self, layer: nn.Layer):
+   def load_zp(self, layer: nn.Layer, state_dict):
        """
        load_zp
        """
-       if self.cache_k_zp_name in self.quant_config.cachekv_scale_dict:
-           cache_k_zp = paddle.cast(
-               paddle.to_tensor(
-                   self.quant_config.cachekv_scale_dict[self.cache_k_zp_name]
-               ),
-               self.cache_scale_dtype,
-           )
-       else:
-           cache_k_zp = paddle.zeros(
-               (
-                   [self.kv_num_heads * self.head_dim]
-                   if self.quant_config.is_channel_wise
-                   else [self.kv_num_heads]
-               ),
-               dtype=self.cache_scale_dtype,
-           )
-       if self.cache_v_zp_name in self.quant_config.cachekv_scale_dict:
-           cache_v_zp = paddle.cast(
-               paddle.to_tensor(
-                   self.quant_config.cachekv_scale_dict[self.cache_v_zp_name]
-               ),
-               self.cache_scale_dtype,
-           )
-       else:
-           cache_v_zp = paddle.zeros(
-               (
-                   [self.kv_num_heads * self.head_dim]
-                   if self.quant_config.is_channel_wise
-                   else [self.kv_num_heads]
-               ),
-               dtype=self.cache_scale_dtype,
-           )
-       layer.cache_k_zp.set_value(cache_k_zp)
-       layer.cache_v_zp.set_value(cache_v_zp)
+       cache_k_zeropoint = get_tensor(state_dict.pop(self.cache_k_zp_name))
+       cache_v_zeropoint = get_tensor(state_dict.pop(self.cache_v_zp_name))
+
+       create_and_set_parameter(layer, "cache_k_zp", cache_k_zeropoint)
+       create_and_set_parameter(layer, "cache_v_zp", cache_v_zeropoint)

-   def load_scale(self, layer: nn.Layer):
+   def load_scale(self, layer: nn.Layer, state_dict):
        """
        load_scale
        """
-       if self.cache_k_scale_name in self.quant_config.cachekv_scale_dict:
-           cache_k_scale = paddle.cast(
-               paddle.to_tensor(
-                   self.quant_config.cachekv_scale_dict[self.cache_k_scale_name]
-               ),
-               self.cache_scale_dtype,
-           )
-           cache_k_out_scale = 1.0 / cache_k_scale
-       else:
-           raise KeyError(
-               f"{self.cache_k_scale_name} not found in scale dict")
-       if self.cache_v_scale_name in self.quant_config.cachekv_scale_dict:
-           cache_v_scale = paddle.cast(
-               paddle.to_tensor(
-                   self.quant_config.cachekv_scale_dict[self.cache_v_scale_name]
-               ),
-               self.cache_scale_dtype,
-           )
-           cache_v_out_scale = 1.0 / cache_v_scale
-       else:
-           raise KeyError(
-               f"{self.cache_v_scale_name} not found in scale dict")
-       layer.cache_k_scale.set_value(cache_k_scale)
-       layer.cache_v_scale.set_value(cache_v_scale)
-       layer.cache_k_out_scale.set_value(cache_k_out_scale)
-       layer.cache_v_out_scale.set_value(cache_v_out_scale)
+       cache_k_scale_tensor = get_tensor(
+           state_dict.pop(self.cache_k_scale_name)).cast(
+               paddle.get_default_dtype()).reshape_([-1])
+       cache_v_scale_tensor = get_tensor(
+           state_dict.pop(self.cache_v_scale_name)).cast(
+               paddle.get_default_dtype()).reshape_([-1])
+
+       cache_k_scale = self.cache_quant_config.max_bound / cache_k_scale_tensor
+       cache_v_scale = self.cache_quant_config.max_bound / cache_v_scale_tensor
+       cache_k_out_scale = cache_k_scale_tensor / self.cache_quant_config.max_bound
+       cache_v_out_scale = cache_v_scale_tensor / self.cache_quant_config.max_bound
+
+       create_and_set_parameter(layer, "cache_k_scale", cache_k_scale)
+       create_and_set_parameter(layer, "cache_v_scale", cache_v_scale)
+       create_and_set_parameter(layer, "cache_k_out_scale", cache_k_out_scale)
+       create_and_set_parameter(layer, "cache_v_out_scale", cache_v_out_scale)

-   def create_scale(self, layer: nn.Layer):
-       """
-       create_scale
-       """
-       layer.cache_k_scale = layer.create_parameter(
-           shape=(
-               [layer.kv_num_heads * layer.head_dim]
-               if self.quant_config.is_channel_wise
-               else [layer.kv_num_heads]
-           ),
-           dtype=self.cache_scale_dtype,
-           is_bias=False,
-       )
-       layer.cache_v_scale = layer.create_parameter(
-           shape=(
-               [layer.kv_num_heads * layer.head_dim]
-               if self.quant_config.is_channel_wise
-               else [layer.kv_num_heads]
-           ),
-           dtype=self.cache_scale_dtype,
-           is_bias=False,
-       )
-       layer.cache_k_out_scale = layer.create_parameter(
-           shape=(
-               [layer.kv_num_heads * layer.head_dim]
-               if self.quant_config.is_channel_wise
-               else [layer.kv_num_heads]
-           ),
-           attr=None,
-           dtype=self.cache_scale_dtype,
-           is_bias=False,
-       )
-       layer.cache_v_out_scale = layer.create_parameter(
-           shape=(
-               [layer.kv_num_heads * layer.head_dim]
-               if self.quant_config.is_channel_wise
-               else [layer.kv_num_heads]
-           ),
-           attr=None,
-           dtype=self.cache_scale_dtype,
-           is_bias=False,
-       )
-
-   def create_zp(self, layer: nn.Layer):
-       """
-       create_zp
-       """
-       layer.cache_k_zp = layer.create_parameter(
-           shape=(
-               [layer.kv_num_heads * layer.head_dim]
-               if self.quant_config.is_channel_wise
-               else [layer.kv_num_heads]
-           ),
-           dtype=self.cache_scale_dtype,
-           is_bias=False,
-       )
-       layer.cache_v_zp = layer.create_parameter(
-           shape=(
-               [layer.kv_num_heads * layer.head_dim]
-               if self.quant_config.is_channel_wise
-               else [layer.kv_num_heads]
-           ),
-           dtype=self.cache_scale_dtype,
-           is_bias=False,
-       )

-   def create_weights(self, layer: nn.Layer):
+   def create_weights(self, layer: nn.Layer, state_dict):
        """
        create_weights
        """
-       self.prefix = layer.prefix
-       self.cache_k_scale_name = layer.prefix + ".cachek_matmul.activation_quanter"
-       self.cache_v_scale_name = layer.prefix + ".cachev_matmul.activation_quanter"
-       self.cache_k_zp_name = layer.cache_k_scale_name + ".zero_point"
-       self.cache_v_zp_name = layer.cache_v_scale_name + ".zero_point"
+       self.cache_k_scale_name = layer.prefix + ".cachek_matmul.activation_scale"
+       self.cache_v_scale_name = layer.prefix + ".cachev_matmul.activation_scale"
+       self.cache_k_zp_name = layer.prefix + ".cachek_matmul.activation_zero_point"
+       self.cache_v_zp_name = layer.prefix + ".cachev_matmul.activation_zero_point"

+       layer.cache_k_zp = None
+       layer.cache_v_zp = None
+       layer.cache_k_scale = None
+       layer.cache_v_scale = None
+       layer.cache_k_out_scale = None
+       layer.cache_v_out_scale = None
+       if self.cache_quant_config.quant_type == KvCacheQuantzationTypes.INT8:
+           setattr(layer, "cache_quant_type_str", "cache_int8")
+           setattr(layer, "quant_max_bound", 127.0)
+           setattr(layer, "quant_min_bound", -127.0)
+       elif self.cache_quant_config.quant_type == KvCacheQuantzationTypes.FP8:
+           setattr(layer, "cache_quant_type_str", "cache_fp8")
+           setattr(layer, "quant_max_bound", 448.0)
+           setattr(layer, "quant_min_bound", -448.0)
+       else:
+           raise NotImplementedError(
+               f"{self.cache_quant_config.quant_type} is not implemented")

-       self._dtype = layer._dtype
-       if self._dtype != "bfloat16" and self._dtype != "float16" and self._dtype != "float32":
-           raise ValueError(
-               f"Just support float32, float16 and \
-               bfloat16 as default dtype, but received {self._dtype}"
-           )
-       self.cache_scale_dtype = (
-           self._dtype if self.quant_config.use_append_attn else "float32"
-       )
-
-       if not self.quant_config.use_dynamic_cachekv_quant:
-           if (
-               self.quant_config.cachekv_dtype == "int8"
-               or self.quant_config.cachekv_dtype == "int4"
-               or self.quant_config.cachekv_dtype == "float8_e4m3fn"
-           ):
-               self.create_scale(layer)
-               self.load_scale(layer)
-               if self.quant_config.has_zero_point:
-                   self.create_zp(layer)
-                   self.load_zp(layer)
-           layer.cache_quant_type_str = self.quant_config.cache_quant_type
+       self.load_scale(layer, state_dict)
+       if self.cache_quant_config.has_zero_point:
+           self.load_zp(layer, state_dict)

    def apply(self, layer):
        """
@@ -264,4 +158,3 @@ class KVCacheMethodBase(QuantMethodBase):
        """
        raise RuntimeError(
            f"{self.__class__.__name__}.apply should not be called.")
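A small sketch of how this config is driven (note that, unlike the other configs, from_config here takes the quant-type string directly rather than a dict):

    from fastdeploy.model_executor.layers.quantization.kv_cache import (
        KvCacheQuantConfig, KvCacheQuantzationTypes)

    cfg = KvCacheQuantConfig.from_config("int8_zp")
    assert cfg.quant_type is KvCacheQuantzationTypes.INT8_ZP
    assert cfg.has_zero_point and cfg.max_bound == 127.0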
fastdeploy/model_executor/layers/quantization/mix_quant.py (new file, +75 lines)
@@ -0,0 +1,75 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from typing import Optional

from ..attention import Attention
from ..moe import FusedMoE
from . import get_quantization_config
from .quant_base import QuantConfigBase, QuantMethodBase


class MixQuantConfig(QuantConfigBase):
    """
    Quantization config for models whose layers use different quantization methods.
    """

    def __init__(
        self,
        dense_quant_type: str,
        moe_quant_type: str,
        kv_cache_quant_type: str = None,
        image_moe_quant_type: str = None,
    ) -> None:
        super().__init__()
        self.dense_quant_type = dense_quant_type
        self.moe_quant_type = moe_quant_type
        self.kv_cache_quant_type = kv_cache_quant_type
        if image_moe_quant_type is None:
            self.image_moe_quant_type = moe_quant_type
        else:
            self.image_moe_quant_type = image_moe_quant_type
        self.quant_max_bound = 0
        self.quant_min_bound = 0
        self.quant_round_type = 0

    def name(self) -> str:
        return "mix_quant"

    @classmethod
    def from_config(cls, config: dict) -> "MixQuantConfig":
        return cls(config['dense_quant_type'], config['moe_quant_type'],
                   config.get('kv_cache_quant_type', None),
                   config.get('image_moe_quant_type', None))

    def get_quant_method(self, layer) -> Optional[QuantMethodBase]:
        if isinstance(layer, FusedMoE):
            if layer.moe_tag == "Image":
                return get_quantization_config(
                    self.image_moe_quant_type).from_config(
                        {}).get_quant_method(layer)
            else:
                return get_quantization_config(
                    self.moe_quant_type).from_config(
                        {}).get_quant_method(layer)
        elif isinstance(layer, Attention):
            if self.kv_cache_quant_type is not None:
                return (get_quantization_config("kvcache").from_config(
                    self.kv_cache_quant_type).get_quant_method(layer))
            else:
                return None
        else:
            return get_quantization_config(self.dense_quant_type).from_config(
                {}).get_quant_method(layer)
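For illustration, a config dict of the shape this from_config accepts (values hypothetical):

    mix_cfg = MixQuantConfig.from_config({
        "dense_quant_type": "wint8",
        "moe_quant_type": "wint4",
        "kv_cache_quant_type": "int8",  # optional; omit to leave the KV cache unquantized
    })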
@@ -0,0 +1,22 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from .cutlass_scaled_mm import cutlass_scaled_mm
from .scaled_fp8_quant import scaled_fp8_quant

__all__ = [
    "cutlass_scaled_mm",
    "scaled_fp8_quant",
]
@@ -0,0 +1,126 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from typing import Optional

import paddle

import fastdeploy


def cutlass_scaled_mm(a: paddle.Tensor,
                      b: paddle.Tensor,
                      scale_a: paddle.Tensor,
                      scale_b: paddle.Tensor,
                      out_dtype: paddle.dtype,
                      bias: Optional[paddle.Tensor] = None) -> paddle.Tensor:
    """
    `cutlass_scaled_mm` implements a fused version of
    `output = paddle.mm((scale_a * a), (scale_b * b)).to(out_dtype)`
    where scale_a * a and scale_b * b are implemented using numpy-style
    broadcasting.

    In order to support blockwise scaling like found in DeepSeek V3 we also
    support extended "group" broadcast rules. We extend the numpy-style
    broadcasting rules with the following rule:
        "if the extent of a dimension in the source shape is between 1 and the
        corresponding extent in the target shape, we repeat each element along
        that dimension target_shape[dim] // src_shape[dim] times consecutively"
    For example, if we have:
        a = [[1, 2],    and target_shape = (2, 4)
             [3, 4]]
    then we would expand a to:
        a = [[1, 1, 2, 2],
             [3, 3, 4, 4]]
    Currently we only support the case:
        scale_a.shape * [1, 128] == a.shape
        scale_b.shape * [128, 128] == b.shape
    """
    assert (out_dtype == paddle.bfloat16 or out_dtype == paddle.float16)
    assert bias is None or bias.shape[0] == b.shape[
        0] and bias.dtype == out_dtype
    # Ensure input tensors have valid shapes
    # assert a.numel() > 0, "Input tensor 'a' must not be empty"
    # assert b.numel() > 0, "Input tensor 'b' must not be empty"
    # assert scale_a.numel() > 0, "Scale tensor 'scale_a' must not be empty"
    # assert scale_b.numel() > 0, "Scale tensor 'scale_b' must not be empty"

    m = a.shape[0]
    n = b.shape[0]
    cutlass_compatible_b = (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
    assert cutlass_compatible_b

    out = paddle.empty([m, n], dtype=out_dtype)
    fastdeploy.model_executor.ops.gpu.cutlass_scaled_mm(
        out, a, b, scale_a, scale_b, bias)

    return out


def scaled_fp8_quant(
    input: paddle.Tensor,
    scale: Optional[paddle.Tensor] = None,
    num_token_padding: Optional[int] = None,
    scale_ub: float = 0,
    use_per_token_if_dynamic: bool = False,
) -> tuple[paddle.Tensor, paddle.Tensor]:
    """
    Quantize the input tensor to FP8 and return the quantized tensor and scale.

    This function supports both static and dynamic quantization: if you
    provide the scale, it will use static scaling, and if you omit it,
    the scale will be determined dynamically. The function also allows
    optional padding of the output tensors for downstream kernels that
    will benefit from padding.

    Args:
        input: The input tensor to be quantized to FP8
        scale: Optional scaling factor for the FP8 quantization
        scale_ub: Optional upper bound for the scaling factor in the dynamic
            per-token case
        num_token_padding: If specified, pad the first dimension
            of the output to at least this value.
        use_per_token_if_dynamic: Whether to do per_tensor or per_token
            in the dynamic quantization case.

    Returns:
        tuple[paddle.Tensor, paddle.Tensor]: The output tensor in FP8 and the
            scaling factor.
    """
    # This code assumes batch_dim and num_tokens are flattened
    assert (input.ndim == 2)
    shape = input.shape
    if num_token_padding:
        shape = (max(num_token_padding, input.shape[0]), shape[1])
    output = paddle.empty(shape, dtype=paddle.float8_e4m3fn)

    if scale is None:
        if use_per_token_if_dynamic:
            scale = paddle.empty([shape[0], 1], dtype=paddle.float32)
            from fastdeploy.model_executor.ops.gpu import \
                dynamic_per_token_scaled_fp8_quant
            dynamic_per_token_scaled_fp8_quant(output, input, scale, scale_ub)
        else:
            scale = paddle.zeros([1], dtype=paddle.float32)
            from fastdeploy.model_executor.ops.gpu import \
                dynamic_scaled_fp8_quant
            dynamic_scaled_fp8_quant(output, input, scale)
    else:
        # num_token_padding not implemented for this case
        # assert (scale.numel() == 1 or num_token_padding is None)
        from fastdeploy.model_executor.ops.gpu import static_scaled_fp8_quant
        static_scaled_fp8_quant(output, input, scale)

    return output, scale
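A usage sketch tying the two ops together, mirroring how the WFP8AFP8 linear method later in this diff calls them (assumes a GPU build of FastDeploy with these custom ops compiled; shapes are illustrative):

    import paddle

    x = paddle.randn([4, 1024]).cast("bfloat16")     # activations, [m, k]
    w = paddle.randn([2048, 1024]).cast("bfloat16")  # weight, [n, k]

    x_q, x_scales = scaled_fp8_quant(x, use_per_token_if_dynamic=True)
    w_q, w_scale = scaled_fp8_quant(w)  # dynamic per-tensor scale
    y = cutlass_scaled_mm(x_q, w_q, x_scales, w_scale, paddle.bfloat16)
    print(y.shape)  # [4, 2048]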
@@ -0,0 +1,75 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from typing import Optional

import paddle


def scaled_fp8_quant(
    input: paddle.Tensor,
    scale: Optional[paddle.Tensor] = None,
    num_token_padding: Optional[int] = None,
    scale_ub: float = 0,
    use_per_token_if_dynamic: bool = False,
) -> tuple[paddle.Tensor, paddle.Tensor]:
    """
    Quantize the input tensor to FP8 and return the quantized tensor and scale.

    This function supports both static and dynamic quantization: if you
    provide the scale, it will use static scaling, and if you omit it,
    the scale will be determined dynamically. The function also allows
    optional padding of the output tensors for downstream kernels that
    will benefit from padding.

    Args:
        input: The input tensor to be quantized to FP8
        scale: Optional scaling factor for the FP8 quantization
        scale_ub: Optional upper bound for the scaling factor in the dynamic
            per-token case
        num_token_padding: If specified, pad the first dimension
            of the output to at least this value.
        use_per_token_if_dynamic: Whether to do per_tensor or per_token
            in the dynamic quantization case.

    Returns:
        tuple[paddle.Tensor, paddle.Tensor]: The output tensor in FP8 and the
            scaling factor.
    """
    # This code assumes batch_dim and num_tokens are flattened
    assert (input.ndim == 2)
    shape = input.shape
    if num_token_padding:
        shape = (max(num_token_padding, input.shape[0]), shape[1])
    output = paddle.empty(shape, dtype=paddle.float8_e4m3fn)

    if scale is None:
        if use_per_token_if_dynamic:
            scale = paddle.empty([shape[0], 1], dtype=paddle.float32)
            from fastdeploy.model_executor.ops.gpu import \
                dynamic_per_token_scaled_fp8_quant
            dynamic_per_token_scaled_fp8_quant(output, input, scale, scale_ub)
        else:
            scale = paddle.zeros([1], dtype=paddle.float32)
            from fastdeploy.model_executor.ops.gpu import \
                dynamic_scaled_fp8_quant
            dynamic_scaled_fp8_quant(output, input, scale)
    else:
        # num_token_padding not implemented for this case
        # assert (scale.numel() == 1 or num_token_padding is None)
        from fastdeploy.model_executor.ops.gpu import static_scaled_fp8_quant
        static_scaled_fp8_quant(output, input, scale)

    return output, scale
@@ -47,12 +47,9 @@ class QuantConfigBase(ABC):

    def __init__(self):
        super().__init__()
-       self.quant_round_type = None
-       self.quant_max_bound = None
-       self.quant_min_bound = None

    @abstractmethod
-   def get_name(self) -> str:
+   def name(self) -> str:
        """Name of the quantization method."""
        raise NotImplementedError
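A minimal subclass sketch showing the contract this base class imposes on the configs throughout this diff (the "demo" name is hypothetical):

    class DemoConfig(QuantConfigBase):
        def name(self) -> str:
            return "demo"

        @classmethod
        def from_config(cls, config: dict) -> "DemoConfig":
            return cls()

        def get_quant_method(self, layer):
            return None  # a real config returns a QuantMethodBase here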
fastdeploy/model_executor/layers/quantization/tensor_wise_fp8.py (new file, +135 lines)
@@ -0,0 +1,135 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from typing import Optional

import paddle

from fastdeploy.model_executor.layers.moe import FusedMoE

from ..utils import get_tensor
from .quant_base import QuantConfigBase, QuantMethodBase


class TensorWiseFP8Config(QuantConfigBase):
    """
    Quantization config for FP8 weights and activations with per-tensor scales.
    """

    def __init__(self) -> None:
        """
        Nothing else to do!
        """
        super().__init__()

    def name(self) -> str:
        return "tensor_wise_fp8"

    @classmethod
    def from_config(cls, config: dict) -> "TensorWiseFP8Config":
        """
        Nothing else to do!
        """
        return cls()

    def get_quant_method(self, layer) -> Optional[QuantMethodBase]:
        """
        Return the quantization method that matches this config.
        """
        if isinstance(layer, FusedMoE):
            from fastdeploy.model_executor.layers.moe.fused_moe_triton_backend import \
                TensorWiseFP8MoEMethod
            return TensorWiseFP8MoEMethod(self)
        else:
            return TensorWiseFP8LinearMethod(self)


class TensorWiseFP8LinearMethod(QuantMethodBase):
    """
    Weight and activation quantization method for linear layers with per-tensor FP8.
    """

    def __init__(
        self,
        quant_config: TensorWiseFP8Config,
    ) -> None:
        """
        Nothing special to do!
        """
        super().__init__()
        self.quant_config = quant_config
        self.quant_max_bound = 448
        self.quant_min_bound = -448
        self.quant_round_type = 1
        self.weight_dtype = "float8_e4m3fn"

    def create_weights(self, layer):
        """
        Nothing to do!
        """
        pass

    def process_prequanted_weights(self, layer, state_dict) -> None:
        """
        Process pre-quantized weights before applying them to the model.
        Args:
            layer: The layer that owns the weights
            state_dict: Holds the quantized weight, its weight scale and the activation scale
        """
        quant_weight = get_tensor(state_dict.pop(layer.weight_key))
        weight_scale = get_tensor(state_dict.pop(layer.weight_scale_key))
        act_scale = get_tensor(state_dict.pop(layer.act_scale_key))

        quant_weight = quant_weight.transpose([1, 0]).contiguous()
        layer.linear_weight.copy_(quant_weight.view("float8_e4m3fn"), False)

        self.act_scale = act_scale.item()
        self.total_scale = (act_scale * weight_scale).item()

    def process_loaded_weights(self, layer, weights, state_dict) -> None:
        """
        Read the FP8 weight, activation scale and weight scale.
        """
        pass

    def apply(self, layer, x):
        """
        Run the quantized matmul.
        """
        from fastdeploy.model_executor.ops.gpu import \
            cutlass_fp8_fp8_half_gemm_fused

        from ..utils import create_hadamard_matrix_map

        hadamard_matrix = create_hadamard_matrix_map[x.shape[-1]]
        new_x = paddle.matmul(x.cast("float32"), hadamard_matrix)
        fp8_x = new_x / self.act_scale
        fp8_x = fp8_x.astype("float8_e4m3fn")

        linear_out = cutlass_fp8_fp8_half_gemm_fused(
            fp8_x,
            layer.linear_weight,
            transpose_x=False,
            transpose_y=True,
            bias=None,
            scale=self.total_scale,
            output_dtype="bfloat16",
            activation_type="identity")
        return linear_out
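Rough numeric intuition for the two scales above (a sketch with hypothetical values): the activation is divided by act_scale before the FP8 cast, and the GEMM epilogue multiplies by total_scale = act_scale * weight_scale to undo both quantizations at once.

    act_scale, weight_scale = 0.02, 0.01   # hypothetical per-tensor scales
    x, w = 3.14, 0.5
    x_q = round(x / act_scale)             # stand-in for the FP8 cast
    w_q = round(w / weight_scale)
    y = (x_q * w_q) * (act_scale * weight_scale)
    print(y)                               # ~1.57, i.e. approximately x * w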
fastdeploy/model_executor/layers/quantization/w4a8.py (new file, +42 lines)
@@ -0,0 +1,42 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from typing import Optional

from ..moe import FusedMoE
from .quant_base import QuantConfigBase, QuantMethodBase


class W4A8Config(QuantConfigBase):
    """
    Quantization config for 4-bit weights and 8-bit activations.
    """

    def __init__(self) -> None:
        super().__init__()

    def name(self) -> str:
        return "w4a8"

    @classmethod
    def from_config(cls, config: dict) -> "W4A8Config":
        return cls()

    def get_quant_method(self, layer) -> Optional[QuantMethodBase]:
        if isinstance(layer, FusedMoE):
            from fastdeploy.model_executor.layers.moe.fused_moe_cutlass_backend import CutlassW4A8MoEMethod
            return CutlassW4A8MoEMethod(self)
        else:
            raise ValueError(f"Unsupported layer type {type(layer)} for w4a8")
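Note that this config only supports fused-MoE layers; a usage sketch of the dispatch (the layer objects are hypothetical):

    cfg = W4A8Config.from_config({})
    method = cfg.get_quant_method(moe_layer)   # CutlassW4A8MoEMethod for a FusedMoE layer
    # cfg.get_quant_method(linear_layer) would raise ValueError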
@@ -23,16 +23,21 @@ from .quant_base import QuantConfigBase, QuantMethodBase

QUANT_SCALING_FACTOR = 448


class W4AFP8Config(QuantConfigBase):
    """
    Quantization config for 4-bit weights and FP8 activations.
    """

    def __init__(self, weight_scale_dict, act_scale_dict) -> None:
        super().__init__()
        self.weight_scale_dict = weight_scale_dict
        self.act_scale_dict = act_scale_dict
        self.quant_max_bound = 448
        self.quant_min_bound = -448
        self.quant_round_type = 1

-   def get_name(self) -> str:
+   def name(self) -> str:
        return "w4afp8"

    @classmethod
@@ -49,6 +54,7 @@ class W4AFP8LinearMethod(QuantMethodBase):
    """
    W4A-FP8 quantization method for linear layers.
    """

    def __init__(
        self,
        quant_config: W4AFP8Config,
@@ -57,6 +63,9 @@ class W4AFP8LinearMethod(QuantMethodBase):
        self.quant_config = quant_config

    def create_weights(self, layer):
+       layer.linear_weight_shape.reverse()
+       layer.linear_weight_shape[0] //= 2
+       layer.weight_dtype = "int8"
        pass

    def process_loaded_weights(self, layer, weights) -> None:
@@ -78,11 +87,11 @@ class W4AFP8LinearMethod(QuantMethodBase):
            layer.linear_weight_scale,
            zero_points=None,
            bias=layer.linear_bias if layer.add_bias else None,
-           out_scale=self.quant_config.weight_scale_dict.get(
-               layer.prefix + ".weight_quanter") /
-           (self.quant_config.act_scale_dict.get(layer.prefix +
-                                                 ".activation_quanter") *
-            QUANT_SCALING_FACTOR * QUANT_SCALING_FACTOR),
+           out_scale=self.quant_config.weight_scale_dict.get(layer.prefix +
+                                                             ".weight_scale")
+           / (self.quant_config.act_scale_dict.get(layer.prefix +
+                                                   ".activation_scale") *
+              QUANT_SCALING_FACTOR * QUANT_SCALING_FACTOR),
            groupsize=0,
            out_dtype=layer._dtype,
        )
@@ -16,11 +16,12 @@
from typing import Optional

import paddle
-from paddlenlp.utils.log import logger
+from paddleformers.utils.log import logger

import fastdeploy
from fastdeploy.platforms.utils import convert_to_npu_dequant_scale

from ..utils import get_tensor
from .quant_base import QuantConfigBase, QuantMethodBase


@@ -29,14 +30,18 @@ class W8A8Config(QuantConfigBase):
    Quantization config for 8-bit weights and 8-bit activations.
    """

-   def __init__(self, weight_scale_dict, act_scale_dict,
-                use_gemm_dequant) -> None:
+   def __init__(self, weight_scale_dict, act_scale_dict, use_gemm_dequant,
+                use_smooth_quant) -> None:
        super().__init__()
        self.weight_scale_dict = weight_scale_dict
        self.act_scale_dict = act_scale_dict
        self.use_gemm_dequant = use_gemm_dequant
+       self.use_smooth_quant = use_smooth_quant
        self.quant_max_bound = 127
        self.quant_min_bound = -127
        self.quant_round_type = 0

-   def get_name(self) -> str:
+   def name(self) -> str:
        return "w8a8"

    @classmethod
@@ -61,12 +66,17 @@ class W8A8LinearMethod(QuantMethodBase):
    ) -> None:
        super().__init__()
        self.quant_config = quant_config
+       self.smooth_quant_method = SmoothQuantLinearMethod(quant_config)

    def create_weights(self, layer):
-       weight_scale = self.quant_config.weight_scale_dict.get(
-           layer.prefix + ".weight_quanter")
+       layer.linear_weight_shape.reverse()
+       layer.weight_dtype = "int8"
+       if self.quant_config.use_smooth_quant:
+           self.smooth_quant_method.create_weights(layer)
+       weight_scale = self.quant_config.weight_scale_dict.get(layer.prefix +
+                                                              ".weight_scale")
        in_scale = self.quant_config.act_scale_dict.get(layer.prefix +
-                                                       ".activation_quanter")
+                                                       ".activation_scale")
        self.skip_quant = False
        if weight_scale is None or in_scale is None:
            self.skip_quant = True
@@ -86,13 +96,15 @@ class W8A8LinearMethod(QuantMethodBase):
            convert_to_npu_dequant_scale(linear_out_scale))

    def process_loaded_weights(self, layer, weights) -> None:
+       if self.quant_config.use_smooth_quant:
+           self.smooth_quant_method.process_loaded_weights(layer, weights)
        if self.skip_quant:
            logger.debug(f"{layer.prefix} skip quant")
            weight_tensor = weights.cast(layer._dtype)
            layer.linear_weight.set_value(weight_tensor)
        else:
            weight_tensor = weights.transpose([1, 0])
-           weight_tensor = paddle.cast(weight_tensor, layer.weight_dtype)
+           weight_tensor = paddle.cast(weight_tensor, "int8")
            layer.linear_weight.set_value(weight_tensor)

    def apply(self, layer, x):
@@ -107,3 +119,53 @@ class W8A8LinearMethod(QuantMethodBase):
        linear_out = fastdeploy.model_executor.ops.gpu.dequant_int8(
            linear_out, layer.linear_out_scale, layer._dtype)
        return linear_out


+class SmoothQuantLinearMethod(QuantMethodBase):
+    """
+    SmoothQuant method.
+    """
+
+    def __init__(
+        self,
+        quant_config: QuantConfigBase,
+    ) -> None:
+        super().__init__()
+        self.quant_config = quant_config
+
+    def create_weights(self, layer):
+        linear_shift_shape = [layer.output_size]
+        linear_smooth_shape = [layer.output_size]
+        layer.linear_shift = layer.create_parameter(
+            shape=linear_shift_shape,
+            dtype=layer._dtype,
+            is_bias=False,
+        )
+        layer.linear_smooth = layer.create_parameter(
+            shape=linear_smooth_shape,
+            dtype=layer._dtype,
+            is_bias=False,
+        )
+
+    def process_loaded_weights(self, layer, weights) -> None:
+        if layer.shift_key in layer.state_dict:
+            shift_tensor = get_tensor(layer.state_dict.pop(
+                layer.shift_key)).astype(paddle.get_default_dtype())
+        else:
+            shift_tensor = paddle.zeros(
+                shape=layer.linear_shift.shape,
+                dtype=paddle.get_default_dtype(),
+            )
+        layer.linear_shift.set_value(shift_tensor)
+        if layer.smooth_key in layer.state_dict:
+            smooth_tensor = get_tensor(layer.state_dict.pop(
+                layer.smooth_key)).astype(paddle.get_default_dtype())
+        else:
+            smooth_tensor = paddle.ones(
+                shape=layer.linear_smooth.shape,
+                dtype=paddle.get_default_dtype(),
+            )
+        layer.linear_smooth.set_value(smooth_tensor)
+
+    def apply(self, layer, x):
+        pass
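A toy sketch of the SmoothQuant idea implemented above (hypothetical numbers): the per-channel smooth vector divides the activations so that outlier channels shrink, and the weight rows are scaled by the same vector offline, leaving the matmul result unchanged while making int8 quantization of the activations easier.

    import paddle

    x = paddle.to_tensor([[10.0, 0.1]])     # one outlier activation channel
    smooth = paddle.to_tensor([4.0, 0.5])   # per-channel smoothing factors
    w = paddle.ones([2, 3])
    y_ref = paddle.matmul(x, w)
    y_smq = paddle.matmul(x / smooth, w * smooth.unsqueeze(1))
    print(paddle.allclose(y_ref, y_smq))    # True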
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""
+import os
from abc import abstractmethod
from typing import Optional

@@ -21,6 +22,8 @@ from paddle.nn.quant import weight_only_linear, weight_quantize

from fastdeploy.platforms import current_platform

+from ..moe import FusedMoE
+from ..utils import get_tensor
from .quant_base import QuantConfigBase, QuantMethodBase


@@ -28,34 +31,92 @@ class WeightOnlyConfig(QuantConfigBase):
    """
    Quantization config for weight-only quantization.
    Args:
-       weight_only_linear_arch: The architecture of the weight-only linear layer
        algo: The quant algorithm ("weight_only_int8" or "weight_only_int4") used for the weight-only linear layer
    """

    def __init__(
        self,
-       weight_only_linear_arch: int,
        algo: str,
    ) -> None:
        super().__init__()
-       self.weight_only_linear_arch = weight_only_linear_arch
        self.algo = algo
+       # arch (int): The compute arch of the target device, e.g. 80 for A100, 70 for V100;
+       # if you do not assign arch, we will detect it from your device. Default: None.
+       self.weight_only_linear_arch = os.getenv(
+           "FLAGS_weight_only_linear_arch")
+       if self.weight_only_linear_arch is not None:
+           self.weight_only_linear_arch = int(self.weight_only_linear_arch)
        self.quant_max_bound = 0
        self.quant_min_bound = 0
        self.quant_round_type = 0

-   def get_name(self) -> str:
+   def name(self) -> str:
        return "weight_only"

    @classmethod
    def from_config(cls, config: dict) -> "WeightOnlyConfig":
-       weight_only_linear_arch = config["weight_only_linear_arch"]
        algo = config["algo"]
-       return cls(weight_only_linear_arch, algo)
+       return cls(algo)

    def get_quant_method(self, layer) -> Optional[QuantMethodBase]:
        if current_platform.is_xpu():
-           from fastdeploy.model_executor.layers.backends import XPUWeightOnlyLinearMethod
-           return XPUWeightOnlyLinearMethod(self)
+           from fastdeploy.model_executor.layers.backends import (
+               XPUWeightOnlyLinearMethod, XPUWeightOnlyMoEMethod)
+           if isinstance(layer, FusedMoE):
+               return XPUWeightOnlyMoEMethod(self)
+           else:
+               return XPUWeightOnlyLinearMethod(self)
        else:
-           return GPUWeightOnlyLinearMethod(self)
+           if isinstance(layer, FusedMoE):
+               if layer.use_method == "cutlass":
+                   from fastdeploy.model_executor.layers.moe.fused_moe_cutlass_backend import \
+                       CutlassWeightOnlyMoEMethod
+                   return CutlassWeightOnlyMoEMethod(self)
+               elif layer.use_method == "triton":
+                   from fastdeploy.model_executor.layers.moe.fused_moe_triton_backend import \
+                       TritonWeightOnlyMoEMethod
+                   return TritonWeightOnlyMoEMethod(self)
+               elif layer.use_method == "marlin":
+                   from fastdeploy.model_executor.layers.moe.fused_moe_marlin_backend import \
+                       MarlinWeightOnlyMoEMethod
+                   return MarlinWeightOnlyMoEMethod(self)
+               else:
+                   raise ValueError(
+                       f"Unsupported MOE backend {layer.use_method}")
+           else:
+               return GPUWeightOnlyLinearMethod(self)


+class WINT8Config(WeightOnlyConfig):
+    """
+    Weight-only int8 config.
+    """
+
+    def __init__(self, ) -> None:
+        super().__init__("weight_only_int8")
+
+    @classmethod
+    def from_config(cls, config: dict) -> "WINT8Config":
+        return cls()
+
+    def name(self) -> str:
+        return "wint8"


+class WINT4Config(WeightOnlyConfig):
+    """
+    Weight-only int4 config.
+    """
+
+    def __init__(self, ) -> None:
+        super().__init__("weight_only_int4")
+
+    @classmethod
+    def from_config(cls, config: dict) -> "WINT4Config":
+        return cls()
+
+    def name(self) -> str:
+        return "wint4"


class WeightOnlyLinearMethod(QuantMethodBase):
@@ -71,12 +132,17 @@ class WeightOnlyLinearMethod(QuantMethodBase):
        self.quant_config = quant_config

    def create_weights(self, layer):
-       weight_only_scale_name = layer.prefix + ".weight_only_scale"
+       layer.linear_weight_shape.reverse()
+       if self.quant_config.name() == "wint4":
+           layer.linear_weight_shape[0] //= 2
+       layer.weight_dtype = "int8"
-       linear_weight_scale_shape = [layer.embed_dim]
+       if hasattr(layer, "linear_weight_shape"):
+           if isinstance(layer.linear_weight_shape, list):
+               layer_weight_shape = layer.linear_weight_shape
+               linear_weight_scale_shape = layer_weight_shape[:1]
+       if self.quant_config.name() == "wint4":
+           linear_weight_scale_shape[0] *= 2

        layer.linear_weight_scale = layer.create_parameter(
            shape=linear_weight_scale_shape,
@@ -94,7 +160,8 @@ class WeightOnlyLinearMethod(QuantMethodBase):
            weight=layer.linear_weight,
            bias=layer.linear_bias if layer.add_bias else None,
            weight_scale=layer.linear_weight_scale,
-           weight_dtype=layer.weight_dtype,
+           weight_dtype="int8"
+           if self.quant_config.name() == "wint8" else "int4",
            arch=self.quant_config.weight_only_linear_arch,
        )
        return linear_out
@@ -113,6 +180,20 @@ class GPUWeightOnlyLinearMethod(WeightOnlyLinearMethod):
    ) -> None:
        super().__init__(quant_config)

+   def process_prequanted_weights(self, layer, state_dict) -> None:
+       """
+       Process pre-quantized weights before applying them to the model.
+       Args:
+           layer: The layer that owns the weights
+           state_dict: Holds the quantized weight and its scale
+       """
+       quant_weight = get_tensor(state_dict.pop(layer.weight_key))
+       weight_scale = get_tensor(state_dict.pop(layer.weight_scale_key))
+       layer.linear_weight.set_value(quant_weight)
+       layer.linear_weight_scale.set_value(
+           weight_scale.astype(paddle.get_default_dtype()))

    def process_loaded_weights(self, layer, weight) -> None:
        quanted_weight_tensor, weight_scale_tensor = weight_quantize(
            weight,
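For instance (a hypothetical usage; the environment variable is the one read in __init__ above), selecting int4 weight-only quantization and pinning the compute arch:

    import os
    os.environ["FLAGS_weight_only_linear_arch"] = "80"  # e.g. A100; optional

    cfg = get_quantization_config("wint4").from_config({})
    print(cfg.name(), cfg.algo)  # "wint4" "weight_only_int4"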
@@ -17,10 +17,10 @@ from typing import Optional

import paddle

import fastdeploy
-from fastdeploy.platforms.utils import convert_to_npu_dequant_scale

-from .quant_base import QuantConfigBase, QuantMethodBase
+from fastdeploy.model_executor.layers.quantization.ops import (
+    cutlass_scaled_mm, scaled_fp8_quant)
+from fastdeploy.model_executor.layers.quantization.quant_base import (
+    QuantConfigBase, QuantMethodBase)


class WFP8AFP8Config(QuantConfigBase):
@@ -32,17 +32,26 @@ class WFP8AFP8Config(QuantConfigBase):
        super().__init__()
        self.weight_scale_dict = weight_scale_dict
        self.act_scale_dict = act_scale_dict
        self.quant_max_bound = 448
        self.quant_min_bound = -448
        self.quant_round_type = 1

-   def get_name(self) -> str:
+   def name(self) -> str:
        return "wfp8afp8"

    @classmethod
    def from_config(cls, config: dict) -> "WFP8AFP8Config":
-       weight_scale_dict = config["weight_scale_dict"]
-       act_scale_dict = config["act_scale_dict"]
+       weight_scale_dict = config.get("weight_scale_dict", None)
+       act_scale_dict = config.get("act_scale_dict", None)
        return cls(weight_scale_dict, act_scale_dict)

    def get_quant_method(self, layer) -> Optional[QuantMethodBase]:
        return WFP8AFP8LinearMethod(self)


@@ -59,58 +68,49 @@ class WFP8AFP8LinearMethod(QuantMethodBase):
        self.quant_config = quant_config

    def create_weights(self, layer):
        layer.linear_weight_shape.reverse()
        layer.weight_dtype = "float8_e4m3fn"
-       # TODO(YuanRisheng): set weight logic should be moved to process_loaded_weights func
-       weight_scale = self.quant_config.weight_scale_dict.get(
-           layer.prefix + ".weight_quanter")
-       in_scale = self.quant_config.act_scale_dict.get(layer.prefix +
-                                                       ".activation_quanter")
-       self.skip_quant = False
-       # we will skip quant if weight_scale or in_scale is not found
-       if weight_scale is None or in_scale is None:
-           self.skip_quant = True
-       else:
-           max_range = 448.0
-           layer.scalar_scale_name = layer.prefix + ".scalar_weight_quanter"
-           layer.scalar_scale = layer.create_parameter(
-               shape=([1]),
-               dtype="float32",
-           )
-           layer.scalar_scale.set_value(
-               paddle.to_tensor([1.0 / (max_range * in_scale)],
-                                dtype="float32"))
-           linear_out_scale = paddle.to_tensor(weight_scale /
-                                               max_range).astype("float32")
-           layer.linear_out_scale = layer.create_parameter(
-               shape=[layer.embed_dim],
-               dtype="float32",
-               is_bias=False,
-               default_initializer=paddle.nn.initializer.Constant(0),
-           )
-           layer.linear_out_scale.set_value(
-               convert_to_npu_dequant_scale(linear_out_scale))
+       layer.linear_weight_scale = layer.create_parameter(
+           shape=[1],
+           dtype="float32",
+           is_bias=False,
+           default_initializer=paddle.nn.initializer.Constant(0),
+       )

    def process_loaded_weights(self, layer, weights) -> None:
-       # TODO(YuanRisheng): We should abstract the skip_quant logic to adapt to more quant methods
-       if self.skip_quant:
-           weight_tensor = weights.cast(layer._dtype)
-           layer.linear_weight.set_value(weight_tensor)
-           return
-       weight_tensor = weights.transpose([1, 0])
-       weight_tensor = paddle.cast(weight_tensor, self.weight_dtype)
-       self.linear_weight.copy_(weight_tensor, False)
+       if weights.dtype != paddle.float8_e4m3fn:
+           self.use_per_token_if_dynamic = True
+       weight_tensor = weights.transpose([1, 0]).contiguous()
+       qweight, weight_scale = scaled_fp8_quant(
+           weight_tensor,
+           use_per_token_if_dynamic=False,
+       )
+       layer.linear_weight.copy_(qweight, False)
+       layer.linear_weight_scale.set_value(weight_scale)

    def apply(self, layer, x):
-       if self.skip_quant:
-           linear_out = paddle.matmul(x, layer.linear_weight, False, True)
-           return linear_out
-       linear_out = fastdeploy.model_executor.ops.gpu.per_channel_fp8_fp8_half_gemm_fused(
-           x,
-           layer.linear_weight,
-           bias=layer.linear_bias if layer.add_bias else None,
-           scalar_scale=layer.scalar_scale,
-           channel_scale=layer.linear_out_scale,
-           transpose_x=False,
-           transpose_y=True,
-           output_dtype=layer._dtype,
-       )
+       if self.use_per_token_if_dynamic:
+           out_type = x.dtype
+           a_q, a_scales = scaled_fp8_quant(
+               x, use_per_token_if_dynamic=self.use_per_token_if_dynamic)
+           linear_out = cutlass_scaled_mm(a_q, layer.linear_weight, a_scales,
+                                          layer.linear_weight_scale, out_type,
+                                          layer.linear_bias)
+       else:
+           raise NotImplementedError
        return linear_out
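Shape intuition for the per-token path above (a sketch; m tokens, k input features, n output features):

    import paddle

    m, k, n = 4, 1024, 2048
    a_scales = paddle.ones([m, 1], dtype="float32")  # one scale per token
    w_scale = paddle.ones([1], dtype="float32")      # one scale per weight tensor
    # cutlass_scaled_mm broadcasts both scale tensors over the [m, n] output.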
fastdeploy/model_executor/layers/quantization/wint2.py (new file, +142 lines)
@@ -0,0 +1,142 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from typing import Optional

from ..moe import FusedMoE
from . import get_quantization_config
from .quant_base import QuantConfigBase, QuantMethodBase


class WINT2Config(QuantConfigBase):
    """
    Quantization config for wint8 linear and w4w2 MoE.
    """

    def __init__(
        self,
        dense_quant_type: str,
        dense_quant_granularity: str,
        moe_quant_type: str,
        moe_w4_quant_type: str,
        moe_w4_quant_granularity: str,
        moe_w4_quant_start_layer: int,
        moe_w4_quant_end_layer: int,
        moe_w2_quant_type: str,
        moe_w2_quant_granularity: str,
        moe_w2_quant_group_size: int,
        moe_w2_quant_start_layer: int,
        moe_w2_quant_end_layer: int,
    ) -> None:
        super().__init__()
        self.quant_max_bound = 0
        self.quant_min_bound = 0
        self.quant_round_type = 0

        # wint2 quantization config
        self.dense_quant_type = dense_quant_type
        self.dense_quant_granularity = dense_quant_granularity
        self.moe_quant_type = moe_quant_type
        self.moe_w4_quant_type = moe_w4_quant_type
        self.moe_w4_quant_granularity = moe_w4_quant_granularity
        self.moe_w4_quant_start_layer = moe_w4_quant_start_layer
        self.moe_w4_quant_end_layer = moe_w4_quant_end_layer
        self.moe_w2_quant_type = moe_w2_quant_type
        self.moe_w2_quant_granularity = moe_w2_quant_granularity
        self.moe_w2_quant_group_size = moe_w2_quant_group_size
        self.moe_w2_quant_start_layer = moe_w2_quant_start_layer
        self.moe_w2_quant_end_layer = moe_w2_quant_end_layer

    def name(self) -> str:
        """
        Get the name of the quantization configuration.
        Returns:
            str: The name of the quantization configuration.
        """
        return "wint2"

    @classmethod
    def from_config(cls, config: dict) -> "WINT2Config":
        """
        Create a new instance of `WINT2Config` using the provided configuration dictionary.
        Args:
            config (dict): A dictionary containing the configuration parameters for the new instance.

        Returns:
            WINT2Config: The newly created instance of `WINT2Config`.
        """
        dense_quant_type = config.get("dense_quant_config", "wint8")
        dense_quant_granularity = config.get("dense_quant_granularity",
                                             "per_channel")

        moe_quant_config = config.get("moe_quant_config", {})
        moe_quant_type = moe_quant_config.get("quant_type", "w4w2")

        moe_w4_quant_config = moe_quant_config.get("moe_w4_quant_config", {})
        moe_w4_quant_type = moe_w4_quant_config.get("quant_type", "wint4")
        moe_w4_quant_granularity = moe_w4_quant_config.get(
            "quant_granularity", "per_channel")
        moe_w4_quant_start_layer = moe_w4_quant_config.get(
            "quant_start_layer", 0)
        moe_w4_quant_end_layer = moe_w4_quant_config.get("quant_end_layer", 6)

        moe_w2_quant_config = moe_quant_config.get("moe_w2_quant_config", {})
        moe_w2_quant_type = moe_w2_quant_config.get("quant_type", "wint2")
        moe_w2_quant_granularity = moe_w2_quant_config.get(
            "quant_granularity", "pp_acc")
        moe_w2_quant_group_size = moe_w2_quant_config.get(
            "quant_group_size", 0)
        moe_w2_quant_start_layer = moe_w2_quant_config.get(
            "quant_start_layer", 0)
        moe_w2_quant_end_layer = moe_w2_quant_config.get("quant_end_layer", 0)

        return cls(
            dense_quant_type,
            dense_quant_granularity,
            moe_quant_type,
            moe_w4_quant_type,
            moe_w4_quant_granularity,
            moe_w4_quant_start_layer,
            moe_w4_quant_end_layer,
            moe_w2_quant_type,
            moe_w2_quant_granularity,
            moe_w2_quant_group_size,
            moe_w2_quant_start_layer,
            moe_w2_quant_end_layer,
        )

    def get_quant_method(self, layer) -> Optional[QuantMethodBase]:
        """
        Get the quantization method associated with the given layer based on the current quantization configuration.
        Args:
            layer (Layer): The layer for which the quantization method should be retrieved.

        Returns:
            QuantMethodBase: The quantization method associated with the given layer.
        """
        if isinstance(layer, FusedMoE):
            if layer.layer_idx <= self.moe_w4_quant_end_layer:
                return get_quantization_config(
                    self.moe_w4_quant_type).from_config(
                        {}).get_quant_method(layer)
            else:
                from fastdeploy.model_executor.layers.moe.fused_moe_wint2_backend import \
                    TritonWint2FusedMoeMethod
                return TritonWint2FusedMoeMethod(self)
        else:
            return get_quantization_config(self.dense_quant_type).from_config(
                {}).get_quant_method(layer)
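For illustration, a nested config of the shape this from_config parses (all values hypothetical):

    wint2_cfg = WINT2Config.from_config({
        "dense_quant_config": "wint8",
        "moe_quant_config": {
            "quant_type": "w4w2",
            "moe_w4_quant_config": {"quant_type": "wint4", "quant_end_layer": 6},
            "moe_w2_quant_config": {"quant_type": "wint2", "quant_start_layer": 7,
                                    "quant_end_layer": 60, "quant_group_size": 64},
        },
    })
    # MoE layers up to layer 6 use wint4; later layers use the wint2 Triton backend.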