mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
[CI] Add unittest for activation, native_paddle_backend, w4a8, w4afp8, platforms/utils (#4812)
* add unnitest for activation, native_paddle_backend, w4a8, w4afp8, platforms/utils * Remove activation function retrieval tests Removed tests for valid and unsupported activation function retrieval. * move w4a8, w4afp8 to quantization * fix code style
This commit is contained in:
168
tests/layers/test_activation.py
Normal file
168
tests/layers/test_activation.py
Normal file
@@ -0,0 +1,168 @@
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
|
||||
import paddle
|
||||
|
||||
from fastdeploy.model_executor.layers.activation import SiluAndMul
|
||||
|
||||
|
||||
class DummyQuantConfig:
|
||||
quant_round_type = 1
|
||||
quant_max_bound = 127
|
||||
quant_min_bound = -128
|
||||
|
||||
def name(self):
|
||||
return "int8"
|
||||
|
||||
|
||||
class DummyFDConfig:
|
||||
def __init__(self):
|
||||
self.quant_config = DummyQuantConfig()
|
||||
self.graph_opt_config = type("GraphOptConfig", (), {"cudagraph_capture_sizes": []})()
|
||||
|
||||
|
||||
class DummyPlatform:
|
||||
def __init__(self, cuda=False, gcu=False, intel_hpu=False):
|
||||
self._cuda = cuda
|
||||
self._gcu = gcu
|
||||
self._intel_hpu = intel_hpu
|
||||
|
||||
def is_cuda(self):
|
||||
return self._cuda
|
||||
|
||||
def is_xpu(self):
|
||||
return False
|
||||
|
||||
def is_iluvatar(self):
|
||||
return False
|
||||
|
||||
def is_dcu(self):
|
||||
return False
|
||||
|
||||
def is_maca(self):
|
||||
return False
|
||||
|
||||
def is_gcu(self):
|
||||
return self._gcu
|
||||
|
||||
def is_intel_hpu(self):
|
||||
return self._intel_hpu
|
||||
|
||||
|
||||
class DummyHelper:
|
||||
def __init__(self, dtype="float16"):
|
||||
self._dtype = dtype
|
||||
|
||||
def get_default_dtype(self):
|
||||
return self._dtype
|
||||
|
||||
|
||||
class TestSiluAndMul(unittest.TestCase):
|
||||
# Test forward computation on CUDA platform
|
||||
@patch(
|
||||
"fastdeploy.model_executor.layers.activation.current_platform", new_callable=lambda: DummyPlatform(cuda=True)
|
||||
)
|
||||
@patch("fastdeploy.model_executor.layers.activation.fused_bias_act", return_value=paddle.ones([2, 2]))
|
||||
def test_forward_cuda(self, mock_fused, mock_platform):
|
||||
fd_config = DummyFDConfig()
|
||||
layer = SiluAndMul(fd_config)
|
||||
x = paddle.ones([2, 2])
|
||||
out = layer.forward(x)
|
||||
self.assertTrue((out.numpy() == 1).all())
|
||||
mock_fused.assert_called_once()
|
||||
|
||||
# Test forward computation on GCU platform
|
||||
@patch(
|
||||
"fastdeploy.model_executor.layers.activation.current_platform", new_callable=lambda: DummyPlatform(gcu=True)
|
||||
)
|
||||
@patch("fastdeploy.model_executor.layers.activation.swiglu", return_value=paddle.ones([2, 2]))
|
||||
def test_forward_gcu(self, mock_swiglu, mock_platform):
|
||||
fd_config = DummyFDConfig()
|
||||
bias = paddle.ones([2, 2])
|
||||
layer = SiluAndMul(fd_config, bias=bias)
|
||||
x = paddle.ones([2, 2])
|
||||
out = layer.forward(x)
|
||||
self.assertTrue((out.numpy() == 2).all())
|
||||
|
||||
# Test forward computation on Intel HPU platform
|
||||
@patch(
|
||||
"fastdeploy.model_executor.layers.activation.current_platform",
|
||||
new_callable=lambda: DummyPlatform(intel_hpu=True),
|
||||
)
|
||||
def test_forward_intel_hpu(self, mock_platform):
|
||||
fd_config = DummyFDConfig()
|
||||
layer = SiluAndMul(fd_config)
|
||||
x = paddle.ones([2, 2])
|
||||
out = layer.forward(x)
|
||||
self.assertIsNone(out)
|
||||
|
||||
# Test behavior on unsupported platforms
|
||||
@patch("fastdeploy.model_executor.layers.activation.current_platform", new_callable=lambda: DummyPlatform())
|
||||
def test_unsupported_platform(self, mock_platform):
|
||||
fd_config = DummyFDConfig()
|
||||
with self.assertRaises(NotImplementedError):
|
||||
SiluAndMul(fd_config)
|
||||
|
||||
# Test dtype branch handling
|
||||
@patch(
|
||||
"fastdeploy.model_executor.layers.activation.current_platform", new_callable=lambda: DummyPlatform(cuda=True)
|
||||
)
|
||||
def test_dtype_branches(self, mock_platform):
|
||||
fd_config = DummyFDConfig()
|
||||
for dtype, expected in [("float16", "fp16"), ("bfloat16", "bf16"), ("float32", "fp32")]:
|
||||
layer = SiluAndMul(fd_config)
|
||||
layer._helper = DummyHelper(dtype)
|
||||
layer._fuse_kernel_compute_dtype = {"float16": "fp16", "bfloat16": "bf16", "float32": "fp32"}[
|
||||
layer._helper.get_default_dtype()
|
||||
]
|
||||
self.assertEqual(layer._fuse_kernel_compute_dtype, expected)
|
||||
|
||||
# Test invalid dtype handling
|
||||
def test_dtype_invalid(self):
|
||||
fd_config = DummyFDConfig()
|
||||
layer = SiluAndMul(fd_config)
|
||||
layer._helper = DummyHelper("int8")
|
||||
with self.assertRaises(ValueError):
|
||||
dtype = layer._helper.get_default_dtype()
|
||||
if dtype not in ["float16", "bfloat16", "float32"]:
|
||||
raise ValueError(f"Just support float32, float16 and bfloat16 as default dtype, but received {dtype}")
|
||||
|
||||
# Test fp8 quantization handling
|
||||
@patch(
|
||||
"fastdeploy.model_executor.layers.activation.current_platform", new_callable=lambda: DummyPlatform(cuda=True)
|
||||
)
|
||||
def test_fp8_quant(self, mock_platform):
|
||||
class DummyFp8Config:
|
||||
quant_round_type = 1
|
||||
quant_max_bound = 127
|
||||
quant_min_bound = -128
|
||||
|
||||
def name(self):
|
||||
return "fp8"
|
||||
|
||||
fd_config = DummyFDConfig()
|
||||
fd_config.quant_config = DummyFp8Config()
|
||||
layer = SiluAndMul(fd_config)
|
||||
layer._helper = DummyHelper("float16")
|
||||
if "fp8" in fd_config.quant_config.name():
|
||||
layer.dequant_scales = None
|
||||
layer.shift = None
|
||||
layer.smooth = None
|
||||
self.assertIsNone(layer.dequant_scales)
|
||||
self.assertIsNone(layer.shift)
|
||||
self.assertIsNone(layer.smooth)
|
||||
|
||||
# Test act_method mapping
|
||||
@patch(
|
||||
"fastdeploy.model_executor.layers.activation.current_platform", new_callable=lambda: DummyPlatform(cuda=True)
|
||||
)
|
||||
def test_act_method_mapping(self, mock_platform):
|
||||
fd_config = DummyFDConfig()
|
||||
layer = SiluAndMul(fd_config, act_method="silu")
|
||||
self.assertEqual(layer.act_method, "swiglu")
|
||||
layer = SiluAndMul(fd_config, act_method="relu")
|
||||
self.assertEqual(layer.act_method, "relu")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
100
tests/layers/test_native_paddle_backend.py
Normal file
100
tests/layers/test_native_paddle_backend.py
Normal file
@@ -0,0 +1,100 @@
|
||||
import unittest
|
||||
from unittest.mock import Mock
|
||||
|
||||
import paddle
|
||||
|
||||
from fastdeploy.model_executor.layers.attention.native_paddle_backend import (
|
||||
PaddleNativeAttnBackend,
|
||||
)
|
||||
|
||||
|
||||
class MockLayer:
|
||||
def __init__(self, num_heads=2, qk_head_dim=8, v_head_dim=8, layer_id=0):
|
||||
self.self = Mock()
|
||||
self.self.num_heads = num_heads
|
||||
self.qk_head_dim = qk_head_dim
|
||||
self.v_head_dim = v_head_dim
|
||||
self.layer_id = layer_id
|
||||
|
||||
|
||||
class MockTokenToKVPool:
|
||||
def set_kv_buffer(self, layer, loc, k, v):
|
||||
pass
|
||||
|
||||
def get_key_buffer(self, layer_id):
|
||||
return paddle.randn([8, 2, 8])
|
||||
|
||||
def get_value_buffer(self, layer_id):
|
||||
return paddle.randn([8, 2, 8])
|
||||
|
||||
|
||||
class MockForwardMeta:
|
||||
def __init__(self):
|
||||
self.token_to_kv_pool = MockTokenToKVPool()
|
||||
self.req_to_token_pool = Mock()
|
||||
self.req_pool_indices = paddle.to_tensor([0, 1], dtype="int64")
|
||||
self.seq_lens = paddle.to_tensor([4, 4], dtype="int64")
|
||||
self.extend_prefix_lens = paddle.to_tensor([2, 2], dtype="int64")
|
||||
self.extend_seq_lens = paddle.to_tensor([2, 2], dtype="int64")
|
||||
self.out_cache_loc = 0
|
||||
self.req_to_token_pool.req_to_token = paddle.arange(8, dtype="int64").reshape([2, 4])
|
||||
|
||||
|
||||
class TestPaddleNativeAttnBackend(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.backend = PaddleNativeAttnBackend()
|
||||
self.layer = MockLayer()
|
||||
self.forward_meta = MockForwardMeta()
|
||||
self.q = paddle.randn([2, 4, 16])
|
||||
self.k = paddle.randn([8, 2, 8])
|
||||
self.v = paddle.randn([8, 2, 8])
|
||||
|
||||
def test_scaled_dot_product_attention_shape(self):
|
||||
q = paddle.randn([1, 2, 4, 8])
|
||||
k = paddle.randn([1, 2, 4, 8])
|
||||
v = paddle.randn([1, 2, 4, 8])
|
||||
out = self.backend._scaled_dot_product_attention(q, k, v, is_causal=False)
|
||||
self.assertEqual(list(out.shape), [1, 2, 4, 8])
|
||||
|
||||
def test_scaled_dot_product_attention_causal(self):
|
||||
q = paddle.randn([1, 2, 4, 8])
|
||||
k = paddle.randn([1, 2, 4, 8])
|
||||
v = paddle.randn([1, 2, 4, 8])
|
||||
out = self.backend._scaled_dot_product_attention(q, k, v, is_causal=True)
|
||||
self.assertEqual(list(out.shape), [1, 2, 4, 8])
|
||||
|
||||
def test_run_sdpa_forward_extend(self):
|
||||
out = paddle.zeros_like(self.k)
|
||||
try:
|
||||
out = self.backend._run_sdpa_forward_extend(
|
||||
self.q.reshape([8, 2, 8]),
|
||||
out,
|
||||
self.k,
|
||||
self.v,
|
||||
self.forward_meta.req_to_token_pool.req_to_token,
|
||||
self.forward_meta.req_pool_indices,
|
||||
self.forward_meta.seq_lens,
|
||||
self.forward_meta.extend_prefix_lens,
|
||||
self.forward_meta.extend_seq_lens,
|
||||
causal=False,
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def test_forward_extend(self):
|
||||
try:
|
||||
o = self.backend.forward_extend(self.q, self.k, self.v, self.layer, self.forward_meta)
|
||||
self.assertEqual(list(o.shape), list(self.q.shape))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def test_forward_decode(self):
|
||||
try:
|
||||
o = self.backend.forward_decode(self.q, self.k, self.v, self.layer, self.forward_meta)
|
||||
self.assertEqual(list(o.shape), list(self.q.shape))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
31
tests/platforms/test_utils.py
Normal file
31
tests/platforms/test_utils.py
Normal file
@@ -0,0 +1,31 @@
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
|
||||
from fastdeploy.platforms.utils import convert_to_npu_dequant_scale
|
||||
|
||||
|
||||
class TestConvertToNpuDequantScale(unittest.TestCase):
|
||||
|
||||
def test_npu_not_available(self):
|
||||
with patch("paddle.is_compiled_with_custom_device", return_value=False):
|
||||
x = paddle.to_tensor([1.0, 2.0, 3.0], dtype=paddle.float32)
|
||||
out = convert_to_npu_dequant_scale(x)
|
||||
self.assertTrue((out.numpy() == x.numpy()).all())
|
||||
|
||||
def test_npu_available(self):
|
||||
with patch("paddle.is_compiled_with_custom_device", return_value=True):
|
||||
x = paddle.to_tensor([1, 2, 3], dtype=paddle.float32)
|
||||
out = convert_to_npu_dequant_scale(x)
|
||||
self.assertEqual(out.dtype, paddle.int64)
|
||||
# Verify scaled output matches expected NPU dequantization format
|
||||
arr = x.numpy()
|
||||
new_deq_scale = np.stack([arr.reshape(-1, 1), np.zeros_like(arr).reshape(-1, 1)], axis=-1).reshape(-1)
|
||||
expected = np.frombuffer(new_deq_scale.tobytes(), dtype=np.int64)
|
||||
self.assertTrue((out.numpy() == expected).all())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
62
tests/quantization/test_w4a8.py
Normal file
62
tests/quantization/test_w4a8.py
Normal file
@@ -0,0 +1,62 @@
|
||||
import unittest
|
||||
from unittest import mock
|
||||
|
||||
from fastdeploy.model_executor.layers.quantization.w4a8 import W4A8Config
|
||||
from fastdeploy.platforms import current_platform
|
||||
|
||||
|
||||
class TestW4A8Config(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.config = W4A8Config(is_permuted=False, hadamard_block_size=128)
|
||||
|
||||
def test_name(self):
|
||||
"""Test name() method"""
|
||||
self.assertEqual(self.config.name(), "w4a8")
|
||||
|
||||
def test_from_config_defaults(self):
|
||||
"""Test from_config with empty dict uses defaults"""
|
||||
cfg = W4A8Config.from_config({})
|
||||
self.assertTrue(cfg.is_permuted)
|
||||
self.assertEqual(cfg.hadamard_block_size, 128)
|
||||
|
||||
def test_from_config_full(self):
|
||||
"""Test from_config with full dict"""
|
||||
cfg = W4A8Config.from_config({"is_permuted": False, "hadamard_block_size": 64})
|
||||
self.assertFalse(cfg.is_permuted)
|
||||
self.assertEqual(cfg.hadamard_block_size, 64)
|
||||
|
||||
def test_get_quant_method_cuda(self):
|
||||
"""Test get_quant_method returns CUDA method when on CUDA platform"""
|
||||
with (
|
||||
mock.patch.object(current_platform, "is_cuda", return_value=True),
|
||||
mock.patch(
|
||||
"fastdeploy.model_executor.layers.moe.fused_moe_cutlass_backend.CutlassW4A8MoEMethod"
|
||||
) as mock_cuda,
|
||||
):
|
||||
layer = mock.Mock()
|
||||
method = self.config.get_quant_method(layer)
|
||||
mock_cuda.assert_called_once_with(self.config)
|
||||
self.assertEqual(method, mock_cuda.return_value)
|
||||
|
||||
@unittest.skipIf(not hasattr(current_platform, "is_xpu") or not current_platform.is_xpu(), "No XPU, skip test")
|
||||
def test_get_quant_method_xpu(self):
|
||||
"""Test get_quant_method returns XPU method when on XPU platform"""
|
||||
with mock.patch("fastdeploy.model_executor.layers.backends.xpu.moe.fused_moe.XPUW4A8MoEMethod") as mock_xpu:
|
||||
layer = mock.Mock()
|
||||
method = self.config.get_quant_method(layer)
|
||||
mock_xpu.assert_called_once_with(self.config)
|
||||
self.assertEqual(method, mock_xpu.return_value)
|
||||
|
||||
def test_get_quant_method_unsupported(self):
|
||||
"""Test that unsupported platform raises ValueError"""
|
||||
with (
|
||||
mock.patch.object(current_platform, "is_cuda", return_value=False),
|
||||
mock.patch.object(current_platform, "is_xpu", return_value=False),
|
||||
):
|
||||
layer = mock.Mock()
|
||||
with self.assertRaises(ValueError):
|
||||
self.config.get_quant_method(layer)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
142
tests/quantization/test_w4afp8.py
Normal file
142
tests/quantization/test_w4afp8.py
Normal file
@@ -0,0 +1,142 @@
|
||||
import unittest
|
||||
from unittest import mock
|
||||
|
||||
from fastdeploy.model_executor.layers.moe import FusedMoE
|
||||
from fastdeploy.model_executor.layers.quantization.w4afp8 import (
|
||||
QUANT_SCALING_FACTOR,
|
||||
W4AFP8Config,
|
||||
W4AFP8LinearMethod,
|
||||
)
|
||||
|
||||
|
||||
class TestW4AFP8(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.config = W4AFP8Config(
|
||||
weight_scale_dict={"layer.weight_scale": 1.0},
|
||||
act_scale_dict={"layer.activation_scale": 1.0},
|
||||
is_permuted=False,
|
||||
hadamard_block_size=128,
|
||||
)
|
||||
self.method = W4AFP8LinearMethod(self.config)
|
||||
|
||||
# Mock layer
|
||||
self.layer = mock.Mock()
|
||||
self.layer.weight_shape = [8, 4]
|
||||
self.layer.create_parameter.return_value = "created_weight"
|
||||
self.layer.bias = "bias"
|
||||
self.layer.add_bias = True
|
||||
self.layer._dtype = "float16"
|
||||
self.layer.prefix = "layer"
|
||||
|
||||
def test_name(self):
|
||||
self.assertEqual(self.config.name(), "w4afp8")
|
||||
|
||||
def test_from_config_defaults(self):
|
||||
cfg = W4AFP8Config.from_config({})
|
||||
self.assertTrue(cfg.is_permuted)
|
||||
self.assertEqual(cfg.hadamard_block_size, 128)
|
||||
|
||||
def test_from_config_full(self):
|
||||
cfg = W4AFP8Config.from_config(
|
||||
{
|
||||
"weight_scale_dict": {"a": 1},
|
||||
"act_scale_dict": {"b": 2},
|
||||
"is_permuted": False,
|
||||
"hadamard_block_size": 64,
|
||||
}
|
||||
)
|
||||
self.assertEqual(cfg.weight_scale_dict["a"], 1)
|
||||
self.assertEqual(cfg.hadamard_block_size, 64)
|
||||
self.assertFalse(cfg.is_permuted)
|
||||
|
||||
def test_get_quant_method_linear(self):
|
||||
# Non-FusedMoE path
|
||||
method = self.config.get_quant_method(mock.Mock())
|
||||
self.assertIsInstance(method, W4AFP8LinearMethod)
|
||||
|
||||
@mock.patch("fastdeploy.model_executor.layers.moe.fused_moe_cutlass_backend.CutlassW4AFP8MoEMethod")
|
||||
def test_get_quant_method_moe(self, mock_cutlass):
|
||||
# Mock FusedMoE instance
|
||||
layer = mock.Mock(spec=FusedMoE)
|
||||
type(layer).return_value = None
|
||||
result = self.config.get_quant_method(layer)
|
||||
|
||||
mock_cutlass.assert_called_once_with(self.config)
|
||||
self.assertEqual(result, mock_cutlass.return_value)
|
||||
|
||||
def test_create_weights(self):
|
||||
original_shape = [8, 4]
|
||||
self.layer.weight_shape = original_shape.copy()
|
||||
|
||||
self.method.create_weights(self.layer)
|
||||
self.assertEqual(self.layer.weight_dtype, "int8")
|
||||
self.assertEqual(self.layer.weight, "created_weight")
|
||||
self.assertEqual(self.layer.weight_shape, [2, 8])
|
||||
|
||||
@mock.patch("fastdeploy.model_executor.ops.gpu.scaled_gemm_f8_i4_f16_weight_quantize")
|
||||
@mock.patch("paddle.view")
|
||||
@mock.patch("paddle.cast")
|
||||
def test_process_loaded_weights(self, mock_cast, mock_view, mock_quant):
|
||||
mock_cast.return_value.cpu.return_value = "cpu_tensor"
|
||||
mock_quant.return_value = ("quanted_weight", "weight_scale")
|
||||
mock_view.return_value = "reshaped_scale"
|
||||
|
||||
self.layer.weight = mock.Mock()
|
||||
self.layer.weight_scale = mock.Mock()
|
||||
|
||||
self.method.process_loaded_weights(self.layer, "weights")
|
||||
|
||||
mock_cast.assert_called_once_with("weights", "float32")
|
||||
mock_quant.assert_called_once()
|
||||
mock_view.assert_called_once_with("weight_scale", self.layer._dtype)
|
||||
self.layer.weight.set_value.assert_called_once_with("quanted_weight")
|
||||
self.layer.weight_scale.set_value.assert_called_once_with("reshaped_scale")
|
||||
|
||||
@mock.patch("fastdeploy.model_executor.ops.gpu.scaled_gemm_f8_i4_f16_weight_quantize")
|
||||
@mock.patch("paddle.view")
|
||||
@mock.patch("paddle.cast")
|
||||
def test_process_loaded_weights_with_error(self, mock_cast, mock_view, mock_quant):
|
||||
mock_cast.return_value.cpu.return_value = "cpu_tensor"
|
||||
mock_quant.return_value = (None, None)
|
||||
self.layer.weight = mock.Mock()
|
||||
self.layer.weight_scale = mock.Mock()
|
||||
|
||||
self.method.process_loaded_weights(self.layer, "weights")
|
||||
|
||||
@mock.patch("fastdeploy.model_executor.ops.gpu.scaled_gemm_f8_i4_f16")
|
||||
def test_apply_with_bias(self, mock_gemm):
|
||||
mock_gemm.return_value = "output"
|
||||
x = mock.Mock()
|
||||
self.layer.weight = "w"
|
||||
self.layer.weight_scale = "s"
|
||||
|
||||
result = self.method.apply(self.layer, x)
|
||||
mock_gemm.assert_called_once()
|
||||
self.assertEqual(result, "output")
|
||||
|
||||
# Verify out_scale value
|
||||
call_args = mock_gemm.call_args.kwargs
|
||||
expected_out_scale = 1.0 / (1.0 * QUANT_SCALING_FACTOR * QUANT_SCALING_FACTOR)
|
||||
self.assertAlmostEqual(call_args["out_scale"], expected_out_scale)
|
||||
|
||||
@mock.patch("fastdeploy.model_executor.ops.gpu.scaled_gemm_f8_i4_f16")
|
||||
def test_apply_without_bias(self, mock_gemm):
|
||||
self.layer.add_bias = False
|
||||
mock_gemm.return_value = "out"
|
||||
x = "x"
|
||||
|
||||
result = self.method.apply(self.layer, x)
|
||||
self.assertEqual(result, "out")
|
||||
args = mock_gemm.call_args.kwargs
|
||||
self.assertIsNone(args["bias"])
|
||||
|
||||
@mock.patch("fastdeploy.model_executor.ops.gpu.scaled_gemm_f8_i4_f16")
|
||||
def test_apply_prefix_missing_key(self, mock_gemm):
|
||||
self.layer.prefix = "unknown"
|
||||
x = mock.Mock()
|
||||
with self.assertRaises(TypeError):
|
||||
self.method.apply(self.layer, x)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user