mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
* [feat] simplify configuration for pd-disaggregated deployment, and refactor post-init and usage for all ports * [fix] fix some bugs * [fix] fix rdma port for cache manager/messager * [fix] temporarily cancel port availability check to see if it can pass ci test * [feat] simplify args for multi api server * [fix] fix dp * [fix] fix port for xpu * [fix] add tests for ports post processing & fix ci * [test] fix test_multi_api_server * [fix] fix rdma_comm_ports args for multi_api_server * [fix] fix test_common_engine * [fix] fix test_cache_transfer_manager * [chore] automatically setting FD_ENABLE_MULTI_API_SERVER * [fix] avoid api server from creating engine_args twice * [fix] fix test_run_batch * [fix] fix test_metrics * [fix] fix splitwise connector init * [test] add test_rdma_transfer and test_expert_service * [fix] fix code syntax * [fix] fix test_rdma_transfer and build wheel with rdma script
181 lines
6.4 KiB
Python
181 lines
6.4 KiB
Python
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||
#
|
||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
# you may not use this file except in compliance with the License.
|
||
# You may obtain a copy of the License at
|
||
#
|
||
# http://www.apache.org/licenses/LICENSE-2.0
|
||
#
|
||
# Unless required by applicable law or agreed to in writing, software
|
||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
# See the License for the specific language governing permissions and
|
||
# limitations under the License.
|
||
|
||
import time
|
||
import unittest
|
||
from unittest.mock import MagicMock, patch
|
||
|
||
import fastdeploy.cache_manager.cache_transfer_manager as cache_transfer_manager
|
||
from fastdeploy.cache_manager.cache_transfer_manager import CacheTransferManager
|
||
|
||
|
||
# ==========================
|
||
# 测试用 Args
|
||
# ==========================
|
||
class Args:
    """Minimal stand-in for the argument object handed to ``CacheTransferManager``.

    Every value is a class attribute, so ``Args()`` yields an object exposing
    the fixed test configuration below.
    """

    # Process / device identity
    device_id = 0
    rank = 0
    mp_num = 1
    local_data_parallel_id = 0

    # Networking and IPC endpoints
    pod_ip = "127.0.0.1"
    engine_worker_queue_port = 9998
    cache_queue_port = 9999
    ipc_suffix = "test_ipc_suffix"

    # Cache sizing and layout
    num_layers = 1
    num_gpu_blocks = 1
    num_cpu_blocks = 1
    key_cache_shape = "1,1,1,1"
    value_cache_shape = ""
    create_cache_tensor = False

    # Data types
    default_dtype = "bfloat16"
    cache_dtype = "bfloat16"

    # Miscellaneous
    speculative_config = {}
|
||
|
||
|
||
# ==========================
|
||
# 测试类
|
||
# ==========================
|
||
class TestCacheTransferManager(unittest.TestCase):
    """Unit tests for ``CacheTransferManager`` with its collaborators mocked out."""

    def _start_patch(self, patcher):
        """Start *patcher* and register its ``stop`` so it is undone after the test.

        Returns whatever ``patcher.start()`` returns (the object put in place).
        """
        replacement = patcher.start()
        self.addCleanup(patcher.stop)
        return replacement

    def setUp(self):
        """Build a ``CacheTransferManager`` whose external dependencies are all mocks.

        Every module-level replacement goes through ``patch`` + ``addCleanup`` so
        that nothing leaks into other test cases running in the same process.
        """
        # --------------------------
        # Mock the module-level logger.
        # ``create=True`` keeps the tolerance of a plain assignment: the patch
        # works even if the attribute does not exist on the module.
        # --------------------------
        self._start_patch(patch.object(cache_transfer_manager, "logger", MagicMock(), create=True))

        # --------------------------
        # Mock current_platform
        # --------------------------
        class DummyPlatform:
            @staticmethod
            def is_iluvatar():
                return False

            @staticmethod
            def is_xpu():
                # XPU is not used in the test environment, so report False.
                return False

            @staticmethod
            def is_cuda():
                # CUDA is not used in the test environment, so report False.
                return False

        self._start_patch(
            patch.object(cache_transfer_manager, "current_platform", DummyPlatform(), create=True)
        )

        # --------------------------
        # Mock EngineCacheQueue
        # --------------------------
        self._start_patch(
            patch("fastdeploy.cache_manager.cache_transfer_manager.EngineCacheQueue", new=MagicMock())
        )

        # --------------------------
        # Mock IPCSignal
        # --------------------------
        self._start_patch(patch("fastdeploy.cache_manager.cache_transfer_manager.IPCSignal", new=MagicMock()))

        # --------------------------
        # Replace _init_cpu_cache and _init_gpu_cache with no-ops
        # --------------------------
        self._start_patch(patch.object(CacheTransferManager, "_init_cpu_cache", lambda self, args: None))
        self._start_patch(patch.object(CacheTransferManager, "_init_gpu_cache", lambda self, args: None))

        # --------------------------
        # Create the manager under test
        # --------------------------
        self.manager = CacheTransferManager(Args())

        # --------------------------
        # Mock worker_healthy_live_signal
        # --------------------------
        class DummySignal:
            def __init__(self):
                self.value = [0]

        self.manager.worker_healthy_live_signal = DummySignal()

        # --------------------------
        # Mock the swap thread pools
        # --------------------------
        self.manager.swap_to_cpu_thread_pool = MagicMock()
        self.manager.swap_to_gpu_thread_pool = MagicMock()

        # --------------------------
        # Mock cache_task_queue
        # --------------------------
        self.manager.cache_task_queue = MagicMock()
        self.manager.cache_task_queue.empty.return_value = False
        self.manager.cache_task_queue.get_transfer_task.return_value = (([0], 0, 0, MagicMock(value=0), 0), True)
        self.manager.cache_task_queue.barrier1 = MagicMock()
        self.manager.cache_task_queue.barrier2 = MagicMock()
        self.manager.cache_task_queue.barrier3 = MagicMock()

        # --------------------------
        # Keep time.sleep from blocking the tests
        # --------------------------
        self.sleep_patch = patch("time.sleep", lambda x: None)
        self._start_patch(self.sleep_patch)

    # ==========================
    # check_work_status tests
    # ==========================
    def test_check_work_status_no_signal(self):
        """With the signal value left at 0, the status is healthy with an empty message."""
        healthy, msg = self.manager.check_work_status()
        self.assertTrue(healthy)
        self.assertEqual(msg, "")

    def test_check_work_status_healthy(self):
        """A signal timestamp of "now" is reported as healthy with an empty message."""
        self.manager.worker_healthy_live_signal.value[0] = int(time.time())
        healthy, msg = self.manager.check_work_status()
        self.assertTrue(healthy)
        self.assertEqual(msg, "")

    def test_check_work_status_unhealthy(self):
        """A signal timestamp 1000s old with a 10s threshold is reported as not healthy."""
        self.manager.worker_healthy_live_signal.value[0] = int(time.time()) - 1000
        healthy, msg = self.manager.check_work_status(time_interval_threashold=10)
        self.assertFalse(healthy)
        self.assertIn("Not Healthy", msg)

    # ==========================
    # do_data_transfer error-handling test
    # ==========================
    def test_do_data_transfer_broken_pipe(self):
        """Invoke a stubbed ``do_data_transfer`` with a broken-pipe queue configured.

        NOTE(review): ``do_data_transfer`` is replaced by a stub here to avoid its
        loop, so the BrokenPipeError handling inside the real method is NOT
        exercised. The assertions below check only what this test really does.
        """
        # Make get_transfer_task raise BrokenPipeError.
        self.manager.cache_task_queue.get_transfer_task.side_effect = BrokenPipeError("mock broken pipe")

        # Make check_work_status report "not healthy".
        self.manager.check_work_status = MagicMock(return_value=(False, "Not Healthy"))

        # Stub do_data_transfer itself so the test cannot hang in its loop.
        with patch.object(self.manager, "do_data_transfer") as mock_transfer:
            mock_transfer.side_effect = lambda: None  # return immediately
            self.manager.do_data_transfer()

        # The previous ``assertTrue(<x> or True)`` checks could never fail;
        # assert the one thing this test actually observes instead.
        mock_transfer.assert_called_once_with()
|
||
|
||
|
||
# Run the test suite when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
|