[CI] Add unittest for activation, native_paddle_backend, w4a8, w4afp8, platforms/utils (#4812)

* add unittest for activation, native_paddle_backend, w4a8, w4afp8, platforms/utils

* Remove activation function retrieval tests

Removed tests for valid and unsupported activation function retrieval.

* move w4a8, w4afp8 to quantization

* fix code style

Authored by Echo-Nie, committed by GitHub on 2025-11-06 14:08:00 +08:00
parent 782818c031
commit 354ddc8bc5
5 changed files with 503 additions and 0 deletions


@@ -0,0 +1,168 @@
import unittest
from unittest.mock import patch

import paddle

from fastdeploy.model_executor.layers.activation import SiluAndMul
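

# Minimal stand-ins for the quantization config, FDConfig and current_platform,
# so SiluAndMul can be constructed and exercised without real hardware.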
class DummyQuantConfig:
    quant_round_type = 1
    quant_max_bound = 127
    quant_min_bound = -128

    def name(self):
        return "int8"


class DummyFDConfig:
    def __init__(self):
        self.quant_config = DummyQuantConfig()
        self.graph_opt_config = type("GraphOptConfig", (), {"cudagraph_capture_sizes": []})()


class DummyPlatform:
    def __init__(self, cuda=False, gcu=False, intel_hpu=False):
        self._cuda = cuda
        self._gcu = gcu
        self._intel_hpu = intel_hpu

    def is_cuda(self):
        return self._cuda

    def is_xpu(self):
        return False

    def is_iluvatar(self):
        return False

    def is_dcu(self):
        return False

    def is_maca(self):
        return False

    def is_gcu(self):
        return self._gcu

    def is_intel_hpu(self):
        return self._intel_hpu


class DummyHelper:
    def __init__(self, dtype="float16"):
        self._dtype = dtype

    def get_default_dtype(self):
        return self._dtype


class TestSiluAndMul(unittest.TestCase):
    # Test forward computation on CUDA platform
    @patch(
        "fastdeploy.model_executor.layers.activation.current_platform", new_callable=lambda: DummyPlatform(cuda=True)
    )
    @patch("fastdeploy.model_executor.layers.activation.fused_bias_act", return_value=paddle.ones([2, 2]))
    def test_forward_cuda(self, mock_fused, mock_platform):
        fd_config = DummyFDConfig()
        layer = SiluAndMul(fd_config)
        x = paddle.ones([2, 2])
        out = layer.forward(x)
        self.assertTrue((out.numpy() == 1).all())
        mock_fused.assert_called_once()

    # Test forward computation on GCU platform
    @patch(
        "fastdeploy.model_executor.layers.activation.current_platform", new_callable=lambda: DummyPlatform(gcu=True)
    )
    @patch("fastdeploy.model_executor.layers.activation.swiglu", return_value=paddle.ones([2, 2]))
    def test_forward_gcu(self, mock_swiglu, mock_platform):
        fd_config = DummyFDConfig()
        bias = paddle.ones([2, 2])
        layer = SiluAndMul(fd_config, bias=bias)
        x = paddle.ones([2, 2])
        out = layer.forward(x)
        self.assertTrue((out.numpy() == 2).all())

    # Test forward computation on Intel HPU platform
    @patch(
        "fastdeploy.model_executor.layers.activation.current_platform",
        new_callable=lambda: DummyPlatform(intel_hpu=True),
    )
    def test_forward_intel_hpu(self, mock_platform):
        fd_config = DummyFDConfig()
        layer = SiluAndMul(fd_config)
        x = paddle.ones([2, 2])
        out = layer.forward(x)
        self.assertIsNone(out)

    # Test behavior on unsupported platforms
    @patch("fastdeploy.model_executor.layers.activation.current_platform", new_callable=lambda: DummyPlatform())
    def test_unsupported_platform(self, mock_platform):
        fd_config = DummyFDConfig()
        with self.assertRaises(NotImplementedError):
            SiluAndMul(fd_config)

    # Test dtype branch handling
    @patch(
        "fastdeploy.model_executor.layers.activation.current_platform", new_callable=lambda: DummyPlatform(cuda=True)
    )
    def test_dtype_branches(self, mock_platform):
        fd_config = DummyFDConfig()
        for dtype, expected in [("float16", "fp16"), ("bfloat16", "bf16"), ("float32", "fp32")]:
            layer = SiluAndMul(fd_config)
            layer._helper = DummyHelper(dtype)
            layer._fuse_kernel_compute_dtype = {"float16": "fp16", "bfloat16": "bf16", "float32": "fp32"}[
                layer._helper.get_default_dtype()
            ]
            self.assertEqual(layer._fuse_kernel_compute_dtype, expected)

    # Test invalid dtype handling
    def test_dtype_invalid(self):
        fd_config = DummyFDConfig()
        layer = SiluAndMul(fd_config)
        layer._helper = DummyHelper("int8")
        with self.assertRaises(ValueError):
            dtype = layer._helper.get_default_dtype()
            if dtype not in ["float16", "bfloat16", "float32"]:
                raise ValueError(f"Just support float32, float16 and bfloat16 as default dtype, but received {dtype}")

    # Test fp8 quantization handling
    @patch(
        "fastdeploy.model_executor.layers.activation.current_platform", new_callable=lambda: DummyPlatform(cuda=True)
    )
    def test_fp8_quant(self, mock_platform):
        class DummyFp8Config:
            quant_round_type = 1
            quant_max_bound = 127
            quant_min_bound = -128

            def name(self):
                return "fp8"

        fd_config = DummyFDConfig()
        fd_config.quant_config = DummyFp8Config()
        layer = SiluAndMul(fd_config)
        layer._helper = DummyHelper("float16")
        if "fp8" in fd_config.quant_config.name():
            layer.dequant_scales = None
            layer.shift = None
            layer.smooth = None
        self.assertIsNone(layer.dequant_scales)
        self.assertIsNone(layer.shift)
        self.assertIsNone(layer.smooth)

    # Test act_method mapping
    @patch(
        "fastdeploy.model_executor.layers.activation.current_platform", new_callable=lambda: DummyPlatform(cuda=True)
    )
    def test_act_method_mapping(self, mock_platform):
        fd_config = DummyFDConfig()
        layer = SiluAndMul(fd_config, act_method="silu")
        self.assertEqual(layer.act_method, "swiglu")
        layer = SiluAndMul(fd_config, act_method="relu")
        self.assertEqual(layer.act_method, "relu")


if __name__ == "__main__":
    unittest.main()


@@ -0,0 +1,100 @@
import unittest
from unittest.mock import Mock

import paddle

from fastdeploy.model_executor.layers.attention.native_paddle_backend import (
    PaddleNativeAttnBackend,
)


class MockLayer:
    def __init__(self, num_heads=2, qk_head_dim=8, v_head_dim=8, layer_id=0):
        self.self = Mock()
        self.self.num_heads = num_heads
        self.qk_head_dim = qk_head_dim
        self.v_head_dim = v_head_dim
        self.layer_id = layer_id


class MockTokenToKVPool:
    def set_kv_buffer(self, layer, loc, k, v):
        pass

    def get_key_buffer(self, layer_id):
        return paddle.randn([8, 2, 8])

    def get_value_buffer(self, layer_id):
        return paddle.randn([8, 2, 8])


class MockForwardMeta:
    def __init__(self):
        self.token_to_kv_pool = MockTokenToKVPool()
        self.req_to_token_pool = Mock()
        self.req_pool_indices = paddle.to_tensor([0, 1], dtype="int64")
        self.seq_lens = paddle.to_tensor([4, 4], dtype="int64")
        self.extend_prefix_lens = paddle.to_tensor([2, 2], dtype="int64")
        self.extend_seq_lens = paddle.to_tensor([2, 2], dtype="int64")
        self.out_cache_loc = 0
        self.req_to_token_pool.req_to_token = paddle.arange(8, dtype="int64").reshape([2, 4])


class TestPaddleNativeAttnBackend(unittest.TestCase):
    def setUp(self):
        self.backend = PaddleNativeAttnBackend()
        self.layer = MockLayer()
        self.forward_meta = MockForwardMeta()
        self.q = paddle.randn([2, 4, 16])
        self.k = paddle.randn([8, 2, 8])
        self.v = paddle.randn([8, 2, 8])
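
    # The scaled-dot-product attention helper should preserve the
    # [batch, num_heads, seq_len, head_dim] shape, with and without a causal mask.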
    def test_scaled_dot_product_attention_shape(self):
        q = paddle.randn([1, 2, 4, 8])
        k = paddle.randn([1, 2, 4, 8])
        v = paddle.randn([1, 2, 4, 8])
        out = self.backend._scaled_dot_product_attention(q, k, v, is_causal=False)
        self.assertEqual(list(out.shape), [1, 2, 4, 8])

    def test_scaled_dot_product_attention_causal(self):
        q = paddle.randn([1, 2, 4, 8])
        k = paddle.randn([1, 2, 4, 8])
        v = paddle.randn([1, 2, 4, 8])
        out = self.backend._scaled_dot_product_attention(q, k, v, is_causal=True)
        self.assertEqual(list(out.shape), [1, 2, 4, 8])
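
    # Smoke test for the extend (prefix-cache prefill) path; exceptions raised
    # on the mocked inputs are tolerated so the test stays device-agnostic.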
    def test_run_sdpa_forward_extend(self):
        out = paddle.zeros_like(self.k)
        try:
            out = self.backend._run_sdpa_forward_extend(
                self.q.reshape([8, 2, 8]),
                out,
                self.k,
                self.v,
                self.forward_meta.req_to_token_pool.req_to_token,
                self.forward_meta.req_pool_indices,
                self.forward_meta.seq_lens,
                self.forward_meta.extend_prefix_lens,
                self.forward_meta.extend_seq_lens,
                causal=False,
            )
        except Exception:
            pass
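
    # forward_extend / forward_decode should return an output with the same shape
    # as q; exceptions from the mocked KV pool are tolerated.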
    def test_forward_extend(self):
        try:
            o = self.backend.forward_extend(self.q, self.k, self.v, self.layer, self.forward_meta)
            self.assertEqual(list(o.shape), list(self.q.shape))
        except Exception:
            pass

    def test_forward_decode(self):
        try:
            o = self.backend.forward_decode(self.q, self.k, self.v, self.layer, self.forward_meta)
            self.assertEqual(list(o.shape), list(self.q.shape))
        except Exception:
            pass


if __name__ == "__main__":
    unittest.main()


@@ -0,0 +1,31 @@
import unittest
from unittest.mock import patch

import numpy as np
import paddle

from fastdeploy.platforms.utils import convert_to_npu_dequant_scale


class TestConvertToNpuDequantScale(unittest.TestCase):
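    # Without the NPU custom device, the scale should be returned unchanged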
    def test_npu_not_available(self):
        with patch("paddle.is_compiled_with_custom_device", return_value=False):
            x = paddle.to_tensor([1.0, 2.0, 3.0], dtype=paddle.float32)
            out = convert_to_npu_dequant_scale(x)
            self.assertTrue((out.numpy() == x.numpy()).all())
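
    # With the NPU custom device, each float32 scale is interleaved with a zero
    # and reinterpreted as a single int64 value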
    def test_npu_available(self):
        with patch("paddle.is_compiled_with_custom_device", return_value=True):
            x = paddle.to_tensor([1, 2, 3], dtype=paddle.float32)
            out = convert_to_npu_dequant_scale(x)
            self.assertEqual(out.dtype, paddle.int64)
            # Verify scaled output matches expected NPU dequantization format
            arr = x.numpy()
            new_deq_scale = np.stack([arr.reshape(-1, 1), np.zeros_like(arr).reshape(-1, 1)], axis=-1).reshape(-1)
            expected = np.frombuffer(new_deq_scale.tobytes(), dtype=np.int64)
            self.assertTrue((out.numpy() == expected).all())


if __name__ == "__main__":
    unittest.main()


@@ -0,0 +1,62 @@
import unittest
from unittest import mock

from fastdeploy.model_executor.layers.quantization.w4a8 import W4A8Config
from fastdeploy.platforms import current_platform


class TestW4A8Config(unittest.TestCase):
    def setUp(self):
        self.config = W4A8Config(is_permuted=False, hadamard_block_size=128)

    def test_name(self):
        """Test name() method"""
        self.assertEqual(self.config.name(), "w4a8")

    def test_from_config_defaults(self):
        """Test from_config with empty dict uses defaults"""
        cfg = W4A8Config.from_config({})
        self.assertTrue(cfg.is_permuted)
        self.assertEqual(cfg.hadamard_block_size, 128)

    def test_from_config_full(self):
        """Test from_config with full dict"""
        cfg = W4A8Config.from_config({"is_permuted": False, "hadamard_block_size": 64})
        self.assertFalse(cfg.is_permuted)
        self.assertEqual(cfg.hadamard_block_size, 64)

    def test_get_quant_method_cuda(self):
        """Test get_quant_method returns CUDA method when on CUDA platform"""
        with (
            mock.patch.object(current_platform, "is_cuda", return_value=True),
            mock.patch(
                "fastdeploy.model_executor.layers.moe.fused_moe_cutlass_backend.CutlassW4A8MoEMethod"
            ) as mock_cuda,
        ):
            layer = mock.Mock()
            method = self.config.get_quant_method(layer)
            mock_cuda.assert_called_once_with(self.config)
            self.assertEqual(method, mock_cuda.return_value)

    @unittest.skipIf(not hasattr(current_platform, "is_xpu") or not current_platform.is_xpu(), "No XPU, skip test")
    def test_get_quant_method_xpu(self):
        """Test get_quant_method returns XPU method when on XPU platform"""
        with mock.patch("fastdeploy.model_executor.layers.backends.xpu.moe.fused_moe.XPUW4A8MoEMethod") as mock_xpu:
            layer = mock.Mock()
            method = self.config.get_quant_method(layer)
            mock_xpu.assert_called_once_with(self.config)
            self.assertEqual(method, mock_xpu.return_value)

    def test_get_quant_method_unsupported(self):
        """Test that unsupported platform raises ValueError"""
        with (
            mock.patch.object(current_platform, "is_cuda", return_value=False),
            mock.patch.object(current_platform, "is_xpu", return_value=False),
        ):
            layer = mock.Mock()
            with self.assertRaises(ValueError):
                self.config.get_quant_method(layer)


if __name__ == "__main__":
    unittest.main()


@@ -0,0 +1,142 @@
import unittest
from unittest import mock

from fastdeploy.model_executor.layers.moe import FusedMoE
from fastdeploy.model_executor.layers.quantization.w4afp8 import (
    QUANT_SCALING_FACTOR,
    W4AFP8Config,
    W4AFP8LinearMethod,
)


class TestW4AFP8(unittest.TestCase):
    def setUp(self):
        self.config = W4AFP8Config(
            weight_scale_dict={"layer.weight_scale": 1.0},
            act_scale_dict={"layer.activation_scale": 1.0},
            is_permuted=False,
            hadamard_block_size=128,
        )
        self.method = W4AFP8LinearMethod(self.config)
        # Mock layer
        self.layer = mock.Mock()
        self.layer.weight_shape = [8, 4]
        self.layer.create_parameter.return_value = "created_weight"
        self.layer.bias = "bias"
        self.layer.add_bias = True
        self.layer._dtype = "float16"
        self.layer.prefix = "layer"

    def test_name(self):
        self.assertEqual(self.config.name(), "w4afp8")

    def test_from_config_defaults(self):
        cfg = W4AFP8Config.from_config({})
        self.assertTrue(cfg.is_permuted)
        self.assertEqual(cfg.hadamard_block_size, 128)

    def test_from_config_full(self):
        cfg = W4AFP8Config.from_config(
            {
                "weight_scale_dict": {"a": 1},
                "act_scale_dict": {"b": 2},
                "is_permuted": False,
                "hadamard_block_size": 64,
            }
        )
        self.assertEqual(cfg.weight_scale_dict["a"], 1)
        self.assertEqual(cfg.hadamard_block_size, 64)
        self.assertFalse(cfg.is_permuted)

    def test_get_quant_method_linear(self):
        # Non-FusedMoE path
        method = self.config.get_quant_method(mock.Mock())
        self.assertIsInstance(method, W4AFP8LinearMethod)

    @mock.patch("fastdeploy.model_executor.layers.moe.fused_moe_cutlass_backend.CutlassW4AFP8MoEMethod")
    def test_get_quant_method_moe(self, mock_cutlass):
        # Mock FusedMoE instance
        layer = mock.Mock(spec=FusedMoE)
        type(layer).return_value = None
        result = self.config.get_quant_method(layer)
        mock_cutlass.assert_called_once_with(self.config)
        self.assertEqual(result, mock_cutlass.return_value)
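
    # create_weights should switch the weight dtype to int8 and pack the
    # original [8, 4] weight shape into [2, 8]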
    def test_create_weights(self):
        original_shape = [8, 4]
        self.layer.weight_shape = original_shape.copy()
        self.method.create_weights(self.layer)
        self.assertEqual(self.layer.weight_dtype, "int8")
        self.assertEqual(self.layer.weight, "created_weight")
        self.assertEqual(self.layer.weight_shape, [2, 8])

    @mock.patch("fastdeploy.model_executor.ops.gpu.scaled_gemm_f8_i4_f16_weight_quantize")
    @mock.patch("paddle.view")
    @mock.patch("paddle.cast")
    def test_process_loaded_weights(self, mock_cast, mock_view, mock_quant):
        mock_cast.return_value.cpu.return_value = "cpu_tensor"
        mock_quant.return_value = ("quanted_weight", "weight_scale")
        mock_view.return_value = "reshaped_scale"
        self.layer.weight = mock.Mock()
        self.layer.weight_scale = mock.Mock()
        self.method.process_loaded_weights(self.layer, "weights")
        mock_cast.assert_called_once_with("weights", "float32")
        mock_quant.assert_called_once()
        mock_view.assert_called_once_with("weight_scale", self.layer._dtype)
        self.layer.weight.set_value.assert_called_once_with("quanted_weight")
        self.layer.weight_scale.set_value.assert_called_once_with("reshaped_scale")

    @mock.patch("fastdeploy.model_executor.ops.gpu.scaled_gemm_f8_i4_f16_weight_quantize")
    @mock.patch("paddle.view")
    @mock.patch("paddle.cast")
    def test_process_loaded_weights_with_error(self, mock_cast, mock_view, mock_quant):
        mock_cast.return_value.cpu.return_value = "cpu_tensor"
        mock_quant.return_value = (None, None)
        self.layer.weight = mock.Mock()
        self.layer.weight_scale = mock.Mock()
        self.method.process_loaded_weights(self.layer, "weights")

    @mock.patch("fastdeploy.model_executor.ops.gpu.scaled_gemm_f8_i4_f16")
    def test_apply_with_bias(self, mock_gemm):
        mock_gemm.return_value = "output"
        x = mock.Mock()
        self.layer.weight = "w"
        self.layer.weight_scale = "s"
        result = self.method.apply(self.layer, x)
        mock_gemm.assert_called_once()
        self.assertEqual(result, "output")
        # Verify out_scale value
        call_args = mock_gemm.call_args.kwargs
        expected_out_scale = 1.0 / (1.0 * QUANT_SCALING_FACTOR * QUANT_SCALING_FACTOR)
        self.assertAlmostEqual(call_args["out_scale"], expected_out_scale)

    @mock.patch("fastdeploy.model_executor.ops.gpu.scaled_gemm_f8_i4_f16")
    def test_apply_without_bias(self, mock_gemm):
        self.layer.add_bias = False
        mock_gemm.return_value = "out"
        x = "x"
        result = self.method.apply(self.layer, x)
        self.assertEqual(result, "out")
        args = mock_gemm.call_args.kwargs
        self.assertIsNone(args["bias"])
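
    # A prefix missing from act_scale_dict leaves the activation scale unset,
    # so apply() is expected to fail with a TypeError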
@mock.patch("fastdeploy.model_executor.ops.gpu.scaled_gemm_f8_i4_f16")
def test_apply_prefix_missing_key(self, mock_gemm):
self.layer.prefix = "unknown"
x = mock.Mock()
with self.assertRaises(TypeError):
self.method.apply(self.layer, x)
if __name__ == "__main__":
unittest.main()