Mirror of https://github.com/PaddlePaddle/FastDeploy.git
fix machete pre quant (#4295)
@@ -129,6 +129,7 @@ class LinearBase(nn.Layer):
         self.with_bias = with_bias
         self.add_bias = add_bias
         self.prefix = prefix
+        self.is_quantized = fd_config.model_config.is_quantized
         # key
         if weight_key:
             self.weight_key = f"{prefix}.{weight_key}"
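Taken with the later hunks, this flag tells backend selection whether the checkpoint was quantized offline. A minimal sketch of the plumbing, using hypothetical simplified stand-ins for fd_config.model_config (only the is_quantized attribute mirrors the diff):

from dataclasses import dataclass

@dataclass
class ModelConfig:
    is_quantized: bool = False  # True when the checkpoint was quantized offline

@dataclass
class FDConfig:
    model_config: ModelConfig

class LinearBaseSketch:
    def __init__(self, fd_config: FDConfig, prefix: str = "linear_0"):
        self.prefix = prefix
        # New in this commit: the layer records whether its weights arrive
        # pre-quantized, so quant-method selection can branch on it later.
        self.is_quantized = fd_config.model_config.is_quantized

layer = LinearBaseSketch(FDConfig(ModelConfig(is_quantized=True)))
print(layer.is_quantized)  # True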
|
@@ -20,6 +20,7 @@ from typing import Optional
 
 import paddle
 from paddle.nn.quant import weight_quantize
+from paddleformers.utils.log import logger
 
 from fastdeploy import envs
 from fastdeploy.model_executor.layers.linear import (
@@ -159,9 +160,11 @@ class WeightOnlyConfig(QuantConfigBase):
             if (
                 _ENABLE_MACHETE
                 and envs.FD_USE_MACHETE == "1"
+                and not layer.is_quantized
                 and layer.weight_shape[1]
                 and layer.weight_shape[1] % 128 == 0
             ):
+                logger.info("Using Machete kernel for WeightOnlyLinearMethod")
                 return MacheteWeightOnlyLinearMethod(self)
             return GPUWeightOnlyLinearMethod(self)
 
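Read as a predicate, Machete is now chosen only when the kernel is built, FD_USE_MACHETE=1, the layer is not pre-quantized (the new check), and the weight's second dimension is a non-zero multiple of 128; everything else falls back to GPUWeightOnlyLinearMethod. A stand-alone sketch of that behavior (the helper function and its arguments are hypothetical; only the conditions themselves come from the diff):

import os

def pick_weight_only_backend(machete_built: bool, layer_is_prequantized: bool, weight_cols: int) -> str:
    # Mirrors the condition above: Machete only when the kernel is built,
    # FD_USE_MACHETE=1, the layer is not pre-quantized (new check), and the
    # weight's second dimension is a non-zero multiple of 128.
    if (
        machete_built
        and os.getenv("FD_USE_MACHETE", "0") == "1"
        and not layer_is_prequantized
        and weight_cols
        and weight_cols % 128 == 0
    ):
        return "MacheteWeightOnlyLinearMethod"
    return "GPUWeightOnlyLinearMethod"

os.environ["FD_USE_MACHETE"] = "1"
print(pick_weight_only_backend(True, True, 4096))   # GPUWeightOnlyLinearMethod (pre-quantized falls back)
print(pick_weight_only_backend(True, False, 4096))  # MacheteWeightOnlyLinearMethod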
@@ -399,7 +402,7 @@ class MacheteWeightOnlyLinearMethod(WeightOnlyLinearMethod):
         super().__init__(quant_config)
 
     def process_prequanted_weights(self, layer, state_dict) -> None:
-        pass
+        raise NotImplementedError("Machete kernel doesn't support prequant. Please set FD_USE_MACHETE to 0.")
 
     def process_loaded_weights(self, layer, weight) -> None:
         from fastdeploy.model_executor.layers.quantization.ops import (
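With the silent pass replaced, a pre-quantized checkpoint that still reaches the Machete method now fails loudly and points users at FD_USE_MACHETE=0. A hedged illustration of the new failure mode (the driver code below is hypothetical, not FastDeploy's loader):

class MacheteLikeMethod:
    # Same behavior as the new method body: fail loudly instead of silently
    # skipping pre-quantized weight processing.
    def process_prequanted_weights(self, layer, state_dict) -> None:
        raise NotImplementedError("Machete kernel doesn't support prequant. Please set FD_USE_MACHETE to 0.")

try:
    MacheteLikeMethod().process_prequanted_weights(layer=None, state_dict={})
except NotImplementedError as err:
    # Setting FD_USE_MACHETE=0 falls back to the GPU weight-only path, as the
    # error message advises.
    print(f"caught: {err}")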
|