[Optimize]support machete weight only gemm (#3561)

* support machete weight only gemm

* add generate

* update

* fix

* change file location

* add sm_version limit

* fix

* fix

* fix ci

* fix coverage

* fix xpu
This commit is contained in:
Sunny-bot1
2025-08-28 09:49:58 +08:00
committed by GitHub
parent e37e86b3b8
commit 479c8b85d3
29 changed files with 5436 additions and 0 deletions

View File

@@ -15,9 +15,12 @@
"""
from .cutlass_scaled_mm import cutlass_scaled_mm
from .machete_mm import machete_quantize_and_pack, machete_wint_mm
from .scaled_fp8_quant import scaled_fp8_quant
__all__ = [
"cutlass_scaled_mm",
"scaled_fp8_quant",
"machete_wint_mm",
"machete_quantize_and_pack",
]

View File

@@ -0,0 +1,185 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
import numpy as np
import paddle
from fastdeploy.platforms import current_platform
def get_sm_version():
prop = paddle.device.cuda.get_device_properties()
cc = prop.major * 10 + prop.minor
return cc
if current_platform.is_cuda() and get_sm_version() == 90:
from fastdeploy.model_executor.ops.gpu import machete_mm, machete_prepack_B
def get_pack_factor(num_bits):
assert 32 % num_bits == 0, f"Unsupported num_bits = {num_bits}"
return 32 // num_bits
def pack_rows(
q_w: paddle.Tensor,
num_bits: int,
size_k: int,
size_n: int,
):
assert q_w.shape == [size_k, size_n]
pack_factor = get_pack_factor(num_bits)
assert size_k % pack_factor == 0
orig_device = q_w.place
q_w_np = q_w.numpy().astype(np.uint32)
q_res = np.zeros((size_k // pack_factor, size_n), dtype=np.uint32)
for i in range(pack_factor):
q_res |= q_w_np[i::pack_factor, :] << num_bits * i
q_res = paddle.to_tensor(q_res.astype(np.int32), place=orig_device)
return q_res
def quantize_weights(
w: paddle.Tensor,
group_size: Optional[int],
quant_type: str = "uint4b8",
):
"""
Quantize weights in PaddlePaddle, similar to PyTorch implementation.
Args:
w: Input weight tensor (must be float type).
quant_type: Target quantization type (e.g., `uint4`, `uint4b8`).
group_size: Group size for quantization. If `-1`, use channel-wise quantization.
zero_points: Whether to compute zero points (only for unsigned quant types).
ref_zero_points_after_scales: If True, apply zero points after scales in dequantization.
Returns:
w_ref: Dequantized reference weights.
w_q: Quantized weights.
w_s: Scales (None if `group_size` is None).
"""
assert paddle.is_floating_point(w), "w must be float type"
assert quant_type in ["uint4", "uint4b8"], "only support quant_type = uint4, uint4b8"
orig_device = w.place
size_k, size_n = w.shape
if group_size == -1:
group_size = size_k
# Reshape to [group_size, -1]
if group_size is not None and group_size < size_k:
w = w.reshape([-1, group_size, size_n])
w = w.transpose([1, 0, 2])
w = w.reshape([group_size, -1])
# Compute scale for each group
max_val = paddle.max(w, axis=0, keepdim=True)
min_val = paddle.min(w, axis=0, keepdim=True)
max_q_val = float(7.0)
min_q_val = float(-8.0)
w_s = paddle.ones([1], dtype=paddle.float32) # unscaled case
if group_size is not None:
# Avoid division by zero
max_scale = paddle.maximum(
paddle.abs(max_val / (max_q_val if max_q_val != 0 else float("inf"))),
paddle.abs(min_val / (min_q_val if min_q_val != 0 else float("inf"))),
)
w_s = max_scale
# Quantize
w_q = paddle.round(w / w_s).astype(paddle.int32)
w_q = paddle.clip(w_q, min_q_val, max_q_val)
# if hasattr(quant_type, 'bias'): # Custom quantization bias (if applicable)
# w_q += quant_type.bias
if quant_type == "uint4b8":
w_q += 8
# Restore original shapes
if group_size is not None and group_size < size_k:
def reshape_w(w_tensor):
w_tensor = w_tensor.reshape([group_size, -1, size_n])
w_tensor = w_tensor.transpose([1, 0, 2])
w_tensor = w_tensor.reshape([size_k, size_n])
return w_tensor
w_q = reshape_w(w_q)
w_s = w_s.reshape([-1, size_n])
# Move tensors back to original device
w_q = w_q.to(orig_device)
if w_s is not None:
w_s = w_s.to(orig_device)
return w_q, w_s
def machete_quantize_and_pack(
w: paddle.Tensor,
atype: paddle.dtype,
quant_type: str = "uint4b8",
scale_type: str = "",
group_size: int = -1,
):
w_q, w_s = quantize_weights(w, group_size, quant_type=quant_type)
w_q = pack_rows(w_q, 4, *w_q.shape)
w_q_col = w_q.transpose([1, 0]).contiguous() # convert to col major
w_q_prepack = machete_prepack_B(
w_q_col,
atype,
quant_type,
scale_type,
)[0]
return w_q_prepack, w_s
def machete_wint_mm(
x: paddle.Tensor,
w_prepack: paddle.Tensor,
w_g_s: paddle.Tensor,
w_g_zp: Optional[paddle.Tensor] = None,
w_ch_s: Optional[paddle.Tensor] = None,
w_tok_s: Optional[paddle.Tensor] = None,
weight_dtype: str = "uint4b8",
group_size: int = -1,
out_dtype: str = "",
scheduler: str = "",
):
out = machete_mm(
x,
w_prepack,
w_g_s, # group scales
w_g_zp, # group zeros
w_ch_s, # per-channel scale
w_tok_s, # per-token scale
weight_dtype, # weight_dtype
out_dtype, # out_dtype
group_size, # group_size
scheduler, # scheduler
)[0]
return out

View File

@@ -21,6 +21,7 @@ from typing import Optional
import paddle
from paddle.nn.quant import weight_only_linear, weight_quantize
from fastdeploy import envs
from fastdeploy.model_executor.layers.linear import (
MergedColumnParallelLinear,
QKVParallelLinear,
@@ -33,6 +34,12 @@ from ..utils import get_tensor
from .quant_base import QuantConfigBase, QuantMethodBase
def get_sm_version():
prop = paddle.device.cuda.get_device_properties()
cc = prop.major * 10 + prop.minor
return cc
class WeightOnlyConfig(QuantConfigBase):
"""
Quantization config for weight only
@@ -132,6 +139,14 @@ class WeightOnlyConfig(QuantConfigBase):
else:
raise ValueError(f"Unsupported MOE backend {layer.use_method}")
else:
if (
self.name() == "wint4"
and envs.FD_USE_MACHETE == "1"
and get_sm_version() == 90
and layer.weight_shape[1]
and layer.weight_shape[1] % 128 == 0
):
return MacheteWeightOnlyLinearMethod(self)
return GPUWeightOnlyLinearMethod(self)
@@ -329,3 +344,73 @@ class GPUWeightOnlyLinearMethod(WeightOnlyLinearMethod):
quanted_weight_tensor = paddle.transpose(quanted_weight_tensor, [1, 0])
layer.weight.set_value(quanted_weight_tensor)
layer.weight_scale.set_value(weight_scale_tensor.astype(paddle.get_default_dtype()))
class MacheteWeightOnlyLinearMethod(WeightOnlyLinearMethod):
"""
Weight only quantization method for linear layer on GPU using Machete
The weights are loaded in the BF16 numerical format. After loading, the quantization coefficients will be computed,
and the weights will be quantized to int8 or int4.
"""
def __init__(
self,
quant_config: WeightOnlyConfig,
) -> None:
super().__init__(quant_config)
def create_weights(self, layer, **extra_weight_attrs):
assert layer.bias is None, "Machete weight only linear method does not support bias."
assert self.quant_config.name() == "wint4", "Machete weight only linear method only supports wint4."
# The scale shape should be equal to the output dim of weight using Per-Channel Quantization.
weight_scale_shape = [1, layer.weight_shape[1]]
# layer.weight_shape.reverse()
if self.quant_config.name() == "wint4":
layer.weight_shape[0] //= 8
layer.weight_dtype = "int32"
layer.weight = layer.create_parameter(
shape=layer.weight_shape,
dtype=layer.weight_dtype,
is_bias=False,
default_initializer=paddle.nn.initializer.Constant(0),
)
layer.weight_scale = layer.create_parameter(
shape=weight_scale_shape,
dtype=layer._dtype,
is_bias=False,
)
def process_prequanted_weights(self, layer, state_dict) -> None:
pass
def process_loaded_weights(self, layer, weight) -> None:
from fastdeploy.model_executor.layers.quantization.ops import (
machete_quantize_and_pack,
)
quanted_weight_tensor, weight_scale_tensor = machete_quantize_and_pack(
w=weight,
atype=layer._dtype,
quant_type="uint4b8",
)
layer.weight.set_value(quanted_weight_tensor)
layer.weight_scale.set_value(weight_scale_tensor.astype(paddle.get_default_dtype()))
def apply(self, layer, x):
assert layer.bias is None, "Machete weight only linear method does not support bias."
assert self.quant_config.name() == "wint4", "Machete weight only linear method only supports wint4."
from fastdeploy.model_executor.layers.quantization.ops import machete_wint_mm
linear_out = machete_wint_mm(
x,
w_prepack=layer.weight,
w_g_s=layer.weight_scale,
weight_dtype="uint4b8",
)
return linear_out