[CI] Add Unittest (#5187)

* add test

* Delete tests/model_executor/test_w4afp8.py

* Rename test_utils.py to test_tool_parsers_utils.py

* add test

* add test

* fix platforms

* Delete tests/cache_manager/test_platforms.py

* don't change

Removed copyright notice and license information.
This commit is contained in:
Echo-Nie
2025-11-25 11:00:34 +08:00
committed by GitHub
parent 717da50b40
commit a418d7b60b
5 changed files with 856 additions and 30 deletions

View File

@@ -0,0 +1,73 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import unittest
from partial_json_parser.core.options import Allow
from fastdeploy.entrypoints.openai.tool_parsers import utils
class TestPartialJsonUtils(unittest.TestCase):
    """Unit tests for the partial-JSON helper functions in tool_parsers.utils."""

    def test_find_common_prefix(self):
        """find_common_prefix returns the longest shared leading substring."""
        string1 = '{"fruit": "ap"}'
        string2 = '{"fruit": "apple"}'
        self.assertEqual(utils.find_common_prefix(string1, string2), '{"fruit": "ap')

    def test_find_common_suffix(self):
        """find_common_suffix returns the longest shared trailing substring."""
        string1 = '{"fruit": "ap"}'
        string2 = '{"fruit": "apple"}'
        self.assertEqual(utils.find_common_suffix(string1, string2), '"}')

    def test_extract_intermediate_diff(self):
        """extract_intermediate_diff isolates the text added between two snapshots."""
        old_string = '{"fruit": "ap"}'
        current_string = '{"fruit": "apple"}'
        self.assertEqual(utils.extract_intermediate_diff(current_string, old_string), "ple")

    def test_find_all_indices(self):
        """find_all_indices reports every start offset of the substring."""
        target_string = "banana"
        substring = "an"
        self.assertEqual(utils.find_all_indices(target_string, substring), [1, 3])

    def test_partial_json_loads_complete(self):
        """partial_json_loads parses a complete document and consumes all of it."""
        input_json = '{"a": 1, "b": 2}'
        parse_flags = Allow.ALL
        parsed_obj, parsed_length = utils.partial_json_loads(input_json, parse_flags)
        self.assertEqual(parsed_obj, {"a": 1, "b": 2})
        self.assertEqual(parsed_length, len(input_json))

    def test_is_complete_json(self):
        """is_complete_json distinguishes closed from truncated JSON."""
        self.assertTrue(utils.is_complete_json('{"a": 1}'))
        self.assertFalse(utils.is_complete_json('{"a": 1'))

    def test_consume_space(self):
        """consume_space skips leading whitespace and returns the first non-space index."""
        # Fix: the literal must contain 3 spaces + 1 tab + 1 newline = 5 whitespace
        # characters so the asserted index 5 (and the original comment) are consistent.
        input_string = "   \t\nabc"
        first_non_whitespace_idx = utils.consume_space(0, input_string)
        self.assertEqual(first_non_whitespace_idx, 5)
# Allow running this test module directly: `python test_tool_parsers_utils.py`.
if __name__ == "__main__":
    unittest.main()

View File

@@ -0,0 +1,111 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import unittest
from unittest.mock import MagicMock, patch
import paddle
from fastdeploy.model_executor.layers.quantization.tensor_wise_fp8 import (
TensorWiseFP8Config,
TensorWiseFP8LinearMethod,
)
# Dummy classes for test
class DummyLayer:
    """Minimal stand-in for a linear layer consumed by the quant-method tests."""

    def __init__(self):
        # Attribute names mirror exactly what TensorWiseFP8LinearMethod reads/writes.
        self.weight_shape = [4, 8]
        self.weight_key = "weight"
        self.weight_scale_key = "weight_scale"
        self.act_scale_key = "act_scale"
        self.weight_dtype = "float32"
        # Mocked so dtype-sensitive copies in the code under test cannot fail.
        self.weight = MagicMock()

    def create_parameter(self, shape, dtype, is_bias=False, default_initializer=None):
        """Stand in for paddle's parameter factory; returns a mock instead of a tensor."""
        return MagicMock()
class DummyFusedMoE:
    """Placeholder patched in place of the real FusedMoE for isinstance dispatch."""
class TestTensorWiseFP8Config(unittest.TestCase):
    """Tests for TensorWiseFP8Config.get_quant_method dispatch."""

    def test_get_quant_method_linear(self):
        """A plain linear layer should be served by TensorWiseFP8LinearMethod."""
        config = TensorWiseFP8Config()
        linear_layer = DummyLayer()
        quant_method = config.get_quant_method(linear_layer)
        self.assertIsInstance(quant_method, TensorWiseFP8LinearMethod)

    def test_get_quant_method_moe(self):
        """A FusedMoE layer should receive a method exposing a quant_config."""
        config = TensorWiseFP8Config()
        moe_layer = DummyFusedMoE()
        with patch("fastdeploy.model_executor.layers.moe.FusedMoE", DummyFusedMoE):
            quant_method = config.get_quant_method(moe_layer)
            self.assertTrue(hasattr(quant_method, "quant_config"))
class TestTensorWiseFP8LinearMethod(unittest.TestCase):
    """Tests for TensorWiseFP8LinearMethod weight handling and apply()."""

    def setUp(self):
        """Build a fresh dummy layer and method; pre-seed scales so apply() can run."""
        self.layer = DummyLayer()
        self.method = TensorWiseFP8LinearMethod(TensorWiseFP8Config())
        self.method.act_scale = 1.0
        self.method.total_scale = 1.0

    def test_create_weights(self):
        """create_weights must switch the layer's weight dtype to float8_e4m3fn."""
        self.method.create_weights(self.layer)
        self.assertEqual(self.layer.weight_dtype, "float8_e4m3fn")

    def test_process_prequanted_weights(self):
        """Prequantized weights are copied and scales taken from the state dict."""
        self.layer.weight.copy_ = MagicMock()
        state_dict = {
            "weight": paddle.randn([8, 4]),
            "weight_scale": paddle.to_tensor([0.5], dtype="float32"),
            "act_scale": paddle.to_tensor([2.0], dtype="float32"),
        }
        self.method.process_prequanted_weights(self.layer, state_dict)
        # 0.5 * 2.0 == 1.0 — presumably total_scale = weight_scale * act_scale;
        # confirm against the implementation if this ever drifts.
        self.assertAlmostEqual(self.method.act_scale, 2.0)
        self.assertAlmostEqual(self.method.total_scale, 1.0)
        self.layer.weight.copy_.assert_called_once()

    @patch("fastdeploy.model_executor.ops.gpu.fused_hadamard_quant_fp8", autospec=True)
    @patch("fastdeploy.model_executor.ops.gpu.cutlass_fp8_fp8_half_gemm_fused", autospec=True)
    def test_apply(self, gemm_mock, quant_mock):
        """apply() runs end-to-end with the fused GPU ops mocked to pass-through."""
        # Both mocked ops simply return the activation unchanged.
        quant_mock.side_effect = lambda x, scale: x
        gemm_mock.side_effect = lambda x, w, **kwargs: x
        activations = paddle.randn([4, 8])
        result = self.method.apply(self.layer, activations)
        self.assertTrue((result == activations).all())
# Allow running this test module directly with `python <file>`.
if __name__ == "__main__":
    unittest.main()

View File

@@ -20,67 +20,209 @@ from unittest.mock import patch
from fastdeploy.platforms.base import _Backend
from fastdeploy.platforms.cpu import CPUPlatform
from fastdeploy.platforms.cuda import CUDAPlatform
from fastdeploy.platforms.dcu import DCUPlatform
from fastdeploy.platforms.gcu import GCUPlatform
from fastdeploy.platforms.intel_hpu import INTEL_HPUPlatform
from fastdeploy.platforms.maca import MACAPlatform
from fastdeploy.platforms.npu import NPUPlatform
from fastdeploy.platforms.xpu import XPUPlatform
class TestCPUPlatform(unittest.TestCase):
    """Test suite for CPUPlatform"""

    def setUp(self):
        self.platform = CPUPlatform()

    @patch("paddle.device.get_device", return_value="cpu")
    def test_is_cpu_and_available(self, mock_get_device):
        """Verify is_cpu() returns True and platform is available"""
        # Fix: the stale pre-edit docstring ("Check hardware type ...") left over
        # from the diff rendering has been removed; only one docstring remains.
        self.assertTrue(self.platform.is_cpu())
        self.assertTrue(self.platform.available())

    def test_attention_backend(self):
        """Verify get_attention_backend_cls returns empty string for CPU"""
        self.assertEqual(self.platform.get_attention_backend_cls(None), "")
class TestCUDAPlatform(unittest.TestCase):
    """Test suite for CUDAPlatform"""

    def setUp(self):
        self.platform = CUDAPlatform()

    # Fix: the source contained both the pre-edit and post-edit versions of this
    # method (a duplicated `def` and duplicated assertion bodies, a diff-rendering
    # artifact); only the final version is kept. Mock parameters are named to match
    # decorator order: the innermost (last) @patch supplies the FIRST argument.
    @patch("paddle.is_compiled_with_cuda", return_value=True)
    @patch("paddle.device.get_device", return_value="cuda")
    @patch("paddle.static.cuda_places", return_value=[0])
    def test_is_cuda_and_available(self, mock_cuda_places, mock_get_device, mock_is_cuda):
        """Verify is_cuda() returns True and platform is available"""
        self.assertTrue(self.platform.is_cuda())
        self.assertTrue(self.platform.available())

    def test_attention_backend_valid(self):
        """Verify valid attention backends return correct class names"""
        self.assertIn("PaddleNativeAttnBackend", self.platform.get_attention_backend_cls(_Backend.NATIVE_ATTN))
        self.assertIn("AppendAttentionBackend", self.platform.get_attention_backend_cls(_Backend.APPEND_ATTN))
        self.assertIn("MLAAttentionBackend", self.platform.get_attention_backend_cls(_Backend.MLA_ATTN))
        self.assertIn("FlashAttentionBackend", self.platform.get_attention_backend_cls(_Backend.FLASH_ATTN))

    def test_attention_backend_invalid(self):
        """Verify invalid backend raises ValueError"""
        with self.assertRaises(ValueError):
            self.platform.get_attention_backend_cls("INVALID_BACKEND")
class TestMACAPlatform(unittest.TestCase):
    """Tests for MACAPlatform availability and attention backend resolution."""

    @patch("paddle.static.cuda_places", return_value=[0, 1])
    def test_available_true(self, places_mock):
        """available() is True when cuda_places reports devices."""
        self.assertTrue(MACAPlatform.available())
        places_mock.assert_called_once()

    @patch("paddle.static.cuda_places", side_effect=Exception("No GPU"))
    def test_available_false(self, places_mock):
        """available() is False when probing devices raises."""
        self.assertFalse(MACAPlatform.available())
        places_mock.assert_called_once()

    def test_get_attention_backend_native(self):
        """NATIVE_ATTN resolves to the paddle-native backend."""
        self.assertIn("PaddleNativeAttnBackend", MACAPlatform.get_attention_backend_cls(_Backend.NATIVE_ATTN))

    def test_get_attention_backend_append(self):
        """APPEND_ATTN resolves to the flash-attention backend on MACA."""
        self.assertIn("FlashAttentionBackend", MACAPlatform.get_attention_backend_cls(_Backend.APPEND_ATTN))

    def test_get_attention_backend_invalid(self):
        """An unknown backend name raises ValueError."""
        with self.assertRaises(ValueError):
            MACAPlatform.get_attention_backend_cls("INVALID_BACKEND")
class TestINTELHPUPlatform(unittest.TestCase):
    """Tests for INTEL_HPUPlatform availability and attention backend resolution."""

    @patch("paddle.base.core.get_custom_device_count", return_value=1)
    def test_available_true(self, count_mock):
        """available() is True when an intel_hpu device is counted."""
        self.assertTrue(INTEL_HPUPlatform.available())
        count_mock.assert_called_with("intel_hpu")

    @patch("paddle.base.core.get_custom_device_count", side_effect=Exception("No HPU"))
    @patch("fastdeploy.utils.console_logger.warning")
    def test_available_false(self, warning_mock, count_mock):
        """available() is False and a warning is emitted when probing fails."""
        self.assertFalse(INTEL_HPUPlatform.available())
        warning_mock.assert_called()
        self.assertIn("No HPU", warning_mock.call_args[0][0])

    def test_attention_backend_native(self):
        """NATIVE_ATTN resolves to the paddle-native backend."""
        self.assertIn("PaddleNativeAttnBackend", INTEL_HPUPlatform.get_attention_backend_cls(_Backend.NATIVE_ATTN))

    def test_attention_backend_hpu(self):
        """HPU_ATTN resolves to the HPU attention backend."""
        self.assertIn("HPUAttentionBackend", INTEL_HPUPlatform.get_attention_backend_cls(_Backend.HPU_ATTN))

    @patch("fastdeploy.utils.console_logger.warning")
    def test_attention_backend_other(self, warning_mock):
        """An unknown backend logs a warning and yields None."""
        self.assertIsNone(INTEL_HPUPlatform.get_attention_backend_cls("INVALID_BACKEND"))
        warning_mock.assert_called()
class TestNPUPlatform(unittest.TestCase):
    """Tests for NPUPlatform."""

    def setUp(self):
        self.platform = NPUPlatform()

    def test_device_name(self):
        """The platform must report its device name as 'npu'."""
        self.assertEqual(self.platform.device_name, "npu")
class TestDCUPlatform(unittest.TestCase):
    """Tests for DCUPlatform availability and attention backend resolution."""

    def setUp(self):
        self.platform = DCUPlatform()

    @patch("paddle.static.cuda_places", return_value=[0])
    def test_available_with_gpu(self, places_mock):
        """available() is True when a device is visible."""
        self.assertTrue(self.platform.available())

    @patch("paddle.static.cuda_places", side_effect=Exception("No GPU"))
    def test_available_no_gpu(self, places_mock):
        """available() is False when probing devices raises."""
        self.assertFalse(self.platform.available())

    def test_attention_backend_native(self):
        """NATIVE_ATTN resolves to the paddle-native backend."""
        self.assertIn("PaddleNativeAttnBackend", self.platform.get_attention_backend_cls(_Backend.NATIVE_ATTN))

    def test_attention_backend_block(self):
        """BLOCK_ATTN resolves to the block-attention backend."""
        self.assertIn("BlockAttentionBackend", self.platform.get_attention_backend_cls(_Backend.BLOCK_ATTN))

    def test_attention_backend_invalid(self):
        """An unknown backend yields None on DCU (unlike CUDA, which raises)."""
        self.assertIsNone(self.platform.get_attention_backend_cls("INVALID_BACKEND"))
class TestGCUPlatform(unittest.TestCase):
    """Tests for GCUPlatform availability and attention backend resolution."""

    def setUp(self):
        self.platform = GCUPlatform()

    @patch("paddle.base.core.get_custom_device_count", return_value=1)
    def test_available_with_gcu(self, count_mock):
        """available() is True when a GCU device is counted."""
        self.assertTrue(self.platform.available())

    @patch("paddle.base.core.get_custom_device_count", side_effect=Exception("No GCU"))
    def test_available_no_gcu(self, count_mock):
        """available() is False when probing devices raises."""
        self.assertFalse(self.platform.available())

    def test_attention_backend_native(self):
        """NATIVE_ATTN resolves to the GCU memory-efficient backend."""
        self.assertIn("GCUMemEfficientAttnBackend", self.platform.get_attention_backend_cls(_Backend.NATIVE_ATTN))

    def test_attention_backend_append(self):
        """APPEND_ATTN resolves to the GCU flash-attention backend."""
        self.assertIn("GCUFlashAttnBackend", self.platform.get_attention_backend_cls(_Backend.APPEND_ATTN))

    def test_attention_backend_invalid(self):
        """An unknown backend name raises ValueError."""
        with self.assertRaises(ValueError):
            self.platform.get_attention_backend_cls("INVALID_BACKEND")
class TestXPUPlatform(unittest.TestCase):
    """Tests for XPUPlatform availability and attention backend resolution."""

    @patch("paddle.is_compiled_with_xpu", return_value=True)
    @patch("paddle.static.xpu_places", return_value=[0])
    def test_available_true(self, places_mock, compiled_mock):
        """available() is True when paddle is XPU-compiled and devices exist."""
        self.assertTrue(XPUPlatform.available())

    @patch("paddle.is_compiled_with_xpu", return_value=False)
    @patch("paddle.static.xpu_places", return_value=[])
    def test_available_false(self, places_mock, compiled_mock):
        """available() is False when XPU support is absent."""
        self.assertFalse(XPUPlatform.available())

    def test_get_attention_backend_cls(self):
        """NATIVE_ATTN resolves to the fully-qualified XPU backend path."""
        expected_cls = "fastdeploy.model_executor.layers.attention.XPUAttentionBackend"
        self.assertEqual(XPUPlatform.get_attention_backend_cls(_Backend.NATIVE_ATTN), expected_cls)
# Allow running this test module directly with `python <file>`.
if __name__ == "__main__":
    unittest.main()

View File

@@ -0,0 +1,277 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import time
import unittest
from fastdeploy.scheduler.workers import Task, Workers
class TestTask(unittest.TestCase):
    """Tests for the Task value object."""

    def test_repr(self):
        """repr(Task) must expose both the task id and the reason."""
        task = Task("123", 456, reason="ok")
        text = repr(task)
        self.assertIn("task_id:123", text)
        self.assertIn("reason:ok", text)
class TestWorkers(unittest.TestCase):
    """Tests for the Workers pool.

    Covers the basic processing flow, task filtering, unique-task addition,
    get_results timeouts, resilience to work-function exceptions, and edge
    cases (empty task list, zero workers, terminating an unstarted pool).
    """

    def _drain(self, workers, expected, deadline):
        """Poll get_results until `expected` items arrive or `deadline` seconds pass."""
        collected = []
        started = time.time()
        while len(collected) < expected and time.time() - started < deadline:
            batch = workers.get_results(10, timeout=0.1)
            if batch:
                collected.extend(batch)
        return collected

    def test_basic_flow(self):
        """Two workers process four batched tasks and every result comes back."""

        def increment(tasks):
            """Return each task with its raw payload incremented by one."""
            return [Task(task.id, task.raw + 1) for task in tasks]

        workers = Workers("test_basic_flow", work=increment, max_task_batch_size=2)
        workers.start(2)
        workers.add_tasks([Task(str(i), i) for i in range(4)])
        results = self._drain(workers, expected=4, deadline=1)
        workers.terminate()
        self.assertEqual({int(t.id): t.raw for t in results}, {0: 1, 1: 2, 2: 3, 3: 4})

    def test_task_filters(self):
        """Even/odd id filters split the stream without dropping or duplicating tasks."""

        def add_ten(tasks):
            """Add 10 to each task's raw payload."""
            return [Task(task.id, task.raw + 10) for task in tasks]

        def filter_even(task):
            """Select tasks whose numeric id is even."""
            return int(task.id) % 2 == 0

        def filter_odd(task):
            """Select tasks whose numeric id is odd."""
            return int(task.id) % 2 == 1

        workers = Workers(
            "test_task_filters",
            work=add_ten,
            max_task_batch_size=1,
            task_filters=[filter_even, filter_odd],
        )
        workers.start(2)
        workers.add_tasks([Task(str(i), i) for i in range(6)])
        results = self._drain(workers, expected=6, deadline=2)
        workers.terminate()
        even_ids = {0, 2, 4}
        odd_ids = {1, 3, 5}
        # Undo the +10 applied by the work function to recover the original ids.
        got_even = {int(t.raw) - 10 for t in results if int(t.id) in even_ids}
        got_odd = {int(t.raw) - 10 for t in results if int(t.id) in odd_ids}
        self.assertEqual(got_even, even_ids)
        self.assertEqual(got_odd, odd_ids)

    def test_unique_task_addition(self):
        """With unique=True a duplicate task id is dropped; new ids pass through."""

        def slow_increment(tasks):
            """Increment with a 50ms delay so duplicates arrive mid-flight."""
            time.sleep(0.05)
            return [Task(task.id, task.raw + 1) for task in tasks]

        # Single worker keeps the processing order deterministic.
        workers = Workers("test_unique_task_addition", work=slow_increment, max_task_batch_size=1)
        workers.start(1)
        workers.add_tasks([Task("1", 100)], unique=True)
        time.sleep(0.02)  # let task "1" enter the running state
        workers.add_tasks([Task("1", 200)], unique=True)  # duplicate id: filtered out
        workers.add_tasks([Task("2", 300)], unique=True)  # fresh id: processed
        results = self._drain(workers, expected=2, deadline=1)
        workers.terminate()
        self.assertEqual(sorted(int(t.id) for t in results), [1, 2])

    def test_get_results_timeout(self):
        """get_results returns [] after the timeout when nothing is produced."""

        def produce_nothing(tasks):
            """Consume tasks but emit no results."""
            time.sleep(0.01)
            return []

        workers = Workers("test_get_results_timeout", work=produce_nothing, max_task_batch_size=1)
        workers.start(1)
        before = time.time()
        results = workers.get_results(max_size=1, timeout=0.05)
        elapsed = time.time() - before
        workers.terminate()
        self.assertEqual(results, [])
        self.assertGreaterEqual(elapsed, 0.05)

    def test_start_zero_workers(self):
        """start(0) leaves the worker pool empty."""
        workers = Workers("test_start_zero_workers", work=lambda tasks: tasks)
        workers.start(0)
        self.assertEqual(len(workers.pool), 0)

    def test_worker_exception_resilience(self):
        """A raising work function must not crash the Workers instance."""
        calls = {"count": 0}

        def always_raise(tasks):
            """Count the invocation, then fail."""
            calls["count"] += 1
            raise RuntimeError("Simulated work function exception")

        workers = Workers("test_worker_exception_resilience", work=always_raise, max_task_batch_size=1)
        workers.start(1)
        workers.add_tasks([Task("1", 100)])
        time.sleep(0.05)  # give the worker time to hit the exception
        workers.terminate()
        self.assertGreaterEqual(calls["count"], 1)

    def test_add_empty_tasks(self):
        """Adding an empty task list leaves the queue untouched."""
        workers = Workers("test_add_empty_tasks", work=lambda tasks: tasks)
        workers.start(1)
        workers.add_tasks([])
        self.assertEqual(len(workers.tasks), 0)
        workers.terminate()

    def test_terminate_empty_workers(self):
        """terminate() on a never-started instance is safe and leaves clean state."""
        workers = Workers("test_terminate_empty_workers", work=lambda tasks: tasks)
        workers.terminate()
        self.assertFalse(workers.stop)
        self.assertEqual(workers.stopped_count, 0)
        self.assertEqual(len(workers.pool), 0)
        self.assertEqual(len(workers.tasks), 0)
        self.assertEqual(len(workers.results), 0)
# Allow running this test module directly with `python <file>`.
if __name__ == "__main__":
    unittest.main()

View File

@@ -0,0 +1,223 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import threading
import unittest
from unittest.mock import MagicMock, patch
from fastdeploy.splitwise import internal_adapter_utils as ia
class DummyEngine:
    """Self-contained stand-in for the engine object InternalAdapter inspects."""

    class ResourceManager:
        """Reports fixed resource numbers."""

        def available_batch(self):
            return 4

        def available_block_num(self):
            return 2

    class Scheduler:
        """Reports a fixed count of unhandled requests."""

        def get_unhandled_request_num(self):
            return 0

    class EngineWorkerQueue:
        """Records the most recent RDMA-connect task submitted to it."""

        def __init__(self):
            self.called_task = None  # last task passed to put_connect_rdma_task

        def put_connect_rdma_task(self, task):
            self.called_task = task

        def get_connect_rdma_task_response(self):
            return None

    def __init__(self):
        self.resource_manager = self.ResourceManager()
        self.scheduler = self.Scheduler()
        self.engine_worker_queue = self.EngineWorkerQueue()
class DummyCfg:
    """Self-contained config stand-in with the nested sections InternalAdapter reads."""

    class SchedulerConfig:
        """Scheduler section: single-instance splitwise role."""

        splitwise_role = "single"

    class CacheConfig:
        """Cache section: block sizing knobs."""

        block_size = 1024
        total_block_num = 8
        dec_token_num = 4

    class ModelConfig:
        """Model section: context-length limit."""

        max_model_len = 2048

    # Top-level knobs read directly off the config object.
    max_prefill_batch = 2
    scheduler_config = SchedulerConfig()
    cache_config = CacheConfig()
    model_config = ModelConfig()
class TestInternalAdapterBasic(unittest.TestCase):
    """Covers InternalAdapter construction, server-info payload, and thread setup."""

    @patch("fastdeploy.splitwise.internal_adapter_utils.ZmqTcpServer")
    def test_basic_initialization(self, zmq_server_cls):
        """__init__ wires a ZMQ server, exposes server info, and spawns worker threads."""
        zmq_server_cls.return_value = MagicMock()

        adapter = ia.InternalAdapter(cfg=DummyCfg(), engine=DummyEngine(), dp_rank=0)

        info = adapter._get_current_server_info()
        for key in ("splitwise_role", "block_size", "available_resource"):
            with self.subTest(key=key):
                self.assertIn(key, info, f"Server info missing required key: {key}")

        self.assertTrue(
            isinstance(adapter.recv_external_instruct_thread, threading.Thread),
            "recv_external_instruct_thread should be a Thread instance",
        )
        self.assertTrue(
            isinstance(adapter.response_external_instruct_thread, threading.Thread),
            "response_external_instruct_thread should be a Thread instance",
        )
class TestInternalAdapterRecvPayload(unittest.TestCase):
    """Unit test suite for payload reception functionality of InternalAdapter.

    Covers handling of different control commands (get_payload, get_metrics,
    connect_rdma) and exception handling.
    """

    @patch("fastdeploy.splitwise.internal_adapter_utils.ZmqTcpServer")
    @patch("fastdeploy.splitwise.internal_adapter_utils.get_filtered_metrics")
    @patch("fastdeploy.splitwise.internal_adapter_utils.logger")
    def test_recv_control_cmd_branches(self, mock_logger, mock_get_metrics, mock_zmq_server):
        """Test all command handling branches in _recv_external_module_control_instruct."""
        # Setup mock ZmqTcpServer instance
        mock_server_instance = MagicMock()
        mock_zmq_server.return_value = mock_server_instance

        # Create a generator to simulate sequential control commands
        def control_cmd_generator():
            """Generator to yield test commands in sequence."""
            yield {"task_id": "1", "cmd": "get_payload"}
            yield {"task_id": "2", "cmd": "get_metrics"}
            yield {"task_id": "3", "cmd": "connect_rdma"}
            # After the three real commands, keep yielding None forever so extra
            # polls in the loop are harmless no-ops.
            while True:
                yield None

        # Configure mock server to return commands from the generator
        mock_server_instance.recv_control_cmd.side_effect = control_cmd_generator()
        mock_server_instance.response_for_control_cmd = MagicMock()  # Track response calls
        mock_get_metrics.return_value = "mocked_metrics"  # Mock metrics response

        # Initialize dependencies and InternalAdapter
        test_engine = DummyEngine()
        adapter = ia.InternalAdapter(cfg=DummyCfg(), engine=test_engine, dp_rank=0)

        # Override _recv_external_module_control_instruct to run only 3 iterations (test all commands)
        def run_limited_iterations(self):
            """Modified method to process 3 commands and exit (avoids infinite loop)."""
            for _ in range(3):
                try:
                    # Acquire response lock and receive command.
                    # NOTE(review): the lock is released after recv and re-acquired per
                    # response below — presumably a non-reentrant threading.Lock; confirm
                    # against the real _recv_external_module_control_instruct.
                    with self.response_lock:
                        control_cmd = self.recv_control_cmd_server.recv_control_cmd()
                    if control_cmd is None:
                        continue  # Skip None commands
                    task_id = control_cmd["task_id"]
                    cmd = control_cmd["cmd"]
                    # Handle each command type
                    if cmd == "get_payload":
                        payload_info = self._get_current_server_info()
                        response = {"task_id": task_id, "result": payload_info}
                        with self.response_lock:
                            self.recv_control_cmd_server.response_for_control_cmd(task_id, response)
                    elif cmd == "get_metrics":
                        metrics_data = mock_get_metrics()
                        response = {"task_id": task_id, "result": metrics_data}
                        with self.response_lock:
                            self.recv_control_cmd_server.response_for_control_cmd(task_id, response)
                    elif cmd == "connect_rdma":
                        # connect_rdma is fire-and-forget: handed to the engine queue,
                        # no response is sent back.
                        test_engine.engine_worker_queue.put_connect_rdma_task(control_cmd)
                except Exception as e:
                    mock_logger.error(f"handle_control_cmd got error: {e}")

        # Bind the modified method to the adapter instance
        adapter._recv_external_module_control_instruct = run_limited_iterations.__get__(adapter)
        # Execute the modified method to process test commands
        adapter._recv_external_module_control_instruct()

        # Verify 'get_payload' and 'get_metrics' triggered responses (2 total calls)
        self.assertEqual(
            mock_server_instance.response_for_control_cmd.call_count,
            2,
            "response_for_control_cmd should be called twice (get_payload + get_metrics)",
        )
        # Verify responses were sent for task IDs "1" and "2"
        called_task_ids = [call_arg[0][0] for call_arg in mock_server_instance.response_for_control_cmd.call_args_list]
        self.assertIn("1", called_task_ids, "Response not sent for 'get_payload' task (ID: 1)")
        self.assertIn("2", called_task_ids, "Response not sent for 'get_metrics' task (ID: 2)")
        # Verify 'connect_rdma' task was submitted to EngineWorkerQueue
        self.assertEqual(
            test_engine.engine_worker_queue.called_task["task_id"],
            "3",
            "connect_rdma task with ID 3 not received by EngineWorkerQueue",
        )

        # Test exception handling branch
        def raise_test_exception(self):
            """Modified method to raise a test exception."""
            raise ValueError("test_exception")

        # Configure mock server to trigger exception: every recv now raises, which
        # the loop's except block should log rather than propagate.
        adapter.recv_control_cmd_server.recv_control_cmd = raise_test_exception.__get__(adapter)
        # Execute to trigger exception
        adapter._recv_external_module_control_instruct()
        # Verify exception was logged
        self.assertTrue(mock_logger.error.called, "Logger should capture exceptions during control command handling")
# Allow running this test module directly with `python <file>`.
if __name__ == "__main__":
    unittest.main()