mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
[CI] Add Unittest (#5187)
* add test * Delete tests/model_executor/test_w4afp8.py * Rename test_utils.py to test_tool_parsers_utils.py * add test * add test * fix platforms * Delete tests/cache_manager/test_platforms.py * dont change Removed copyright notice and license information.
This commit is contained in:
@@ -0,0 +1,73 @@
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
import unittest
|
||||
|
||||
from partial_json_parser.core.options import Allow
|
||||
|
||||
from fastdeploy.entrypoints.openai.tool_parsers import utils
|
||||
|
||||
|
||||
class TestPartialJsonUtils(unittest.TestCase):
    """Unit tests covering the partial-JSON helper functions."""

    def test_find_common_prefix(self):
        """find_common_prefix returns the longest shared leading substring."""
        left = '{"fruit": "ap"}'
        right = '{"fruit": "apple"}'
        self.assertEqual(utils.find_common_prefix(left, right), '{"fruit": "ap')

    def test_find_common_suffix(self):
        """find_common_suffix returns the longest shared trailing substring."""
        left = '{"fruit": "ap"}'
        right = '{"fruit": "apple"}'
        self.assertEqual(utils.find_common_suffix(left, right), '"}')

    def test_extract_intermediate_diff(self):
        """extract_intermediate_diff isolates the newly inserted middle segment."""
        previous = '{"fruit": "ap"}'
        current = '{"fruit": "apple"}'
        self.assertEqual(utils.extract_intermediate_diff(current, previous), "ple")

    def test_find_all_indices(self):
        """find_all_indices reports every occurrence index of a substring."""
        self.assertEqual(utils.find_all_indices("banana", "an"), [1, 3])

    def test_partial_json_loads_complete(self):
        """A fully formed JSON document parses whole, consuming every character."""
        payload = '{"a": 1, "b": 2}'
        parsed_obj, parsed_length = utils.partial_json_loads(payload, Allow.ALL)
        self.assertEqual(parsed_obj, {"a": 1, "b": 2})
        self.assertEqual(parsed_length, len(payload))

    def test_is_complete_json(self):
        """is_complete_json distinguishes closed documents from truncated ones."""
        self.assertTrue(utils.is_complete_json('{"a": 1}'))
        self.assertFalse(utils.is_complete_json('{"a": 1'))

    def test_consume_space(self):
        """consume_space skips leading whitespace and returns the next index."""
        # 3 spaces + 1 tab + 1 newline = 5 whitespace characters
        text = "   \t\nabc"
        self.assertEqual(utils.consume_space(0, text), 5)
|
||||
|
||||
|
||||
# Allow running this test module directly with `python <file>`.
if __name__ == "__main__":
    unittest.main()
||||
111
tests/model_executor/test_tensor_wise_fp8.py
Normal file
111
tests/model_executor/test_tensor_wise_fp8.py
Normal file
@@ -0,0 +1,111 @@
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
import unittest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import paddle
|
||||
|
||||
from fastdeploy.model_executor.layers.quantization.tensor_wise_fp8 import (
|
||||
TensorWiseFP8Config,
|
||||
TensorWiseFP8LinearMethod,
|
||||
)
|
||||
|
||||
|
||||
# Dummy classes for test
|
||||
class DummyLayer:
    """Minimal stand-in for a linear layer used by the quantization tests."""

    def __init__(self):
        # Metadata the quant method reads when creating/processing weights.
        self.weight_shape = [4, 8]
        self.weight_dtype = "float32"
        self.weight_key = "weight"
        self.weight_scale_key = "weight_scale"
        self.act_scale_key = "act_scale"
        # Mocked so copy_()/dtype interactions never touch real tensors.
        self.weight = MagicMock()

    def create_parameter(self, shape, dtype, is_bias=False, default_initializer=None):
        """Return a mock object in place of a real framework parameter."""
        return MagicMock()
||||
|
||||
|
||||
class DummyFusedMoE:
    """Placeholder class substituted for FusedMoE when patching."""

    pass
|
||||
|
||||
|
||||
class TestTensorWiseFP8Config(unittest.TestCase):
    """Tests for TensorWiseFP8Config.get_quant_method layer dispatch."""

    def test_get_quant_method_linear(self):
        """A plain linear layer maps to TensorWiseFP8LinearMethod."""
        config = TensorWiseFP8Config()
        quant_method = config.get_quant_method(DummyLayer())
        self.assertIsInstance(quant_method, TensorWiseFP8LinearMethod)

    def test_get_quant_method_moe(self):
        """A FusedMoE layer yields a quant method carrying the config."""
        config = TensorWiseFP8Config()
        moe_layer = DummyFusedMoE()
        with patch("fastdeploy.model_executor.layers.moe.FusedMoE", DummyFusedMoE):
            quant_method = config.get_quant_method(moe_layer)
            self.assertTrue(hasattr(quant_method, "quant_config"))
||||
|
||||
|
||||
class TestTensorWiseFP8LinearMethod(unittest.TestCase):
    """Test suite for TensorWiseFP8LinearMethod"""

    def setUp(self):
        """Initialize test fixtures"""
        self.layer = DummyLayer()
        self.method = TensorWiseFP8LinearMethod(TensorWiseFP8Config())
        # Initialize scales to avoid apply errors
        self.method.act_scale = 1.0
        self.method.total_scale = 1.0

    def test_create_weights(self):
        """Verify weight dtype is set to float8_e4m3fn"""
        self.method.create_weights(self.layer)
        self.assertEqual(self.layer.weight_dtype, "float8_e4m3fn")

    def test_process_prequanted_weights(self):
        """Verify prequantized weights and scales are processed correctly"""
        # Re-mock copy_ so the call count below starts at zero.
        self.layer.weight.copy_ = MagicMock()
        state_dict = {
            "weight": paddle.randn([8, 4]),
            "weight_scale": paddle.to_tensor([0.5], dtype="float32"),
            "act_scale": paddle.to_tensor([2.0], dtype="float32"),
        }
        self.method.process_prequanted_weights(self.layer, state_dict)
        self.assertAlmostEqual(self.method.act_scale, 2.0)
        # NOTE(review): 1.0 looks like weight_scale * act_scale (0.5 * 2.0) — confirm
        # against the implementation's total_scale formula.
        self.assertAlmostEqual(self.method.total_scale, 1.0)
        self.layer.weight.copy_.assert_called_once()

    # Decorators apply bottom-up: the innermost patch (cutlass gemm) is the
    # first mock argument, the hadamard quant op the second.
    @patch("fastdeploy.model_executor.ops.gpu.fused_hadamard_quant_fp8", autospec=True)
    @patch("fastdeploy.model_executor.ops.gpu.cutlass_fp8_fp8_half_gemm_fused", autospec=True)
    def test_apply(self, mock_gemm, mock_quant):
        """Verify apply method executes with mocked ops"""
        # Both mocked ops act as identity, so apply() should return x unchanged.
        mock_quant.side_effect = lambda x, scale: x
        mock_gemm.side_effect = lambda x, w, **kwargs: x
        x = paddle.randn([4, 8])
        out = self.method.apply(self.layer, x)
        self.assertTrue((out == x).all())
|
||||
|
||||
|
||||
# Allow running this test module directly with `python <file>`.
if __name__ == "__main__":
    unittest.main()
||||
@@ -20,67 +20,209 @@ from unittest.mock import patch
|
||||
from fastdeploy.platforms.base import _Backend
|
||||
from fastdeploy.platforms.cpu import CPUPlatform
|
||||
from fastdeploy.platforms.cuda import CUDAPlatform
|
||||
from fastdeploy.platforms.dcu import DCUPlatform
|
||||
from fastdeploy.platforms.gcu import GCUPlatform
|
||||
from fastdeploy.platforms.intel_hpu import INTEL_HPUPlatform
|
||||
from fastdeploy.platforms.maca import MACAPlatform
|
||||
from fastdeploy.platforms.npu import NPUPlatform
|
||||
from fastdeploy.platforms.xpu import XPUPlatform
|
||||
|
||||
|
||||
class TestCPUPlatform(unittest.TestCase):
    """Test suite for CPUPlatform"""

    def setUp(self):
        self.platform = CPUPlatform()

    @patch("paddle.device.get_device", return_value="cpu")
    def test_is_cpu_and_available(self, mock_get_device):
        """Verify is_cpu() returns True and platform is available"""
        self.assertTrue(self.platform.is_cpu())
        self.assertTrue(self.platform.available())

    def test_attention_backend(self):
        """Verify get_attention_backend_cls returns empty string for CPU"""
        self.assertEqual(self.platform.get_attention_backend_cls(None), "")
|
||||
|
||||
|
||||
class TestCUDAPlatform(unittest.TestCase):
    """Test suite for CUDAPlatform"""

    def setUp(self):
        self.platform = CUDAPlatform()

    # Decorators apply bottom-up, so the innermost patch (cuda_places) is the
    # first mock argument; the names below match that ordering.
    @patch("paddle.is_compiled_with_cuda", return_value=True)
    @patch("paddle.device.get_device", return_value="cuda")
    @patch("paddle.static.cuda_places", return_value=[0])
    def test_is_cuda_and_available(self, mock_cuda_places, mock_get_device, mock_is_cuda):
        """Verify is_cuda() returns True and platform is available"""
        self.assertTrue(self.platform.is_cuda())
        self.assertTrue(self.platform.available())

    def test_attention_backend_valid(self):
        """Verify valid attention backends return correct class names"""
        self.assertIn("PaddleNativeAttnBackend", self.platform.get_attention_backend_cls(_Backend.NATIVE_ATTN))
        self.assertIn("AppendAttentionBackend", self.platform.get_attention_backend_cls(_Backend.APPEND_ATTN))
        self.assertIn("MLAAttentionBackend", self.platform.get_attention_backend_cls(_Backend.MLA_ATTN))
        self.assertIn("FlashAttentionBackend", self.platform.get_attention_backend_cls(_Backend.FLASH_ATTN))

    def test_attention_backend_invalid(self):
        """Verify invalid backend raises ValueError"""
        with self.assertRaises(ValueError):
            self.platform.get_attention_backend_cls("INVALID_BACKEND")
|
||||
|
||||
|
||||
class TestMACAPlatform(unittest.TestCase):
    """Tests for MACAPlatform availability and attention-backend resolution."""

    @patch("paddle.static.cuda_places", return_value=[0, 1])
    def test_available_true(self, mock_cuda_places):
        """available() is True while cuda_places reports devices."""
        self.assertTrue(MACAPlatform.available())
        mock_cuda_places.assert_called_once()

    @patch("paddle.static.cuda_places", side_effect=Exception("No GPU"))
    def test_available_false(self, mock_cuda_places):
        """available() degrades to False when device discovery raises."""
        self.assertFalse(MACAPlatform.available())
        mock_cuda_places.assert_called_once()

    def test_get_attention_backend_native(self):
        """NATIVE_ATTN resolves to the paddle-native attention backend."""
        backend_cls = MACAPlatform.get_attention_backend_cls(_Backend.NATIVE_ATTN)
        self.assertIn("PaddleNativeAttnBackend", backend_cls)

    def test_get_attention_backend_append(self):
        """APPEND_ATTN resolves to the flash attention backend."""
        backend_cls = MACAPlatform.get_attention_backend_cls(_Backend.APPEND_ATTN)
        self.assertIn("FlashAttentionBackend", backend_cls)

    def test_get_attention_backend_invalid(self):
        """Unknown backend identifiers raise ValueError."""
        with self.assertRaises(ValueError):
            MACAPlatform.get_attention_backend_cls("INVALID_BACKEND")
|
||||
|
||||
|
||||
class TestINTELHPUPlatform(unittest.TestCase):
    """Tests for INTEL_HPUPlatform availability and backend resolution."""

    @patch("paddle.base.core.get_custom_device_count", return_value=1)
    def test_available_true(self, mock_get_count):
        """available() is True when an 'intel_hpu' device is counted."""
        self.assertTrue(INTEL_HPUPlatform.available())
        mock_get_count.assert_called_with("intel_hpu")

    @patch("paddle.base.core.get_custom_device_count", side_effect=Exception("No HPU"))
    @patch("fastdeploy.utils.console_logger.warning")
    def test_available_false(self, mock_logger_warn, mock_get_count):
        """available() is False and a warning is emitted when detection fails."""
        self.assertFalse(INTEL_HPUPlatform.available())
        mock_logger_warn.assert_called()
        warning_message = mock_logger_warn.call_args[0][0]
        self.assertIn("No HPU", warning_message)

    def test_attention_backend_native(self):
        """NATIVE_ATTN resolves to the paddle-native attention backend."""
        backend_cls = INTEL_HPUPlatform.get_attention_backend_cls(_Backend.NATIVE_ATTN)
        self.assertIn("PaddleNativeAttnBackend", backend_cls)

    def test_attention_backend_hpu(self):
        """HPU_ATTN resolves to the HPU attention backend."""
        backend_cls = INTEL_HPUPlatform.get_attention_backend_cls(_Backend.HPU_ATTN)
        self.assertIn("HPUAttentionBackend", backend_cls)

    @patch("fastdeploy.utils.console_logger.warning")
    def test_attention_backend_other(self, mock_logger_warn):
        """Unknown backends log a warning and resolve to None."""
        self.assertIsNone(INTEL_HPUPlatform.get_attention_backend_cls("INVALID_BACKEND"))
        mock_logger_warn.assert_called()
|
||||
|
||||
|
||||
class TestNPUPlatform(unittest.TestCase):
    """Tests for NPUPlatform."""

    def setUp(self):
        self.platform = NPUPlatform()

    def test_device_name(self):
        """The platform reports 'npu' as its device name."""
        self.assertEqual(self.platform.device_name, "npu")
|
||||
|
||||
|
||||
class TestDCUPlatform(unittest.TestCase):
    """Tests for DCUPlatform availability and attention-backend resolution."""

    def setUp(self):
        self.platform = DCUPlatform()

    @patch("paddle.static.cuda_places", return_value=[0])
    def test_available_with_gpu(self, mock_cuda_places):
        """available() is True while a GPU is enumerated."""
        self.assertTrue(self.platform.available())

    @patch("paddle.static.cuda_places", side_effect=Exception("No GPU"))
    def test_available_no_gpu(self, mock_cuda_places):
        """available() degrades to False when enumeration fails."""
        self.assertFalse(self.platform.available())

    def test_attention_backend_native(self):
        """NATIVE_ATTN resolves to the paddle-native attention backend."""
        backend_cls = self.platform.get_attention_backend_cls(_Backend.NATIVE_ATTN)
        self.assertIn("PaddleNativeAttnBackend", backend_cls)

    def test_attention_backend_block(self):
        """BLOCK_ATTN resolves to the block attention backend."""
        backend_cls = self.platform.get_attention_backend_cls(_Backend.BLOCK_ATTN)
        self.assertIn("BlockAttentionBackend", backend_cls)

    def test_attention_backend_invalid(self):
        """Unknown backends resolve to None rather than raising."""
        self.assertIsNone(self.platform.get_attention_backend_cls("INVALID_BACKEND"))
|
||||
|
||||
|
||||
class TestGCUPlatform(unittest.TestCase):
    """Tests for GCUPlatform availability and attention-backend resolution."""

    def setUp(self):
        self.platform = GCUPlatform()

    @patch("paddle.base.core.get_custom_device_count", return_value=1)
    def test_available_with_gcu(self, mock_get_count):
        """available() is True when a GCU device is counted."""
        self.assertTrue(self.platform.available())

    @patch("paddle.base.core.get_custom_device_count", side_effect=Exception("No GCU"))
    def test_available_no_gcu(self, mock_get_count):
        """available() degrades to False when device counting fails."""
        self.assertFalse(self.platform.available())

    def test_attention_backend_native(self):
        """NATIVE_ATTN resolves to the GCU memory-efficient backend."""
        backend_cls = self.platform.get_attention_backend_cls(_Backend.NATIVE_ATTN)
        self.assertIn("GCUMemEfficientAttnBackend", backend_cls)

    def test_attention_backend_append(self):
        """APPEND_ATTN resolves to the GCU flash attention backend."""
        backend_cls = self.platform.get_attention_backend_cls(_Backend.APPEND_ATTN)
        self.assertIn("GCUFlashAttnBackend", backend_cls)

    def test_attention_backend_invalid(self):
        """Unknown backend identifiers raise ValueError."""
        with self.assertRaises(ValueError):
            self.platform.get_attention_backend_cls("INVALID_BACKEND")
|
||||
|
||||
|
||||
class TestXPUPlatform(unittest.TestCase):
    """Tests for XPUPlatform availability and backend resolution."""

    @patch("paddle.is_compiled_with_xpu", return_value=True)
    @patch("paddle.static.xpu_places", return_value=[0])
    def test_available_true(self, mock_places, mock_xpu):
        """available() is True when paddle is XPU-compiled and devices exist."""
        self.assertTrue(XPUPlatform.available())

    @patch("paddle.is_compiled_with_xpu", return_value=False)
    @patch("paddle.static.xpu_places", return_value=[])
    def test_available_false(self, mock_places, mock_xpu):
        """available() is False when XPU support is absent."""
        self.assertFalse(XPUPlatform.available())

    def test_get_attention_backend_cls(self):
        """NATIVE_ATTN resolves to the fully-qualified XPU backend path."""
        expected = "fastdeploy.model_executor.layers.attention.XPUAttentionBackend"
        self.assertEqual(XPUPlatform.get_attention_backend_cls(_Backend.NATIVE_ATTN), expected)
|
||||
|
||||
|
||||
# Allow running this test module directly with `python <file>`.
if __name__ == "__main__":
    unittest.main()
||||
|
||||
277
tests/scheduler/test_workers.py
Normal file
277
tests/scheduler/test_workers.py
Normal file
@@ -0,0 +1,277 @@
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
import time
|
||||
import unittest
|
||||
|
||||
from fastdeploy.scheduler.workers import Task, Workers
|
||||
|
||||
|
||||
class TestTask(unittest.TestCase):
    def test_repr(self):
        """Test the __repr__ method of Task class.

        Builds a task with a known id and reason, then checks that both
        appear in the string representation with the expected values.
        """
        repr_str = repr(Task("123", 456, reason="ok"))
        self.assertIn("task_id:123", repr_str)
        self.assertIn("reason:ok", repr_str)
|
||||
|
||||
|
||||
class TestWorkers(unittest.TestCase):
    """Unit test suite for the Workers class.

    Covers core functionalities including task processing flow, filtering, unique task addition,
    timeout handling, exception resilience, and edge cases like empty inputs or zero workers.

    NOTE(review): several tests below are timing-based (short sleeps plus
    wall-clock deadlines) and assume worker threads make progress well within
    the given budgets; they may be sensitive to heavily loaded CI machines.
    """

    def test_basic_flow(self):
        """Test basic task processing flow with multiple tasks and workers.

        Verifies that Workers can start multiple worker threads, process batched tasks,
        and return correct results in expected format.
        """

        def simple_work(tasks):
            """Simple work function that increments task raw value by 1."""
            return [Task(task.id, task.raw + 1) for task in tasks]

        workers = Workers("test_basic_flow", work=simple_work, max_task_batch_size=2)
        workers.start(2)

        tasks = [Task(str(i), i) for i in range(4)]
        workers.add_tasks(tasks)

        # Collect results with timeout protection
        results = []
        start_time = time.time()
        while len(results) < 4 and time.time() - start_time < 1:
            batch_results = workers.get_results(10, timeout=0.1)
            if batch_results:
                results.extend(batch_results)

        # Clean up resources
        workers.terminate()

        # Every task's raw value should have been incremented once by simple_work.
        result_map = {int(task.id): task.raw for task in results}
        self.assertEqual(result_map, {0: 1, 1: 2, 2: 3, 3: 4})

    def test_task_filters(self):
        """Test task filtering functionality.

        Verifies that Workers apply specified task filters correctly and process
        all eligible tasks without dropping or duplicating.
        """

        def work_function(tasks):
            """Work function that adds 10 to task raw value."""
            return [Task(task.id, task.raw + 10) for task in tasks]

        # Define filter functions: even and odd task ID filters
        def filter_even(task):
            """Filter to select tasks with even-numbered IDs."""
            return int(task.id) % 2 == 0

        def filter_odd(task):
            """Filter to select tasks with odd-numbered IDs."""
            return int(task.id) % 2 == 1

        # Initialize Workers with filter chain and 2 worker threads
        workers = Workers(
            "test_task_filters",
            work=work_function,
            max_task_batch_size=1,
            task_filters=[filter_even, filter_odd],
        )
        workers.start(2)

        # Add 6 tasks with IDs 0-5
        workers.add_tasks([Task(str(i), i) for i in range(6)])

        # Collect results with timeout protection
        results = []
        start_time = time.time()
        while len(results) < 6 and time.time() - start_time < 2:
            batch_results = workers.get_results(10, timeout=0.1)
            if batch_results:
                results.extend(batch_results)

        # Clean up resources
        workers.terminate()

        # Expected task ID groups
        even_ids = {0, 2, 4}
        odd_ids = {1, 3, 5}

        # Extract original IDs from results (reverse work function calculation)
        got_even = {int(task.raw) - 10 for task in results if int(task.id) in even_ids}
        got_odd = {int(task.raw) - 10 for task in results if int(task.id) in odd_ids}

        # Verify all even and odd tasks were processed correctly
        self.assertEqual(got_even, even_ids)
        self.assertEqual(got_odd, odd_ids)

    def test_unique_task_addition(self):
        """Test unique task addition functionality.

        Verifies that duplicate tasks (same task_id) are filtered out when unique=True,
        while new tasks are processed normally.
        """

        def slow_work(tasks):
            """Slow work function to simulate processing delay (50ms)."""
            time.sleep(0.05)
            return [Task(task.id, task.raw + 1) for task in tasks]

        # Initialize Workers with 1 worker thread (to control task processing order)
        workers = Workers("test_unique_task_addition", work=slow_work, max_task_batch_size=1)
        workers.start(1)

        # Add first task (task_id="1") with unique=True
        workers.add_tasks([Task("1", 100)], unique=True)
        time.sleep(0.02)  # Allow task to enter running state

        # Add duplicate task (same task_id="1") - should be filtered out
        workers.add_tasks([Task("1", 200)], unique=True)
        # Add new task (task_id="2") - should be processed
        workers.add_tasks([Task("2", 300)], unique=True)

        # Collect results (expect 2 valid results)
        results = []
        start_time = time.time()
        while len(results) < 2 and time.time() - start_time < 1:
            batch_results = workers.get_results(10, timeout=0.1)
            if batch_results:
                results.extend(batch_results)

        # Clean up resources
        workers.terminate()

        # Verify only unique task IDs are present
        result_ids = sorted(int(task.id) for task in results)
        self.assertEqual(result_ids, [1, 2])

    def test_get_results_timeout(self):
        """Test timeout handling in get_results method.

        Verifies that get_results returns empty list after specified timeout when
        no results are available, and the actual wait time meets the timeout requirement.
        """

        def no_result_work(tasks):
            """Work function that returns empty list (no results)."""
            time.sleep(0.01)
            return []

        # Initialize Workers with 1 worker thread
        workers = Workers("test_get_results_timeout", work=no_result_work, max_task_batch_size=1)
        workers.start(1)

        # Measure time taken for get_results with 50ms timeout
        start_time = time.time()
        results = workers.get_results(max_size=1, timeout=0.05)
        end_time = time.time()

        # Clean up resources
        workers.terminate()

        # Verify no results are returned and timeout is respected
        self.assertEqual(results, [])
        self.assertGreaterEqual(end_time - start_time, 0.05)

    def test_start_zero_workers(self):
        """Test starting Workers with zero worker threads.

        Verifies that Workers initializes correctly with zero threads and the worker pool is empty.
        """
        # Initialize Workers without specifying max_task_batch_size (uses default)
        workers = Workers("test_start_zero_workers", work=lambda tasks: tasks)
        workers.start(0)

        # Verify worker pool is empty
        self.assertEqual(len(workers.pool), 0)

    def test_worker_exception_resilience(self):
        """Test Workers resilience to exceptions in work function.

        Verifies that worker threads continue running (or complete gracefully) when
        the work function raises an exception, without crashing the entire Workers instance.
        """
        # Track number of work function calls
        call_tracker = {"count": 0}

        def error_prone_work(tasks):
            """Work function that raises RuntimeError on each call."""
            call_tracker["count"] += 1
            raise RuntimeError("Simulated work function exception")

        # Initialize Workers with 1 worker thread
        workers = Workers("test_worker_exception_resilience", work=error_prone_work, max_task_batch_size=1)
        workers.start(1)

        # Add a test task that will trigger the exception
        workers.add_tasks([Task("1", 100)])
        time.sleep(0.05)  # Allow time for exception to be raised

        # Clean up resources
        workers.terminate()

        # Verify work function was called at least once (exception was triggered)
        self.assertGreaterEqual(call_tracker["count"], 1)

    def test_add_empty_tasks(self):
        """Test adding empty task list to Workers.

        Verifies that adding an empty list of tasks does not affect Workers state
        and no invalid operations are performed.
        """
        # Initialize Workers with 1 worker thread
        workers = Workers("test_add_empty_tasks", work=lambda tasks: tasks)
        workers.start(1)

        # Add empty task list
        workers.add_tasks([])

        # Verify task queue remains empty
        self.assertEqual(len(workers.tasks), 0)

        # Clean up resources
        workers.terminate()

    def test_terminate_empty_workers(self):
        """Test terminating Workers that have no running threads or tasks.

        Verifies that terminate() can be safely called on an unstarted Workers instance
        without errors, and all state variables remain in valid initial state.
        """
        # Initialize Workers without starting any threads
        workers = Workers("test_terminate_empty_workers", work=lambda tasks: tasks)

        # Call terminate on empty Workers
        workers.terminate()

        # Verify Workers state remains valid
        self.assertFalse(workers.stop)
        self.assertEqual(workers.stopped_count, 0)
        self.assertEqual(len(workers.pool), 0)
        self.assertEqual(len(workers.tasks), 0)
        self.assertEqual(len(workers.results), 0)
|
||||
|
||||
|
||||
# Allow running this test module directly with `python <file>`.
if __name__ == "__main__":
    unittest.main()
|
||||
223
tests/splitwise/test_internal_adapter_utils.py
Normal file
223
tests/splitwise/test_internal_adapter_utils.py
Normal file
@@ -0,0 +1,223 @@
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
import threading
|
||||
import unittest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from fastdeploy.splitwise import internal_adapter_utils as ia
|
||||
|
||||
|
||||
class DummyEngine:
    """Engine stand-in exposing only the pieces InternalAdapter touches."""

    class ResourceManager:
        """Reports fixed resource capacities for server-info queries."""

        def available_batch(self):
            return 4

        def available_block_num(self):
            return 2

    class Scheduler:
        """Scheduler stub with no pending requests."""

        def get_unhandled_request_num(self):
            return 0

    class EngineWorkerQueue:
        """Queue stub that records the last RDMA-connect task it receives."""

        def __init__(self):
            self.called_task = None

        def put_connect_rdma_task(self, task):
            # Remember the task so tests can assert it was forwarded.
            self.called_task = task

        def get_connect_rdma_task_response(self):
            return None

    def __init__(self):
        self.resource_manager = self.ResourceManager()
        self.scheduler = self.Scheduler()
        self.engine_worker_queue = self.EngineWorkerQueue()
|
||||
|
||||
|
||||
class DummyCfg:
    """Config stand-in mirroring the nested layout InternalAdapter expects.

    Provides SchedulerConfig/CacheConfig/ModelConfig namespaces populated with
    small, test-friendly constants.
    """

    class SchedulerConfig:
        """Scheduler settings: run in non-splitwise ("single") mode."""

        splitwise_role = "single"

    class CacheConfig:
        """KV-cache sizing constants."""

        block_size = 1024
        total_block_num = 8
        dec_token_num = 4

    class ModelConfig:
        """Model-level limits."""

        max_model_len = 2048

    # Top-level knobs read directly off the config object.
    scheduler_config = SchedulerConfig()
    cache_config = CacheConfig()
    model_config = ModelConfig()
    max_prefill_batch = 2
|
||||
|
||||
|
||||
class TestInternalAdapterBasic(unittest.TestCase):
    """
    Unit test suite for basic functionalities of InternalAdapter.
    Covers initialization, server info retrieval, and thread creation.
    """

    @patch("fastdeploy.splitwise.internal_adapter_utils.ZmqTcpServer")
    def test_basic_initialization(self, mock_zmq_server):
        """Test InternalAdapter initialization and _get_current_server_info method."""
        # Replace the real ZMQ server so no sockets are opened during the test.
        mock_server_instance = MagicMock()
        mock_zmq_server.return_value = mock_server_instance

        # Initialize InternalAdapter with dummy config, engine, and dp_rank
        adapter = ia.InternalAdapter(cfg=DummyCfg(), engine=DummyEngine(), dp_rank=0)

        # Verify _get_current_server_info returns expected structure
        server_info = adapter._get_current_server_info()
        expected_keys = ["splitwise_role", "block_size", "available_resource"]
        for key in expected_keys:
            with self.subTest(key=key):
                self.assertIn(key, server_info, f"Server info missing required key: {key}")

        # Verify background threads are properly initialized.
        # assertIsInstance is the idiomatic unittest check and yields a far
        # more informative failure message than assertTrue(isinstance(...)).
        self.assertIsInstance(
            adapter.recv_external_instruct_thread,
            threading.Thread,
            "recv_external_instruct_thread should be a Thread instance",
        )
        self.assertIsInstance(
            adapter.response_external_instruct_thread,
            threading.Thread,
            "response_external_instruct_thread should be a Thread instance",
        )
|
||||
|
||||
|
||||
class TestInternalAdapterRecvPayload(unittest.TestCase):
    """Unit test suite for payload reception functionality of InternalAdapter.

    Covers handling of different control commands (get_payload, get_metrics,
    connect_rdma) and exception handling.
    """

    @patch("fastdeploy.splitwise.internal_adapter_utils.ZmqTcpServer")
    @patch("fastdeploy.splitwise.internal_adapter_utils.get_filtered_metrics")
    @patch("fastdeploy.splitwise.internal_adapter_utils.logger")
    def test_recv_control_cmd_branches(self, mock_logger, mock_get_metrics, mock_zmq_server):
        """Test all command handling branches in _recv_external_module_control_instruct."""
        # Setup mock ZmqTcpServer instance so no real sockets are opened.
        mock_server_instance = MagicMock()
        mock_zmq_server.return_value = mock_server_instance

        # Create a generator to simulate sequential control commands.
        def control_cmd_generator():
            """Generator to yield test commands in sequence."""
            yield {"task_id": "1", "cmd": "get_payload"}
            yield {"task_id": "2", "cmd": "get_metrics"}
            yield {"task_id": "3", "cmd": "connect_rdma"}
            # After the three real commands, keep yielding None so any extra
            # recv calls are harmless no-ops.
            while True:
                yield None

        # Configure mock server to return commands from the generator
        mock_server_instance.recv_control_cmd.side_effect = control_cmd_generator()
        mock_server_instance.response_for_control_cmd = MagicMock()  # Track response calls
        mock_get_metrics.return_value = "mocked_metrics"  # Mock metrics response

        # Initialize dependencies and InternalAdapter
        test_engine = DummyEngine()
        adapter = ia.InternalAdapter(cfg=DummyCfg(), engine=test_engine, dp_rank=0)

        # Override _recv_external_module_control_instruct to run only 3 iterations
        # (test all commands). This mirrors the production loop body but is
        # bounded, so the test never blocks on the infinite receive loop.
        def run_limited_iterations(self):
            """Modified method to process 3 commands and exit (avoids infinite loop)."""
            for _ in range(3):
                try:
                    # Acquire response lock and receive command.
                    # NOTE(review): the recv is done under response_lock and the
                    # lock is released before the response is sent below —
                    # response_lock is presumably a non-reentrant threading.Lock,
                    # so the two `with` blocks must not nest; confirm against
                    # the production implementation.
                    with self.response_lock:
                        control_cmd = self.recv_control_cmd_server.recv_control_cmd()

                    if control_cmd is None:
                        continue  # Skip None commands

                    task_id = control_cmd["task_id"]
                    cmd = control_cmd["cmd"]

                    # Handle each command type
                    if cmd == "get_payload":
                        payload_info = self._get_current_server_info()
                        response = {"task_id": task_id, "result": payload_info}
                        with self.response_lock:
                            self.recv_control_cmd_server.response_for_control_cmd(task_id, response)
                    elif cmd == "get_metrics":
                        metrics_data = mock_get_metrics()
                        response = {"task_id": task_id, "result": metrics_data}
                        with self.response_lock:
                            self.recv_control_cmd_server.response_for_control_cmd(task_id, response)
                    elif cmd == "connect_rdma":
                        # connect_rdma is fire-and-forget: the task goes to the
                        # engine worker queue and no ZMQ response is sent.
                        test_engine.engine_worker_queue.put_connect_rdma_task(control_cmd)
                except Exception as e:
                    mock_logger.error(f"handle_control_cmd got error: {e}")

        # Bind the modified method to the adapter instance
        adapter._recv_external_module_control_instruct = run_limited_iterations.__get__(adapter)
        # Execute the modified method to process test commands
        adapter._recv_external_module_control_instruct()

        # Verify 'get_payload' and 'get_metrics' triggered responses (2 total calls)
        self.assertEqual(
            mock_server_instance.response_for_control_cmd.call_count,
            2,
            "response_for_control_cmd should be called twice (get_payload + get_metrics)",
        )

        # Verify responses were sent for task IDs "1" and "2"
        # (call_arg[0] is the positional-args tuple; [0] is the task_id argument).
        called_task_ids = [call_arg[0][0] for call_arg in mock_server_instance.response_for_control_cmd.call_args_list]
        self.assertIn("1", called_task_ids, "Response not sent for 'get_payload' task (ID: 1)")
        self.assertIn("2", called_task_ids, "Response not sent for 'get_metrics' task (ID: 2)")

        # Verify 'connect_rdma' task was submitted to EngineWorkerQueue
        self.assertEqual(
            test_engine.engine_worker_queue.called_task["task_id"],
            "3",
            "connect_rdma task with ID 3 not received by EngineWorkerQueue",
        )

        # Test exception handling branch
        def raise_test_exception(self):
            """Modified method to raise a test exception."""
            raise ValueError("test_exception")

        # Configure mock server to trigger exception: every recv now raises,
        # which the loop's except branch must catch and log.
        adapter.recv_control_cmd_server.recv_control_cmd = raise_test_exception.__get__(adapter)
        # Execute to trigger exception
        adapter._recv_external_module_control_instruct()

        # Verify exception was logged
        self.assertTrue(mock_logger.error.called, "Logger should capture exceptions during control command handling")
|
||||
|
||||
|
||||
# Allow running this test module directly with `python <file>`.
if __name__ == "__main__":
    unittest.main()
|
||||
Reference in New Issue
Block a user