mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 08:37:06 +08:00
[Optimize] Support WINT8 and group scale for Machete (#3905)
This commit is contained in:
@@ -142,11 +142,11 @@ class WeightOnlyConfig(QuantConfigBase):
|
||||
)
|
||||
|
||||
if (
|
||||
self.name() == "wint4"
|
||||
and _ENABLE_MACHETE
|
||||
_ENABLE_MACHETE
|
||||
and envs.FD_USE_MACHETE == "1"
|
||||
and layer.weight_shape[1]
|
||||
and layer.weight_shape[1] % 128 == 0
|
||||
and not layer.add_bias
|
||||
):
|
||||
return MacheteWeightOnlyLinearMethod(self)
|
||||
return GPUWeightOnlyLinearMethod(self)
|
||||
@@ -230,6 +230,8 @@ class WeightOnlyLinearMethod(QuantMethodBase):
|
||||
weight_scale_shape = [1, layer.weight_shape[1]]
|
||||
if self.quant_config.name() == "wint4":
|
||||
layer.weight_shape[0] //= 8
|
||||
else:
|
||||
layer.weight_shape[0] //= 4
|
||||
layer.weight_dtype = "int32"
|
||||
else:
|
||||
# The scale shape should be equal to the output dim of weight using Per-Channel Quantization.
|
||||
@@ -282,7 +284,7 @@ class WeightOnlyLinearMethod(QuantMethodBase):
|
||||
quanted_weight_tensor, weight_scale_tensor = machete_quantize_and_pack(
|
||||
w=layer.weight,
|
||||
atype=layer._dtype,
|
||||
quant_type="uint4b8",
|
||||
quant_type="uint4b8" if self.quant_config.name() == "wint4" else "uint8b128",
|
||||
)
|
||||
else:
|
||||
quanted_weight_tensor, weight_scale_tensor = weight_quantize(
|
||||
@@ -387,7 +389,7 @@ class MacheteWeightOnlyLinearMethod(WeightOnlyLinearMethod):
|
||||
quanted_weight_tensor, weight_scale_tensor = machete_quantize_and_pack(
|
||||
w=weight,
|
||||
atype=layer._dtype,
|
||||
quant_type="uint4b8",
|
||||
quant_type="uint4b8" if self.quant_config.name() == "wint4" else "uint8b128",
|
||||
)
|
||||
layer.weight.set_value(quanted_weight_tensor)
|
||||
layer.weight_scale.set_value(weight_scale_tensor.astype(paddle.get_default_dtype()))
|
||||
@@ -400,7 +402,7 @@ class MacheteWeightOnlyLinearMethod(WeightOnlyLinearMethod):
|
||||
x,
|
||||
w_prepack=layer.weight,
|
||||
w_g_s=layer.weight_scale,
|
||||
weight_dtype="uint4b8",
|
||||
weight_dtype="uint4b8" if self.quant_config.name() == "wint4" else "uint8b128",
|
||||
)
|
||||
|
||||
return linear_out
|
||||
|
Reference in New Issue
Block a user