[Optimize] Support WINT8 and group scale for Machete (#3905)

Sunny-bot1
2025-09-15 12:01:34 +08:00
committed by GitHub
parent 4408dc7f67
commit b1a5b756a3
5 changed files with 125 additions and 42 deletions


@@ -85,7 +85,7 @@ def quantize_weights(
         w_s: Scales (None if `group_size` is None).
     """
     assert paddle.is_floating_point(w), "w must be float type"
-    assert quant_type in ["uint4", "uint4b8"], "only support quant_type = uint4, uint4b8"
+    assert quant_type in ["uint4b8", "uint8b128"], "only support quant_type = uint4b8, uint8b128"

     orig_device = w.place
     size_k, size_n = w.shape
@@ -103,8 +103,12 @@ def quantize_weights(
     max_val = paddle.max(w, axis=0, keepdim=True)
     min_val = paddle.min(w, axis=0, keepdim=True)

-    max_q_val = float(7.0)
-    min_q_val = float(-8.0)
+    if quant_type == "uint4b8":
+        max_q_val = float(7.0)
+        min_q_val = float(-8.0)
+    else:
+        max_q_val = float(127.0)
+        min_q_val = float(-128.0)

     w_s = paddle.ones([1], dtype=paddle.float32)  # unscaled case
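The type names encode the storage bias: `uint4b8` quantizes to a signed [-8, 7] range and stores it biased by 8 into unsigned [0, 15], while the new `uint8b128` (WINT8) quantizes to [-128, 127] and stores it biased by 128 into [0, 255]. A minimal NumPy sketch of that encoding, with illustrative helper names not taken from the patch:

```python
import numpy as np

RANGES = {
    "uint4b8": (-8.0, 7.0, 8),          # 4-bit: round to [-8, 7], store as [0, 15]
    "uint8b128": (-128.0, 127.0, 128),  # 8-bit (WINT8): [-128, 127] -> [0, 255]
}

def quantize_column(col, quant_type):
    min_q, max_q, bias = RANGES[quant_type]
    # Symmetric per-column scale, mirroring the max_val/min_val computation above.
    scale = max(col.max() / max_q, col.min() / min_q)
    q = np.clip(np.round(col / scale), min_q, max_q)
    return (q + bias).astype(np.uint8), scale  # biased unsigned storage

col = np.random.randn(128).astype(np.float32)
q, scale = quantize_column(col, "uint8b128")
assert q.min() >= 0 and q.max() <= 255
```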
@@ -124,6 +128,8 @@ def quantize_weights(
     # w_q += quant_type.bias
     if quant_type == "uint4b8":
         w_q += 8
+    else:
+        w_q += 128

     # Restore original shapes
     if group_size is not None and group_size < size_k:
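The added `else` branch applies the 128 bias for the 8-bit path. A quick NumPy round-trip check of the relation the bias establishes, w ≈ (q_stored - bias) * scale (illustrative code, not from the patch):

```python
import numpy as np

def dequantize(q_stored, scale, bias):
    # Invert the "+ bias" before applying the scale.
    return (q_stored.astype(np.float32) - bias) * scale

q = np.array([0, 128, 255], dtype=np.uint8)
print(dequantize(q, scale=0.05, bias=128))  # [-6.4   0.    6.35]
```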
@@ -131,11 +137,11 @@ def quantize_weights(
         def reshape_w(w_tensor):
             w_tensor = w_tensor.reshape([group_size, -1, size_n])
             w_tensor = w_tensor.transpose([1, 0, 2])
-            w_tensor = w_tensor.reshape([size_k, size_n])
+            w_tensor = w_tensor.reshape([size_k, size_n]).contiguous()
             return w_tensor

         w_q = reshape_w(w_q)
-        w_s = w_s.reshape([-1, size_n])
+        w_s = w_s.reshape([-1, size_n]).contiguous()

     # Move tensors back to original device
     w_q = w_q.to(orig_device)
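`reshape_w` undoes the earlier grouping in which each block of `group_size` consecutive k-rows shared one scale; the added `.contiguous()` calls presumably guarantee a dense memory layout before the packing step below, since `transpose` can leave a strided tensor. A NumPy sketch of the grouping round-trip, assuming the layout implied by the reshape/transpose sequence above:

```python
import numpy as np

size_k, size_n, group_size = 8, 2, 4
w = np.arange(size_k * size_n, dtype=np.float32).reshape(size_k, size_n)

# Group: [size_k, size_n] -> [group_size, num_groups, size_n],
# so one scale can be computed per (group, column).
grouped = w.reshape(-1, group_size, size_n).transpose(1, 0, 2)

# Ungroup, as reshape_w does: back to [size_k, size_n].
# np.ascontiguousarray stands in for Paddle's .contiguous().
restored = np.ascontiguousarray(grouped.transpose(1, 0, 2)).reshape(size_k, size_n)
assert (restored == w).all()
```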
@@ -153,7 +159,8 @@ def machete_quantize_and_pack(
     group_size: int = -1,
 ):
     w_q, w_s = quantize_weights(w, group_size, quant_type=quant_type)
-    w_q = pack_rows(w_q, 4, *w_q.shape)
+    num_bits = 4 if quant_type == "uint4b8" else 8
+    w_q = pack_rows(w_q, num_bits, *w_q.shape)
     w_q_col = w_q.transpose([1, 0]).contiguous()  # convert to col major
     w_q_prepack = machete_prepack_B(
         w_q_col,
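`pack_rows` packs quantized values along the k axis into 32-bit words, so the bit width passed in must now track the quant type: 8 values per word at 4 bits, 4 per word at 8 bits. A hedged NumPy sketch of that packing scheme (the real `pack_rows` layout may differ in detail):

```python
import numpy as np

def pack_rows_sketch(q, num_bits):
    """Pack values along the k axis into uint32 words, 32 // num_bits per word."""
    pack_factor = 32 // num_bits  # 8 values/word at 4 bits, 4 at 8 bits
    size_k, size_n = q.shape
    assert size_k % pack_factor == 0
    q = q.astype(np.uint32).reshape(size_k // pack_factor, pack_factor, size_n)
    packed = np.zeros((size_k // pack_factor, size_n), dtype=np.uint32)
    for i in range(pack_factor):
        packed |= q[:, i, :] << np.uint32(num_bits * i)
    return packed

q4 = np.random.randint(0, 16, size=(16, 8))   # 4-bit codes (after the +8 bias)
q8 = np.random.randint(0, 256, size=(16, 8))  # 8-bit codes (after the +128 bias)
print(pack_rows_sketch(q4, 4).shape, pack_rows_sketch(q8, 8).shape)  # (2, 8) (4, 8)
```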