mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
* add test * Delete tests/model_executor/test_w4afp8.py * Rename test_utils.py to test_tool_parsers_utils.py * add test * add test * fix platforms * Delete tests/cache_manager/test_platforms.py * dont change Removed copyright notice and license information.
224 lines
8.9 KiB
Python
224 lines
8.9 KiB
Python
"""
|
|
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""
|
|
|
|
import threading
|
|
import unittest
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
from fastdeploy.splitwise import internal_adapter_utils as ia
|
|
|
|
|
|
class DummyEngine:
|
|
"""Dummy Engine class to simulate the actual Engine for testing."""
|
|
|
|
class ResourceManager:
|
|
def available_batch(self):
|
|
return 4
|
|
|
|
def available_block_num(self):
|
|
return 2
|
|
|
|
class Scheduler:
|
|
def get_unhandled_request_num(self):
|
|
return 0
|
|
|
|
class EngineWorkerQueue:
|
|
def __init__(self):
|
|
self.called_task = None
|
|
|
|
def put_connect_rdma_task(self, task):
|
|
self.called_task = task
|
|
|
|
def get_connect_rdma_task_response(self):
|
|
return None
|
|
|
|
def __init__(self):
|
|
self.resource_manager = self.ResourceManager()
|
|
self.scheduler = self.Scheduler()
|
|
self.engine_worker_queue = self.EngineWorkerQueue()
|
|
|
|
|
|
class DummyCfg:
|
|
"""Dummy configuration class to simulate input config for InternalAdapter.
|
|
|
|
Contains nested configuration classes (SchedulerConfig, CacheConfig, ModelConfig)
|
|
with test-friendly default values.
|
|
"""
|
|
|
|
class SchedulerConfig:
|
|
"""Mock SchedulerConfig with splitwise role configuration."""
|
|
|
|
splitwise_role = "single"
|
|
|
|
class CacheConfig:
|
|
"""Mock CacheConfig with cache-related configuration."""
|
|
|
|
block_size = 1024
|
|
total_block_num = 8
|
|
dec_token_num = 4
|
|
|
|
class ModelConfig:
|
|
"""Mock ModelConfig with model-related configuration."""
|
|
|
|
max_model_len = 2048
|
|
|
|
# Top-level configuration attributes
|
|
max_prefill_batch = 2
|
|
scheduler_config = SchedulerConfig()
|
|
cache_config = CacheConfig()
|
|
model_config = ModelConfig()
|
|
|
|
|
|
class TestInternalAdapterBasic(unittest.TestCase):
|
|
"""
|
|
Unit test suite for basic functionalities of InternalAdapter.
|
|
Covers initialization, server info retrieval, and thread creation.
|
|
"""
|
|
|
|
@patch("fastdeploy.splitwise.internal_adapter_utils.ZmqTcpServer")
|
|
def test_basic_initialization(self, mock_zmq_server):
|
|
"""Test InternalAdapter initialization and _get_current_server_info method."""
|
|
# Setup mock ZmqTcpServer instance
|
|
mock_server_instance = MagicMock()
|
|
mock_zmq_server.return_value = mock_server_instance
|
|
|
|
# Initialize InternalAdapter with dummy config, engine, and dp_rank
|
|
adapter = ia.InternalAdapter(cfg=DummyCfg(), engine=DummyEngine(), dp_rank=0)
|
|
|
|
# Verify _get_current_server_info returns expected structure
|
|
server_info = adapter._get_current_server_info()
|
|
expected_keys = ["splitwise_role", "block_size", "available_resource"]
|
|
for key in expected_keys:
|
|
with self.subTest(key=key):
|
|
self.assertIn(key, server_info, f"Server info missing required key: {key}")
|
|
|
|
# Verify background threads are properly initialized
|
|
self.assertTrue(
|
|
isinstance(adapter.recv_external_instruct_thread, threading.Thread),
|
|
"recv_external_instruct_thread should be a Thread instance",
|
|
)
|
|
self.assertTrue(
|
|
isinstance(adapter.response_external_instruct_thread, threading.Thread),
|
|
"response_external_instruct_thread should be a Thread instance",
|
|
)
|
|
|
|
|
|
class TestInternalAdapterRecvPayload(unittest.TestCase):
|
|
"""Unit test suite for payload reception functionality of InternalAdapter.
|
|
|
|
Covers handling of different control commands (get_payload, get_metrics, connect_rdma)
|
|
and exception handling.
|
|
"""
|
|
|
|
@patch("fastdeploy.splitwise.internal_adapter_utils.ZmqTcpServer")
|
|
@patch("fastdeploy.splitwise.internal_adapter_utils.get_filtered_metrics")
|
|
@patch("fastdeploy.splitwise.internal_adapter_utils.logger")
|
|
def test_recv_control_cmd_branches(self, mock_logger, mock_get_metrics, mock_zmq_server):
|
|
"""Test all command handling branches in _recv_external_module_control_instruct."""
|
|
# Setup mock ZmqTcpServer instance
|
|
mock_server_instance = MagicMock()
|
|
mock_zmq_server.return_value = mock_server_instance
|
|
|
|
# Create a generator to simulate sequential control commands
|
|
def control_cmd_generator():
|
|
"""Generator to yield test commands in sequence."""
|
|
yield {"task_id": "1", "cmd": "get_payload"}
|
|
yield {"task_id": "2", "cmd": "get_metrics"}
|
|
yield {"task_id": "3", "cmd": "connect_rdma"}
|
|
while True:
|
|
yield None
|
|
|
|
# Configure mock server to return commands from the generator
|
|
mock_server_instance.recv_control_cmd.side_effect = control_cmd_generator()
|
|
mock_server_instance.response_for_control_cmd = MagicMock() # Track response calls
|
|
mock_get_metrics.return_value = "mocked_metrics" # Mock metrics response
|
|
|
|
# Initialize dependencies and InternalAdapter
|
|
test_engine = DummyEngine()
|
|
adapter = ia.InternalAdapter(cfg=DummyCfg(), engine=test_engine, dp_rank=0)
|
|
|
|
# Override _recv_external_module_control_instruct to run only 3 iterations (test all commands)
|
|
def run_limited_iterations(self):
|
|
"""Modified method to process 3 commands and exit (avoids infinite loop)."""
|
|
for _ in range(3):
|
|
try:
|
|
# Acquire response lock and receive command
|
|
with self.response_lock:
|
|
control_cmd = self.recv_control_cmd_server.recv_control_cmd()
|
|
|
|
if control_cmd is None:
|
|
continue # Skip None commands
|
|
|
|
task_id = control_cmd["task_id"]
|
|
cmd = control_cmd["cmd"]
|
|
|
|
# Handle each command type
|
|
if cmd == "get_payload":
|
|
payload_info = self._get_current_server_info()
|
|
response = {"task_id": task_id, "result": payload_info}
|
|
with self.response_lock:
|
|
self.recv_control_cmd_server.response_for_control_cmd(task_id, response)
|
|
elif cmd == "get_metrics":
|
|
metrics_data = mock_get_metrics()
|
|
response = {"task_id": task_id, "result": metrics_data}
|
|
with self.response_lock:
|
|
self.recv_control_cmd_server.response_for_control_cmd(task_id, response)
|
|
elif cmd == "connect_rdma":
|
|
test_engine.engine_worker_queue.put_connect_rdma_task(control_cmd)
|
|
except Exception as e:
|
|
mock_logger.error(f"handle_control_cmd got error: {e}")
|
|
|
|
# Bind the modified method to the adapter instance
|
|
adapter._recv_external_module_control_instruct = run_limited_iterations.__get__(adapter)
|
|
# Execute the modified method to process test commands
|
|
adapter._recv_external_module_control_instruct()
|
|
|
|
# Verify 'get_payload' and 'get_metrics' triggered responses (2 total calls)
|
|
self.assertEqual(
|
|
mock_server_instance.response_for_control_cmd.call_count,
|
|
2,
|
|
"response_for_control_cmd should be called twice (get_payload + get_metrics)",
|
|
)
|
|
|
|
# Verify responses were sent for task IDs "1" and "2"
|
|
called_task_ids = [call_arg[0][0] for call_arg in mock_server_instance.response_for_control_cmd.call_args_list]
|
|
self.assertIn("1", called_task_ids, "Response not sent for 'get_payload' task (ID: 1)")
|
|
self.assertIn("2", called_task_ids, "Response not sent for 'get_metrics' task (ID: 2)")
|
|
|
|
# Verify 'connect_rdma' task was submitted to EngineWorkerQueue
|
|
self.assertEqual(
|
|
test_engine.engine_worker_queue.called_task["task_id"],
|
|
"3",
|
|
"connect_rdma task with ID 3 not received by EngineWorkerQueue",
|
|
)
|
|
|
|
# Test exception handling branch
|
|
def raise_test_exception(self):
|
|
"""Modified method to raise a test exception."""
|
|
raise ValueError("test_exception")
|
|
|
|
# Configure mock server to trigger exception
|
|
adapter.recv_control_cmd_server.recv_control_cmd = raise_test_exception.__get__(adapter)
|
|
# Execute to trigger exception
|
|
adapter._recv_external_module_control_instruct()
|
|
|
|
# Verify exception was logged
|
|
self.assertTrue(mock_logger.error.called, "Logger should capture exceptions during control command handling")
|
|
|
|
|
|
if __name__ == "__main__":
|
|
unittest.main()
|