This commit is contained in:
bukejiyu
2025-08-06 14:45:27 +08:00
committed by GitHub
parent 91dc87f1c5
commit 20839abccf
30 changed files with 1361 additions and 1087 deletions

View File

@@ -37,7 +37,7 @@ class GCUWeightOnlyLinearMethod(WeightOnlyLinearMethod):
self.quant_config = quant_config
self.group_size = -1
def create_weights(self, layer):
def create_weights(self, layer, **extra_weight_attrs):
# The scale shape should be equal to the output dim of weight using Per-Channel Quantization.
weight_scale_shape = [layer.weight_shape[1]]
@@ -45,6 +45,14 @@ class GCUWeightOnlyLinearMethod(WeightOnlyLinearMethod):
if self.quant_config.name() == "wint4":
layer.weight_shape[0] //= 2
layer.weight_dtype = "int8"
layer.weight = layer.create_parameter(
shape=layer.weight_shape,
dtype=layer.weight_dtype,
is_bias=False,
default_initializer=paddle.nn.initializer.Constant(0),
)
layer.weight_scale = layer.create_parameter(
shape=weight_scale_shape,
dtype=layer._dtype,