mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-10-05 16:48:03 +08:00)
polish code with new pre-commit rule (#2923)
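This commit applies mechanical reformatting driven by an updated pre-commit setup: single-quoted strings become double-quoted, backslash-continued imports become parenthesized imports with trailing commas, and manually wrapped calls are re-wrapped onto single lines or black-style blocks. The exact hook configuration is not part of this excerpt, so the before/after pair below is only an illustrative sketch of the import rewrite the new rule enforces:

    # before (style removed in this commit): backslash-continued import
    from fastdeploy.model_executor.layers.moe.fused_moe_triton_backend import \
        BlockWiseFP8MoEMethod

    # after (style enforced by the new pre-commit rule): parenthesized import with trailing comma
    from fastdeploy.model_executor.layers.moe.fused_moe_triton_backend import (
        BlockWiseFP8MoEMethod,
    )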
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """
+
 from typing import Optional

 import paddle
@@ -49,17 +50,20 @@ class BlockWiseFP8Config(QuantConfigBase):
         return cls(weight_block_size)

     def get_quant_method(self, layer) -> Optional[QuantMethodBase]:
-        '''
+        """
         Get quantization method.
-        '''
+        """
         if isinstance(layer, FusedMoE):
             if self.use_deep_gemm:
-                from fastdeploy.model_executor.layers.moe.fused_moe_deepgemm_backend import \
-                    DeepGemmFusedMoeMethod
+                from fastdeploy.model_executor.layers.moe.fused_moe_deepgemm_backend import (
+                    DeepGemmFusedMoeMethod,
+                )
+
                 return DeepGemmFusedMoeMethod(self)
             else:
-                from fastdeploy.model_executor.layers.moe.fused_moe_triton_backend import \
-                    BlockWiseFP8MoEMethod
+                from fastdeploy.model_executor.layers.moe.fused_moe_triton_backend import (
+                    BlockWiseFP8MoEMethod,
+                )
                 return BlockWiseFP8MoEMethod(self)
         else:
             return BlockWiseFP8LinearMethod(self)
@@ -81,8 +85,8 @@ class BlockWiseFP8LinearMethod(QuantMethodBase):
         layer.weight_shape.reverse()
         layer.weight_scale = layer.create_parameter(
             shape=[
-                (layer.output_size + self.quant_config.weight_block_size[0] -
-                 1) // self.quant_config.weight_block_size[0],
+                (layer.output_size + self.quant_config.weight_block_size[0] - 1)
+                // self.quant_config.weight_block_size[0],
                 (layer.input_size + self.quant_config.weight_block_size[1] - 1)
                 // self.quant_config.weight_block_size[1],
             ],
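The shape arithmetic above is ceiling division: weight_scale gets one entry per weight block along each dimension. A small sketch with made-up sizes (the real values come from the layer and from weight_block_size):

    # ceiling division used above to count scale blocks per weight dimension
    output_size, input_size = 4096, 5120        # hypothetical layer shape
    block_n, block_k = 128, 128                 # hypothetical weight_block_size
    scale_rows = (output_size + block_n - 1) // block_n  # -> 32
    scale_cols = (input_size + block_k - 1) // block_k   # -> 40
    # weight_scale is created with shape [32, 40]: one scale per 128x128 weight block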
@@ -93,8 +97,7 @@ class BlockWiseFP8LinearMethod(QuantMethodBase):

     def process_loaded_weights(self, layer, weights) -> None:
         weight_tensor = weights.transpose([1, 0])
-        quanted_weight_tensor, weight_block_scale_tensor = (
-            per_block_cast_to_fp8(weight_tensor))
+        quanted_weight_tensor, weight_block_scale_tensor = per_block_cast_to_fp8(weight_tensor)
         layer.weight.copy_(quanted_weight_tensor, False)
         layer.weight_scale.set_value(weight_block_scale_tensor)

@@ -113,10 +116,11 @@ class BlockWiseFP8LinearMethod(QuantMethodBase):

     def apply(self, layer, x):
         x, x_scale_tensor = fastdeploy.model_executor.ops.gpu.per_token_quant_padding(
-            x, self.quant_config.weight_block_size[0])
-        linear_out = paddle.empty((x.shape[0], layer.output_size),
-                                  dtype=paddle.bfloat16)
-        import fastdeploy.model_executor.ops.gpu.deep_gemm as deep_gemm
+            x, self.quant_config.weight_block_size[0]
+        )
+        linear_out = paddle.empty((x.shape[0], layer.output_size), dtype=paddle.bfloat16)
+        from fastdeploy.model_executor.ops.gpu import deep_gemm
+
         deep_gemm.gemm_fp8_fp8_bf16_nt(
             (x, x_scale_tensor),
             (layer.weight, layer.weight_scale),
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """
+
 from enum import Enum
 from typing import Optional

@@ -29,6 +30,7 @@ class KvCacheQuantzationTypes(str, Enum):
     """
     KvCacheQuantzationTypes
     """
+
     INT8 = "int8"
     FP8 = "float8_e4m3fn"
     INT8_ZP = "int8_zp"
@@ -50,7 +52,7 @@ class KvCacheQuantConfig(QuantConfigBase):
         try:
             self.quant_type = KvCacheQuantzationTypes(kv_cache_quant_type)
         except ValueError:
-            raise ValueError(f'Invalid Kvcache type: {kv_cache_quant_type}')
+            raise ValueError(f"Invalid Kvcache type: {kv_cache_quant_type}")

         self.has_zero_point = "zp" in kv_cache_quant_type

@@ -59,7 +61,7 @@ class KvCacheQuantConfig(QuantConfigBase):
         elif self.quant_type == KvCacheQuantzationTypes.FP8 or self.quant_type == KvCacheQuantzationTypes.FP8_ZP:
             self.max_bound = 448.0
         else:
-            raise ValueError(f'Invalid Kvcache type: {kv_cache_quant_type}')
+            raise ValueError(f"Invalid Kvcache type: {kv_cache_quant_type}")

     def name(self) -> str:
         """
@@ -110,12 +112,12 @@ class KVCacheMethodBase(QuantMethodBase):
         """
         load_scale
         """
-        cache_k_scale_tensor = get_tensor(
-            state_dict.pop(self.cache_k_scale_name)).cast(
-                paddle.get_default_dtype()).reshape_([-1])
-        cache_v_scale_tensor = get_tensor(
-            state_dict.pop(self.cache_v_scale_name)).cast(
-                paddle.get_default_dtype()).reshape_([-1])
+        cache_k_scale_tensor = (
+            get_tensor(state_dict.pop(self.cache_k_scale_name)).cast(paddle.get_default_dtype()).reshape_([-1])
+        )
+        cache_v_scale_tensor = (
+            get_tensor(state_dict.pop(self.cache_v_scale_name)).cast(paddle.get_default_dtype()).reshape_([-1])
+        )

         cache_k_scale = self.cache_quant_config.max_bound / cache_k_scale_tensor
         cache_v_scale = self.cache_quant_config.max_bound / cache_v_scale_tensor
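In load_scale above, the checkpoint entries appear to hold calibrated per-channel maxima, and the runtime scale is max_bound divided by them (127.0 for int8, 448.0 for float8_e4m3fn). A hedged numeric sketch with made-up values:

    # illustrative scale derivation matching the load_scale code above
    import paddle

    max_bound = 448.0                                         # FP8 (float8_e4m3fn) bound from KvCacheQuantConfig
    cache_k_scale_tensor = paddle.to_tensor([2.0, 0.5, 8.0])  # hypothetical calibrated maxima
    cache_k_scale = max_bound / cache_k_scale_tensor          # -> [224.0, 896.0, 56.0]
    # multiplying K by cache_k_scale maps values within the calibrated range onto [-448, 448]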
@@ -138,13 +140,13 @@ class KVCacheMethodBase(QuantMethodBase):
         self.cache_v_zp_name = layer.prefix + ".cachev_matmul.activation_zero_point"

         if self.cache_quant_config.quant_type == KvCacheQuantzationTypes.INT8:
-            setattr(layer, "cache_quant_type_str", "cache_int8")
-            setattr(layer, "quant_max_bound", 127.0)
-            setattr(layer, "quant_min_bound", -127.0)
+            layer.cache_quant_type_str = "cache_int8"
+            layer.quant_max_bound = 127.0
+            layer.quant_min_bound = -127.0
         elif self.cache_quant_config.quant_type == KvCacheQuantzationTypes.FP8:
-            setattr(layer, "cache_quant_type_str", "cache_fp8")
-            setattr(layer, "quant_max_bound", 448.0)
-            setattr(layer, "quant_min_bound", -448.0)
+            layer.cache_quant_type_str = "cache_fp8"
+            layer.quant_max_bound = 448.0
+            layer.quant_min_bound = -448.0
         else:
             raise NotImplementedError(f"{self.cache_quant_config.quant_type} is not implemented")

@@ -156,5 +158,4 @@ class KVCacheMethodBase(QuantMethodBase):
         """
         apply
         """
-        raise RuntimeError(
-            f"{self.__class__.__name__}.apply should not be called.")
+        raise RuntimeError(f"{self.__class__.__name__}.apply should not be called.")
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """
+
 from typing import Optional

 from fastdeploy.model_executor.layers.attention.attention import Attention
@@ -51,26 +52,23 @@ class MixQuantConfig(QuantConfigBase):

     @classmethod
     def from_config(cls, config: dict) -> "MixQuantConfig":
-        return cls(config['dense_quant_type'], config['moe_quant_type'],
-                   config.get('kv_cache_quant_type', None),
-                   config.get('image_moe_quant_type', None))
+        return cls(
+            config["dense_quant_type"],
+            config["moe_quant_type"],
+            config.get("kv_cache_quant_type", None),
+            config.get("image_moe_quant_type", None),
+        )

     def get_quant_method(self, layer) -> Optional[QuantMethodBase]:
         if isinstance(layer, FusedMoE):
             if layer.moe_tag == "Image":
-                return get_quantization_config(
-                    self.image_moe_quant_type).from_config(
-                        {}).get_quant_method(layer)
+                return get_quantization_config(self.image_moe_quant_type).from_config({}).get_quant_method(layer)
             else:
-                return get_quantization_config(
-                    self.moe_quant_type).from_config(
-                        {}).get_quant_method(layer)
+                return get_quantization_config(self.moe_quant_type).from_config({}).get_quant_method(layer)
         elif isinstance(layer, Attention):
             if self.kv_cache_quant_type is not None:
-                return (get_quantization_config("kvcache").from_config(
-                    self.kv_cache_quant_type).get_quant_method(layer))
+                return get_quantization_config("kvcache").from_config(self.kv_cache_quant_type).get_quant_method(layer)
             else:
                 return None
         else:
-            return get_quantization_config(self.dense_quant_type).from_config(
-                {}).get_quant_method(layer)
+            return get_quantization_config(self.dense_quant_type).from_config({}).get_quant_method(layer)
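MixQuantConfig.from_config above expects two required keys and two optional ones; a hedged sketch of such a config dict (the specific type strings are illustrative):

    # illustrative input for MixQuantConfig.from_config
    mix_quant_config = {
        "dense_quant_type": "wint8",        # used for plain linear layers
        "moe_quant_type": "wint4",          # used for FusedMoE layers not tagged "Image"
        "kv_cache_quant_type": "int8",      # optional; None leaves attention layers unquantized
        "image_moe_quant_type": None,       # optional; only consulted when layer.moe_tag == "Image"
    }
    # cfg = MixQuantConfig.from_config(mix_quant_config)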
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """
+
 from .cutlass_scaled_mm import cutlass_scaled_mm
 from .scaled_fp8_quant import scaled_fp8_quant

@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """
+
 from typing import Optional

 import paddle
@@ -20,12 +21,14 @@ import paddle
 import fastdeploy


-def cutlass_scaled_mm(a: paddle.Tensor,
-                      b: paddle.Tensor,
-                      scale_a: paddle.Tensor,
-                      scale_b: paddle.Tensor,
-                      out_dtype: paddle.dtype,
-                      bias: Optional[paddle.Tensor] = None) -> paddle.Tensor:
+def cutlass_scaled_mm(
+    a: paddle.Tensor,
+    b: paddle.Tensor,
+    scale_a: paddle.Tensor,
+    scale_b: paddle.Tensor,
+    out_dtype: paddle.dtype,
+    bias: Optional[paddle.Tensor] = None,
+) -> paddle.Tensor:
     """
     `cutlass_scaled_mm` implements a fused version of
     `output = paddle.mm((scale_a * a), (scale_b * b)).to(out_dtype)`
@@ -48,9 +51,8 @@ def cutlass_scaled_mm(a: paddle.Tensor,
     scale_a.shape * [1, 128] == a.shape
     scale_b.shape * [128, 128] == b.shape
     """
-    assert (out_dtype == paddle.bfloat16 or out_dtype == paddle.float16)
-    assert bias is None or bias.shape[0] == b.shape[
-        0] and bias.dtype == out_dtype
+    assert out_dtype == paddle.bfloat16 or out_dtype == paddle.float16
+    assert bias is None or bias.shape[0] == b.shape[0] and bias.dtype == out_dtype
     # Ensure input tensors have valid shapes
     # assert a.numel() > 0, "Input tensor 'a' must not be empty"
    # assert b.numel() > 0, "Input tensor 'b' must not be empty"
@@ -59,12 +61,11 @@ def cutlass_scaled_mm(a: paddle.Tensor,

     m = a.shape[0]
     n = b.shape[0]
-    cutlass_compatible_b = (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
+    cutlass_compatible_b = b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0
     assert cutlass_compatible_b

     out = paddle.empty([m, n], dtype=out_dtype)
-    fastdeploy.model_executor.ops.gpu.cutlass_scaled_mm(
-        out, a, b, scale_a, scale_b, bias)
+    fastdeploy.model_executor.ops.gpu.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias)

     return out

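The docstring's shape contract can be spelled out with concrete sizes; a hedged sketch in pure Python (sizes are illustrative only):

    # shape relations required by cutlass_scaled_mm above (per-token and 128x128 block scales)
    m, k, n = 16, 256, 512                      # hypothetical GEMM sizes
    a_shape, b_shape = [m, k], [n, k]           # a: FP8 activations, b: FP8 weights stored as [n, k]
    scale_a_shape = [m, k // 128]               # one scale per token per 128-wide column group
    scale_b_shape = [n // 128, k // 128]        # one scale per 128x128 weight block

    assert [scale_a_shape[0] * 1, scale_a_shape[1] * 128] == a_shape
    assert [scale_b_shape[0] * 128, scale_b_shape[1] * 128] == b_shape
    # out = cutlass_scaled_mm(a, b, scale_a, scale_b, paddle.bfloat16)  # would return an [m, n] tensor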
@@ -100,7 +101,7 @@ def scaled_fp8_quant(
         scaling factor.
     """
     # This code assumes batch_dim and num_tokens are flattened
-    assert (input.ndim == 2)
+    assert input.ndim == 2
     shape = input.shape
     if num_token_padding:
         shape = (max(num_token_padding, input.shape[0]), shape[1])
@@ -109,18 +110,21 @@ def scaled_fp8_quant(
     if scale is None:
         if use_per_token_if_dynamic:
             scale = paddle.empty([shape[0], 1], dtype=paddle.float32)
-            from fastdeploy.model_executor.ops.gpu import \
-                dynamic_per_token_scaled_fp8_quant
+            from fastdeploy.model_executor.ops.gpu import (
+                dynamic_per_token_scaled_fp8_quant,
+            )
+
             dynamic_per_token_scaled_fp8_quant(output, input, scale, scale_ub)
         else:
             scale = paddle.zeros([1], dtype=paddle.float32)
-            from fastdeploy.model_executor.ops.gpu import \
-                dynamic_scaled_fp8_quant
+            from fastdeploy.model_executor.ops.gpu import dynamic_scaled_fp8_quant
+
             dynamic_scaled_fp8_quant(output, input, scale)
     else:
         # num_token_padding not implemented for this case
         # assert (scale.numel() == 1 or num_token_padding is None)
         from fastdeploy.model_executor.ops.gpu import static_scaled_fp8_quant
+
         static_scaled_fp8_quant(output, input, scale)

     return output, scale
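A hedged sketch of how scaled_fp8_quant is typically driven in its two dynamic modes; it assumes a GPU build of FastDeploy that exposes the custom quant ops imported above:

    # illustrative calls of the scaled_fp8_quant wrapper defined above
    import paddle
    from fastdeploy.model_executor.layers.quantization.ops import scaled_fp8_quant

    x = paddle.randn([8, 4096]).astype(paddle.bfloat16)

    x_q, x_scale = scaled_fp8_quant(x)                                  # dynamic per-tensor: x_scale shape [1]
    x_q, x_scale = scaled_fp8_quant(x, use_per_token_if_dynamic=True)   # dynamic per-token: x_scale shape [8, 1]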
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """
+
 from typing import Optional

 import paddle
@@ -49,7 +50,7 @@ def scaled_fp8_quant(
         scaling factor.
     """
     # This code assumes batch_dim and num_tokens are flattened
-    assert (input.ndim == 2)
+    assert input.ndim == 2
     shape = input.shape
     if num_token_padding:
         shape = (max(num_token_padding, input.shape[0]), shape[1])
@@ -58,18 +59,21 @@ def scaled_fp8_quant(
     if scale is None:
         if use_per_token_if_dynamic:
             scale = paddle.empty([shape[0], 1], dtype=paddle.float32)
-            from fastdeploy.model_executor.ops.gpu import \
-                dynamic_per_token_scaled_fp8_quant
+            from fastdeploy.model_executor.ops.gpu import (
+                dynamic_per_token_scaled_fp8_quant,
+            )
+
             dynamic_per_token_scaled_fp8_quant(output, input, scale, scale_ub)
         else:
             scale = paddle.zeros([1], dtype=paddle.float32)
-            from fastdeploy.model_executor.ops.gpu import \
-                dynamic_scaled_fp8_quant
+            from fastdeploy.model_executor.ops.gpu import dynamic_scaled_fp8_quant
+
             dynamic_scaled_fp8_quant(output, input, scale)
     else:
         # num_token_padding not implemented for this case
         # assert (scale.numel() == 1 or num_token_padding is None)
         from fastdeploy.model_executor.ops.gpu import static_scaled_fp8_quant
+
         static_scaled_fp8_quant(output, input, scale)

     return output, scale
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """
+
 from abc import ABC, abstractmethod
 from typing import Any, Optional

@@ -65,8 +66,7 @@ class QuantConfigBase(ABC):
         for key in keys:
             if key in config:
                 return config[key]
-        raise ValueError(f"Cannot find any of {keys} in the model's "
-                         "quantization config.")
+        raise ValueError(f"Cannot find any of {keys} in the model's " "quantization config.")

     @abstractmethod
     def get_quant_method(self, layer, prefix) -> Optional[QuantMethodBase]:
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """
+
 from typing import Optional

 from fastdeploy.model_executor.layers.moe import FusedMoE
@@ -50,8 +51,10 @@ class TensorWiseFP8Config(QuantConfigBase):
         return method according to this config!
         """
         if isinstance(layer, FusedMoE):
-            from fastdeploy.model_executor.layers.moe.fused_moe_triton_backend import \
-                TensorWiseFP8MoEMethod
+            from fastdeploy.model_executor.layers.moe.fused_moe_triton_backend import (
+                TensorWiseFP8MoEMethod,
+            )
+
             return TensorWiseFP8MoEMethod(self)
         else:
             return TensorWiseFP8LinearMethod(self)
@@ -112,7 +115,9 @@ class TensorWiseFP8LinearMethod(QuantMethodBase):
         compute!
         """
         from fastdeploy.model_executor.ops.gpu import (
-            cutlass_fp8_fp8_half_gemm_fused, fused_hadamard_quant_fp8)
+            cutlass_fp8_fp8_half_gemm_fused,
+            fused_hadamard_quant_fp8,
+        )

         fp8_x = fused_hadamard_quant_fp8(x, scale=self.act_scale)

@@ -124,5 +129,6 @@ class TensorWiseFP8LinearMethod(QuantMethodBase):
             bias=None,
             scale=self.total_scale,
             output_dtype="bfloat16",
-            activation_type="identity")
+            activation_type="identity",
+        )
         return linear_out
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """
+
 from typing import Optional

 from ..moe import FusedMoE
@@ -36,7 +37,10 @@ class W4A8Config(QuantConfigBase):

     def get_quant_method(self, layer) -> Optional[QuantMethodBase]:
         if isinstance(layer, FusedMoE):
-            from fastdeploy.model_executor.layers.moe.fused_moe_cutlass_backend import CutlassW4A8MoEMethod
+            from fastdeploy.model_executor.layers.moe.fused_moe_cutlass_backend import (
+                CutlassW4A8MoEMethod,
+            )
+
             return CutlassW4A8MoEMethod(self)
         else:
             raise ValueError(f"Unsupported layer type {type(layer)} for w4a8")
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """
+
 from typing import Optional

 import paddle
@@ -69,13 +70,14 @@ class W4AFP8LinearMethod(QuantMethodBase):
         pass

     def process_loaded_weights(self, layer, weights) -> None:
-        quanted_weight_tensor, weight_scale_tensor = (
-            fastdeploy.model_executor.ops.gpu.
-            scaled_gemm_f8_i4_f16_weight_quantize(
-                paddle.cast(weights, "float32").cpu(),
-                groupsize=-1,
-                scale_dtype="float16",
-            ))
+        (
+            quanted_weight_tensor,
+            weight_scale_tensor,
+        ) = fastdeploy.model_executor.ops.gpu.scaled_gemm_f8_i4_f16_weight_quantize(
+            paddle.cast(weights, "float32").cpu(),
+            groupsize=-1,
+            scale_dtype="float16",
+        )
         weight_scale_tensor = paddle.view(weight_scale_tensor, layer._dtype)
         layer.weight.set_value(quanted_weight_tensor)
         layer.weight_scale.set_value(weight_scale_tensor)
@@ -87,11 +89,12 @@ class W4AFP8LinearMethod(QuantMethodBase):
             layer.weight_scale,
             zero_points=None,
             bias=layer.bias if layer.add_bias else None,
-            out_scale=self.quant_config.weight_scale_dict.get(layer.prefix +
-                                                              ".weight_scale")
-            / (self.quant_config.act_scale_dict.get(layer.prefix +
-                                                    ".activation_scale") *
-               QUANT_SCALING_FACTOR * QUANT_SCALING_FACTOR),
+            out_scale=self.quant_config.weight_scale_dict.get(layer.prefix + ".weight_scale")
+            / (
+                self.quant_config.act_scale_dict.get(layer.prefix + ".activation_scale")
+                * QUANT_SCALING_FACTOR
+                * QUANT_SCALING_FACTOR
+            ),
             groupsize=0,
             out_dtype=layer._dtype,
         )
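The out_scale expression above folds the stored weight and activation scales into a single dequantization factor; a hedged numeric sketch (the real QUANT_SCALING_FACTOR is defined elsewhere in this file, so the value here is only an assumption):

    # illustrative out_scale arithmetic for the w4afp8 GEMM call above
    weight_scale = 0.02            # hypothetical "<prefix>.weight_scale" entry
    act_scale = 0.01               # hypothetical "<prefix>.activation_scale" entry
    QUANT_SCALING_FACTOR = 448     # assumed value for illustration only
    out_scale = weight_scale / (act_scale * QUANT_SCALING_FACTOR * QUANT_SCALING_FACTOR)
    # -> 0.02 / (0.01 * 448 * 448) ≈ 1.0e-5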
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """
+
 from typing import Optional

 import paddle
@@ -30,8 +31,13 @@ class W8A8Config(QuantConfigBase):
     quantization config for weight 8bits and activation 8bits
     """

-    def __init__(self, weight_scale_dict, act_scale_dict, use_gemm_dequant,
-                 use_smooth_quant) -> None:
+    def __init__(
+        self,
+        weight_scale_dict,
+        act_scale_dict,
+        use_gemm_dequant,
+        use_smooth_quant,
+    ) -> None:
         super().__init__()
         self.weight_scale_dict = weight_scale_dict
         self.act_scale_dict = act_scale_dict
@@ -73,27 +79,22 @@ class W8A8LinearMethod(QuantMethodBase):
         layer.weight_dtype = "int8"
         if self.quant_config.use_smooth_quant:
             self.smooth_quant_method.create_weights(layer)
-        weight_scale = self.quant_config.weight_scale_dict.get(layer.prefix +
-                                                               ".weight_scale")
-        in_scale = self.quant_config.act_scale_dict.get(layer.prefix +
-                                                        ".activation_scale")
+        weight_scale = self.quant_config.weight_scale_dict.get(layer.prefix + ".weight_scale")
+        in_scale = self.quant_config.act_scale_dict.get(layer.prefix + ".activation_scale")
         self.skip_quant = False
         if weight_scale is None or in_scale is None:
             self.skip_quant = True
             return

         max_range = 127.0
-        linear_out_scale = paddle.to_tensor(
-            weight_scale /
-            (max_range * max_range * in_scale)).astype("float32")
+        linear_out_scale = paddle.to_tensor(weight_scale / (max_range * max_range * in_scale)).astype("float32")
         layer.linear_out_scale = layer.create_parameter(
             shape=[layer.embed_dim],
             dtype="float32",
             is_bias=False,
             default_initializer=paddle.nn.initializer.Constant(0),
         )
-        layer.linear_out_scale.set_value(
-            convert_to_npu_dequant_scale(linear_out_scale))
+        layer.linear_out_scale.set_value(convert_to_npu_dequant_scale(linear_out_scale))

     def process_loaded_weights(self, layer, weights) -> None:
         if self.quant_config.use_smooth_quant:
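linear_out_scale above turns the int32 GEMM output back into floating point by dividing out both quantization scales; a small hedged sketch with made-up numbers:

    # illustrative dequant-scale arithmetic for the int8 W8A8 path above
    max_range = 127.0              # symmetric int8 range, as in the code
    weight_scale = 0.5             # hypothetical "<prefix>.weight_scale" entry
    in_scale = 0.25                # hypothetical "<prefix>.activation_scale" entry
    linear_out_scale = weight_scale / (max_range * max_range * in_scale)
    # -> 0.5 / (127 * 127 * 0.25) ≈ 1.24e-4, applied to the int32 matmul result at dequant time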
@@ -113,11 +114,13 @@ class W8A8LinearMethod(QuantMethodBase):
             return linear_out
         if self.quant_config.use_gemm_dequant:
             linear_out = fastdeploy.model_executor.ops.gpu.gemm_dequant(
-                x, layer.weight, layer.linear_out_scale, layer._dtype)
+                x, layer.weight, layer.linear_out_scale, layer._dtype
+            )
         else:
             linear_out = paddle.matmul(x, layer.weight, False, True)
             linear_out = fastdeploy.model_executor.ops.gpu.dequant_int8(
-                linear_out, layer.linear_out_scale, layer._dtype)
+                linear_out, layer.linear_out_scale, layer._dtype
+            )
         return linear_out


@@ -149,8 +152,7 @@ class SmoothQuantLinearMethod(QuantMethodBase):

     def process_loaded_weights(self, layer, weights) -> None:
         if layer.shift_key in layer.state_dict:
-            shift_tensor = get_tensor(layer.state_dict.pop(
-                layer.shift_key)).astype(paddle.get_default_dtype())
+            shift_tensor = get_tensor(layer.state_dict.pop(layer.shift_key)).astype(paddle.get_default_dtype())
         else:
             shift_tensor = paddle.zeros(
                 shape=layer.linear_shift_shape,
@@ -158,8 +160,7 @@ class SmoothQuantLinearMethod(QuantMethodBase):
             )
         layer.linear_shift.set_value(shift_tensor)
         if layer.smooth_key in layer.state_dict:
-            smooth_tensor = get_tensor(layer.state_dict.pop(
-                layer.smooth_key)).astype(paddle.get_default_dtype())
+            smooth_tensor = get_tensor(layer.state_dict.pop(layer.smooth_key)).astype(paddle.get_default_dtype())
         else:
             smooth_tensor = paddle.ones(
                 shape=[layer.linear_smooth_shape],
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """
+
 import os
 from abc import abstractmethod
 from typing import Optional
@@ -42,8 +43,7 @@ class WeightOnlyConfig(QuantConfigBase):
         self.algo = algo
         # arch (int): The compute arch for target device. For example, A100 is 80, v100 is 70,
         # if you do not assign arch, we will get arch from your device, default: None.
-        self.weight_only_linear_arch = os.getenv(
-            "FLAGS_weight_only_linear_arch")
+        self.weight_only_linear_arch = os.getenv("FLAGS_weight_only_linear_arch")
         if self.weight_only_linear_arch is not None:
             self.weight_only_linear_arch = int(self.weight_only_linear_arch)
         self.quant_max_bound = 0
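As the comment above notes, the target compute arch can be pinned through an environment variable instead of being auto-detected; a hedged usage sketch:

    # pin weight-only GEMM kernels to SM80 (A100) before the config object is built
    import os

    os.environ["FLAGS_weight_only_linear_arch"] = "80"   # read back via os.getenv in __init__ above
    # any WeightOnlyConfig created afterwards will carry weight_only_linear_arch == 80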
@@ -60,47 +60,62 @@ class WeightOnlyConfig(QuantConfigBase):

     def get_quant_method(self, layer) -> Optional[QuantMethodBase]:
         if current_platform.is_xpu():
-            from fastdeploy.model_executor.layers.backends import \
-                XPUWeightOnlyLinearMethod
-            from fastdeploy.model_executor.layers.moe.fused_moe_xpu_backend import \
-                XPUWeightOnlyMoEMethod
+            from fastdeploy.model_executor.layers.backends import (
+                XPUWeightOnlyLinearMethod,
+            )
+            from fastdeploy.model_executor.layers.moe.fused_moe_xpu_backend import (
+                XPUWeightOnlyMoEMethod,
+            )
+
             if isinstance(layer, FusedMoE):
                 return XPUWeightOnlyMoEMethod(self)
             else:
                 return XPUWeightOnlyLinearMethod(self)
         elif current_platform.is_gcu():
             from fastdeploy.model_executor.layers.backends import (
-                GCUWeightOnlyLinearMethod, GCUWeightOnlyMoEMethod)
+                GCUWeightOnlyLinearMethod,
+                GCUWeightOnlyMoEMethod,
+            )
+
             if isinstance(layer, FusedMoE):
                 return GCUWeightOnlyMoEMethod(self)
             else:
                 return GCUWeightOnlyLinearMethod(self)
         elif current_platform.is_dcu():
             if isinstance(layer, FusedMoE):
-                from fastdeploy.model_executor.layers.backends import \
-                    DCUTritonWeightOnlyMoEMethod
+                from fastdeploy.model_executor.layers.backends import (
+                    DCUTritonWeightOnlyMoEMethod,
+                )
+
                 return DCUTritonWeightOnlyMoEMethod(self)
             else:
-                from fastdeploy.model_executor.layers.backends import \
-                    DCUWeightOnlyLinearMethod
+                from fastdeploy.model_executor.layers.backends import (
+                    DCUWeightOnlyLinearMethod,
+                )
+
                 return DCUWeightOnlyLinearMethod(self)
         else:
             if isinstance(layer, FusedMoE):
                 if layer.use_method == "cutlass":
-                    from fastdeploy.model_executor.layers.moe.fused_moe_cutlass_backend import \
-                        CutlassWeightOnlyMoEMethod
+                    from fastdeploy.model_executor.layers.moe.fused_moe_cutlass_backend import (
+                        CutlassWeightOnlyMoEMethod,
+                    )
+
                     return CutlassWeightOnlyMoEMethod(self)
                 elif layer.use_method == "triton":
-                    from fastdeploy.model_executor.layers.moe.fused_moe_triton_backend import \
-                        TritonWeightOnlyMoEMethod
+                    from fastdeploy.model_executor.layers.moe.fused_moe_triton_backend import (
+                        TritonWeightOnlyMoEMethod,
+                    )
+
                     return TritonWeightOnlyMoEMethod(self)
                 elif layer.use_method == "marlin":
-                    from fastdeploy.model_executor.layers.moe.fused_moe_marlin_backend import \
-                        MarlinWeightOnlyMoEMethod
+                    from fastdeploy.model_executor.layers.moe.fused_moe_marlin_backend import (
+                        MarlinWeightOnlyMoEMethod,
+                    )
+
                     return MarlinWeightOnlyMoEMethod(self)
                 else:
-                    raise ValueError(
-                        f"Unsupported MOE backend {layer.use_method}")
+                    raise ValueError(f"Unsupported MOE backend {layer.use_method}")
             else:
                 return GPUWeightOnlyLinearMethod(self)

@@ -110,7 +125,9 @@ class WINT8Config(WeightOnlyConfig):
     weight only int8 config
     """

-    def __init__(self, ) -> None:
+    def __init__(
+        self,
+    ) -> None:
         super().__init__("weight_only_int8")

     @classmethod
@@ -126,7 +143,9 @@ class WINT4Config(WeightOnlyConfig):
     weight only int4 config
     """

-    def __init__(self, ) -> None:
+    def __init__(
+        self,
+    ) -> None:
         super().__init__("weight_only_int4")

     @classmethod
@@ -174,8 +193,7 @@ class WeightOnlyLinearMethod(QuantMethodBase):
             weight=layer.weight,
             bias=layer.bias if layer.add_bias else None,
             weight_scale=layer.weight_scale,
-            weight_dtype="int8"
-            if self.quant_config.name() == "wint8" else "int4",
+            weight_dtype=("int8" if self.quant_config.name() == "wint8" else "int4"),
             arch=self.quant_config.weight_only_linear_arch,
         )
         return linear_out
@@ -205,8 +223,7 @@ class GPUWeightOnlyLinearMethod(WeightOnlyLinearMethod):
         quant_weight = get_tensor(state_dict.pop(layer.weight_key))
         weight_scale = get_tensor(state_dict.pop(layer.weight_scale_key))
         layer.weight.set_value(quant_weight)
-        layer.weight_scale.set_value(
-            weight_scale.astype(paddle.get_default_dtype()))
+        layer.weight_scale.set_value(weight_scale.astype(paddle.get_default_dtype()))

     def process_loaded_weights(self, layer, weight) -> None:

@@ -217,5 +234,4 @@ class GPUWeightOnlyLinearMethod(WeightOnlyLinearMethod):
         )

         layer.weight.set_value(quanted_weight_tensor)
-        layer.weight_scale.set_value(
-            weight_scale_tensor.astype(paddle.get_default_dtype()))
+        layer.weight_scale.set_value(weight_scale_tensor.astype(paddle.get_default_dtype()))
@@ -13,14 +13,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """
+
 from typing import Optional

 import paddle

 from fastdeploy.model_executor.layers.quantization.ops import (
-    cutlass_scaled_mm, scaled_fp8_quant)
+    cutlass_scaled_mm,
+    scaled_fp8_quant,
+)
 from fastdeploy.model_executor.layers.quantization.quant_base import (
-    QuantConfigBase, QuantMethodBase)
+    QuantConfigBase,
+    QuantMethodBase,
+)


 class WFP8AFP8Config(QuantConfigBase):
@@ -37,21 +42,18 @@ class WFP8AFP8Config(QuantConfigBase):
         self.quant_round_type = 1

     def name(self) -> str:
-        """
-        """
+        """ """
         return "wfp8afp8"

     @classmethod
     def from_config(cls, config: dict) -> "WFP8AFP8Config":
-        """
-        """
+        """ """
         weight_scale_dict = config.get("weight_scale_dict", None)
         act_scale_dict = config.get("act_scale_dict", None)
         return cls(weight_scale_dict, act_scale_dict)

     def get_quant_method(self, layer) -> Optional[QuantMethodBase]:
-        """
-        """
+        """ """
         return WFP8AFP8LinearMethod(self)


@@ -68,8 +70,7 @@ class WFP8AFP8LinearMethod(QuantMethodBase):
         self.quant_config = quant_config

     def create_weights(self, layer):
-        """
-        """
+        """ """
         layer.weight_shape.reverse()
         layer.weight_dtype = "float8_e4m3fn"
         # TODO(YuanRisheng): set weight logic should be moved to process_loaded_weights func
@@ -82,8 +83,7 @@ class WFP8AFP8LinearMethod(QuantMethodBase):
         )

     def process_loaded_weights(self, layer, weights) -> None:
-        """
-        """
+        """ """
         if self.skip_quant:
             weight_tensor = weights.cast(layer._dtype)
             layer.weight.set_value(weight_tensor)
@@ -99,18 +99,21 @@ class WFP8AFP8LinearMethod(QuantMethodBase):
         layer.weight_scale.set_value(weight_scale)

     def apply(self, layer, x):
-        """
-        """
+        """ """
         if self.skip_quant:
             linear_out = paddle.matmul(x, layer.weight, False, True)
             return linear_out
         if self.use_per_token_if_dynamic:
             out_type = x.dtype
-            a_q, a_scales = scaled_fp8_quant(
-                x, use_per_token_if_dynamic=self.use_per_token_if_dynamic)
-            linear_out = cutlass_scaled_mm(a_q, layer.weight, a_scales,
-                                           layer.weight_scale, out_type,
-                                           layer.bias)
+            a_q, a_scales = scaled_fp8_quant(x, use_per_token_if_dynamic=self.use_per_token_if_dynamic)
+            linear_out = cutlass_scaled_mm(
+                a_q,
+                layer.weight,
+                a_scales,
+                layer.weight_scale,
+                out_type,
+                layer.bias,
+            )
         else:
             raise NotImplementedError
         return linear_out
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """
+
 from typing import Optional

 from ..moe import FusedMoE
@@ -79,29 +80,22 @@ class WINT2Config(QuantConfigBase):
         """

         dense_quant_type = config.get("dense_quant_config", "wint8")
-        dense_quant_granularity = config.get("dense_quant_granularity",
-                                             "per_channel")
+        dense_quant_granularity = config.get("dense_quant_granularity", "per_channel")

         moe_quant_config = config.get("moe_quant_config", {})
         moe_quant_type = moe_quant_config.get("quant_type", "w4w2")

         moe_w4_quant_config = moe_quant_config.get("moe_w4_quant_config", {})
-        moe_w4_quant_type = moe_w4_quant_config.get("quant_type",
-                                                    "wint4")
-        moe_w4_quant_granularity = moe_w4_quant_config.get(
-            "quant_granularity", "per_channel")
-        moe_w4_quant_start_layer = moe_w4_quant_config.get(
-            "quant_start_layer", 0)
+        moe_w4_quant_type = moe_w4_quant_config.get("quant_type", "wint4")
+        moe_w4_quant_granularity = moe_w4_quant_config.get("quant_granularity", "per_channel")
+        moe_w4_quant_start_layer = moe_w4_quant_config.get("quant_start_layer", 0)
         moe_w4_quant_end_layer = moe_w4_quant_config.get("quant_end_layer", 6)

         moe_w2_quant_config = moe_quant_config.get("moe_w2_quant_config", {})
         moe_w2_quant_type = moe_w2_quant_config.get("quant_type", "wint2")
-        moe_w2_quant_granularity = moe_w2_quant_config.get(
-            "quant_granularity", "pp_acc")
-        moe_w2_quant_group_size = moe_w2_quant_config.get(
-            "quant_group_size", 0)
-        moe_w2_quant_start_layer = moe_w2_quant_config.get(
-            "quant_start_layer", 0)
+        moe_w2_quant_granularity = moe_w2_quant_config.get("quant_granularity", "pp_acc")
+        moe_w2_quant_group_size = moe_w2_quant_config.get("quant_group_size", 0)
+        moe_w2_quant_start_layer = moe_w2_quant_config.get("quant_start_layer", 0)
         moe_w2_quant_end_layer = moe_w2_quant_config.get("quant_end_layer", 0)

         return cls(
@@ -130,13 +124,12 @@ class WINT2Config(QuantConfigBase):
         """
         if isinstance(layer, FusedMoE):
             if layer.layer_idx <= self.moe_w4_quant_end_layer:
-                return get_quantization_config(
-                    self.moe_w4_quant_type).from_config(
-                        {}).get_quant_method(layer)
+                return get_quantization_config(self.moe_w4_quant_type).from_config({}).get_quant_method(layer)
             else:
-                from fastdeploy.model_executor.layers.moe.fused_moe_wint2_backend import \
-                    CutlassWint2FusedMoeMethod
+                from fastdeploy.model_executor.layers.moe.fused_moe_wint2_backend import (
+                    CutlassWint2FusedMoeMethod,
+                )
+
                 return CutlassWint2FusedMoeMethod(self)
         else:
-            return get_quantization_config(self.dense_quant_type).from_config(
-                {}).get_quant_method(layer)
+            return get_quantization_config(self.dense_quant_type).from_config({}).get_quant_method(layer)
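For reference, the nested dictionary that WINT2Config.from_config reads would look roughly like the sketch below; the key names and values come from the defaults in the code above, and are shown only for illustration:

    # illustrative quantization config for the mixed w4/w2 MoE scheme parsed above
    wint2_config = {
        "dense_quant_config": "wint8",
        "dense_quant_granularity": "per_channel",
        "moe_quant_config": {
            "quant_type": "w4w2",
            "moe_w4_quant_config": {
                "quant_type": "wint4",
                "quant_granularity": "per_channel",
                "quant_start_layer": 0,
                "quant_end_layer": 6,      # layers up to this index use the wint4 MoE method
            },
            "moe_w2_quant_config": {
                "quant_type": "wint2",
                "quant_granularity": "pp_acc",
                "quant_group_size": 0,     # defaults taken from the .get() calls above
                "quant_start_layer": 0,
                "quant_end_layer": 0,
            },
        },
    }
    # cfg = WINT2Config.from_config(wint2_config)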