[Optimize] Support WINT8 and group scale for Machete (#3905)

Author: Sunny-bot1
Date: 2025-09-15 12:01:34 +08:00
Committed by: GitHub
parent 4408dc7f67
commit b1a5b756a3
5 changed files with 125 additions and 42 deletions

View File

@@ -30,10 +30,12 @@ paddle::Tensor mm(paddle::Tensor const& A, paddle::Tensor const& B,
std::optional<paddle::Tensor> const& maybe_token_scales,
std::string maybe_schedule) {
machete::ScalarType const b_type = machete::ScalarType::from_id(b_type_id);
std::optional<int64_t> maybe_group_size_opt;
std::optional<int64_t> maybe_group_size_opt = std::optional<int64_t>(maybe_group_size);
std::optional<std::string> maybe_schedule_opt;
if (maybe_schedule == "") {
maybe_schedule_opt = std::nullopt;
} else {
maybe_schedule_opt = std::optional<std::string>(maybe_schedule);
}
return machete::mm_dispatch({.A = A,
.B = B,
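Here the binding stops discarding the group size: `maybe_group_size` is now wrapped into an `std::optional` and forwarded to `machete::mm_dispatch`, which is what enables group-wise scales. For orientation, a minimal sketch of what the group size means for the scale tensor handed to the kernel (the helper name and the `-1` convention are assumptions based on `quantize_weights` further down, not part of the API):

```python
def group_scale_shape(size_k: int, size_n: int, group_size: int):
    # Channel-wise case: one scale per output column (group_size is None or -1,
    # or a single group covers the whole K dim).
    if group_size is None or group_size == -1 or group_size >= size_k:
        return [1, size_n]
    # Group-wise case: every block of `group_size` rows of the weight shares a scale.
    assert size_k % group_size == 0
    return [size_k // group_size, size_n]

# e.g. the 7168x1024 weight used in the new wint8 test, with group_size=128:
assert group_scale_shape(7168, 1024, 128) == [56, 1024]
assert group_scale_shape(7168, 1024, -1) == [1, 1024]
```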
@@ -63,6 +65,8 @@ std::vector<paddle::Tensor> MacheteMMKernel(
paddle::DataType maybe_out_type;
if (b_type_str == "uint4b8") {
b_type_id = machete::kU4B8.id();
} else if (b_type_str == "uint8b128") {
b_type_id = machete::kU8B128.id();
} else {
PADDLE_ENFORCE(false, "b_type_str not supported!");
}

View File

@@ -51,6 +51,8 @@ std::vector<paddle::Tensor> MachetePrepackBKernel(
if (b_type_str == "uint4b8") {
b_type_id = machete::kU4B8.id();
} else if (b_type_str == "uint8b128") {
b_type_id = machete::kU8B128.id();
} else {
PADDLE_ENFORCE(false, "b_type_str not supported!");
}
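Both the matmul and the prepack kernels now accept `uint8b128` in addition to `uint4b8`. The names follow a `uint<bits>b<bias>` convention: a signed quantized value is stored with a bias so it fits an unsigned container, which is exactly the +8 / +128 shift `quantize_weights` applies below. A small sketch of that convention (illustrative only):

```python
WEIGHT_TYPES = {
    #             bits  bias  signed range
    "uint4b8":   (4,    8,    (-8, 7)),
    "uint8b128": (8,    128,  (-128, 127)),
}

def store(q_signed: int, quant_type: str) -> int:
    bits, bias, (qmin, qmax) = WEIGHT_TYPES[quant_type]
    assert qmin <= q_signed <= qmax
    return q_signed + bias      # value written into the packed buffer

def load(q_stored: int, quant_type: str) -> int:
    _, bias, _ = WEIGHT_TYPES[quant_type]
    return q_stored - bias

assert store(-8, "uint4b8") == 0 and store(7, "uint4b8") == 15
assert store(-128, "uint8b128") == 0 and store(127, "uint8b128") == 255
```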

View File

@@ -85,7 +85,7 @@ def quantize_weights(
w_s: Scales (None if `group_size` is None).
"""
assert paddle.is_floating_point(w), "w must be float type"
assert quant_type in ["uint4", "uint4b8"], "only support quant_type = uint4, uint4b8"
assert quant_type in ["uint4b8", "uint8b128"], "only support quant_type = uint4b8, uint8b128"
orig_device = w.place
size_k, size_n = w.shape
@@ -103,8 +103,12 @@ def quantize_weights(
max_val = paddle.max(w, axis=0, keepdim=True)
min_val = paddle.min(w, axis=0, keepdim=True)
max_q_val = float(7.0)
min_q_val = float(-8.0)
if quant_type == "uint4b8":
max_q_val = float(7.0)
min_q_val = float(-8.0)
else:
max_q_val = float(127.0)
min_q_val = float(-128.0)
w_s = paddle.ones([1], dtype=paddle.float32) # unscaled case
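With the wider type, the quantization range switches from [-8, 7] to [-128, 127]. The scale computation itself sits outside this hunk; the sketch below assumes the usual symmetric form max(|max|/q_max, |min|/q_min) just to show how the two ranges feed into it:

```python
import paddle

w = paddle.to_tensor([[0.5, -1.2], [0.9, 0.3]], dtype="float32")
max_q_val, min_q_val = 127.0, -128.0   # uint8b128 case; 7.0 / -8.0 for uint4b8

max_val = paddle.max(w, axis=0, keepdim=True)
min_val = paddle.min(w, axis=0, keepdim=True)
# Assumed scale formula (not shown in this hunk): choose the scale so both the
# most positive and most negative value stay inside the representable range.
w_s = paddle.maximum(paddle.abs(max_val / max_q_val), paddle.abs(min_val / min_q_val))
w_q = paddle.clip(paddle.round(w / w_s), min_q_val, max_q_val)
```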
@@ -124,6 +128,8 @@ def quantize_weights(
# w_q += quant_type.bias
if quant_type == "uint4b8":
w_q += 8
else:
w_q += 128
# Restore original shapes
if group_size is not None and group_size < size_k:
@@ -131,11 +137,11 @@ def quantize_weights(
def reshape_w(w_tensor):
w_tensor = w_tensor.reshape([group_size, -1, size_n])
w_tensor = w_tensor.transpose([1, 0, 2])
w_tensor = w_tensor.reshape([size_k, size_n])
w_tensor = w_tensor.reshape([size_k, size_n]).contiguous()
return w_tensor
w_q = reshape_w(w_q)
w_s = w_s.reshape([-1, size_n])
w_s = w_s.reshape([-1, size_n]).contiguous()
# Move tensors back to original device
w_q = w_q.to(orig_device)
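The added `.contiguous()` calls make sure the tensors are densely laid out after the transpose before they are packed. For reference, a tiny sketch of the grouping trick these reshapes undo: rows are regrouped so that the per-column reduction earlier in the function sees one `group_size`-row block per column, and `reshape_w` restores the original row order afterwards (sizes here are illustrative; the forward grouping is not part of this hunk):

```python
import paddle

size_k, size_n, group_size = 8, 2, 4
w = paddle.rand([size_k, size_n])

# Forward grouping (assumed to happen earlier in quantize_weights): axis 0
# becomes the within-group row index, so an axis-0 max/min is a per-group stat.
w_grouped = w.reshape([-1, group_size, size_n]).transpose([1, 0, 2]).reshape([group_size, -1])

# reshape_w above is the inverse: it recovers the original [size_k, size_n] order.
w_back = w_grouped.reshape([group_size, -1, size_n]).transpose([1, 0, 2]).reshape([size_k, size_n])
assert bool(paddle.all(w_back == w))
```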
@@ -153,7 +159,8 @@ def machete_quantize_and_pack(
group_size: int = -1,
):
w_q, w_s = quantize_weights(w, group_size, quant_type=quant_type)
w_q = pack_rows(w_q, 4, *w_q.shape)
num_bits = 4 if quant_type == "uint4b8" else 8
w_q = pack_rows(w_q, num_bits, *w_q.shape)
w_q_col = w_q.transpose([1, 0]).contiguous() # convert to col major
w_q_prepack = machete_prepack_B(
w_q_col,

View File

@@ -142,11 +142,11 @@ class WeightOnlyConfig(QuantConfigBase):
)
if (
self.name() == "wint4"
and _ENABLE_MACHETE
_ENABLE_MACHETE
and envs.FD_USE_MACHETE == "1"
and layer.weight_shape[1]
and layer.weight_shape[1] % 128 == 0
and not layer.add_bias
):
return MacheteWeightOnlyLinearMethod(self)
return GPUWeightOnlyLinearMethod(self)
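Dropping the `self.name() == "wint4"` clause is what routes wint8 layers to Machete as well. A hedged restatement of the resulting rule as a free function (names are illustrative; `_ENABLE_MACHETE` stands for the kernel-availability check, and `envs.FD_USE_MACHETE` is modeled here as an environment variable, which is an assumption):

```python
import os

def picks_machete(machete_available: bool, weight_shape, add_bias: bool) -> bool:
    # Machete now serves any weight-only config (wint4 and wint8 alike), provided
    # the kernel is available, FD_USE_MACHETE=1 is set, the output dim is a
    # multiple of 128, and the layer has no fused bias.
    return (
        machete_available
        and os.getenv("FD_USE_MACHETE") == "1"
        and weight_shape[1] % 128 == 0
        and not add_bias
    )
```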
@@ -230,6 +230,8 @@ class WeightOnlyLinearMethod(QuantMethodBase):
weight_scale_shape = [1, layer.weight_shape[1]]
if self.quant_config.name() == "wint4":
layer.weight_shape[0] //= 8
else:
layer.weight_shape[0] //= 4
layer.weight_dtype = "int32"
else:
# The scale shape should be equal to the output dim of weight using Per-Channel Quantization.
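The new `else` branch covers wint8 on the Machete path: the packed weight is declared as int32, so one element holds eight 4-bit values but only four 8-bit values, which is what the `//= 8` and `//= 4` implement. A quick check of the resulting shapes, using the 7168x1024 layer from the new test (helper name is illustrative):

```python
def packed_weight_shape(size_k: int, size_n: int, quant_name: str):
    # int32 storage: 32 / 4 = 8 nibbles per element for wint4,
    #                32 / 8 = 4 bytes   per element for wint8.
    pack_factor = 8 if quant_name == "wint4" else 4
    assert size_k % pack_factor == 0
    return [size_k // pack_factor, size_n]

assert packed_weight_shape(7168, 1024, "wint4") == [896, 1024]
assert packed_weight_shape(7168, 1024, "wint8") == [1792, 1024]
```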
@@ -282,7 +284,7 @@ class WeightOnlyLinearMethod(QuantMethodBase):
quanted_weight_tensor, weight_scale_tensor = machete_quantize_and_pack(
w=layer.weight,
atype=layer._dtype,
quant_type="uint4b8",
quant_type="uint4b8" if self.quant_config.name() == "wint4" else "uint8b128",
)
else:
quanted_weight_tensor, weight_scale_tensor = weight_quantize(
@@ -387,7 +389,7 @@ class MacheteWeightOnlyLinearMethod(WeightOnlyLinearMethod):
quanted_weight_tensor, weight_scale_tensor = machete_quantize_and_pack(
w=weight,
atype=layer._dtype,
quant_type="uint4b8",
quant_type="uint4b8" if self.quant_config.name() == "wint4" else "uint8b128",
)
layer.weight.set_value(quanted_weight_tensor)
layer.weight_scale.set_value(weight_scale_tensor.astype(paddle.get_default_dtype()))
@@ -400,7 +402,7 @@ class MacheteWeightOnlyLinearMethod(WeightOnlyLinearMethod):
x,
w_prepack=layer.weight,
w_g_s=layer.weight_scale,
weight_dtype="uint4b8",
weight_dtype="uint4b8" if self.quant_config.name() == "wint4" else "uint8b128",
)
return linear_out
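Taken together, the wint8 Machete path now looks roughly like this from the caller's side. This is a minimal sketch modeled on the new test below; the module path for the two helpers is not shown in the diff, and the shapes are the test's:

```python
import paddle
# machete_quantize_and_pack / machete_wint_mm are the helpers used above;
# their import path is not shown in this diff, so the import is omitted here.

x = paddle.rand([512, 7168]).astype("float16").cuda()   # activations
w = paddle.rand([7168, 1024]).astype("float16")         # fp weight, [in, out]

# Quantize to the biased 8-bit representation with 128-row group scales and
# prepack into Machete's weight layout.
w_q, w_s = machete_quantize_and_pack(
    w=w.cuda(),
    atype="float16",
    quant_type="uint8b128",
    group_size=128,
)

# The same group_size must be passed at matmul time along with the group scales.
out = machete_wint_mm(
    x,
    w_prepack=w_q,
    w_g_s=w_s,
    weight_dtype="uint8b128",
    group_size=128,
)
```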

View File

@@ -64,11 +64,11 @@ def convert_uint16_to_float(in_list):
not core.is_compiled_with_cuda() or get_sm_version() < 90,
"machete only support sm90.",
)
class WeightOnlyLinearTestCase(unittest.TestCase):
class WeightOnlyInt4LinearTestCase(unittest.TestCase):
def config(self):
self.dtype = "float16"
self.rtol = 1e-5
self.atol = 1e-2
self.atol = 1.3e-1
self.bias = False
self.batch = 1
self.token = 512
@@ -77,11 +77,10 @@ class WeightOnlyLinearTestCase(unittest.TestCase):
self.weight_dtype = "int4"
self.static = False
self.group_size = -1
self.machete_group_size = -1
def setUp(self):
self.config()
if self.dtype == "bfloat16" or self.weight_dtype == "int4":
self.atol = 1.3e-1
x = np.random.random((self.token, self.in_features))
self.x = paddle.to_tensor(x, dtype=self.dtype)
if self.bias:
@@ -111,29 +110,30 @@ class WeightOnlyLinearTestCase(unittest.TestCase):
return out.numpy()
def get_weight_only_linear_out(self):
for i in range(10):
out = Q.weight_only_linear(
self.x,
self.weight,
bias=self.bias,
weight_scale=self.weight_scale,
weight_dtype=self.weight_dtype,
group_size=self.group_size,
)
out = Q.weight_only_linear(
self.x,
self.weight,
bias=self.bias,
weight_scale=self.weight_scale,
weight_dtype=self.weight_dtype,
group_size=self.group_size,
)
return out.numpy()
def get_machete_weight_only_linear_out(self):
w_q, w_s = machete_quantize_and_pack(
w=self.float_weight.cuda(),
atype=self.dtype,
quant_type="uint4b8",
quant_type="uint4b8" if self.weight_dtype == "int4" else "uint8b128",
group_size=self.machete_group_size,
)
out = machete_wint_mm(
self.x,
w_prepack=w_q,
w_g_s=w_s, # group scales
weight_dtype="uint4b8", # weight_dtype
weight_dtype="uint4b8" if self.weight_dtype == "int4" else "uint8b128", # weight_dtype
group_size=self.machete_group_size,
)
return out.numpy()
@@ -149,26 +149,94 @@ class WeightOnlyLinearTestCase(unittest.TestCase):
np.testing.assert_allclose(out_paddle, out_machete, rtol=self.rtol, atol=self.atol)
M = [32, 128]
K_N = [[2048, 4096]]
@unittest.skipIf(
not core.is_compiled_with_cuda() or get_sm_version() < 90,
"machete only support sm90.",
)
class WeightOnlyInt8LinearTestCase(unittest.TestCase):
def config(self):
self.dtype = "float16"
self.rtol = 1e-5
self.atol = 1e-1
self.bias = False
self.batch = 1
self.token = 512
self.in_features = 7168
self.out_features = 1024
self.weight_dtype = "int8"
self.static = False
self.group_size = -1
self.machete_group_size = 128
def setUp(self):
self.config()
x = np.random.random((self.token, self.in_features))
self.x = paddle.to_tensor(x, dtype=self.dtype)
if self.bias:
bias_attr = base.ParamAttr(
trainable=False,
regularizer=None,
initializer=paddle.nn.initializer.Constant(value=1.0),
)
else:
bias_attr = None
set_default_dtype(self.dtype)
self.linear = paddle.nn.Linear(self.in_features, self.out_features, bias_attr=bias_attr)
self.bias = self.linear.bias
self.weight = self.linear.weight
self.float_weight = self.linear.weight
self.weight_scale = None
self.weight, self.weight_scale = Q.weight_quantize(
(self.float_weight.cuda() if self.weight_dtype == "int8" else self.weight.cpu()),
algo=("weight_only_int8" if self.weight_dtype == "int8" else "weight_only_int4"),
group_size=self.group_size,
)
def make_case(m, k, n):
class Case(WeightOnlyLinearTestCase):
def config(self, _m=m, _k=k, _n=n):
super().config()
self.token = m
self.in_features = k
self.out_features = n
Case.name = f"WeightOnlyLinearTestCase{m}{k}{n}"
return Case
def get_linear_out(self):
out = self.linear(self.x)
return out.numpy()
def get_weight_only_linear_out(self):
out = Q.weight_only_linear(
self.x,
self.weight,
bias=self.bias,
weight_scale=self.weight_scale,
weight_dtype=self.weight_dtype,
group_size=self.group_size,
)
return out.numpy()
def get_machete_weight_only_linear_out(self):
w_q, w_s = machete_quantize_and_pack(
w=self.float_weight.cuda(),
atype=self.dtype,
quant_type="uint4b8" if self.weight_dtype == "int4" else "uint8b128",
group_size=self.machete_group_size,
)
out = machete_wint_mm(
self.x,
w_prepack=w_q,
w_g_s=w_s, # group scales
weight_dtype="uint4b8" if self.weight_dtype == "int4" else "uint8b128", # weight_dtype
group_size=self.machete_group_size,
)
return out.numpy()
def test_weight_only_linear(self):
out_expect = self.get_linear_out()
# out_paddle = self.get_weight_only_linear_out()
out_machete = self.get_machete_weight_only_linear_out()
if self.dtype == "bfloat16":
# out_paddle = convert_uint16_to_float(out_paddle)
out_expect = convert_uint16_to_float(out_expect)
out_machete = convert_uint16_to_float(out_machete)
np.testing.assert_allclose(out_expect, out_machete, rtol=self.rtol, atol=self.atol)
for k, n in K_N:
for m in M:
cls = make_case(m, k, n)
globals()[cls.name] = cls
if __name__ == "__main__":
unittest.main()
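A note on the parametrization added at the bottom: `make_case` builds one subclass per (M, K, N) combination and registers it in `globals()`, which is enough for unittest discovery to treat each as a normal test class. With the values above it generates (restated only, not additional test code):

```python
M = [32, 128]
K_N = [[2048, 4096]]
for k, n in K_N:
    for m in M:
        print(f"WeightOnlyLinearTestCase{m}{k}{n}")
# -> WeightOnlyLinearTestCase3220484096
# -> WeightOnlyLinearTestCase12820484096
```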