diff --git a/custom_ops/gpu_ops/machete/machete_mm.cu b/custom_ops/gpu_ops/machete/machete_mm.cu
index 53774fa0c..c6f56d1c9 100644
--- a/custom_ops/gpu_ops/machete/machete_mm.cu
+++ b/custom_ops/gpu_ops/machete/machete_mm.cu
@@ -30,10 +30,12 @@ paddle::Tensor mm(paddle::Tensor const& A, paddle::Tensor const& B,
                   std::optional const& maybe_token_scales,
                   std::string maybe_schedule) {
   machete::ScalarType const b_type = machete::ScalarType::from_id(b_type_id);
-  std::optional maybe_group_size_opt;
+  std::optional maybe_group_size_opt = std::optional(maybe_group_size);
   std::optional maybe_schedule_opt;
   if (maybe_schedule == "") {
     maybe_schedule_opt = std::nullopt;
+  } else {
+    maybe_schedule_opt = std::optional(maybe_schedule);
   }
   return machete::mm_dispatch({.A = A,
                                .B = B,
@@ -63,6 +65,8 @@ std::vector MacheteMMKernel(
   paddle::DataType maybe_out_type;
   if (b_type_str == "uint4b8") {
     b_type_id = machete::kU4B8.id();
+  } else if (b_type_str == "uint8b128") {
+    b_type_id = machete::kU8B128.id();
   } else {
     PADDLE_ENFORCE(false, "b_type_str not supported!");
   }
diff --git a/custom_ops/gpu_ops/machete/machete_prepack_B.cu b/custom_ops/gpu_ops/machete/machete_prepack_B.cu
index 6014ca9ef..34bd1c705 100644
--- a/custom_ops/gpu_ops/machete/machete_prepack_B.cu
+++ b/custom_ops/gpu_ops/machete/machete_prepack_B.cu
@@ -51,6 +51,8 @@ std::vector MachetePrepackBKernel(
 
   if (b_type_str == "uint4b8") {
     b_type_id = machete::kU4B8.id();
+  } else if (b_type_str == "uint8b128") {
+    b_type_id = machete::kU8B128.id();
   } else {
     PADDLE_ENFORCE(false, "b_type_str not supported!");
   }
diff --git a/fastdeploy/model_executor/layers/quantization/ops/machete_mm.py b/fastdeploy/model_executor/layers/quantization/ops/machete_mm.py
index 218da0d21..b080bb627 100644
--- a/fastdeploy/model_executor/layers/quantization/ops/machete_mm.py
+++ b/fastdeploy/model_executor/layers/quantization/ops/machete_mm.py
@@ -85,7 +85,7 @@ def quantize_weights(
         w_s: Scales (None if `group_size` is None).
""" assert paddle.is_floating_point(w), "w must be float type" - assert quant_type in ["uint4", "uint4b8"], "only support quant_type = uint4, uint4b8" + assert quant_type in ["uint4b8", "uint8b128"], "only support quant_type = uint4b8, uint8b128" orig_device = w.place size_k, size_n = w.shape @@ -103,8 +103,12 @@ def quantize_weights( max_val = paddle.max(w, axis=0, keepdim=True) min_val = paddle.min(w, axis=0, keepdim=True) - max_q_val = float(7.0) - min_q_val = float(-8.0) + if quant_type == "uint4b8": + max_q_val = float(7.0) + min_q_val = float(-8.0) + else: + max_q_val = float(127.0) + min_q_val = float(-128.0) w_s = paddle.ones([1], dtype=paddle.float32) # unscaled case @@ -124,6 +128,8 @@ def quantize_weights( # w_q += quant_type.bias if quant_type == "uint4b8": w_q += 8 + else: + w_q += 128 # Restore original shapes if group_size is not None and group_size < size_k: @@ -131,11 +137,11 @@ def quantize_weights( def reshape_w(w_tensor): w_tensor = w_tensor.reshape([group_size, -1, size_n]) w_tensor = w_tensor.transpose([1, 0, 2]) - w_tensor = w_tensor.reshape([size_k, size_n]) + w_tensor = w_tensor.reshape([size_k, size_n]).contiguous() return w_tensor w_q = reshape_w(w_q) - w_s = w_s.reshape([-1, size_n]) + w_s = w_s.reshape([-1, size_n]).contiguous() # Move tensors back to original device w_q = w_q.to(orig_device) @@ -153,7 +159,8 @@ def machete_quantize_and_pack( group_size: int = -1, ): w_q, w_s = quantize_weights(w, group_size, quant_type=quant_type) - w_q = pack_rows(w_q, 4, *w_q.shape) + num_bits = 4 if quant_type == "uint4b8" else 8 + w_q = pack_rows(w_q, num_bits, *w_q.shape) w_q_col = w_q.transpose([1, 0]).contiguous() # convert to col major w_q_prepack = machete_prepack_B( w_q_col, diff --git a/fastdeploy/model_executor/layers/quantization/weight_only.py b/fastdeploy/model_executor/layers/quantization/weight_only.py index 5e4c8ed52..0d56491a1 100644 --- a/fastdeploy/model_executor/layers/quantization/weight_only.py +++ b/fastdeploy/model_executor/layers/quantization/weight_only.py @@ -142,11 +142,11 @@ class WeightOnlyConfig(QuantConfigBase): ) if ( - self.name() == "wint4" - and _ENABLE_MACHETE + _ENABLE_MACHETE and envs.FD_USE_MACHETE == "1" and layer.weight_shape[1] and layer.weight_shape[1] % 128 == 0 + and not layer.add_bias ): return MacheteWeightOnlyLinearMethod(self) return GPUWeightOnlyLinearMethod(self) @@ -230,6 +230,8 @@ class WeightOnlyLinearMethod(QuantMethodBase): weight_scale_shape = [1, layer.weight_shape[1]] if self.quant_config.name() == "wint4": layer.weight_shape[0] //= 8 + else: + layer.weight_shape[0] //= 4 layer.weight_dtype = "int32" else: # The scale shape should be equal to the output dim of weight using Per-Channel Quantization. 
@@ -282,7 +284,7 @@ class WeightOnlyLinearMethod(QuantMethodBase):
             quanted_weight_tensor, weight_scale_tensor = machete_quantize_and_pack(
                 w=layer.weight,
                 atype=layer._dtype,
-                quant_type="uint4b8",
+                quant_type="uint4b8" if self.quant_config.name() == "wint4" else "uint8b128",
             )
         else:
             quanted_weight_tensor, weight_scale_tensor = weight_quantize(
@@ -387,7 +389,7 @@ class MacheteWeightOnlyLinearMethod(WeightOnlyLinearMethod):
         quanted_weight_tensor, weight_scale_tensor = machete_quantize_and_pack(
             w=weight,
             atype=layer._dtype,
-            quant_type="uint4b8",
+            quant_type="uint4b8" if self.quant_config.name() == "wint4" else "uint8b128",
         )
         layer.weight.set_value(quanted_weight_tensor)
         layer.weight_scale.set_value(weight_scale_tensor.astype(paddle.get_default_dtype()))
@@ -400,7 +402,7 @@ class MacheteWeightOnlyLinearMethod(WeightOnlyLinearMethod):
             x,
             w_prepack=layer.weight,
             w_g_s=layer.weight_scale,
-            weight_dtype="uint4b8",
+            weight_dtype="uint4b8" if self.quant_config.name() == "wint4" else "uint8b128",
         )
         return linear_out
 
diff --git a/tests/operators/test_machete_mm.py b/tests/operators/test_machete_mm.py
index 117fd7928..fafdf717d 100644
--- a/tests/operators/test_machete_mm.py
+++ b/tests/operators/test_machete_mm.py
@@ -64,11 +64,11 @@ def convert_uint16_to_float(in_list):
     not core.is_compiled_with_cuda() or get_sm_version() < 90,
     "machete only support sm90.",
 )
-class WeightOnlyLinearTestCase(unittest.TestCase):
+class WeightOnlyInt4LinearTestCase(unittest.TestCase):
     def config(self):
         self.dtype = "float16"
         self.rtol = 1e-5
-        self.atol = 1e-2
+        self.atol = 1.3e-1
         self.bias = False
         self.batch = 1
         self.token = 512
@@ -77,11 +77,10 @@ class WeightOnlyLinearTestCase(unittest.TestCase):
         self.weight_dtype = "int4"
         self.static = False
         self.group_size = -1
+        self.machete_group_size = -1
 
     def setUp(self):
         self.config()
-        if self.dtype == "bfloat16" or self.weight_dtype == "int4":
-            self.atol = 1.3e-1
         x = np.random.random((self.token, self.in_features))
         self.x = paddle.to_tensor(x, dtype=self.dtype)
         if self.bias:
@@ -111,29 +110,30 @@ class WeightOnlyLinearTestCase(unittest.TestCase):
         return out.numpy()
 
     def get_weight_only_linear_out(self):
-        for i in range(10):
-            out = Q.weight_only_linear(
-                self.x,
-                self.weight,
-                bias=self.bias,
-                weight_scale=self.weight_scale,
-                weight_dtype=self.weight_dtype,
-                group_size=self.group_size,
-            )
+        out = Q.weight_only_linear(
+            self.x,
+            self.weight,
+            bias=self.bias,
+            weight_scale=self.weight_scale,
+            weight_dtype=self.weight_dtype,
+            group_size=self.group_size,
+        )
         return out.numpy()
 
     def get_machete_weight_only_linear_out(self):
         w_q, w_s = machete_quantize_and_pack(
             w=self.float_weight.cuda(),
             atype=self.dtype,
-            quant_type="uint4b8",
+            quant_type="uint4b8" if self.weight_dtype == "int4" else "uint8b128",
+            group_size=self.machete_group_size,
         )
 
         out = machete_wint_mm(
             self.x,
             w_prepack=w_q,
             w_g_s=w_s,  # group scales
-            weight_dtype="uint4b8",  # weight_dtype
+            weight_dtype="uint4b8" if self.weight_dtype == "int4" else "uint8b128",  # weight_dtype
+            group_size=self.machete_group_size,
         )
         return out.numpy()
 
@@ -149,26 +149,94 @@ class WeightOnlyLinearTestCase(unittest.TestCase):
         np.testing.assert_allclose(out_paddle, out_machete, rtol=self.rtol, atol=self.atol)
 
 
-M = [32, 128]
-K_N = [[2048, 4096]]
+@unittest.skipIf(
+    not core.is_compiled_with_cuda() or get_sm_version() < 90,
+    "machete only support sm90.",
+)
+class WeightOnlyInt8LinearTestCase(unittest.TestCase):
+    def config(self):
+        self.dtype = "float16"
+        self.rtol = 1e-5
+        self.atol = 1e-1
+        self.bias = False
+        self.batch = 1
+        self.token = 512
+        self.in_features = 7168
+        self.out_features = 1024
+        self.weight_dtype = "int8"
+        self.static = False
+        self.group_size = -1
+        self.machete_group_size = 128
 
+    def setUp(self):
+        self.config()
+        x = np.random.random((self.token, self.in_features))
+        self.x = paddle.to_tensor(x, dtype=self.dtype)
+        if self.bias:
+            bias_attr = base.ParamAttr(
+                trainable=False,
+                regularizer=None,
+                initializer=paddle.nn.initializer.Constant(value=1.0),
+            )
+        else:
+            bias_attr = None
+        set_default_dtype(self.dtype)
+        self.linear = paddle.nn.Linear(self.in_features, self.out_features, bias_attr=bias_attr)
 
-def make_case(m, k, n):
-    class Case(WeightOnlyLinearTestCase):
-        def config(self, _m=m, _k=k, _n=n):
-            super().config()
-            self.token = m
-            self.in_features = k
-            self.out_features = n
+        self.bias = self.linear.bias
+        self.weight = self.linear.weight
+        self.float_weight = self.linear.weight
+        self.weight_scale = None
 
-    Case.name = f"WeightOnlyLinearTestCase{m}{k}{n}"
-    return Case
+        self.weight, self.weight_scale = Q.weight_quantize(
+            (self.float_weight.cuda() if self.weight_dtype == "int8" else self.weight.cpu()),
+            algo=("weight_only_int8" if self.weight_dtype == "int8" else "weight_only_int4"),
+            group_size=self.group_size,
+        )
 
+    def get_linear_out(self):
+        out = self.linear(self.x)
+        return out.numpy()
+
+    def get_weight_only_linear_out(self):
+        out = Q.weight_only_linear(
+            self.x,
+            self.weight,
+            bias=self.bias,
+            weight_scale=self.weight_scale,
+            weight_dtype=self.weight_dtype,
+            group_size=self.group_size,
+        )
+        return out.numpy()
+
+    def get_machete_weight_only_linear_out(self):
+        w_q, w_s = machete_quantize_and_pack(
+            w=self.float_weight.cuda(),
+            atype=self.dtype,
+            quant_type="uint4b8" if self.weight_dtype == "int4" else "uint8b128",
+            group_size=self.machete_group_size,
+        )
+
+        out = machete_wint_mm(
+            self.x,
+            w_prepack=w_q,
+            w_g_s=w_s,  # group scales
+            weight_dtype="uint4b8" if self.weight_dtype == "int4" else "uint8b128",  # weight_dtype
+            group_size=self.machete_group_size,
+        )
+        return out.numpy()
+
+    def test_weight_only_linear(self):
+        out_expect = self.get_linear_out()
+        # out_paddle = self.get_weight_only_linear_out()
+        out_machete = self.get_machete_weight_only_linear_out()
+
+        if self.dtype == "bfloat16":
+            # out_paddle = convert_uint16_to_float(out_paddle)
+            out_expect = convert_uint16_to_float(out_expect)
+            out_machete = convert_uint16_to_float(out_machete)
+        np.testing.assert_allclose(out_expect, out_machete, rtol=self.rtol, atol=self.atol)
 
-for k, n in K_N:
-    for m in M:
-        cls = make_case(m, k, n)
-        globals()[cls.name] = cls
 
 if __name__ == "__main__":
     unittest.main()
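
Note (reviewer sketch, not part of the patch): the "uint8b128" weight type added above is the 8-bit analogue of "uint4b8": weights are quantized symmetrically to [-128, 127] with a float scale per output channel (or per group) and then offset by +128 so the codes can be stored unsigned. The snippet below is a minimal NumPy illustration of that idea under those assumptions; the helper name quantize_uint8b128 is invented for this sketch, and the real implementation is quantize_weights() / machete_quantize_and_pack() in machete_mm.py.

import numpy as np

def quantize_uint8b128(w):
    """w: [size_k, size_n] float weights -> (uint8 codes, per-output-channel float32 scales)."""
    # Symmetric max-abs scale per column so the largest magnitude maps to 127.
    w_s = np.maximum(np.abs(w).max(axis=0, keepdims=True), 1e-8) / 127.0
    w_q = np.clip(np.round(w / w_s), -128.0, 127.0)
    w_q = (w_q + 128.0).astype(np.uint8)  # the "b128" bias makes the codes unsigned
    return w_q, w_s.astype(np.float32)

w = np.random.randn(256, 64).astype(np.float32)
w_q, w_s = quantize_uint8b128(w)
# Dequantize and check the round trip stays within half a quantization step per column.
w_dq = (w_q.astype(np.float32) - 128.0) * w_s
assert np.abs(w - w_dq).max() <= 0.5 * w_s.max() + 1e-6

The 4-bit "uint4b8" path follows the same pattern with a [-8, 7] range and a +8 bias, which is why machete_quantize_and_pack() now switches pack_rows() between 4 and 8 bits depending on quant_type.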