Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-10-05 16:48:03 +08:00)
[CP2.2] Machete support group scale & wint8 & v1 loader (#4166)
* support v1 loader for machete (#3999)
* [Optimize] Support WINT8 and group scale for Machete (#3905)
* [Optimize] Machete using group scale default (#4121)
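Taken together, the three PRs move the Machete weight-only GEMM path from per-channel wint4 to grouped scales (group size 128 by default) plus wint8, and hook it into the default_v1 weight loader. A minimal end-to-end sketch of the two Python entry points this diff touches (shapes illustrative, assuming an sm90 GPU build):

import paddle
from fastdeploy.model_executor.layers.quantization.ops import (
    machete_quantize_and_pack,
    machete_wint_mm,
)

# [K, N] float weight; group_size=128 yields one scale per 128 rows of K,
# group_size=-1 keeps the old per-channel behavior.
w = paddle.randn([7168, 1024], dtype="float16")
w_q, w_s = machete_quantize_and_pack(w=w, atype="float16", quant_type="uint8b128", group_size=128)

x = paddle.randn([16, 7168], dtype="float16")
out = machete_wint_mm(x, w_prepack=w_q, w_g_s=w_s, weight_dtype="uint8b128", group_size=128)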
@@ -30,10 +30,12 @@ paddle::Tensor mm(paddle::Tensor const& A, paddle::Tensor const& B,
                   std::optional<paddle::Tensor> const& maybe_token_scales,
                   std::string maybe_schedule) {
   machete::ScalarType const b_type = machete::ScalarType::from_id(b_type_id);
-  std::optional<int64_t> maybe_group_size_opt;
+  std::optional<int64_t> maybe_group_size_opt = std::optional<int64_t>(maybe_group_size);
   std::optional<std::string> maybe_schedule_opt;
   if (maybe_schedule == "") {
     maybe_schedule_opt = std::nullopt;
+  } else {
+    maybe_schedule_opt = std::optional<std::string>(maybe_schedule);
   }
   return machete::mm_dispatch({.A = A,
                                .B = B,
@@ -63,6 +65,8 @@ std::vector<paddle::Tensor> MacheteMMKernel(
   paddle::DataType maybe_out_type;
   if (b_type_str == "uint4b8") {
     b_type_id = machete::kU4B8.id();
+  } else if (b_type_str == "uint8b128") {
+    b_type_id = machete::kU8B128.id();
   } else {
     PADDLE_ENFORCE(false, "b_type_str not supported!");
   }
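kU4B8 and kU8B128 denote biased-unsigned encodings: the stored code is unsigned, and the kernel subtracts a fixed bias (8 for 4-bit, 128 for 8-bit) to recover a signed value in [-8, 7] or [-128, 127]. A small sketch of the mapping (the helper is illustrative, not a library function):

def to_biased_code(v: int, quant_type: str) -> int:
    # uint4b8:   signed [-8, 7]     -> stored [0, 15]  (bias 8)
    # uint8b128: signed [-128, 127] -> stored [0, 255] (bias 128)
    bias = 8 if quant_type == "uint4b8" else 128
    return v + bias

assert to_biased_code(-8, "uint4b8") == 0 and to_biased_code(7, "uint4b8") == 15
assert to_biased_code(-128, "uint8b128") == 0 and to_biased_code(127, "uint8b128") == 255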
@@ -51,6 +51,8 @@ std::vector<paddle::Tensor> MachetePrepackBKernel(

  if (b_type_str == "uint4b8") {
    b_type_id = machete::kU4B8.id();
+  } else if (b_type_str == "uint8b128") {
+    b_type_id = machete::kU8B128.id();
  } else {
    PADDLE_ENFORCE(false, "b_type_str not supported!");
  }
@@ -85,7 +85,7 @@ def quantize_weights(
         w_s: Scales (None if `group_size` is None).
     """
     assert paddle.is_floating_point(w), "w must be float type"
-    assert quant_type in ["uint4", "uint4b8"], "only support quant_type = uint4, uint4b8"
+    assert quant_type in ["uint4b8", "uint8b128"], "only support quant_type = uint4b8, uint8b128"

     orig_device = w.place
     size_k, size_n = w.shape
@@ -103,8 +103,12 @@ def quantize_weights(
     max_val = paddle.max(w, axis=0, keepdim=True)
     min_val = paddle.min(w, axis=0, keepdim=True)

-    max_q_val = float(7.0)
-    min_q_val = float(-8.0)
+    if quant_type == "uint4b8":
+        max_q_val = float(7.0)
+        min_q_val = float(-8.0)
+    else:
+        max_q_val = float(127.0)
+        min_q_val = float(-128.0)

     w_s = paddle.ones([1], dtype=paddle.float32)  # unscaled case

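These ranges feed the scale computation elided from the hunk; the standard symmetric scheme the surrounding code implements looks roughly like this (a sketch, not the verbatim elided lines):

import paddle

w = paddle.randn([256, 64], dtype="float32")
max_q_val, min_q_val = 7.0, -8.0  # uint4b8 branch from the hunk above
max_val = paddle.max(w, axis=0, keepdim=True)
min_val = paddle.min(w, axis=0, keepdim=True)
w_s = paddle.maximum(max_val.abs(), min_val.abs()) / max_q_val  # per-column scale
w_q = paddle.clip(paddle.round(w / w_s), min_q_val, max_q_val)  # signed codes, biased below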
@@ -124,6 +128,8 @@ def quantize_weights(
     # w_q += quant_type.bias
     if quant_type == "uint4b8":
         w_q += 8
+    else:
+        w_q += 128

     # Restore original shapes
     if group_size is not None and group_size < size_k:
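The +8 / +128 shift turns the signed codes into the unsigned form that gets packed and that the kernel un-biases on the fly; a one-line round trip:

code = -5                  # signed uint4b8 value in [-8, 7]
stored = code + 8          # 3: unsigned 4-bit code in [0, 15], as packed below
assert stored - 8 == code  # the kernel's implicit bias subtraction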
@@ -131,11 +137,11 @@ def quantize_weights(
         def reshape_w(w_tensor):
             w_tensor = w_tensor.reshape([group_size, -1, size_n])
             w_tensor = w_tensor.transpose([1, 0, 2])
-            w_tensor = w_tensor.reshape([size_k, size_n])
+            w_tensor = w_tensor.reshape([size_k, size_n]).contiguous()
             return w_tensor

         w_q = reshape_w(w_q)
-        w_s = w_s.reshape([-1, size_n])
+        w_s = w_s.reshape([-1, size_n]).contiguous()

     # Move tensors back to original device
     w_q = w_q.to(orig_device)
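The reshape/transpose pair inverts the group-first folding applied before quantization, restoring [size_k, size_n] row order; the new .contiguous() matters because pack_rows reads raw memory, and a transposed view would pack the wrong elements. A toy check (the forward folding is inferred from this inverse, it is not shown in the diff):

import paddle

size_k, size_n, group_size = 8, 2, 4
w = paddle.arange(size_k * size_n).reshape([size_k, size_n])

folded = w.reshape([-1, group_size, size_n]).transpose([1, 0, 2])  # group-first layout
restored = folded.transpose([1, 0, 2]).reshape([size_k, size_n]).contiguous()
assert bool((restored == w).all())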
@@ -153,7 +159,8 @@ def machete_quantize_and_pack(
     group_size: int = -1,
 ):
     w_q, w_s = quantize_weights(w, group_size, quant_type=quant_type)
-    w_q = pack_rows(w_q, 4, *w_q.shape)
+    num_bits = 4 if quant_type == "uint4b8" else 8
+    w_q = pack_rows(w_q, num_bits, *w_q.shape)
     w_q_col = w_q.transpose([1, 0]).contiguous()  # convert to col major
     w_q_prepack = machete_prepack_B(
         w_q_col,
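pack_rows itself is not in the diff; with num_bits parameterized it packs 32 // num_bits codes into each int32 word along K, so wint4 stores 8 codes per word and wint8 stores 4. A hypothetical equivalent (the actual bit order inside pack_rows may differ):

import numpy as np

def pack_rows_ref(codes: np.ndarray, num_bits: int) -> np.ndarray:
    # codes: [K, N] unsigned values in [0, 2**num_bits); returns [K * num_bits // 32, N]
    per_word = 32 // num_bits
    k, n = codes.shape
    grouped = codes.astype(np.uint32).reshape(k // per_word, per_word, n)
    packed = np.zeros((k // per_word, n), dtype=np.uint32)
    for i in range(per_word):
        packed |= grouped[:, i, :] << np.uint32(num_bits * i)
    return packed.view(np.int32)

codes = np.random.randint(0, 16, size=(64, 8))
assert pack_rows_ref(codes, 4).shape == (8, 8)  # 8 four-bit codes per int32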
@@ -141,8 +141,7 @@ class WeightOnlyConfig(QuantConfigBase):
                 )

                 if (
-                    self.name() == "wint4"
-                    and _ENABLE_MACHETE
+                    _ENABLE_MACHETE
                     and envs.FD_USE_MACHETE == "1"
                     and layer.weight_shape[1]
                     and layer.weight_shape[1] % 128 == 0
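With the self.name() == "wint4" guard gone, Machete is now picked for wint8 layers too, still gated on the build flag, the FD_USE_MACHETE switch, and an output dim that is a multiple of 128 (plus the sm90 requirement the tests below encode). Opting in, assuming envs.FD_USE_MACHETE mirrors the FD_USE_MACHETE environment variable:

import os

# set before FastDeploy's env config is read
os.environ["FD_USE_MACHETE"] = "1"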
@@ -219,12 +218,22 @@ class WeightOnlyLinearMethod(QuantMethodBase):
                 quant_attrs,
             )
         else:
-            # The scale shape should be equal to the output dim of weight using Per-Channel Quantization.
-            weight_scale_shape = [layer.weight_shape[1]]
-            layer.weight_shape.reverse()
-            if self.quant_config.name() == "wint4":
-                layer.weight_shape[0] //= 2
-            layer.weight_dtype = "int8"
+            if isinstance(self, MacheteWeightOnlyLinearMethod):
+                # Using group scale for machete, group size is 128
+                weight_scale_shape = [(layer.weight_shape[0] + 127) // 128, layer.weight_shape[1]]
+                if self.quant_config.name() == "wint4":
+                    layer.weight_shape[0] //= 8
+                else:
+                    layer.weight_shape[0] //= 4
+                layer.weight_dtype = "int32"
+            else:
+                # The scale shape should be equal to the output dim of weight using Per-Channel Quantization.
+                weight_scale_shape = [layer.weight_shape[1]]
+                layer.weight_shape.reverse()
+                if self.quant_config.name() == "wint4":
+                    layer.weight_shape[0] //= 2
+                layer.weight_dtype = "int8"

         layer.weight = layer.create_parameter(
             shape=layer.weight_shape,
             dtype=layer.weight_dtype,
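Worked shapes for a [K, N] = [7168, 1024] layer (note the Machete branch keeps the [K, N] orientation, there is no reverse() call):

K, N = 7168, 1024
weight_shape_wint4 = [K // 8, N]            # [896, 1024] int32: 8 four-bit codes per word
weight_shape_wint8 = [K // 4, N]            # [1792, 1024] int32: 4 eight-bit codes per word
weight_scale_shape = [(K + 127) // 128, N]  # [56, 1024]: one scale row per 128-row group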
@@ -260,17 +269,30 @@ class WeightOnlyLinearMethod(QuantMethodBase):
     def process_weights_after_loading(self, layer) -> None:
         if not layer.fd_config.load_config.load_choices == "default_v1":
             return
-        quanted_weight_tensor, weight_scale_tensor = weight_quantize(
-            layer.weight,
-            algo=self.quant_config.algo,
-            arch=self.quant_config.weight_only_linear_arch,
-        )
+        if isinstance(self, MacheteWeightOnlyLinearMethod):
+            from fastdeploy.model_executor.layers.quantization.ops import (
+                machete_quantize_and_pack,
+            )
+
+            # Using group scale for machete, group size is 128
+            quanted_weight_tensor, weight_scale_tensor = machete_quantize_and_pack(
+                w=layer.weight,
+                atype=layer._dtype,
+                quant_type="uint4b8" if self.quant_config.name() == "wint4" else "uint8b128",
+                group_size=128,
+            )
+        else:
+            quanted_weight_tensor, weight_scale_tensor = weight_quantize(
+                layer.weight,
+                algo=self.quant_config.algo,
+                arch=self.quant_config.weight_only_linear_arch,
+            )

         free_tensor(layer.weight)

         layer.weight = layer.create_parameter(
             shape=quanted_weight_tensor.shape,
-            dtype="int8",
+            dtype="int8" if not isinstance(self, MacheteWeightOnlyLinearMethod) else "int32",
             is_bias=False,
             default_initializer=paddle.nn.initializer.Constant(0),
         )
@@ -361,32 +383,6 @@ class MacheteWeightOnlyLinearMethod(WeightOnlyLinearMethod):
     ) -> None:
         super().__init__(quant_config)

-    def create_weights(self, layer, **extra_weight_attrs):
-
-        assert layer.bias is None, "Machete weight only linear method does not support bias."
-        assert self.quant_config.name() == "wint4", "Machete weight only linear method only supports wint4."
-
-        # The scale shape should be equal to the output dim of weight using Per-Channel Quantization.
-        weight_scale_shape = [1, layer.weight_shape[1]]
-
-        # layer.weight_shape.reverse()
-        if self.quant_config.name() == "wint4":
-            layer.weight_shape[0] //= 8
-        layer.weight_dtype = "int32"
-
-        layer.weight = layer.create_parameter(
-            shape=layer.weight_shape,
-            dtype=layer.weight_dtype,
-            is_bias=False,
-            default_initializer=paddle.nn.initializer.Constant(0),
-        )
-
-        layer.weight_scale = layer.create_parameter(
-            shape=weight_scale_shape,
-            dtype=layer._dtype,
-            is_bias=False,
-        )
-
     def process_prequanted_weights(self, layer, state_dict) -> None:
         pass
@@ -395,24 +391,27 @@ class MacheteWeightOnlyLinearMethod(WeightOnlyLinearMethod):
             machete_quantize_and_pack,
         )

+        # Using group scale for machete, group size is 128
         quanted_weight_tensor, weight_scale_tensor = machete_quantize_and_pack(
             w=weight,
             atype=layer._dtype,
-            quant_type="uint4b8",
+            quant_type="uint4b8" if self.quant_config.name() == "wint4" else "uint8b128",
+            group_size=128,
         )
         layer.weight.set_value(quanted_weight_tensor)
         layer.weight_scale.set_value(weight_scale_tensor.astype(paddle.get_default_dtype()))

     def apply(self, layer, x):
-        assert layer.bias is None, "Machete weight only linear method does not support bias."
-        assert self.quant_config.name() == "wint4", "Machete weight only linear method only supports wint4."
         from fastdeploy.model_executor.layers.quantization.ops import machete_wint_mm

+        # Using group scale for machete, group size is 128
         linear_out = machete_wint_mm(
             x,
             w_prepack=layer.weight,
             w_g_s=layer.weight_scale,
-            weight_dtype="uint4b8",
+            weight_dtype="uint4b8" if self.quant_config.name() == "wint4" else "uint8b128",
+            group_size=128,
         )
+        if layer.with_bias:
+            linear_out = paddle.add(linear_out, layer.bias)
         return linear_out
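Numerically, apply() now computes x @ dequant(W) + bias, where dequant removes the uint4b8/uint8b128 bias and multiplies each 128-row group of K by its scale. A reference for the unpacked case (a semantics sketch; the real weight is prepacked int32):

import paddle

def dequant_ref(codes, scales, group_size, bias):
    # codes: [K, N] unsigned codes; scales: [K // group_size, N]
    w = codes.astype("float32") - float(bias)
    s = paddle.repeat_interleave(scales.astype("float32"), group_size, axis=0)
    return w * s

K, N, G = 256, 64, 128
codes = paddle.randint(0, 16, [K, N])
scales = paddle.rand([K // G, N], dtype="float32")
x = paddle.rand([4, K], dtype="float32")
out_ref = paddle.matmul(x, dequant_ref(codes, scales, G, bias=8))  # ≈ machete_wint_mm(x, ...)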
@@ -64,11 +64,11 @@ def convert_uint16_to_float(in_list):
     not core.is_compiled_with_cuda() or get_sm_version() < 90,
     "machete only support sm90.",
 )
-class WeightOnlyLinearTestCase(unittest.TestCase):
+class WeightOnlyInt4LinearTestCase(unittest.TestCase):
     def config(self):
         self.dtype = "float16"
         self.rtol = 1e-5
-        self.atol = 1e-2
+        self.atol = 1.3e-1
         self.bias = False
         self.batch = 1
         self.token = 512
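The atol jump from 1e-2 to 1.3e-1 reflects that this (renamed) case always exercises int4 quantization, whose rounding error dominates fp16 noise; the old conditional bump in setUp is removed in the next hunk. A quick experiment to eyeball that error (sizes and init are illustrative, not from the commit):

import numpy as np

rng = np.random.default_rng(0)
K, N, G = 7168, 1024, 128
w = rng.standard_normal((K, N)).astype(np.float32) * 0.02
s = np.repeat(np.abs(w.reshape(K // G, G, N)).max(axis=1) / 7.0, G, axis=0)  # uint4b8 max code
w_hat = np.clip(np.round(w / s), -8, 7) * s
x = rng.random((4, K)).astype(np.float32)
print(float(np.abs(x @ w - x @ w_hat).max()))  # rough scale of int4 grouped rounding error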
@@ -77,11 +77,10 @@ class WeightOnlyLinearTestCase(unittest.TestCase):
         self.weight_dtype = "int4"
         self.static = False
         self.group_size = -1
+        self.machete_group_size = -1

     def setUp(self):
         self.config()
-        if self.dtype == "bfloat16" or self.weight_dtype == "int4":
-            self.atol = 1.3e-1
         x = np.random.random((self.token, self.in_features))
         self.x = paddle.to_tensor(x, dtype=self.dtype)
         if self.bias:
@@ -111,30 +110,33 @@ class WeightOnlyLinearTestCase(unittest.TestCase):
         return out.numpy()

     def get_weight_only_linear_out(self):
-        for i in range(10):
-            out = Q.weight_only_linear(
-                self.x,
-                self.weight,
-                bias=self.bias,
-                weight_scale=self.weight_scale,
-                weight_dtype=self.weight_dtype,
-                group_size=self.group_size,
-            )
+        out = Q.weight_only_linear(
+            self.x,
+            self.weight,
+            bias=self.bias,
+            weight_scale=self.weight_scale,
+            weight_dtype=self.weight_dtype,
+            group_size=self.group_size,
+        )
         return out.numpy()

     def get_machete_weight_only_linear_out(self):
         w_q, w_s = machete_quantize_and_pack(
             w=self.float_weight.cuda(),
             atype=self.dtype,
-            quant_type="uint4b8",
+            quant_type="uint4b8" if self.weight_dtype == "int4" else "uint8b128",
+            group_size=self.machete_group_size,
         )

         out = machete_wint_mm(
             self.x,
             w_prepack=w_q,
             w_g_s=w_s,  # group scales
-            weight_dtype="uint4b8",  # weight_dtype
+            weight_dtype="uint4b8" if self.weight_dtype == "int4" else "uint8b128",  # weight_dtype
+            group_size=self.machete_group_size,
         )
+        if self.bias is not None:
+            out = paddle.add(out, self.bias)
         return out.numpy()

     def test_weight_only_linear(self):
@@ -149,26 +151,96 @@ class WeightOnlyLinearTestCase(unittest.TestCase):
         np.testing.assert_allclose(out_paddle, out_machete, rtol=self.rtol, atol=self.atol)


-M = [32, 128]
-K_N = [[2048, 4096]]
-
-
-def make_case(m, k, n):
-    class Case(WeightOnlyLinearTestCase):
-        def config(self, _m=m, _k=k, _n=n):
-            super().config()
-            self.token = m
-            self.in_features = k
-            self.out_features = n
-
-    Case.name = f"WeightOnlyLinearTestCase{m}{k}{n}"
-    return Case
-
-
-for k, n in K_N:
-    for m in M:
-        cls = make_case(m, k, n)
-        globals()[cls.name] = cls
+@unittest.skipIf(
+    not core.is_compiled_with_cuda() or get_sm_version() < 90,
+    "machete only support sm90.",
+)
+class WeightOnlyInt8LinearTestCase(unittest.TestCase):
+    def config(self):
+        self.dtype = "float16"
+        self.rtol = 1e-5
+        self.atol = 1e-1
+        self.bias = True
+        self.batch = 1
+        self.token = 512
+        self.in_features = 7168
+        self.out_features = 1024
+        self.weight_dtype = "int8"
+        self.static = False
+        self.group_size = -1
+        self.machete_group_size = 128
+
+    def setUp(self):
+        self.config()
+        x = np.random.random((self.token, self.in_features))
+        self.x = paddle.to_tensor(x, dtype=self.dtype)
+        if self.bias:
+            bias_attr = base.ParamAttr(
+                trainable=False,
+                regularizer=None,
+                initializer=paddle.nn.initializer.Constant(value=1.0),
+            )
+        else:
+            bias_attr = None
+        set_default_dtype(self.dtype)
+        self.linear = paddle.nn.Linear(self.in_features, self.out_features, bias_attr=bias_attr)
+
+        self.bias = self.linear.bias
+        self.weight = self.linear.weight
+        self.float_weight = self.linear.weight
+        self.weight_scale = None
+
+        self.weight, self.weight_scale = Q.weight_quantize(
+            (self.float_weight.cuda() if self.weight_dtype == "int8" else self.weight.cpu()),
+            algo=("weight_only_int8" if self.weight_dtype == "int8" else "weight_only_int4"),
+            group_size=self.group_size,
+        )
+
+    def get_linear_out(self):
+        out = self.linear(self.x)
+        return out.numpy()
+
+    def get_weight_only_linear_out(self):
+        out = Q.weight_only_linear(
+            self.x,
+            self.weight,
+            bias=self.bias,
+            weight_scale=self.weight_scale,
+            weight_dtype=self.weight_dtype,
+            group_size=self.group_size,
+        )
+        return out.numpy()
+
+    def get_machete_weight_only_linear_out(self):
+        w_q, w_s = machete_quantize_and_pack(
+            w=self.float_weight.cuda(),
+            atype=self.dtype,
+            quant_type="uint4b8" if self.weight_dtype == "int4" else "uint8b128",
+            group_size=self.machete_group_size,
+        )
+
+        out = machete_wint_mm(
+            self.x,
+            w_prepack=w_q,
+            w_g_s=w_s,  # group scales
+            weight_dtype="uint4b8" if self.weight_dtype == "int4" else "uint8b128",  # weight_dtype
+            group_size=self.machete_group_size,
+        )
+        if self.bias is not None:
+            out = paddle.add(out, self.bias)
+        return out.numpy()
+
+    def test_weight_only_linear(self):
+        out_expect = self.get_linear_out()
+        # out_paddle = self.get_weight_only_linear_out()
+        out_machete = self.get_machete_weight_only_linear_out()
+
+        if self.dtype == "bfloat16":
+            # out_paddle = convert_uint16_to_float(out_paddle)
+            out_expect = convert_uint16_to_float(out_expect)
+            out_machete = convert_uint16_to_float(out_machete)
+        np.testing.assert_allclose(out_expect, out_machete, rtol=self.rtol, atol=self.atol)


 if __name__ == "__main__":
     unittest.main()
|