mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
* fix async download bug
* update log
* Revert "update log"
This reverts commit 5816e602f4.
* update code
* fix mtp bug
177 lines
7.0 KiB
Python
177 lines
7.0 KiB
Python
import concurrent.futures
|
|
import pickle
|
|
import unittest
|
|
from dataclasses import asdict
|
|
from types import SimpleNamespace
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
import numpy as np
|
|
|
|
from fastdeploy.config import CacheConfig, FDConfig, ParallelConfig, SchedulerConfig
|
|
from fastdeploy.engine.args_utils import EngineArgs
|
|
from fastdeploy.engine.request import Request
|
|
from fastdeploy.engine.sched.resource_manager_v1 import ResourceManagerV1
|
|
|
|
|
|
class TestResourceManagerV1(unittest.TestCase):
    """Tests for ResourceManagerV1's async-preprocess waiting and feature download."""

    def setUp(self):
        """Create a ResourceManagerV1 on a minimal FDConfig and one empty request."""
        seq_limit = 2
        engine_args = EngineArgs(
            max_num_seqs=seq_limit,
            num_gpu_blocks_override=102,
            max_num_batched_tokens=3200,
        )
        arg_map = asdict(engine_args)

        cache_config = CacheConfig(arg_map)
        cache_config.bytes_per_layer_per_block = 1
        parallel_config = ParallelConfig(arg_map)
        scheduler_config = SchedulerConfig(arg_map)
        graph_opt_config = engine_args.create_graph_optimization_config()

        # Plain namespaces stand in for the model and speculative configs;
        # enable_mm=True turns on multimodal handling for the feature tests.
        model_config = SimpleNamespace(enable_mm=True, print=print, max_model_len=5120)
        speculative_config = SimpleNamespace(method=None)

        self.manager = ResourceManagerV1(
            max_num_seqs=seq_limit,
            config=FDConfig(
                model_config=model_config,
                cache_config=cache_config,
                parallel_config=parallel_config,
                graph_opt_config=graph_opt_config,
                speculative_config=speculative_config,
                scheduler_config=scheduler_config,
            ),
            tensor_parallel_size=8,
            splitwise_role="mixed",
        )

        self.request = Request.from_dict({"request_id": "test_request", "multimodal_inputs": {}})
        self.request.async_process_futures = []
        self.request.multimodal_inputs = {}

    def _install_bos_client(self, **get_object_behavior):
        """Attach a MagicMock BOS client to the manager.

        Keyword arguments (e.g. ``return_value`` / ``side_effect``) configure
        the mock's ``get_object_as_string`` method. Returns the mock client.
        """
        client = MagicMock()
        client.get_object_as_string.configure_mock(**get_object_behavior)
        self.manager.bos_client = client
        return client

    def _assert_download_failed(self, expected_fragment):
        """Check the request carries ``expected_fragment`` in its error message and code 530."""
        self.assertIn(expected_fragment, self.request.error_message)
        self.assertEqual(self.request.error_code, 530)

    def test_waiting_async_process_no_futures(self):
        """With no futures attached, waiting_async_process reports a falsy result."""
        self.assertFalse(self.manager.waiting_async_process(self.request))

    def test_waiting_async_process_future_done_no_error(self):
        """A completed future without an error is cleared and the result is falsy."""
        finished = concurrent.futures.Future()
        finished.set_result(True)
        self.request.async_process_futures = [finished]

        outcome = self.manager.waiting_async_process(self.request)

        self.assertFalse(outcome)
        self.assertEqual(len(self.request.async_process_futures), 0)

    def test_waiting_async_process_future_done_with_error(self):
        """A completed future on a request with an error message yields None."""
        finished = concurrent.futures.Future()
        finished.set_result(True)
        self.request.async_process_futures = [finished]
        self.request.error_message = "Download failed"

        outcome = self.manager.waiting_async_process(self.request)

        self.assertIsNone(outcome)

    def test_waiting_async_process_future_not_done(self):
        """A pending future keeps the request waiting and stays in the list."""
        pending = concurrent.futures.Future()
        self.request.async_process_futures = [pending]

        outcome = self.manager.waiting_async_process(self.request)

        self.assertTrue(outcome)
        self.assertEqual(len(self.request.async_process_futures), 1)

    def test_apply_async_preprocess(self):
        """apply_async_preprocess submits _download_features and stores the returned future."""
        with patch.object(self.manager.async_preprocess_pool, "submit") as submit_mock:
            submit_mock.return_value = "mock_future"
            self.manager.apply_async_preprocess(self.request)

            submit_mock.assert_called_once_with(self.manager._download_features, self.request)
            futures = self.request.async_process_futures
            self.assertEqual(len(futures), 1)
            self.assertEqual(futures[0], "mock_future")

    @patch("fastdeploy.utils.init_bos_client")
    @patch("fastdeploy.utils.download_from_bos")
    def test_download_features_no_features(self, mock_download, mock_init):
        """Empty multimodal inputs trigger neither client initialisation nor a download."""
        self.request.multimodal_inputs = {}

        outcome = self.manager._download_features(self.request)

        self.assertIsNone(outcome)
        mock_download.assert_not_called()
        mock_init.assert_not_called()

    def test_download_features_video_success(self):
        """A video feature URL is fetched and exposed as an ndarray under video_features."""
        payload = pickle.dumps(np.array([[1, 2, 3]], dtype=np.float32))
        self._install_bos_client(return_value=payload)
        self.request.multimodal_inputs = {"video_feature_urls": ["bos://bucket-name/path/to/object1"]}

        outcome = self.manager._download_features(self.request)

        self.assertIsNone(outcome)
        self.assertIn("video_features", self.request.multimodal_inputs)
        self.assertIsInstance(self.request.multimodal_inputs["video_features"][0], np.ndarray)

    def test_download_features_image_error(self):
        """A failing image feature download records an error message and code on the request."""
        self._install_bos_client(side_effect=Exception("network error"))
        self.request.multimodal_inputs = {"image_feature_urls": ["bos://bucket-name/path/to/object1"]}

        outcome = self.manager._download_features(self.request)

        self.assertIsNone(outcome)
        self._assert_download_failed("request test_request download features error")

    def test_download_features_audio_mixed(self):
        """One successful and one failing audio download still marks the request as errored."""
        self._install_bos_client(
            side_effect=[
                pickle.dumps(np.array([[1, 2, 3]], dtype=np.float32)),
                Exception("timeout"),
            ]
        )
        self.request.multimodal_inputs = {
            "audio_feature_urls": ["bos://bucket-name/path/to/object1", "bos://bucket-name/path/to/object2"]
        }

        outcome = self.manager._download_features(self.request)

        self.assertIsNone(outcome)
        self._assert_download_failed("request test_request download features error")

    def test_download_features_retry(self):
        """A rate-limit error from the client ends in a 'Failed after 1 retries' message."""
        self._install_bos_client(
            side_effect=Exception("Your request rate is too high. We have put limits on your bucket.")
        )
        self.request.multimodal_inputs = {"image_feature_urls": ["bos://bucket-name/path/to/object1"]}

        outcome = self.manager._download_features(self.request)

        self.assertIsNone(outcome)
        self._assert_download_failed("Failed after 1 retries for bos://bucket-name/path/to/object1")
|
|
|
|
|
|
# Run the test cases in this module when the file is executed directly.
if __name__ == "__main__":
    unittest.main()
|