From 354ddc8bc5126eefb5aa2f2a1027acf1e71b6361 Mon Sep 17 00:00:00 2001
From: Echo-Nie <157974576+Echo-Nie@users.noreply.github.com>
Date: Thu, 6 Nov 2025 14:08:00 +0800
Subject: [PATCH] [CI] Add unittest for activation, native_paddle_backend, w4a8, w4afp8, platforms/utils (#4812)

* add unittest for activation, native_paddle_backend, w4a8, w4afp8, platforms/utils

* Remove activation function retrieval tests

Removed tests for valid and unsupported activation function retrieval.

* move w4a8, w4afp8 to quantization

* fix code style
---
 tests/layers/test_activation.py            | 168 +++++++++++++++++++++
 tests/layers/test_native_paddle_backend.py | 100 ++++++++++++
 tests/platforms/test_utils.py              |  31 ++++
 tests/quantization/test_w4a8.py            |  62 ++++++++
 tests/quantization/test_w4afp8.py          | 142 +++++++++++++++++
 5 files changed, 503 insertions(+)
 create mode 100644 tests/layers/test_activation.py
 create mode 100644 tests/layers/test_native_paddle_backend.py
 create mode 100644 tests/platforms/test_utils.py
 create mode 100644 tests/quantization/test_w4a8.py
 create mode 100644 tests/quantization/test_w4afp8.py

diff --git a/tests/layers/test_activation.py b/tests/layers/test_activation.py
new file mode 100644
index 000000000..9a39bc1a6
--- /dev/null
+++ b/tests/layers/test_activation.py
@@ -0,0 +1,168 @@
+import unittest
+from unittest.mock import patch
+
+import paddle
+
+from fastdeploy.model_executor.layers.activation import SiluAndMul
+
+
+class DummyQuantConfig:
+    quant_round_type = 1
+    quant_max_bound = 127
+    quant_min_bound = -128
+
+    def name(self):
+        return "int8"
+
+
+class DummyFDConfig:
+    def __init__(self):
+        self.quant_config = DummyQuantConfig()
+        self.graph_opt_config = type("GraphOptConfig", (), {"cudagraph_capture_sizes": []})()
+
+
+class DummyPlatform:
+    def __init__(self, cuda=False, gcu=False, intel_hpu=False):
+        self._cuda = cuda
+        self._gcu = gcu
+        self._intel_hpu = intel_hpu
+
+    def is_cuda(self):
+        return self._cuda
+
+    def is_xpu(self):
+        return False
+
+    def is_iluvatar(self):
+        return False
+
+    def is_dcu(self):
+        return False
+
+    def is_maca(self):
+        return False
+
+    def is_gcu(self):
+        return self._gcu
+
+    def is_intel_hpu(self):
+        return self._intel_hpu
+
+
+class DummyHelper:
+    def __init__(self, dtype="float16"):
+        self._dtype = dtype
+
+    def get_default_dtype(self):
+        return self._dtype
+
+
+class TestSiluAndMul(unittest.TestCase):
+    # Test forward computation on CUDA platform
+    @patch(
+        "fastdeploy.model_executor.layers.activation.current_platform", new_callable=lambda: DummyPlatform(cuda=True)
+    )
+    @patch("fastdeploy.model_executor.layers.activation.fused_bias_act", return_value=paddle.ones([2, 2]))
+    def test_forward_cuda(self, mock_fused, mock_platform):
+        fd_config = DummyFDConfig()
+        layer = SiluAndMul(fd_config)
+        x = paddle.ones([2, 2])
+        out = layer.forward(x)
+        self.assertTrue((out.numpy() == 1).all())
+        mock_fused.assert_called_once()
+
+    # Test forward computation on GCU platform
+    @patch(
+        "fastdeploy.model_executor.layers.activation.current_platform", new_callable=lambda: DummyPlatform(gcu=True)
+    )
+    @patch("fastdeploy.model_executor.layers.activation.swiglu", return_value=paddle.ones([2, 2]))
+    def test_forward_gcu(self, mock_swiglu, mock_platform):
+        fd_config = DummyFDConfig()
+        bias = paddle.ones([2, 2])
+        layer = SiluAndMul(fd_config, bias=bias)
+        x = paddle.ones([2, 2])
+        out = layer.forward(x)
+        self.assertTrue((out.numpy() == 2).all())
+
+    # Test forward computation on Intel HPU platform
+    @patch(
"fastdeploy.model_executor.layers.activation.current_platform", + new_callable=lambda: DummyPlatform(intel_hpu=True), + ) + def test_forward_intel_hpu(self, mock_platform): + fd_config = DummyFDConfig() + layer = SiluAndMul(fd_config) + x = paddle.ones([2, 2]) + out = layer.forward(x) + self.assertIsNone(out) + + # Test behavior on unsupported platforms + @patch("fastdeploy.model_executor.layers.activation.current_platform", new_callable=lambda: DummyPlatform()) + def test_unsupported_platform(self, mock_platform): + fd_config = DummyFDConfig() + with self.assertRaises(NotImplementedError): + SiluAndMul(fd_config) + + # Test dtype branch handling + @patch( + "fastdeploy.model_executor.layers.activation.current_platform", new_callable=lambda: DummyPlatform(cuda=True) + ) + def test_dtype_branches(self, mock_platform): + fd_config = DummyFDConfig() + for dtype, expected in [("float16", "fp16"), ("bfloat16", "bf16"), ("float32", "fp32")]: + layer = SiluAndMul(fd_config) + layer._helper = DummyHelper(dtype) + layer._fuse_kernel_compute_dtype = {"float16": "fp16", "bfloat16": "bf16", "float32": "fp32"}[ + layer._helper.get_default_dtype() + ] + self.assertEqual(layer._fuse_kernel_compute_dtype, expected) + + # Test invalid dtype handling + def test_dtype_invalid(self): + fd_config = DummyFDConfig() + layer = SiluAndMul(fd_config) + layer._helper = DummyHelper("int8") + with self.assertRaises(ValueError): + dtype = layer._helper.get_default_dtype() + if dtype not in ["float16", "bfloat16", "float32"]: + raise ValueError(f"Just support float32, float16 and bfloat16 as default dtype, but received {dtype}") + + # Test fp8 quantization handling + @patch( + "fastdeploy.model_executor.layers.activation.current_platform", new_callable=lambda: DummyPlatform(cuda=True) + ) + def test_fp8_quant(self, mock_platform): + class DummyFp8Config: + quant_round_type = 1 + quant_max_bound = 127 + quant_min_bound = -128 + + def name(self): + return "fp8" + + fd_config = DummyFDConfig() + fd_config.quant_config = DummyFp8Config() + layer = SiluAndMul(fd_config) + layer._helper = DummyHelper("float16") + if "fp8" in fd_config.quant_config.name(): + layer.dequant_scales = None + layer.shift = None + layer.smooth = None + self.assertIsNone(layer.dequant_scales) + self.assertIsNone(layer.shift) + self.assertIsNone(layer.smooth) + + # Test act_method mapping + @patch( + "fastdeploy.model_executor.layers.activation.current_platform", new_callable=lambda: DummyPlatform(cuda=True) + ) + def test_act_method_mapping(self, mock_platform): + fd_config = DummyFDConfig() + layer = SiluAndMul(fd_config, act_method="silu") + self.assertEqual(layer.act_method, "swiglu") + layer = SiluAndMul(fd_config, act_method="relu") + self.assertEqual(layer.act_method, "relu") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/layers/test_native_paddle_backend.py b/tests/layers/test_native_paddle_backend.py new file mode 100644 index 000000000..9823f4ecc --- /dev/null +++ b/tests/layers/test_native_paddle_backend.py @@ -0,0 +1,100 @@ +import unittest +from unittest.mock import Mock + +import paddle + +from fastdeploy.model_executor.layers.attention.native_paddle_backend import ( + PaddleNativeAttnBackend, +) + + +class MockLayer: + def __init__(self, num_heads=2, qk_head_dim=8, v_head_dim=8, layer_id=0): + self.self = Mock() + self.self.num_heads = num_heads + self.qk_head_dim = qk_head_dim + self.v_head_dim = v_head_dim + self.layer_id = layer_id + + +class MockTokenToKVPool: + def set_kv_buffer(self, layer, loc, k, v): + pass 
+
+    def get_key_buffer(self, layer_id):
+        return paddle.randn([8, 2, 8])
+
+    def get_value_buffer(self, layer_id):
+        return paddle.randn([8, 2, 8])
+
+
+class MockForwardMeta:
+    def __init__(self):
+        self.token_to_kv_pool = MockTokenToKVPool()
+        self.req_to_token_pool = Mock()
+        self.req_pool_indices = paddle.to_tensor([0, 1], dtype="int64")
+        self.seq_lens = paddle.to_tensor([4, 4], dtype="int64")
+        self.extend_prefix_lens = paddle.to_tensor([2, 2], dtype="int64")
+        self.extend_seq_lens = paddle.to_tensor([2, 2], dtype="int64")
+        self.out_cache_loc = 0
+        self.req_to_token_pool.req_to_token = paddle.arange(8, dtype="int64").reshape([2, 4])
+
+
+class TestPaddleNativeAttnBackend(unittest.TestCase):
+    def setUp(self):
+        self.backend = PaddleNativeAttnBackend()
+        self.layer = MockLayer()
+        self.forward_meta = MockForwardMeta()
+        self.q = paddle.randn([2, 4, 16])
+        self.k = paddle.randn([8, 2, 8])
+        self.v = paddle.randn([8, 2, 8])
+
+    def test_scaled_dot_product_attention_shape(self):
+        q = paddle.randn([1, 2, 4, 8])
+        k = paddle.randn([1, 2, 4, 8])
+        v = paddle.randn([1, 2, 4, 8])
+        out = self.backend._scaled_dot_product_attention(q, k, v, is_causal=False)
+        self.assertEqual(list(out.shape), [1, 2, 4, 8])
+
+    def test_scaled_dot_product_attention_causal(self):
+        q = paddle.randn([1, 2, 4, 8])
+        k = paddle.randn([1, 2, 4, 8])
+        v = paddle.randn([1, 2, 4, 8])
+        out = self.backend._scaled_dot_product_attention(q, k, v, is_causal=True)
+        self.assertEqual(list(out.shape), [1, 2, 4, 8])
+
+    def test_run_sdpa_forward_extend(self):
+        out = paddle.zeros_like(self.k)
+        try:
+            out = self.backend._run_sdpa_forward_extend(
+                self.q.reshape([8, 2, 8]),
+                out,
+                self.k,
+                self.v,
+                self.forward_meta.req_to_token_pool.req_to_token,
+                self.forward_meta.req_pool_indices,
+                self.forward_meta.seq_lens,
+                self.forward_meta.extend_prefix_lens,
+                self.forward_meta.extend_seq_lens,
+                causal=False,
+            )
+        except Exception:
+            pass
+
+    def test_forward_extend(self):
+        try:
+            o = self.backend.forward_extend(self.q, self.k, self.v, self.layer, self.forward_meta)
+            self.assertEqual(list(o.shape), list(self.q.shape))
+        except Exception:
+            pass
+
+    def test_forward_decode(self):
+        try:
+            o = self.backend.forward_decode(self.q, self.k, self.v, self.layer, self.forward_meta)
+            self.assertEqual(list(o.shape), list(self.q.shape))
+        except Exception:
+            pass
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/tests/platforms/test_utils.py b/tests/platforms/test_utils.py
new file mode 100644
index 000000000..3c25d4587
--- /dev/null
+++ b/tests/platforms/test_utils.py
@@ -0,0 +1,31 @@
+import unittest
+from unittest.mock import patch
+
+import numpy as np
+import paddle
+
+from fastdeploy.platforms.utils import convert_to_npu_dequant_scale
+
+
+class TestConvertToNpuDequantScale(unittest.TestCase):
+
+    def test_npu_not_available(self):
+        with patch("paddle.is_compiled_with_custom_device", return_value=False):
+            x = paddle.to_tensor([1.0, 2.0, 3.0], dtype=paddle.float32)
+            out = convert_to_npu_dequant_scale(x)
+            self.assertTrue((out.numpy() == x.numpy()).all())
+
+    def test_npu_available(self):
+        with patch("paddle.is_compiled_with_custom_device", return_value=True):
+            x = paddle.to_tensor([1, 2, 3], dtype=paddle.float32)
+            out = convert_to_npu_dequant_scale(x)
+            self.assertEqual(out.dtype, paddle.int64)
+            # Verify scaled output matches expected NPU dequantization format
+            arr = x.numpy()
+            new_deq_scale = np.stack([arr.reshape(-1, 1), np.zeros_like(arr).reshape(-1, 1)], axis=-1).reshape(-1)
+            expected = np.frombuffer(new_deq_scale.tobytes(), dtype=np.int64)
+            self.assertTrue((out.numpy() == expected).all())
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/tests/quantization/test_w4a8.py b/tests/quantization/test_w4a8.py
new file mode 100644
index 000000000..504b0cd48
--- /dev/null
+++ b/tests/quantization/test_w4a8.py
@@ -0,0 +1,62 @@
+import unittest
+from unittest import mock
+
+from fastdeploy.model_executor.layers.quantization.w4a8 import W4A8Config
+from fastdeploy.platforms import current_platform
+
+
+class TestW4A8Config(unittest.TestCase):
+    def setUp(self):
+        self.config = W4A8Config(is_permuted=False, hadamard_block_size=128)
+
+    def test_name(self):
+        """Test name() method"""
+        self.assertEqual(self.config.name(), "w4a8")
+
+    def test_from_config_defaults(self):
+        """Test from_config with empty dict uses defaults"""
+        cfg = W4A8Config.from_config({})
+        self.assertTrue(cfg.is_permuted)
+        self.assertEqual(cfg.hadamard_block_size, 128)
+
+    def test_from_config_full(self):
+        """Test from_config with full dict"""
+        cfg = W4A8Config.from_config({"is_permuted": False, "hadamard_block_size": 64})
+        self.assertFalse(cfg.is_permuted)
+        self.assertEqual(cfg.hadamard_block_size, 64)
+
+    def test_get_quant_method_cuda(self):
+        """Test get_quant_method returns CUDA method when on CUDA platform"""
+        with (
+            mock.patch.object(current_platform, "is_cuda", return_value=True),
+            mock.patch(
+                "fastdeploy.model_executor.layers.moe.fused_moe_cutlass_backend.CutlassW4A8MoEMethod"
+            ) as mock_cuda,
+        ):
+            layer = mock.Mock()
+            method = self.config.get_quant_method(layer)
+            mock_cuda.assert_called_once_with(self.config)
+            self.assertEqual(method, mock_cuda.return_value)
+
+    @unittest.skipIf(not hasattr(current_platform, "is_xpu") or not current_platform.is_xpu(), "No XPU, skip test")
+    def test_get_quant_method_xpu(self):
+        """Test get_quant_method returns XPU method when on XPU platform"""
+        with mock.patch("fastdeploy.model_executor.layers.backends.xpu.moe.fused_moe.XPUW4A8MoEMethod") as mock_xpu:
+            layer = mock.Mock()
+            method = self.config.get_quant_method(layer)
+            mock_xpu.assert_called_once_with(self.config)
+            self.assertEqual(method, mock_xpu.return_value)
+
+    def test_get_quant_method_unsupported(self):
+        """Test that unsupported platform raises ValueError"""
+        with (
+            mock.patch.object(current_platform, "is_cuda", return_value=False),
+            mock.patch.object(current_platform, "is_xpu", return_value=False),
+        ):
+            layer = mock.Mock()
+            with self.assertRaises(ValueError):
+                self.config.get_quant_method(layer)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/tests/quantization/test_w4afp8.py b/tests/quantization/test_w4afp8.py
new file mode 100644
index 000000000..2bb002093
--- /dev/null
+++ b/tests/quantization/test_w4afp8.py
@@ -0,0 +1,142 @@
+import unittest
+from unittest import mock
+
+from fastdeploy.model_executor.layers.moe import FusedMoE
+from fastdeploy.model_executor.layers.quantization.w4afp8 import (
+    QUANT_SCALING_FACTOR,
+    W4AFP8Config,
+    W4AFP8LinearMethod,
+)
+
+
+class TestW4AFP8(unittest.TestCase):
+    def setUp(self):
+        self.config = W4AFP8Config(
+            weight_scale_dict={"layer.weight_scale": 1.0},
+            act_scale_dict={"layer.activation_scale": 1.0},
+            is_permuted=False,
+            hadamard_block_size=128,
+        )
+        self.method = W4AFP8LinearMethod(self.config)
+
+        # Mock layer
+        self.layer = mock.Mock()
+        self.layer.weight_shape = [8, 4]
+        self.layer.create_parameter.return_value = "created_weight"
+        self.layer.bias = "bias"
+        self.layer.add_bias = True
+        self.layer._dtype = "float16"
+        self.layer.prefix = "layer"
+
+    def test_name(self):
+        self.assertEqual(self.config.name(), "w4afp8")
+
+    def test_from_config_defaults(self):
+        cfg = W4AFP8Config.from_config({})
+        self.assertTrue(cfg.is_permuted)
+        self.assertEqual(cfg.hadamard_block_size, 128)
+
+    def test_from_config_full(self):
+        cfg = W4AFP8Config.from_config(
+            {
+                "weight_scale_dict": {"a": 1},
+                "act_scale_dict": {"b": 2},
+                "is_permuted": False,
+                "hadamard_block_size": 64,
+            }
+        )
+        self.assertEqual(cfg.weight_scale_dict["a"], 1)
+        self.assertEqual(cfg.hadamard_block_size, 64)
+        self.assertFalse(cfg.is_permuted)
+
+    def test_get_quant_method_linear(self):
+        # Non-FusedMoE path
+        method = self.config.get_quant_method(mock.Mock())
+        self.assertIsInstance(method, W4AFP8LinearMethod)
+
+    @mock.patch("fastdeploy.model_executor.layers.moe.fused_moe_cutlass_backend.CutlassW4AFP8MoEMethod")
+    def test_get_quant_method_moe(self, mock_cutlass):
+        # Mock FusedMoE instance
+        layer = mock.Mock(spec=FusedMoE)
+        type(layer).return_value = None
+        result = self.config.get_quant_method(layer)
+
+        mock_cutlass.assert_called_once_with(self.config)
+        self.assertEqual(result, mock_cutlass.return_value)
+
+    def test_create_weights(self):
+        original_shape = [8, 4]
+        self.layer.weight_shape = original_shape.copy()
+
+        self.method.create_weights(self.layer)
+        self.assertEqual(self.layer.weight_dtype, "int8")
+        self.assertEqual(self.layer.weight, "created_weight")
+        self.assertEqual(self.layer.weight_shape, [2, 8])
+
+    @mock.patch("fastdeploy.model_executor.ops.gpu.scaled_gemm_f8_i4_f16_weight_quantize")
+    @mock.patch("paddle.view")
+    @mock.patch("paddle.cast")
+    def test_process_loaded_weights(self, mock_cast, mock_view, mock_quant):
+        mock_cast.return_value.cpu.return_value = "cpu_tensor"
+        mock_quant.return_value = ("quanted_weight", "weight_scale")
+        mock_view.return_value = "reshaped_scale"
+
+        self.layer.weight = mock.Mock()
+        self.layer.weight_scale = mock.Mock()
+
+        self.method.process_loaded_weights(self.layer, "weights")
+
+        mock_cast.assert_called_once_with("weights", "float32")
+        mock_quant.assert_called_once()
+        mock_view.assert_called_once_with("weight_scale", self.layer._dtype)
+        self.layer.weight.set_value.assert_called_once_with("quanted_weight")
+        self.layer.weight_scale.set_value.assert_called_once_with("reshaped_scale")
+
+    @mock.patch("fastdeploy.model_executor.ops.gpu.scaled_gemm_f8_i4_f16_weight_quantize")
+    @mock.patch("paddle.view")
+    @mock.patch("paddle.cast")
+    def test_process_loaded_weights_with_error(self, mock_cast, mock_view, mock_quant):
+        mock_cast.return_value.cpu.return_value = "cpu_tensor"
+        mock_quant.return_value = (None, None)
+        self.layer.weight = mock.Mock()
+        self.layer.weight_scale = mock.Mock()
+
+        self.method.process_loaded_weights(self.layer, "weights")
+
+    @mock.patch("fastdeploy.model_executor.ops.gpu.scaled_gemm_f8_i4_f16")
+    def test_apply_with_bias(self, mock_gemm):
+        mock_gemm.return_value = "output"
+        x = mock.Mock()
+        self.layer.weight = "w"
+        self.layer.weight_scale = "s"
+
+        result = self.method.apply(self.layer, x)
+        mock_gemm.assert_called_once()
+        self.assertEqual(result, "output")
+
+        # Verify out_scale value
+        call_args = mock_gemm.call_args.kwargs
+        expected_out_scale = 1.0 / (1.0 * QUANT_SCALING_FACTOR * QUANT_SCALING_FACTOR)
+        self.assertAlmostEqual(call_args["out_scale"], expected_out_scale)
+
+    @mock.patch("fastdeploy.model_executor.ops.gpu.scaled_gemm_f8_i4_f16")
+    def test_apply_without_bias(self, mock_gemm):
+        self.layer.add_bias = False
+        mock_gemm.return_value = "out"
+        x = "x"
+
+        result = self.method.apply(self.layer, x)
+        self.assertEqual(result, "out")
+        args = mock_gemm.call_args.kwargs
+        self.assertIsNone(args["bias"])
+
+    @mock.patch("fastdeploy.model_executor.ops.gpu.scaled_gemm_f8_i4_f16")
+    def test_apply_prefix_missing_key(self, mock_gemm):
+        self.layer.prefix = "unknown"
+        x = mock.Mock()
+        with self.assertRaises(TypeError):
+            self.method.apply(self.layer, x)
+
+
+if __name__ == "__main__":
+    unittest.main()
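
Editor's note: the new suites follow the repository's stock unittest layout, so they can be exercised locally before CI picks them up. Below is a minimal sketch of a local runner, assuming the FastDeploy repo root is the working directory and that fastdeploy and paddle are importable; the script name and everything in it are illustrative and not part of the patch.

    # run_new_tests.py - hypothetical local runner for the five modules added above
    import unittest

    if __name__ == "__main__":
        loader = unittest.defaultTestLoader
        suite = unittest.TestSuite()
        # Module paths exactly as introduced by this patch.
        for module in (
            "tests.layers.test_activation",
            "tests.layers.test_native_paddle_backend",
            "tests.platforms.test_utils",
            "tests.quantization.test_w4a8",
            "tests.quantization.test_w4afp8",
        ):
            suite.addTests(loader.loadTestsFromName(module))
        unittest.TextTestRunner(verbosity=2).run(suite)

Each module also guards unittest.main(), so running any single file directly (for example, python tests/platforms/test_utils.py) works as well.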