[Optimize] Support WINT8 and group scale for Machete (#3905)

Author: Sunny-bot1
Date: 2025-09-15 12:01:34 +08:00
Committed by: GitHub
Parent: 4408dc7f67
Commit: b1a5b756a3
5 changed files with 125 additions and 42 deletions


@@ -142,11 +142,11 @@ class WeightOnlyConfig(QuantConfigBase):
         )
         if (
-            self.name() == "wint4"
-            and _ENABLE_MACHETE
+            _ENABLE_MACHETE
             and envs.FD_USE_MACHETE == "1"
             and layer.weight_shape[1]
             and layer.weight_shape[1] % 128 == 0
+            and not layer.add_bias
         ):
             return MacheteWeightOnlyLinearMethod(self)
         return GPUWeightOnlyLinearMethod(self)
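
With this dispatch change, Machete is no longer gated on wint4: any weight-only config takes the Machete path when the extension is compiled in, FD_USE_MACHETE is set to "1", the reduction dim is a multiple of 128, and the layer has no bias. A minimal standalone sketch of the selection rule (plain values stand in for the config and layer objects, and the returned strings for the method classes):

    def pick_weight_only_method(machete_built: bool, fd_use_machete: str,
                                weight_shape, add_bias: bool) -> str:
        # Machete now serves wint4 and wint8 alike, so the config name no
        # longer gates the choice; only build, env, shape, and bias do.
        if (
            machete_built
            and fd_use_machete == "1"
            and weight_shape[1]
            and weight_shape[1] % 128 == 0
            and not add_bias
        ):
            return "MacheteWeightOnlyLinearMethod"
        return "GPUWeightOnlyLinearMethod"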
@@ -230,6 +230,8 @@ class WeightOnlyLinearMethod(QuantMethodBase):
             weight_scale_shape = [1, layer.weight_shape[1]]
             if self.quant_config.name() == "wint4":
                 layer.weight_shape[0] //= 8
+            else:
+                layer.weight_shape[0] //= 4
             layer.weight_dtype = "int32"
         else:
             # The scale shape should be equal to the output dim of weight using Per-Channel Quantization.
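
The new else branch is shape bookkeeping for packing: quantized weights are stored as int32 words, and one int32 holds eight 4-bit values but only four 8-bit values, so the leading weight dim shrinks by 8 for wint4 and by 4 for wint8. A self-contained sketch of the arithmetic (the shapes are illustrative):

    def packed_weight_shape(weight_shape, bits):
        # One int32 word stores 32 // bits quantized values.
        vals_per_int32 = 32 // bits        # 8 for wint4, 4 for wint8
        assert weight_shape[0] % vals_per_int32 == 0
        return [weight_shape[0] // vals_per_int32, weight_shape[1]]

    assert packed_weight_shape([4096, 11008], 4) == [512, 11008]   # wint4
    assert packed_weight_shape([4096, 11008], 8) == [1024, 11008]  # wint8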
@@ -282,7 +284,7 @@ class WeightOnlyLinearMethod(QuantMethodBase):
             quanted_weight_tensor, weight_scale_tensor = machete_quantize_and_pack(
                 w=layer.weight,
                 atype=layer._dtype,
-                quant_type="uint4b8",
+                quant_type="uint4b8" if self.quant_config.name() == "wint4" else "uint8b128",
             )
         else:
             quanted_weight_tensor, weight_scale_tensor = weight_quantize(
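
The quant_type strings follow Machete's scalar-type naming: uint4b8 is unsigned 4-bit storage with a fixed bias of 8, and the new uint8b128 is unsigned 8-bit with a bias of 128, i.e. a symmetric signed value shifted by 2**(bits - 1) into unsigned range. A sketch of that encoding (the helper name is ours, not the kernel's API):

    def to_biased_unsigned(q_signed: int, bits: int) -> int:
        # uint4b8 / uint8b128: bias = 2**(bits - 1), so 8 or 128.
        bias = 1 << (bits - 1)
        assert -bias <= q_signed <= bias - 1
        return q_signed + bias             # lands in [0, 2**bits - 1]

    assert to_biased_unsigned(-8, 4) == 0 and to_biased_unsigned(7, 4) == 15
    assert to_biased_unsigned(-128, 8) == 0 and to_biased_unsigned(127, 8) == 255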
@@ -387,7 +389,7 @@ class MacheteWeightOnlyLinearMethod(WeightOnlyLinearMethod):
         quanted_weight_tensor, weight_scale_tensor = machete_quantize_and_pack(
             w=weight,
             atype=layer._dtype,
-            quant_type="uint4b8",
+            quant_type="uint4b8" if self.quant_config.name() == "wint4" else "uint8b128",
         )
         layer.weight.set_value(quanted_weight_tensor)
         layer.weight_scale.set_value(weight_scale_tensor.astype(paddle.get_default_dtype()))
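
The same wint4/wint8 ternary now appears at three call sites: create-time quantization above, loaded-weight processing here, and the apply path below. A tiny helper, purely a hypothetical refactor and not part of this commit, would keep them in sync:

    def machete_quant_type(quant_config) -> str:
        # Single source of truth for the Machete scalar-type string.
        return "uint4b8" if quant_config.name() == "wint4" else "uint8b128"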
@@ -400,7 +402,7 @@ class MacheteWeightOnlyLinearMethod(WeightOnlyLinearMethod):
             x,
             w_prepack=layer.weight,
             w_g_s=layer.weight_scale,
-            weight_dtype="uint4b8",
+            weight_dtype="uint4b8" if self.quant_config.name() == "wint4" else "uint8b128",
         )
         return linear_out
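
In the apply path, w_g_s hands the scales to the prepacked matmul; with the group-scale support this commit's title advertises, those scales may carry one row per group along the reduction dim rather than a single per-channel row. A NumPy sketch of group-wise dequantization under that assumed layout (not the Machete kernel itself):

    import numpy as np

    def dequant_group_scale(q, scales, group_size, bits):
        # q: [K, N] biased-unsigned ints; scales: [K // group_size, N].
        bias = 1 << (bits - 1)             # undo the uint4b8/uint8b128 bias
        signed = q.astype(np.float32) - bias
        return signed * np.repeat(scales, group_size, axis=0)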