This commit is contained in:
bukejiyu
2025-08-06 14:45:27 +08:00
committed by GitHub
parent 91dc87f1c5
commit 20839abccf
30 changed files with 1361 additions and 1087 deletions

View File

@@ -74,7 +74,7 @@ class W8A8LinearMethod(QuantMethodBase):
self.quant_config = quant_config
self.smooth_quant_method = SmoothQuantLinearMethod(quant_config)
def create_weights(self, layer):
def create_weights(self, layer, **extra_weight_attrs):
layer.weight_shape.reverse()
layer.weight_dtype = "int8"
if self.quant_config.use_smooth_quant:
@@ -85,7 +85,12 @@ class W8A8LinearMethod(QuantMethodBase):
if weight_scale is None or in_scale is None:
self.skip_quant = True
return
layer.wieght = layer.create_parameter(
shape=layer.weight_shape,
dtype=layer.weight_dtype,
is_bias=False,
default_initializer=paddle.nn.initializer.Constant(0),
)
max_range = 127.0
linear_out_scale = paddle.to_tensor(weight_scale / (max_range * max_range * in_scale)).astype("float32")
layer.linear_out_scale = layer.create_parameter(
@@ -136,7 +141,7 @@ class SmoothQuantLinearMethod(QuantMethodBase):
super().__init__()
self.quant_config = quant_config
def create_weights(self, layer):
def create_weights(self, layer, **extra_weight_attrs):
linear_shift_shape = [layer.output_size]
linear_smooth_shape = [layer.output_size]
layer.linear_shift = self.create_parameter(