Files
FastDeploy/tests/cache_manager/test_cache_transfer_manager.py
Yonghua Li 0c8c6369ed [Feature] [PD Disaggregation] simplify configuration for pd-disaggregated deployment, and refactor post-init and usage for all ports (#5415)
* [feat] simplify configuration for pd-disaggregated deployment, and refactor post-init and usage for all ports

* [fix] fix some bugs

* [fix] fix rdma port for cache manager/messager

* [fix] temporarily cancel port availability check to see if it can pass ci test

* [feat] simplify args for multi api server

* [fix] fix dp

* [fix] fix port for xpu

* [fix] add tests for ports post processing & fix ci

* [test] fix test_multi_api_server

* [fix] fix rdma_comm_ports args for multi_api_server

* [fix] fix test_common_engine

* [fix] fix test_cache_transfer_manager

* [chore] automatically setting FD_ENABLE_MULTI_API_SERVER

* [fix] avoid api server from creating engine_args twice

* [fix] fix test_run_batch

* [fix] fix test_metrics

* [fix] fix splitwise connector init

* [test] add test_rdma_transfer and test_expert_service

* [fix] fix code syntax

* [fix] fix test_rdma_transfer and build wheel with rdma script
2025-12-17 15:50:42 +08:00

181 lines
6.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
import unittest
from unittest.mock import MagicMock, patch
import fastdeploy.cache_manager.cache_transfer_manager as cache_transfer_manager
from fastdeploy.cache_manager.cache_transfer_manager import CacheTransferManager
# ==========================
# Args used by the tests
# ==========================
class Args:
    """Stand-in for the argument object handed to ``CacheTransferManager``.

    Every field is a plain class attribute, so ``Args()`` needs no
    constructor arguments. The values are small placeholders chosen for the
    tests in this file; how the manager interprets each field is defined by
    the manager, not here.
    """

    # Process / device placement.
    rank = 0
    local_data_parallel_id = 0
    mp_num = 1
    device_id = 0

    # Model-related settings.
    speculative_config = {}
    num_layers = 1

    # IPC and networking endpoints.
    ipc_suffix = "test_ipc_suffix"
    cache_queue_port = 9999
    pod_ip = "127.0.0.1"
    engine_worker_queue_port = 9998

    # Cache sizing and layout (shapes are given as strings).
    num_cpu_blocks = 1
    num_gpu_blocks = 1
    key_cache_shape = "1,1,1,1"
    value_cache_shape = ""
    create_cache_tensor = False

    # Data types.
    cache_dtype = "bfloat16"
    default_dtype = "bfloat16"
# ==========================
# Test class
# ==========================
class TestCacheTransferManager(unittest.TestCase):
    """Unit tests for ``CacheTransferManager`` with its collaborators mocked.

    ``setUp`` replaces the module logger, the platform probe, the queue and
    IPC signal classes, and the cache initialisers before constructing the
    manager, then swaps the manager's signal, thread pools and task queue
    for test doubles.
    """

    def _start_patch(self, patcher):
        """Start *patcher*, register its ``stop`` as a cleanup, and return the patched object."""
        patched = patcher.start()
        self.addCleanup(patcher.stop)
        return patched

    def setUp(self):
        """Create ``self.manager`` with all external dependencies stubbed out."""

        # Platform probe that answers False to every check used here.
        class DummyPlatform:
            @staticmethod
            def is_iluvatar():
                return False

            @staticmethod
            def is_xpu():
                # XPU is not used in the test environment.
                return False

            @staticmethod
            def is_cuda():
                # CUDA is not used in the test environment.
                return False

        # Module-level globals are swapped through patchers (not by plain
        # assignment) so the real objects are restored after every test and
        # the mocks cannot leak into other test modules in the same process.
        self._start_patch(patch.object(cache_transfer_manager, "logger", MagicMock(), create=True))
        self._start_patch(patch.object(cache_transfer_manager, "current_platform", DummyPlatform(), create=True))

        # Replace the queue and IPC signal classes with mocks.
        self._start_patch(patch("fastdeploy.cache_manager.cache_transfer_manager.EngineCacheQueue", new=MagicMock()))
        self._start_patch(patch("fastdeploy.cache_manager.cache_transfer_manager.IPCSignal", new=MagicMock()))

        # Turn the CPU / GPU cache initialisers into no-ops.
        self._start_patch(patch.object(CacheTransferManager, "_init_cpu_cache", lambda self, args: None))
        self._start_patch(patch.object(CacheTransferManager, "_init_gpu_cache", lambda self, args: None))

        # Build the manager under test.
        self.manager = CacheTransferManager(Args())

        # Simple signal object whose single value the tests can overwrite.
        class DummySignal:
            def __init__(self):
                self.value = [0]

        self.manager.worker_healthy_live_signal = DummySignal()

        # Swap thread pools are replaced by mocks.
        self.manager.swap_to_cpu_thread_pool = MagicMock()
        self.manager.swap_to_gpu_thread_pool = MagicMock()

        # Task queue double: reports non-empty and hands back a canned task tuple.
        task_queue = MagicMock()
        task_queue.empty.return_value = False
        task_queue.get_transfer_task.return_value = (([0], 0, 0, MagicMock(value=0), 0), True)
        task_queue.barrier1 = MagicMock()
        task_queue.barrier2 = MagicMock()
        task_queue.barrier3 = MagicMock()
        self.manager.cache_task_queue = task_queue

        # Keep time.sleep from blocking the tests.
        self.sleep_patch = patch("time.sleep", lambda x: None)
        self._start_patch(self.sleep_patch)

    # ==========================
    # check_work_status tests
    # ==========================
    def test_check_work_status_no_signal(self):
        """With the signal value left at its initial 0, the status is healthy with an empty message."""
        healthy, msg = self.manager.check_work_status()
        self.assertTrue(healthy)
        self.assertEqual(msg, "")

    def test_check_work_status_healthy(self):
        """A signal value equal to the current time is reported as healthy."""
        self.manager.worker_healthy_live_signal.value[0] = int(time.time())
        healthy, msg = self.manager.check_work_status()
        self.assertTrue(healthy)
        self.assertEqual(msg, "")

    def test_check_work_status_unhealthy(self):
        """A signal value 1000 s in the past with a 10 s threshold is reported as not healthy."""
        self.manager.worker_healthy_live_signal.value[0] = int(time.time()) - 1000
        healthy, msg = self.manager.check_work_status(time_interval_threashold=10)
        self.assertFalse(healthy)
        self.assertIn("Not Healthy", msg)

    # ==========================
    # do_data_transfer error-handling test
    # ==========================
    def test_do_data_transfer_broken_pipe(self):
        """Call ``do_data_transfer`` with a broken-pipe queue and an unhealthy worker configured.

        NOTE(review): ``do_data_transfer`` is itself replaced by a mock below
        (to avoid its loop), so the real method and its ``BrokenPipeError``
        handling never execute. The only thing that can be asserted honestly
        is that the stub was invoked. Exercising the real loop needs a
        bounded way to terminate it — TODO.
        """
        # Queue raises BrokenPipeError when a transfer task is requested.
        self.manager.cache_task_queue.get_transfer_task.side_effect = BrokenPipeError("mock broken pipe")
        # Worker status reports unhealthy.
        self.manager.check_work_status = MagicMock(return_value=(False, "Not Healthy"))

        # Stub out do_data_transfer so the call returns immediately.
        with patch.object(self.manager, "do_data_transfer") as mock_transfer:
            mock_transfer.side_effect = lambda: None
            self.manager.do_data_transfer()

        # The previous ``assertTrue(<expr> or True)`` checks could never fail;
        # assert what actually happened instead.
        mock_transfer.assert_called_once_with()
if __name__ == "__main__":
    # Run this module's tests when the file is executed directly.
    unittest.main()