""" # Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License" # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ import os import time import unittest from types import SimpleNamespace from unittest.mock import AsyncMock, MagicMock, Mock, patch import numpy as np from fastdeploy.entrypoints.engine_client import EngineClient from fastdeploy.utils import EngineError, ParameterError class DummyConfig(SimpleNamespace): def __getattr__(self, name): return None class TestEngineClient(unittest.IsolatedAsyncioTestCase): async def asyncSetUp(self): """Set up test fixtures before each test method.""" # Create a properly configured tokenizer mock first mock_tokenizer = Mock() mock_tokenizer.sp_model = Mock() mock_tokenizer.sp_model.__len__ = Mock(return_value=1000) mock_tokenizer.vocab = Mock() mock_tokenizer.vocab.__len__ = Mock(return_value=1000) # Add len() method directly to the tokenizer mock mock_tokenizer.__len__ = Mock(return_value=1000) # Create a proper ModelConfig mock with enable_mm attribute mock_model_config = Mock() mock_model_config.enable_mm = True # Match engine_config.model_config.enable_mm mock_model_config.enable_logprob = True # Match engine_config.model_config.enable_logprob mock_model_config.max_model_len = 1024 # Create a mock FDConfig that contains the model_config mock_config = Mock() mock_config.model_config = mock_model_config mock_config.cache_config = Mock() mock_config.cache_config.max_processor_cache = 10 mock_config.cache_config.enable_prefix_caching = True mock_config.eplb_config = Mock() mock_config.eplb_config.enable_eplb = False mock_config.parallel_config = Mock() mock_config.parallel_config.tensor_parallel_rank = 0 mock_config.parallel_config.local_data_parallel_id = 0 mock_config.parallel_config.tensor_parallel_size = 1 mock_config.scheduler_config = Mock() mock_config.scheduler_config.splitwise_role = None mock_config.limit_mm_per_prompt = 5 mock_config.mm_processor_kwargs = {} mock_config.tool_parser = None mock_config.structured_outputs_config = Mock() mock_config.structured_outputs_config.reasoning_parser = None # Create mocks for all the external dependencies mock_input_processor = Mock() mock_processor = Mock() mock_processor.tokenizer = mock_tokenizer # Set the tokenizer on the processor mock_input_processor.create_processor.return_value = mock_processor # Mock current platform mock_platform = Mock() mock_platform.is_iluvatar.return_value = False mock_platform.max_chips_per_node = 8 # Create mock IPCSignal that behaves properly mock_ipcsignal = Mock() mock_signal_instance = Mock() mock_signal_instance.value = np.array([0]) mock_ipcsignal.return_value = mock_signal_instance # Mock envs for FD_SUPPORT_MAX_CONNECTIONS mock_envs = Mock() mock_envs.FD_SUPPORT_MAX_CONNECTIONS = 100 # Mock all the dependencies and external components with ( patch("fastdeploy.entrypoints.engine_client.IPCSignal"), patch("fastdeploy.entrypoints.engine_client.DealerConnectionManager"), patch("fastdeploy.entrypoints.engine_client.InputPreprocessor"), 
patch("fastdeploy.entrypoints.engine_client.FileLock"), patch("fastdeploy.entrypoints.engine_client.StatefulSemaphore"), patch.multiple( "fastdeploy.entrypoints.engine_client", InputPreprocessor=Mock(return_value=mock_input_processor), ZmqIpcClient=Mock, IPCSignal=mock_ipcsignal, StatefulSemaphore=Mock, DealerConnectionManager=Mock, FileLock=Mock, main_process_metrics=Mock(), current_platform=mock_platform, envs=mock_envs, ), patch("fastdeploy.metrics.metrics.main_process_metrics", Mock()), patch("os.getenv", return_value="50"), ): self.engine_config = DummyConfig( model_config=DummyConfig(enable_mm=True, enable_logprob=True, max_model_len=1024), cache_config=DummyConfig(enable_prefix_caching=True, max_processor_cache=10), scheduler_config=DummyConfig(splitwise_role="mixed", max_num_seqs=128), parallel_config=DummyConfig(tensor_parallel_size=1), structured_outputs_config=DummyConfig(reasoning_parser="reasoning_parser"), eplb_config=DummyConfig(enable_eplb=True, eplb_max_tokens=1024), ) # Create EngineClient instance with mocked dependencies self.engine_client = EngineClient(pid=1234, port=8080, fd_config=mock_config, workers=1) self.engine_client.zmq_client = MagicMock() self.engine_client.zmq_client = MagicMock() def test_engine_client_initialized_by_fd_config(self): for config_group_name, config_group in self.engine_config.__dict__.items(): for config_name, config_value in config_group.__dict__.items(): if hasattr(self.engine_client, config_name): # Skip enable_mm, enable_logprob, and enable_prefix_caching checks as they're handled differently in EngineClient if config_name in ["enable_mm", "enable_logprob", "enable_prefix_caching"]: continue assert getattr(self.engine_client, config_name) == config_value # Check enable_mm separately since it's copied from model_config assert getattr(self.engine_client, "enable_mm") == self.engine_config.model_config.enable_mm # Check enable_logprob separately since it's copied from model_config assert getattr(self.engine_client, "enable_logprob") == self.engine_config.model_config.enable_logprob # Check enable_prefix_caching separately since it's copied from cache_config assert ( getattr(self.engine_client, "enable_prefix_caching") == self.engine_config.cache_config.enable_prefix_caching ) # Set up mock attributes self.engine_client.data_processor = Mock() self.engine_client.data_processor.process_request_dict = Mock() self.engine_client.zmq_client = Mock() self.engine_client.zmq_client.send_json = Mock() self.engine_client.zmq_client.send_pyobj = Mock() self.engine_client.max_model_len = 1024 self.engine_client.enable_mm = False self.engine_client.max_logprobs = 20 self.engine_client.enable_logprob = True self.engine_client.ori_vocab_size = 1000 self.engine_client.enable_prefix_caching = False self.engine_client.enable_splitwise = False self.engine_client.disable_prefix_mm = False # Set up mock attributes for TestEngineClientValidParameters class too if hasattr(self, "engine_client_valid"): self.engine_client_valid.zmq_client = Mock() self.engine_client_valid.zmq_client.send_json = Mock() self.engine_client_valid.zmq_client.send_pyobj = Mock() # Mock IPC signals self.engine_client.worker_healthy_live_signal = Mock() self.engine_client.worker_healthy_live_signal.value = np.array([time.time()]) self.engine_client.model_weights_status_signal = Mock() self.engine_client.model_weights_status_signal.value = np.array([0]) # NORMAL self.engine_client.prefix_tree_status_signal = Mock() self.engine_client.prefix_tree_status_signal.value = np.array([0]) # NORMAL 
        self.engine_client.kv_cache_status_signal = Mock()
        self.engine_client.kv_cache_status_signal.value = np.array([0])  # NORMAL
        # Mock file lock
        self.engine_client.clear_update_lock = Mock()
        self.engine_client.clear_update_lock.__enter__ = Mock(return_value=None)
        self.engine_client.clear_update_lock.__exit__ = Mock(return_value=None)

    async def test_add_request(self):
        request = {
            "request_id": "test-request-id",
            "chat_template_kwargs": {"enable_thinking": True},
            "prompt_token_ids": [1],
            "chat_template": "Hello",
            "max_tokens": 20,
            "tools": [1],
        }
        await self.engine_client.add_requests(request)
        assert "chat_template" in request["chat_template_kwargs"], "'chat_template' not found in 'chat_template_kwargs'"
        # assert "tools" in request["chat_template_kwargs"], "'tools' not found in 'chat_template_kwargs'"
        assert request["chat_template_kwargs"]["chat_template"] == "Hello"
        assert request["tools"] == [1]
        # assert request["chat_template_kwargs"]["tools"] == [1]


class TestEngineClientValidParameters(unittest.TestCase):
    """Test cases for EngineClient.valid_parameters method"""

    def setUp(self):
        """Set up test fixtures for valid_parameters tests"""
        # Mock the dependencies
        mock_tokenizer = MagicMock()
        mock_tokenizer.sp_model = MagicMock()
        mock_tokenizer.sp_model.__len__ = MagicMock(return_value=1000)
        mock_tokenizer.vocab = MagicMock()
        mock_tokenizer.vocab.__len__ = MagicMock(return_value=1000)
        mock_data_processor = MagicMock()
        mock_data_processor.tokenizer = mock_tokenizer
        mock_model_config = MagicMock()
        mock_model_config.enable_mm = False
        # Mock config object
        mock_config = MagicMock()
        mock_config.model_config = mock_model_config
        mock_config.eplb_config = MagicMock()
        mock_config.eplb_config.enable_eplb = False
        mock_config.parallel_config = MagicMock()
        mock_config.parallel_config.tensor_parallel_rank = 0
        mock_config.parallel_config.local_data_parallel_id = 0
        mock_config.parallel_config.tensor_parallel_size = 1  # Add this missing attribute
        mock_config.scheduler_config = MagicMock()
        mock_config.scheduler_config.splitwise_role = None
        mock_config.cache_config = MagicMock()  # Add cache_config
        mock_config.cache_config.enable_prefix_caching = False
        mock_config.cache_config.max_processor_cache = 0
        mock_config.limit_mm_per_prompt = 5  # Add this attribute
        mock_config.mm_processor_kwargs = {}  # Add this attribute
        mock_config.structured_outputs_config = MagicMock()  # Add this
        mock_config.structured_outputs_config.reasoning_parser = None
        mock_config.tool_parser = None  # Add this attribute
        # Mock IPCSignal to avoid file system dependencies
        with patch("fastdeploy.entrypoints.engine_client.IPCSignal") as mock_ipcsignal:
            mock_ipcsignal.return_value = MagicMock()
            with patch("fastdeploy.entrypoints.engine_client.StatefulSemaphore") as mock_semaphore:
                mock_semaphore.return_value = MagicMock()
                with patch("fastdeploy.entrypoints.engine_client.DealerConnectionManager") as mock_connection_manager:
                    mock_connection_manager.return_value = MagicMock()
                    with patch("fastdeploy.entrypoints.engine_client.FileLock") as mock_filelock:
                        mock_filelock.return_value = MagicMock()
                        with patch("fastdeploy.config.ModelConfig") as mock_model_config_class:
                            mock_model_config_class.return_value = mock_model_config
                            with patch(
                                "fastdeploy.entrypoints.engine_client.InputPreprocessor"
                            ) as mock_input_processor:
                                mock_input_processor_instance = MagicMock()
                                mock_input_processor_instance.create_processor.return_value = mock_data_processor
                                mock_input_processor.return_value = mock_input_processor_instance
                                # Create EngineClient with minimal required parameters
self.engine_client = EngineClient( pid=1234, port=8080, fd_config=mock_config, workers=1, ) # Set up mock attributes for TestEngineClientValidParameters class self.engine_client.zmq_client = Mock() self.engine_client.zmq_client.send_json = Mock() self.engine_client.zmq_client.send_pyobj = Mock() self.engine_client.max_logprobs = 20 self.engine_client.enable_logprob = True self.engine_client.ori_vocab_size = 1000 self.engine_client.enable_prefix_caching = False self.engine_client.enable_splitwise = False self.engine_client.disable_prefix_mm = False self.engine_client.max_model_len = 1024 self.engine_client.enable_mm = False self.engine_client.config = mock_config self.engine_client.max_chips_per_node = 8 self.engine_client.tensor_parallel_size = 1 self.engine_client.is_master = True self.engine_client.worker_healthy_live_signal = Mock() self.engine_client.worker_healthy_live_signal.value = np.array([0]) self.engine_client.model_weights_status_signal = Mock() self.engine_client.model_weights_status_signal.value = np.array([0]) self.engine_client.clear_update_lock = Mock() self.engine_client.clear_update_lock.__enter__ = Mock(return_value=None) self.engine_client.clear_update_lock.__exit__ = Mock(return_value=None) self.engine_client.kv_cache_status_signal = Mock() self.engine_client.kv_cache_status_signal.value = np.array([0]) self.engine_client.prefix_tree_status_signal = Mock() self.engine_client.prefix_tree_status_signal.value = np.array([0]) def test_max_logprobs_valid_values(self): """Test valid max_logprobs values""" # Test positive max_logprobs self.engine_client.max_logprobs = 20 data = {"request_id": "test"} self.engine_client.valid_parameters(data) # Should not raise # Test -1 (unlimited) self.engine_client.max_logprobs = -1 data = {"request_id": "test"} self.engine_client.valid_parameters(data) # Should not raise def test_max_logprobs_invalid_values(self): """Test invalid max_logprobs values""" # Test negative value less than -1 self.engine_client.max_logprobs = -2 data = {"request_id": "test"} with self.assertRaises(ValueError) as context: self.engine_client.valid_parameters(data) self.assertIn("max_logprobs", str(context.exception)) self.assertIn("must be >= -1", str(context.exception)) self.assertIn("got -2", str(context.exception)) def test_max_logprobs_exceeds_vocab_size(self): """Test max_logprobs exceeding vocab_size""" self.engine_client.max_logprobs = 1500 self.engine_client.ori_vocab_size = 1000 data = {"request_id": "test"} with self.assertRaises(ValueError) as context: self.engine_client.valid_parameters(data) self.assertIn("max_logprobs", str(context.exception)) self.assertIn("must be <= vocab_size", str(context.exception)) self.assertIn("1000", str(context.exception)) self.assertIn("got 1500", str(context.exception)) def test_max_logprobs_unlimited(self): """Test max_logprobs = -1 (unlimited) sets to ori_vocab_size""" self.engine_client.max_logprobs = -1 self.engine_client.ori_vocab_size = 1000 data = {"request_id": "test"} # This should not raise and internally max_logprobs should be set to ori_vocab_size self.engine_client.valid_parameters(data) # Should not raise # The actual max_logprobs value should be set to ori_vocab_size internally self.assertEqual(self.engine_client.max_logprobs, -1) # Original value remains unchanged def test_prompt_logprobs_valid_values(self): """Test valid prompt_logprobs values""" self.engine_client.max_logprobs = 20 self.engine_client.enable_logprob = True # Test valid positive value with FD_USE_GET_SAVE_OUTPUT_V1=1 with 
patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}): data = {"prompt_logprobs": 10, "request_id": "test"} self.engine_client.valid_parameters(data) # Should not raise # Test -1 (unlimited) with FD_USE_GET_SAVE_OUTPUT_V1=1 with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}): self.engine_client.max_logprobs = -1 data = {"prompt_logprobs": -1, "request_id": "test"} self.engine_client.valid_parameters(data) # Should not raise # Test None (default) data = {"request_id": "test"} self.engine_client.valid_parameters(data) # Should not raise def test_prompt_logprobs_unlimited_sets_to_vocab_size(self): """Test prompt_logprobs = -1 sets to ori_vocab_size""" self.engine_client.max_logprobs = -1 # Set to unlimited to allow prompt_logprobs = -1 self.engine_client.enable_logprob = True self.engine_client.ori_vocab_size = 1000 with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}): data = {"prompt_logprobs": -1, "request_id": "test"} self.engine_client.valid_parameters(data) # Should not raise # prompt_logprobs should be set to ori_vocab_size internally def test_prompt_logprobs_disabled_when_fd_use_get_save_output_v1_disabled(self): """Test prompt_logprobs when FD_USE_GET_SAVE_OUTPUT_V1 is disabled""" with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "0"}): data = {"prompt_logprobs": 10, "request_id": "test"} with self.assertRaises(ParameterError) as context: self.engine_client.valid_parameters(data) self.assertIn( "prompt_logprobs is not support when FD_USE_GET_SAVE_OUTPUT_V1 is disabled", str(context.exception) ) def test_prompt_logprobs_disabled_logprob(self): """Test prompt_logprobs when logprob is disabled""" self.engine_client.enable_logprob = False data = {"prompt_logprobs": 10, "request_id": "test"} with self.assertRaises(ParameterError) as context: self.engine_client.valid_parameters(data) self.assertIn("`enable_logprob` is disabled, please enable it in startup config.", str(context.exception)) def test_prompt_logprobs_disabled_when_prefix_caching_enabled(self): """Test prompt_logprobs when prefix caching is enabled""" self.engine_client.enable_prefix_caching = True self.engine_client.enable_logprob = True with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}): data = {"prompt_logprobs": 10, "request_id": "test"} with self.assertRaises(ParameterError) as context: self.engine_client.valid_parameters(data) self.assertIn("prompt_logprobs is not support when prefix caching is enabled", str(context.exception)) def test_prompt_logprobs_invalid_values(self): """Test invalid prompt_logprobs values""" self.engine_client.enable_logprob = True # Test negative value less than -1 with FD_USE_GET_SAVE_OUTPUT_V1=1 with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}): data = {"prompt_logprobs": -2, "request_id": "test"} with self.assertRaises(ValueError) as context: self.engine_client.valid_parameters(data) self.assertIn("prompt_logprobs", str(context.exception)) self.assertIn("must be a non-negative value or -1", str(context.exception)) self.assertIn("current value is -2", str(context.exception)) def test_prompt_logprobs_exceeds_max_logprobs(self): """Test prompt_logprobs exceeding max_logprobs""" self.engine_client.max_logprobs = 10 self.engine_client.enable_logprob = True with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}): data = {"prompt_logprobs": 15, "request_id": "test"} with self.assertRaises(ValueError) as context: self.engine_client.valid_parameters(data) self.assertIn("prompt_logprobs", str(context.exception)) self.assertIn("exceeds 
maximum allowed value", str(context.exception)) self.assertIn("15", str(context.exception)) self.assertIn("10", str(context.exception)) def test_top_logprobs_validation_with_fd_use_get_save_output_v1_enabled(self): """Test top_logprobs validation when FD_USE_GET_SAVE_OUTPUT_V1 is enabled""" self.engine_client.max_logprobs = 20 self.engine_client.enable_logprob = True with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}): # Test -1 (unlimited) - should set to ori_vocab_size, but need max_logprobs also to be -1 self.engine_client.max_logprobs = -1 # Set to unlimited to allow top_logprobs = -1 data = {"logprobs": True, "top_logprobs": -1, "request_id": "test"} self.engine_client.valid_parameters(data) # Should not raise # Reset max_logprobs for other tests self.engine_client.max_logprobs = 20 # Test valid positive value data = {"logprobs": True, "top_logprobs": 10, "request_id": "test"} self.engine_client.valid_parameters(data) # Should not raise # Test value less than -1 - should raise ValueError data = {"logprobs": True, "top_logprobs": -2, "request_id": "test"} with self.assertRaises(ValueError) as context: self.engine_client.valid_parameters(data) self.assertIn("must be a non-negative value or -1", str(context.exception)) self.assertIn("current value is -2", str(context.exception)) # Test value exceeding max_logprobs - should raise ValueError data = {"logprobs": True, "top_logprobs": 25, "request_id": "test"} with self.assertRaises(ValueError) as context: self.engine_client.valid_parameters(data) self.assertIn("exceeds maximum allowed value", str(context.exception)) def test_top_logprobs_validation_with_fd_use_get_save_output_v1_disabled(self): """Test top_logprobs validation when FD_USE_GET_SAVE_OUTPUT_V1 is disabled""" self.engine_client.max_logprobs = 20 self.engine_client.enable_logprob = True with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "0"}): # Test negative value - should raise ValueError data = {"logprobs": True, "top_logprobs": -1, "request_id": "test"} with self.assertRaises(ValueError) as context: self.engine_client.valid_parameters(data) self.assertIn("top_logprobs must be between 0 and 20", str(context.exception)) self.assertIn("current value is -1", str(context.exception)) # Test value > 20 - should raise ValueError data = {"logprobs": True, "top_logprobs": 25, "request_id": "test"} with self.assertRaises(ValueError) as context: self.engine_client.valid_parameters(data) self.assertIn("top_logprobs must be between 0 and 20", str(context.exception)) self.assertIn("current value is 25", str(context.exception)) # Test valid value data = {"logprobs": True, "top_logprobs": 10, "request_id": "test"} self.engine_client.valid_parameters(data) # Should not raise def test_top_logprobs_disabled_logprob(self): """Test top_logprobs when logprob is disabled""" self.engine_client.enable_logprob = False data = {"logprobs": True, "top_logprobs": 10, "request_id": "test"} with self.assertRaises(ParameterError) as context: self.engine_client.valid_parameters(data) self.assertIn("disabled", str(context.exception)) def test_top_logprobs_invalid_type(self): """Test top_logprobs with invalid type""" self.engine_client.enable_logprob = True # Test with string type data = {"logprobs": True, "top_logprobs": "10", "request_id": "test"} with self.assertRaises(ParameterError) as context: self.engine_client.valid_parameters(data) self.assertIn("top_logprobs", str(context.exception)) self.assertIn("Invalid type", str(context.exception)) self.assertIn("expected int", 
str(context.exception)) def test_logprobs_invalid_type(self): """Test logprobs with invalid type""" self.engine_client.enable_logprob = True # Test with string type data = {"logprobs": "true", "request_id": "test"} with self.assertRaises(ParameterError) as context: self.engine_client.valid_parameters(data) self.assertIn("logprobs", str(context.exception)) self.assertIn("Invalid type", str(context.exception)) def test_logprobs_disabled(self): """Test logprobs when logprob is disabled""" self.engine_client.enable_logprob = False # Test with logprobs=True data = {"logprobs": True, "request_id": "test"} with self.assertRaises(ParameterError) as context: self.engine_client.valid_parameters(data) self.assertIn("disabled", str(context.exception)) def test_unlimited_max_logprobs_with_prompt_logprobs(self): """Test unlimited max_logprobs (-1) with prompt_logprobs""" self.engine_client.max_logprobs = -1 # Unlimited self.engine_client.enable_logprob = True with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}): # Should allow any prompt_logprobs value data = {"prompt_logprobs": 1000, "request_id": "test"} self.engine_client.valid_parameters(data) # Should not raise def test_unlimited_max_logprobs_with_top_logprobs(self): """Test unlimited max_logprobs (-1) with top_logprobs""" self.engine_client.max_logprobs = -1 # Unlimited self.engine_client.enable_logprob = True with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}): # Should allow any top_logprobs value data = {"logprobs": True, "top_logprobs": 1000, "request_id": "test"} self.engine_client.valid_parameters(data) # Should not raise def test_edge_case_zero_values(self): """Test edge cases with zero values""" self.engine_client.max_logprobs = 20 self.engine_client.enable_logprob = True # Test prompt_logprobs = 0 with FD_USE_GET_SAVE_OUTPUT_V1=1 with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}): data = {"prompt_logprobs": 0, "request_id": "test"} self.engine_client.valid_parameters(data) # Should not raise # Test top_logprobs = 0 with FD_USE_GET_SAVE_OUTPUT_V1=0 with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "0"}): data = {"logprobs": True, "top_logprobs": 0, "request_id": "test"} self.engine_client.valid_parameters(data) # Should not raise def test_valid_parameters(self): request = { "request_id": "test-request-id", "chat_template_kwargs": {"enable_thinking": True}, "prompt_token_ids": [1], "chat_template": "Hello", "max_tokens": 20, "tools": [1], "temperature": 0, } self.engine_client.valid_parameters(request) assert request["temperature"] == 1e-6 async def test_init_basic_parameters(self): """Test EngineClient initialization with basic parameters.""" # Create a proper ModelConfig mock with enable_mm attribute mock_model_config = Mock() mock_model_config.enable_mm = False # Create mocks for all the external dependencies mock_input_processor = Mock() mock_processor = Mock() mock_input_processor.create_processor.return_value = mock_processor # Mock current platform mock_platform = Mock() mock_platform.is_iluvatar.return_value = False # Create mock IPCSignal that behaves properly mock_ipcsignal = Mock() mock_signal_instance = Mock() mock_signal_instance.value = np.array([0]) mock_ipcsignal.return_value = mock_signal_instance # Mock envs for FD_SUPPORT_MAX_CONNECTIONS mock_envs = Mock() mock_envs.FD_SUPPORT_MAX_CONNECTIONS = 100 with ( patch.multiple( "fastdeploy.entrypoints.engine_client", InputPreprocessor=Mock(return_value=mock_input_processor), current_platform=mock_platform, IPCSignal=mock_ipcsignal, 
StatefulSemaphore=Mock, DealerConnectionManager=Mock, FileLock=Mock, work_process_metrics=Mock(), envs=mock_envs, ), patch("os.getenv", return_value="50"), ): # Create a mock config for this test mock_config = Mock() mock_config.model_config = Mock() mock_config.model_config.enable_mm = False client = EngineClient( model_name_or_path="test_model", tokenizer=Mock(), max_model_len=2048, tensor_parallel_size=2, pid=5678, port=9090, limit_mm_per_prompt=3, mm_processor_kwargs={"test": "value"}, config=mock_config, reasoning_parser=None, data_parallel_size=1, enable_logprob=False, workers=2, tool_parser=None, enable_prefix_caching=True, splitwise_role="master", max_processor_cache=100, ) self.assertEqual(client.max_model_len, 2048) self.assertEqual(client.enable_logprob, False) self.assertEqual(client.enable_prefix_caching, True) self.assertEqual(client.enable_splitwise, True) async def test_format_and_add_data_without_request_id(self): """Test format_and_add_data adds request_id when missing.""" prompts = {"prompt_token_ids": [1, 2, 3], "max_tokens": 50} with patch.object(self.engine_client, "add_requests") as mock_add: mock_add.return_value = None result = await self.engine_client.format_and_add_data(prompts) self.assertIn("request_id", prompts) self.assertEqual(result, prompts["prompt_token_ids"]) mock_add.assert_called_once_with(prompts) async def test_format_and_add_data_with_max_tokens_default(self): """Test format_and_add_data sets default max_tokens when missing.""" prompts = {"request_id": "test-id", "prompt_token_ids": [1, 2, 3]} with patch.object(self.engine_client, "add_requests") as mock_add: mock_add.return_value = None await self.engine_client.format_and_add_data(prompts) self.assertEqual(prompts["max_tokens"], self.engine_client.max_model_len - 1) async def test_check_mm_disable_prefix_cache_with_disabled_cache(self): """Test _check_mm_disable_prefix_cache when prefix cache is disabled.""" self.engine_client.disable_prefix_mm = False task = {"multimodal_inputs": {"token_type_ids": [1, 2, 3]}} result = self.engine_client._check_mm_disable_prefix_cache(task) self.assertFalse(result) async def test_check_mm_disable_prefix_cache_with_no_multimodal_data(self): """Test _check_mm_disable_prefix_cache with no multimodal inputs.""" self.engine_client.disable_prefix_mm = True task = {"multimodal_inputs": []} result = self.engine_client._check_mm_disable_prefix_cache(task) self.assertFalse(result) async def test_check_mm_disable_prefix_cache_with_multimodal_data(self): """Test _check_mm_disable_prefix_cache detects multimodal data.""" self.engine_client.disable_prefix_mm = True task = {"multimodal_inputs": {"token_type_ids": [1, 0, 2]}} result = self.engine_client._check_mm_disable_prefix_cache(task) self.assertTrue(result) async def test_add_requests_successful_processing(self): """Test successful request processing in add_requests.""" task = { "request_id": "test-id", "chat_template_kwargs": {"existing": "value"}, "chat_template": "test_template", "prompt_token_ids": [1, 2, 3, 4, 5], "max_tokens": 100, "min_tokens": 1, "messages": "test message", } self.engine_client.data_processor.process_request_dict = Mock() with patch.object(self.engine_client, "_send_task") as mock_send: await self.engine_client.add_requests(task) self.assertEqual(task["chat_template_kwargs"]["chat_template"], "test_template") self.assertEqual(task["prompt_token_ids_len"], 5) self.assertNotIn("messages", task) mock_send.assert_called_once() async def test_add_requests_with_coroutine_processor(self): """Test 
add_requests with async processor.""" task = {"request_id": "test-id", "prompt_token_ids": [1, 2, 3], "max_tokens": 100} async_mock = AsyncMock() self.engine_client.data_processor.process_request_dict = async_mock with patch.object(self.engine_client, "_send_task"): await self.engine_client.add_requests(task) async_mock.assert_called_once() async def test_add_requests_with_multimodal_prefix_cache_error(self): """Test add_requests raises error for multimodal data with prefix cache.""" self.engine_client.enable_mm = True self.engine_client.enable_prefix_caching = True self.engine_client.disable_prefix_mm = True task = { "request_id": "test-id", "prompt_token_ids": [1, 2, 3], "multimodal_inputs": {"token_type_ids": [1, 0, 1]}, } with self.assertRaises(Exception): # EngineError await self.engine_client.add_requests(task) async def test_add_requests_input_length_validation_error(self): """Test add_requests validation for input length.""" task = {"request_id": "test-id", "prompt_token_ids": list(range(1024)), "min_tokens": 1} # At max length with self.assertRaises(Exception): # EngineError await self.engine_client.add_requests(task) async def test_add_requests_stop_sequences_validation(self): """Test add_requests validation for stop sequences.""" task = { "request_id": "test-id", "prompt_token_ids": [1, 2, 3], "stop_seqs_len": list(range(25)), # Exceeds default limit } with patch("fastdeploy.entrypoints.engine_client.envs") as mock_envs: mock_envs.FD_MAX_STOP_SEQS_NUM = 20 mock_envs.FD_STOP_SEQS_MAX_LEN = 100 with self.assertRaises(Exception): # EngineError await self.engine_client.add_requests(task) async def test_add_requests_with_n_parameter_multiple_requests(self): """Test add_requests with n parameter for multiple requests.""" task = {"request_id": "test-id_1", "prompt_token_ids": [1, 2, 3], "n": 3, "max_tokens": 100} with patch.object(self.engine_client, "_send_task") as mock_send: await self.engine_client.add_requests(task) # Should send 3 tasks with indices 3, 4, 5 (1*3 to (1+1)*3) self.assertEqual(mock_send.call_count, 3) def test_send_task_without_multimodal(self): """Test _send_task for non-multimodal content.""" self.engine_client.enable_mm = False task = {"test": "data"} self.engine_client._send_task(task) self.engine_client.zmq_client.send_json.assert_called_once_with(task) def test_send_task_with_multimodal(self): """Test _send_task for multimodal content.""" self.engine_client.enable_mm = True task = {"test": "multimodal_data"} self.engine_client._send_task(task) self.engine_client.zmq_client.send_pyobj.assert_called_once_with(task) def test_valid_parameters_max_tokens_valid(self): """Test valid_parameters accepts valid max_tokens.""" data = {"max_tokens": 100} # Should not raise exception self.engine_client.valid_parameters(data) def test_valid_parameters_max_tokens_too_small(self): """Test valid_parameters rejects max_tokens < 1.""" data = {"max_tokens": 0} with self.assertRaises(Exception): # ParameterError self.engine_client.valid_parameters(data) def test_valid_parameters_max_tokens_too_large(self): """Test valid_parameters rejects max_tokens >= max_model_len.""" data = {"max_tokens": 2048} # Equal to max_model_len, should raise exception with self.assertRaises(Exception): # ParameterError self.engine_client.valid_parameters(data) def test_valid_parameters_reasoning_max_tokens_adjustment(self): """Test valid_parameters adjusts reasoning_max_tokens when needed.""" data = {"max_tokens": 50, "reasoning_max_tokens": 100, "request_id": "test-id"} # Larger than max_tokens with 
patch("fastdeploy.entrypoints.engine_client.api_server_logger") as mock_logger: self.engine_client.valid_parameters(data) self.assertEqual(data["reasoning_max_tokens"], 50) mock_logger.warning.assert_called_once() def test_valid_parameters_temperature_zero_adjustment(self): """Test valid_parameters adjusts zero temperature.""" data = {"temperature": 0} self.engine_client.valid_parameters(data) self.assertEqual(data["temperature"], 1e-6) def test_valid_parameters_logprobs_disabled_when_enabled(self): """Test valid_parameters rejects logprobs when disabled.""" self.engine_client.enable_logprob = False data = {"logprobs": True} with self.assertRaises(Exception): # ParameterError self.engine_client.valid_parameters(data) def test_valid_parameters_logprobs_with_invalid_type(self): """Test valid_parameters rejects invalid logprobs type.""" data = {"logprobs": "invalid"} with self.assertRaises(Exception): # ParameterError self.engine_client.valid_parameters(data) def test_valid_parameters_top_logprobs_disabled(self): """Test valid_parameters rejects top_logprobs when disabled.""" self.engine_client.enable_logprob = False data = {"logprobs": True, "top_logprobs": 5} with self.assertRaises(Exception): # ParameterError self.engine_client.valid_parameters(data) def test_valid_parameters_top_logprobs_invalid_type(self): """Test valid_parameters rejects invalid top_logprobs type.""" self.engine_client.enable_logprob = True data = {"logprobs": True, "top_logprobs": "invalid"} with self.assertRaises(Exception): # ParameterError self.engine_client.valid_parameters(data) def test_valid_parameters_top_logprobs_negative(self): """Test valid_parameters rejects negative top_logprobs.""" self.engine_client.enable_logprob = True data = {"logprobs": True, "top_logprobs": -1} with self.assertRaises(Exception): # ParameterError self.engine_client.valid_parameters(data) def test_valid_parameters_top_logprobs_too_large(self): """Test valid_parameters rejects top_logprobs > 20.""" self.engine_client.enable_logprob = True data = {"logprobs": True, "top_logprobs": 25} with self.assertRaises(Exception): # ParameterError self.engine_client.valid_parameters(data) def test_valid_parameters_top_logprobs_valid(self): """Test valid_parameters accepts valid top_logprobs.""" self.engine_client.enable_logprob = True data = {"logprobs": True, "top_logprobs": 10} # Should not raise exception self.engine_client.valid_parameters(data) def test_check_health_healthy(self): """Test check_health returns healthy status.""" self.engine_client.worker_healthy_live_signal.value = np.array([time.time()]) result, message = self.engine_client.check_health() self.assertTrue(result) self.assertEqual(message, "") def test_check_health_unhealthy_timeout(self): """Test check_health returns unhealthy due to timeout.""" # Set signal to old time (more than 30 seconds ago) old_time = time.time() - 60 self.engine_client.worker_healthy_live_signal.value = np.array([old_time]) result, message = self.engine_client.check_health(time_interval_threashold=30) self.assertFalse(result) self.assertEqual(message, "Worker Service Not Healthy") def test_is_workers_alive_normal(self): """Test is_workers_alive returns True when weights are normal.""" with patch("fastdeploy.entrypoints.engine_client.ModelWeightsStatus") as mock_status: mock_status.NORMAL = 0 self.engine_client.model_weights_status_signal.value = np.array([0]) result, message = self.engine_client.is_workers_alive() self.assertTrue(result) self.assertEqual(message, "") def 
test_is_workers_alive_no_weights(self): """Test is_workers_alive returns False when no weights.""" with patch("fastdeploy.entrypoints.engine_client.ModelWeightsStatus") as mock_status: mock_status.NORMAL = 0 self.engine_client.model_weights_status_signal.value = np.array([1]) result, message = self.engine_client.is_workers_alive() self.assertFalse(result) self.assertEqual(message, "No model weight enabled") def test_update_model_weight_already_normal(self): """Test update_model_weight when weights are already normal.""" with patch("fastdeploy.entrypoints.engine_client.ModelWeightsStatus") as mock_status: mock_status.NORMAL = 0 self.engine_client.model_weights_status_signal.value = np.array([0]) result, message = self.engine_client.update_model_weight() self.assertTrue(result) self.assertEqual(message, "") def test_update_model_weight_already_updating(self): """Test update_model_weight when already updating.""" with patch("fastdeploy.entrypoints.engine_client.ModelWeightsStatus") as mock_status: mock_status.NORMAL = 0 mock_status.UPDATING = 1 self.engine_client.model_weights_status_signal.value = np.array([1]) result, message = self.engine_client.update_model_weight() self.assertFalse(result) self.assertEqual(message, "worker is updating model weight already") def test_update_model_weight_clearing(self): """Test update_model_weight when clearing weights.""" with patch("fastdeploy.entrypoints.engine_client.ModelWeightsStatus") as mock_status: mock_status.NORMAL = 0 mock_status.CLEARING = -1 self.engine_client.model_weights_status_signal.value = np.array([-1]) result, message = self.engine_client.update_model_weight() self.assertFalse(result) self.assertEqual(message, "worker is clearing model weight, cannot update now") def test_update_model_weight_timeout(self): """Test update_model_weight timeout scenario.""" with patch("fastdeploy.entrypoints.engine_client.ModelWeightsStatus") as mock_status: with patch("fastdeploy.entrypoints.engine_client.KVCacheStatus") as mock_kv_status: with patch("fastdeploy.entrypoints.engine_client.PrefixTreeStatus") as mock_prefix_status: mock_status.NORMAL = 0 mock_status.UPDATING = 1 mock_status.CLEARED = -2 mock_kv_status.NORMAL = 0 mock_kv_status.UPDATING = 1 mock_kv_status.CLEARED = -2 mock_prefix_status.NORMAL = 0 mock_prefix_status.UPDATING = 1 mock_prefix_status.CLEARED = -2 self.engine_client.enable_prefix_caching = True # Start with CLEARED status to enter the updating loop self.engine_client.model_weights_status_signal.value = np.array([-2]) self.engine_client.kv_cache_status_signal.value = np.array([-2]) # Start as CLEARED self.engine_client.prefix_tree_status_signal.value = np.array([-2]) # Start as CLEARED result, message = self.engine_client.update_model_weight(timeout=1) self.assertFalse(result) self.assertEqual(message, "Update model weight timeout") def test_clear_load_weight_already_cleared(self): """Test clear_load_weight when weights are already cleared.""" with patch("fastdeploy.entrypoints.engine_client.ModelWeightsStatus") as mock_status: mock_status.CLEARED = -2 self.engine_client.model_weights_status_signal.value = np.array([-2]) result, message = self.engine_client.clear_load_weight() self.assertTrue(result) self.assertEqual(message, "") def test_clear_load_weight_already_clearing(self): """Test clear_load_weight when already clearing.""" with patch("fastdeploy.entrypoints.engine_client.ModelWeightsStatus") as mock_status: mock_status.CLEARED = -2 mock_status.CLEARING = -1 self.engine_client.model_weights_status_signal.value = 
np.array([-1]) result, message = self.engine_client.clear_load_weight() self.assertFalse(result) self.assertEqual(message, "worker is clearing model weight already") def test_clear_load_weight_updating(self): """Test clear_load_weight when updating weights.""" with patch("fastdeploy.entrypoints.engine_client.ModelWeightsStatus") as mock_status: mock_status.CLEARED = -2 mock_status.CLEARING = -1 mock_status.UPDATING = 1 self.engine_client.model_weights_status_signal.value = np.array([1]) result, message = self.engine_client.clear_load_weight() self.assertFalse(result) self.assertEqual(message, "worker is updating model weight, cannot clear now") def test_clear_load_weight_timeout(self): """Test clear_load_weight timeout scenario.""" with patch("fastdeploy.entrypoints.engine_client.ModelWeightsStatus") as mock_status: with patch("fastdeploy.entrypoints.engine_client.KVCacheStatus") as mock_kv_status: with patch("fastdeploy.entrypoints.engine_client.PrefixTreeStatus") as mock_prefix_status: mock_status.NORMAL = 0 mock_status.CLEARED = -2 mock_status.CLEARING = -1 mock_kv_status.CLEARED = -2 mock_kv_status.CLEARING = -1 mock_prefix_status.CLEARED = -2 mock_prefix_status.CLEARING = -1 self.engine_client.enable_prefix_caching = True # Start with NORMAL status to enter the clearing loop self.engine_client.model_weights_status_signal.value = np.array([0]) self.engine_client.kv_cache_status_signal.value = np.array([0]) # Start as NORMAL self.engine_client.prefix_tree_status_signal.value = np.array([0]) # Start as NORMAL result, message = self.engine_client.clear_load_weight(timeout=1) self.assertFalse(result) self.assertEqual(message, "Clear model weight timeout") def test_check_model_weight_status(self): """Test check_model_weight_status returns correct status.""" # Status < 0 indicates abnormal self.engine_client.model_weights_status_signal.value = np.array([-1]) result = self.engine_client.check_model_weight_status() self.assertTrue(result) # Status >= 0 indicates normal self.engine_client.model_weights_status_signal.value = np.array([0]) result = self.engine_client.check_model_weight_status() self.assertFalse(result) def test_create_zmq_client(self): """Test create_zmq_client method.""" mock_zmq_client = Mock() with patch("fastdeploy.entrypoints.engine_client.ZmqIpcClient", return_value=mock_zmq_client) as mock_zmq: self.engine_client.create_zmq_client("test_model", "test_mode") mock_zmq.assert_called_once_with("test_model", "test_mode") mock_zmq_client.connect.assert_called_once() self.assertEqual(self.engine_client.zmq_client, mock_zmq_client) async def test_init_with_multimodal_prefix_cache(self): """Test EngineClient initialization with multimodal prefix cache enabled.""" mock_model_config = Mock() mock_model_config.enable_mm = True mock_config = Mock() mock_config.model_config = mock_model_config mock_config.eplb_config = Mock() mock_config.eplb_config.enable_eplb = False with ( patch("fastdeploy.entrypoints.engine_client.InputPreprocessor") as mock_processor_class, patch("fastdeploy.entrypoints.engine_client.current_platform") as mock_platform, patch("fastdeploy.entrypoints.engine_client.IPCSignal") as mock_ipcsignal, patch("fastdeploy.entrypoints.engine_client.envs") as mock_envs, patch("os.getenv", return_value="50"), patch("fastdeploy.cache_manager.cache_data.is_mm_model_disable_prefix_cache", return_value=True), ): mock_platform.is_iluvatar.return_value = False mock_input_processor = Mock() mock_processor_class.return_value = mock_input_processor mock_processor = Mock() 
mock_input_processor.create_processor.return_value = mock_processor mock_signal_instance = Mock() mock_signal_instance.value = np.array([0]) mock_ipcsignal.return_value = mock_signal_instance mock_envs.FD_SUPPORT_MAX_CONNECTIONS = 100 client = EngineClient( model_name_or_path="test_model", tokenizer=Mock(), max_model_len=2048, tensor_parallel_size=1, pid=5678, port=8080, limit_mm_per_prompt=5, mm_processor_kwargs={}, config=mock_config, reasoning_parser=None, data_parallel_size=1, enable_logprob=True, workers=1, tool_parser=None, enable_prefix_caching=True, # Enable prefix caching splitwise_role=None, max_processor_cache=0, ) self.assertTrue(client.enable_mm) self.assertTrue(client.enable_prefix_caching) self.assertTrue(client.disable_prefix_mm) async def test_init_as_worker_node(self): """Test EngineClient initialization as worker node (not master).""" mock_model_config = Mock() mock_model_config.enable_mm = False mock_config = Mock() mock_config.model_config = mock_model_config mock_config.eplb_config = Mock() mock_config.eplb_config.enable_eplb = False with ( patch("fastdeploy.entrypoints.engine_client.InputPreprocessor") as mock_processor_class, patch("fastdeploy.entrypoints.engine_client.current_platform") as mock_platform, patch("fastdeploy.entrypoints.engine_client.IPCSignal") as mock_ipcsignal, patch("fastdeploy.entrypoints.engine_client.envs") as mock_envs, patch("os.getenv", return_value="50"), ): mock_platform.is_iluvatar.return_value = False mock_platform.max_chips_per_node = 8 mock_input_processor = Mock() mock_processor_class.return_value = mock_input_processor mock_processor = Mock() mock_input_processor.create_processor.return_value = mock_processor mock_signal_instance = Mock() mock_signal_instance.value = np.array([0]) mock_ipcsignal.return_value = mock_signal_instance mock_envs.FD_SUPPORT_MAX_CONNECTIONS = 100 # Use tensor_parallel_size > max_chips_per_node to make it a worker client = EngineClient( model_name_or_path="test_model", tokenizer=Mock(), max_model_len=2048, tensor_parallel_size=16, # Large number to make it a worker pid=5678, port=8080, limit_mm_per_prompt=5, mm_processor_kwargs={}, config=mock_config, reasoning_parser=None, data_parallel_size=1, enable_logprob=True, workers=1, tool_parser=None, enable_prefix_caching=False, splitwise_role=None, max_processor_cache=0, ) self.assertFalse(client.is_master) async def test_format_and_add_data(self): """Test format_and_add_data method.""" prompts = {"prompt_token_ids": [1, 2, 3], "max_tokens": 50} with patch.object(self.engine_client, "add_requests") as mock_add: mock_add.return_value = None await self.engine_client.format_and_add_data(prompts) mock_add.assert_called_once() call_args = mock_add.call_args[0][0] self.assertIn("request_id", call_args) self.assertEqual(call_args["prompt_token_ids"], [1, 2, 3]) self.assertEqual(call_args["max_tokens"], 50) async def test_rearrange_experts_disabled(self): """Test rearrange_experts when EPLB is disabled.""" mock_config = Mock() mock_config.eplb_config = Mock() mock_config.eplb_config.enable_eplb = False self.engine_client.config = mock_config request_dict = {"user": "test", "passwd": "test"} content, status_code = await self.engine_client.rearrange_experts(request_dict) self.assertEqual(content["code"], 1) self.assertEqual(content["msg"], "redundant expert is disabled") self.assertEqual(status_code, 400) async def test_get_per_expert_tokens_stats_disabled(self): """Test get_per_expert_tokens_stats when EPLB is disabled.""" mock_config = Mock() mock_config.eplb_config = 
Mock() mock_config.eplb_config.enable_eplb = False self.engine_client.config = mock_config request_dict = {"user": "test", "passwd": "test"} content, status_code = await self.engine_client.get_per_expert_tokens_stats(request_dict) self.assertEqual(content["code"], 1) self.assertEqual(content["msg"], "redundant expert is disabled") self.assertEqual(status_code, 400) async def test_get_per_expert_tokens_stats_invalid_auth(self): """Test get_per_expert_tokens_stats with invalid authentication.""" mock_eplb_config = Mock() mock_eplb_config.enable_eplb = True mock_eplb_config.redundant_expert_api_user = "correct_user" mock_eplb_config.redundant_expert_api_password = "correct_pass" mock_parallel_config = Mock() mock_parallel_config.tensor_parallel_rank = 0 mock_config = Mock() mock_config.eplb_config = mock_eplb_config mock_config.parallel_config = mock_parallel_config self.engine_client.config = mock_config request_dict = {"user": "wrong_user", "passwd": "wrong_pass"} content, status_code = await self.engine_client.get_per_expert_tokens_stats(request_dict) self.assertEqual(content["code"], 1) self.assertEqual(content["msg"], "user or passwd is invalid") self.assertEqual(status_code, 401) async def test_get_per_expert_tokens_stats_success(self): """Test get_per_expert_tokens_stats successful response.""" mock_eplb_config = Mock() mock_eplb_config.enable_eplb = True mock_eplb_config.redundant_expert_api_user = "test_user" mock_eplb_config.redundant_expert_api_password = "test_pass" mock_parallel_config = Mock() mock_parallel_config.tensor_parallel_rank = 0 mock_config = Mock() mock_config.eplb_config = mock_eplb_config mock_config.parallel_config = mock_parallel_config self.engine_client.config = mock_config # Set up mock arrays mock_local_stats = Mock() mock_local_stats.value = np.array([1, 2, 3]) self.engine_client.local_experts_token_stats_array_list = [mock_local_stats] self.engine_client.signal_clear_experts_token_stats_list = [] request_dict = {"user": "test_user", "passwd": "test_pass"} content, status_code = await self.engine_client.get_per_expert_tokens_stats(request_dict) self.assertEqual(content["code"], 0) self.assertEqual(content["msg"], "ok") self.assertIn("data", content) self.assertEqual(content["data"], [[1, 2, 3]]) self.assertEqual(status_code, 200) async def test_get_per_expert_tokens_stats_clear_stat(self): """Test get_per_expert_tokens_stats with clear_stat flag.""" mock_eplb_config = Mock() mock_eplb_config.enable_eplb = True mock_eplb_config.redundant_expert_api_user = "test_user" mock_eplb_config.redundant_expert_api_password = "test_pass" mock_parallel_config = Mock() mock_parallel_config.tensor_parallel_rank = 0 mock_config = Mock() mock_config.eplb_config = mock_eplb_config mock_config.parallel_config = mock_parallel_config self.engine_client.config = mock_config # Set up mock arrays and signals mock_clear_signal = Mock() mock_clear_signal.value = np.array([0]) self.engine_client.signal_clear_experts_token_stats_list = [mock_clear_signal] mock_local_stats = Mock() mock_local_stats.value = np.array([1, 2, 3]) self.engine_client.local_experts_token_stats_array_list = [mock_local_stats] request_dict = {"user": "test_user", "passwd": "test_pass", "clear_stat": True} content, status_code = await self.engine_client.get_per_expert_tokens_stats(request_dict) self.assertEqual(content["code"], 0) self.assertEqual(content["msg"], "ok") self.assertEqual(mock_clear_signal.value[0], 1) # Clear signal should be set self.assertEqual(status_code, 200) async def 
test_check_redundant_disabled(self): """Test check_redundant when EPLB is disabled.""" mock_config = Mock() mock_config.eplb_config = Mock() mock_config.eplb_config.enable_eplb = False self.engine_client.config = mock_config request_dict = {"user": "test", "passwd": "test"} content, status_code = await self.engine_client.check_redundant(request_dict) self.assertEqual(content["code"], 1) self.assertEqual(content["msg"], "redundant expert is disabled") self.assertEqual(status_code, 400) async def test_check_redundant_invalid_auth(self): """Test check_redundant with invalid authentication.""" mock_eplb_config = Mock() mock_eplb_config.enable_eplb = True mock_eplb_config.redundant_expert_api_user = "correct_user" mock_eplb_config.redundant_expert_api_password = "correct_pass" mock_parallel_config = Mock() mock_parallel_config.tensor_parallel_rank = 0 mock_config = Mock() mock_config.eplb_config = mock_eplb_config mock_config.parallel_config = mock_parallel_config self.engine_client.config = mock_config request_dict = {"user": "wrong_user", "passwd": "wrong_pass"} content, status_code = await self.engine_client.check_redundant(request_dict) self.assertEqual(content["code"], 1) self.assertEqual(content["msg"], "user or passwd is invalid") self.assertEqual(status_code, 401) async def test_check_redundant_wrong_rank(self): """Test check_redundant with wrong tensor parallel rank.""" mock_eplb_config = Mock() mock_eplb_config.enable_eplb = True mock_eplb_config.redundant_expert_api_user = "test_user" mock_eplb_config.redundant_expert_api_password = "test_pass" mock_parallel_config = Mock() mock_parallel_config.tensor_parallel_rank = 1 # Not rank 0 mock_config = Mock() mock_config.eplb_config = mock_eplb_config mock_config.parallel_config = mock_parallel_config self.engine_client.config = mock_config request_dict = {"user": "test_user", "passwd": "test_pass"} content, status_code = await self.engine_client.check_redundant(request_dict) self.assertEqual(content["code"], 1) self.assertIn("actual rank 1, expect rank 0", content["msg"]) self.assertEqual(status_code, 400) async def test_check_redundant_status_unknown(self): """Test check_redundant with unknown status (invalid signal value).""" mock_eplb_config = Mock() mock_eplb_config.enable_eplb = True mock_eplb_config.redundant_expert_api_user = "test_user" mock_eplb_config.redundant_expert_api_password = "test_pass" mock_parallel_config = Mock() mock_parallel_config.tensor_parallel_rank = 0 mock_config = Mock() mock_config.eplb_config = mock_eplb_config mock_config.parallel_config = mock_parallel_config self.engine_client.config = mock_config self.engine_client.rearrange_experts_signal = Mock() self.engine_client.rearrange_experts_signal.value = np.array([999]) # Invalid status with patch("fastdeploy.entrypoints.engine_client.RearrangeExpertStatus") as mock_status: mock_status.side_effect = Exception("Invalid status") request_dict = {"user": "test_user", "passwd": "test_pass", "action": ""} content, status_code = await self.engine_client.check_redundant(request_dict) self.assertEqual(content["code"], 0) self.assertEqual(content["msg"], "ok") self.assertEqual(content["status"], "unknown") # Should fallback to unknown self.assertEqual(status_code, 200) async def test_check_redundant_status_known(self): """Test check_redundant with known status.""" mock_eplb_config = Mock() mock_eplb_config.enable_eplb = True mock_eplb_config.redundant_expert_api_user = "test_user" mock_eplb_config.redundant_expert_api_password = "test_pass" mock_parallel_config = Mock() 
mock_parallel_config.tensor_parallel_rank = 0 mock_config = Mock() mock_config.eplb_config = mock_eplb_config mock_config.parallel_config = mock_parallel_config self.engine_client.config = mock_config self.engine_client.rearrange_experts_signal = Mock() self.engine_client.rearrange_experts_signal.value = np.array([0]) # FREE status with patch("fastdeploy.entrypoints.engine_client.RearrangeExpertStatus") as mock_status: mock_status_instance = Mock() mock_status_instance.name = "FREE" mock_status.return_value = mock_status_instance request_dict = {"user": "test_user", "passwd": "test_pass", "action": ""} content, status_code = await self.engine_client.check_redundant(request_dict) self.assertEqual(content["code"], 0) self.assertEqual(content["msg"], "ok") self.assertEqual(content["status"], "FREE") self.assertEqual(status_code, 200) async def test_check_redundant_check_load_weight_result(self): """Test check_redundant with check_load_weight_result action.""" mock_eplb_config = Mock() mock_eplb_config.enable_eplb = True mock_eplb_config.redundant_expert_api_user = "test_user" mock_eplb_config.redundant_expert_api_password = "test_pass" mock_parallel_config = Mock() mock_parallel_config.tensor_parallel_rank = 0 mock_config = Mock() mock_config.eplb_config = mock_eplb_config mock_config.parallel_config = mock_parallel_config self.engine_client.config = mock_config # Set up mock update_weight_from_disk_result_list mock_result1 = Mock() mock_result1.value = np.array([1, 2, 3]) mock_result2 = Mock() mock_result2.value = np.array([4, 5, 6]) self.engine_client.update_weight_from_disk_result_list = [mock_result1, mock_result2] request_dict = {"user": "test_user", "passwd": "test_pass", "action": "check_load_weight_result"} content, status_code = await self.engine_client.check_redundant(request_dict) self.assertEqual(content["code"], 0) self.assertEqual(content["msg"], "ok") self.assertIn("data", content) # Code does: update_weight_result.value[0].tolist(), so only first elements self.assertEqual(content["data"], [1, 4]) self.assertEqual(status_code, 200) async def test_check_redundant_invalid_action(self): """Test check_redundant with invalid action.""" mock_eplb_config = Mock() mock_eplb_config.enable_eplb = True mock_eplb_config.redundant_expert_api_user = "test_user" mock_eplb_config.redundant_expert_api_password = "test_pass" mock_parallel_config = Mock() mock_parallel_config.tensor_parallel_rank = 0 mock_config = Mock() mock_config.eplb_config = mock_eplb_config mock_config.parallel_config = mock_parallel_config self.engine_client.config = mock_config request_dict = {"user": "test_user", "passwd": "test_pass", "action": "invalid_action"} content, status_code = await self.engine_client.check_redundant(request_dict) # For invalid action, content remains None and status_code is HTTPStatus.OK self.assertIsNone(content) self.assertEqual(status_code, 200) def test_init_eplb_signals_non_zero_rank(self): """Test init_eplb_signals returns early for non-zero tensor parallel rank.""" mock_parallel_config = Mock() mock_parallel_config.tensor_parallel_rank = 1 # Non-zero rank mock_parallel_config.local_data_parallel_id = 0 mock_config = Mock() mock_config.parallel_config = mock_parallel_config # Set fd_config to ensure the method checks the correct config self.engine_client.fd_config = mock_config self.engine_client.config = mock_config # Mock IPCSignal to prevent actual file system calls with patch("fastdeploy.entrypoints.engine_client.IPCSignal") as mock_ipcsignal: # Should return early without 
initializing signals self.engine_client.init_eplb_signals("test_suffix") # Should not create any IPCSignal instances mock_ipcsignal.assert_not_called() # Should return None (implicitly) and not create any signals self.assertFalse(hasattr(self.engine_client, "rearrange_experts_signal")) self.assertFalse(hasattr(self.engine_client, "signal_clear_experts_token_stats_list")) def test_init_eplb_signals_rank_zero_success(self): """Test init_eplb_signals successful initialization for rank 0.""" mock_model_config = Mock() mock_model_config.num_hidden_layers = 12 mock_model_config.moe_num_experts = 8 mock_eplb_config = Mock() mock_eplb_config.redundant_expert_ip_shm_size = 1024 mock_parallel_config = Mock() mock_parallel_config.tensor_parallel_rank = 0 mock_parallel_config.local_data_parallel_id = 2 mock_parallel_config.tensor_parallel_size = 4 mock_config = Mock() mock_config.model_config = mock_model_config mock_config.eplb_config = mock_eplb_config mock_config.parallel_config = mock_parallel_config self.engine_client.config = mock_config self.engine_client.fd_config = mock_config # Also set fd_config for proper access self.engine_client.tensor_parallel_size = 4 # Set this to match the config with patch("fastdeploy.entrypoints.engine_client.IPCSignal") as mock_ipcsignal: mock_signal = Mock() mock_ipcsignal.return_value = mock_signal self.engine_client.init_eplb_signals("8080") # Check that IPCSignal was called with correct parameters # Based on the actual implementation: 4 base signals + 4 TP ranks * 5 signals each = 24 total self.assertEqual(mock_ipcsignal.call_count, 24) # 4 TP ranks * 5 signals each + 4 base signals = 24 total # Check that the suffix includes data parallel ID call_args_list = mock_ipcsignal.call_args_list dp_suffix_found = any("8080_dp2" in str(call) for call in call_args_list) self.assertTrue(dp_suffix_found) # Check that all required signal lists were created self.assertEqual(len(self.engine_client.signal_clear_experts_token_stats_list), 4) self.assertEqual(len(self.engine_client.local_experts_token_stats_array_list), 4) self.assertEqual(len(self.engine_client.expert_tokens_stats_array_list), 4) self.assertEqual(len(self.engine_client.signal_update_weight_from_disk_array_list), 4) self.assertEqual(len(self.engine_client.update_weight_from_disk_result_list), 4) # Check that base signals were created self.assertTrue(hasattr(self.engine_client, "rearrange_experts_signal")) self.assertTrue(hasattr(self.engine_client, "rearrange_experts_ips_size_signal")) self.assertTrue(hasattr(self.engine_client, "shm_rearrange_experts_ips_list")) self.assertTrue(hasattr(self.engine_client, "signal_update_weight_from_tensor_array")) def test_init_eplb_signals_array_dimensions(self): """Test init_eplb_signals creates arrays with correct dimensions.""" mock_model_config = Mock() mock_model_config.num_hidden_layers = 6 mock_model_config.moe_num_experts = 4 mock_eplb_config = Mock() mock_eplb_config.redundant_expert_ip_shm_size = 512 mock_parallel_config = Mock() mock_parallel_config.tensor_parallel_rank = 0 mock_parallel_config.local_data_parallel_id = 1 mock_parallel_config.tensor_parallel_size = 2 mock_config = Mock() mock_config.model_config = mock_model_config mock_config.eplb_config = mock_eplb_config mock_config.parallel_config = mock_parallel_config self.engine_client.config = mock_config self.engine_client.tensor_parallel_size = 2 # Set this to match mock_parallel_config.tensor_parallel_size self.engine_client.fd_config = mock_config # Also set fd_config to ensure proper access with 
patch("fastdeploy.entrypoints.engine_client.IPCSignal") as mock_ipcsignal: mock_signal = Mock() mock_ipcsignal.return_value = mock_signal self.engine_client.init_eplb_signals("9090") # Check that IPCSignal was called with arrays of correct shape call_args_list = mock_ipcsignal.call_args_list # Find calls for expert token stats arrays (should be 6x4 shape for 2D arrays) all_experts_token_stats_calls = [call for call in call_args_list if "all_experts_token_stats" in str(call)] local_experts_token_stats_calls = [ call for call in call_args_list if "local_experts_token_stats" in str(call) ] # These should be 2D arrays with shape (6, 4) for call in all_experts_token_stats_calls: array_arg = call[1]["array"] self.assertEqual(array_arg.shape, (6, 4)) # (num_hidden_layers, moe_num_experts) for call in local_experts_token_stats_calls: array_arg = call[1]["array"] self.assertEqual(array_arg.shape, (6, 4)) # (num_hidden_layers, moe_num_experts) # Check that single-element signals have shape (1,) single_element_calls = [ call for call in call_args_list if "rearrange_experts_status" in str(call) or "rearrange_experts_ips_size" in str(call) or "signal_update_weight_from_tensor" in str(call) ] for call in single_element_calls: array_arg = call[1]["array"] self.assertEqual(array_arg.shape, (1,)) # Single element array def test_init_eplb_signals_suffix_format(self): """Test init_eplb_signals uses correct suffix format.""" mock_model_config = Mock() mock_model_config.num_hidden_layers = 4 mock_model_config.moe_num_experts = 2 mock_eplb_config = Mock() mock_eplb_config.redundant_expert_ip_shm_size = 256 mock_parallel_config = Mock() mock_parallel_config.tensor_parallel_rank = 0 mock_parallel_config.local_data_parallel_id = 3 mock_parallel_config.tensor_parallel_size = 1 mock_config = Mock() mock_config.model_config = mock_model_config mock_config.eplb_config = mock_eplb_config mock_config.parallel_config = mock_parallel_config self.engine_client.config = mock_config self.engine_client.fd_config = mock_config # Set fd_config as well # Ensure tensor_parallel_size is set correctly self.engine_client.tensor_parallel_size = 1 with patch("fastdeploy.entrypoints.engine_client.IPCSignal") as mock_ipcsignal: mock_signal = Mock() mock_ipcsignal.return_value = mock_signal self.engine_client.init_eplb_signals("7777") # Check suffix format call_args_list = mock_ipcsignal.call_args_list # Check DP suffix dp_calls = [call for call in call_args_list if "rearrange_experts_status" in str(call)] self.assertEqual(len(dp_calls), 1) self.assertEqual(dp_calls[0][1]["suffix"], "7777_dp3") # Check TP suffix for TP rank 0 tp_calls = [call for call in call_args_list if "signal_clear_experts_token_stats" in str(call)] self.assertEqual(len(tp_calls), 1) self.assertEqual(tp_calls[0][1]["suffix"], "7777_dp3_tp0") def test_init_eplb_signals_list_initialization(self): """Test init_eplb_signals properly initializes all signal lists.""" mock_model_config = Mock() mock_model_config.num_hidden_layers = 2 mock_model_config.moe_num_experts = 2 mock_eplb_config = Mock() mock_eplb_config.redundant_expert_ip_shm_size = 128 mock_parallel_config = Mock() mock_parallel_config.tensor_parallel_rank = 0 mock_parallel_config.local_data_parallel_id = 0 mock_parallel_config.tensor_parallel_size = 3 mock_config = Mock() mock_config.model_config = mock_model_config mock_config.eplb_config = mock_eplb_config mock_config.parallel_config = mock_parallel_config self.engine_client.config = mock_config self.engine_client.tensor_parallel_size = 3 # Set this to match 
mock_parallel_config.tensor_parallel_size self.engine_client.fd_config = mock_config # Also set fd_config to ensure proper access # Ensure lists start empty self.engine_client.signal_clear_experts_token_stats_list = [] self.engine_client.local_experts_token_stats_array_list = [] self.engine_client.expert_tokens_stats_array_list = [] self.engine_client.signal_update_weight_from_disk_array_list = [] self.engine_client.update_weight_from_disk_result_list = [] with patch("fastdeploy.entrypoints.engine_client.IPCSignal") as mock_ipcsignal: mock_signal = Mock() mock_ipcsignal.return_value = mock_signal self.engine_client.init_eplb_signals("6666") # Check that all lists have correct length (3 TP ranks) self.assertEqual(len(self.engine_client.signal_clear_experts_token_stats_list), 3) self.assertEqual(len(self.engine_client.local_experts_token_stats_array_list), 3) self.assertEqual(len(self.engine_client.expert_tokens_stats_array_list), 3) self.assertEqual(len(self.engine_client.signal_update_weight_from_disk_array_list), 3) self.assertEqual(len(self.engine_client.update_weight_from_disk_result_list), 3) async def test_init_iluvatar_platform(self): """Test EngineClient initialization on Iluvatar platform.""" mock_model_config = Mock() mock_model_config.enable_mm = False mock_config = Mock() mock_config.model_config = mock_model_config mock_config.eplb_config = Mock() mock_config.eplb_config.enable_eplb = False with ( patch("fastdeploy.entrypoints.engine_client.InputPreprocessor") as mock_processor_class, patch("fastdeploy.entrypoints.engine_client.current_platform") as mock_platform, patch("fastdeploy.entrypoints.engine_client.IPCSignal") as mock_ipcsignal, patch("fastdeploy.entrypoints.engine_client.envs") as mock_envs, patch("os.getenv", return_value="50"), ): mock_platform.is_iluvatar.return_value = True # Iluvatar platform mock_input_processor = Mock() mock_processor_class.return_value = mock_input_processor mock_processor = Mock() mock_input_processor.create_processor.return_value = mock_processor mock_signal_instance = Mock() mock_signal_instance.value = np.array([0]) mock_ipcsignal.return_value = mock_signal_instance mock_envs.FD_SUPPORT_MAX_CONNECTIONS = 100 client = EngineClient( model_name_or_path="test_model", tokenizer=Mock(), max_model_len=2048, tensor_parallel_size=1, pid=5678, port=8080, limit_mm_per_prompt=5, mm_processor_kwargs={}, config=mock_config, reasoning_parser=None, data_parallel_size=1, enable_logprob=True, workers=1, tool_parser=None, enable_prefix_caching=False, splitwise_role=None, max_processor_cache=0, ) self.assertTrue(client.is_master) # With 1 tensor_parallel_size, should be master even on Iluvatar def test_check_mm_disable_prefix_cache_without_multimodal_data(self): """Test _check_mm_disable_prefix_cache without multimodal data.""" self.engine_client.disable_prefix_mm = True task = {"multimodal_inputs": {"token_type_ids": [0, 0, 0]}} # Sum = 0 result = self.engine_client._check_mm_disable_prefix_cache(task) self.assertFalse(result) async def test_add_requests_multimodal_prefix_cache_error(self): """Test add_requests with multimodal data when prefix cache is enabled.""" self.engine_client.enable_mm = True self.engine_client.enable_prefix_caching = True self.engine_client.disable_prefix_mm = True self.engine_client.data_processor = Mock() self.engine_client.data_processor.process_request_dict = Mock() task = { "request_id": "test_request", "user": "test_user", "multimodal_inputs": {"token_type_ids": [1, 1, 0, 1]}, # Multimodal data present "prompt_token_ids": [1, 

        with self.assertRaises(EngineError) as context:
            await self.engine_client.add_requests(task)

        self.assertIn("does not support processing requests containing multimodal data", str(context.exception))
        self.assertEqual(context.exception.error_code, 400)

    async def test_add_requests_input_too_long_error(self):
        """Test add_requests with input length too long."""
        self.engine_client.max_model_len = 10
        self.engine_client.data_processor = Mock()
        self.engine_client.data_processor.process_request_dict = Mock()

        task = {
            "request_id": "test_request",
            "user": "test_user",
            "prompt_token_ids": [1, 2, 3, 4, 5, 6, 7, 8],  # length = 8
            "max_tokens": 5,
            "min_tokens": 2,  # input_ids_len(8) + min_tokens(2) = 10 >= max_model_len(10)
        }

        with self.assertRaises(EngineError) as context:
            await self.engine_client.add_requests(task)

        self.assertIn("Input text is too long", str(context.exception))
        self.assertIn("input_ids_len (8) + min_tokens(2) >= max_model_len(10)", str(context.exception))
        self.assertEqual(context.exception.error_code, 400)

    @patch("fastdeploy.entrypoints.engine_client.envs.FD_MAX_STOP_SEQS_NUM", 3)
    async def test_add_requests_stop_seqs_num_exceeds_limit(self):
        """Test add_requests with stop sequences number exceeding limit."""
        self.engine_client.data_processor = Mock()
        self.engine_client.data_processor.process_request_dict = Mock()

        task = {
            "request_id": "test_request",
            "user": "test_user",
            "prompt_token_ids": [1, 2, 3],
            "max_tokens": 10,
            "stop_seqs_len": [10, 20, 30, 40],  # 4 sequences > limit of 3
        }

        with self.assertRaises(EngineError) as context:
            await self.engine_client.add_requests(task)

        self.assertIn(
            "Length of stop ([10, 20, 30, 40]) exceeds the limit max_stop_seqs_num(3)", str(context.exception)
        )
        self.assertIn("Please reduce the number of stop or set a lager max_stop_seqs_num", str(context.exception))
        self.assertEqual(context.exception.error_code, 400)

    @patch("fastdeploy.entrypoints.engine_client.envs.FD_STOP_SEQS_MAX_LEN", 5)
    async def test_add_requests_single_stop_seq_len_exceeds_limit(self):
        """Test add_requests with single stop sequence length exceeding limit."""
        self.engine_client.data_processor = Mock()
        self.engine_client.data_processor.process_request_dict = Mock()

        task = {
            "request_id": "test_request",
            "user": "test_user",
            "prompt_token_ids": [1, 2, 3],
            "max_tokens": 10,
            "stop_seqs_len": [3, 10, 2],  # 10 > limit of 5
        }

        with self.assertRaises(EngineError) as context:
            await self.engine_client.add_requests(task)

        self.assertIn("Length of stop_seqs(10) exceeds the limit stop_seqs_max_len(5)", str(context.exception))
        self.assertIn(
            "Please reduce the length of stop sequences or set a larger stop_seqs_max_len", str(context.exception)
        )
        self.assertEqual(context.exception.error_code, 400)

    async def test_rearrange_experts_eplb_disabled(self):
        """Test rearrange_experts when EPLB is disabled."""
        # Mock eplb_config with enable_eplb = False
        mock_eplb_config = Mock()
        mock_eplb_config.enable_eplb = False
        mock_config = Mock()
        mock_config.eplb_config = mock_eplb_config
        self.engine_client.config = mock_config

        request_dict = {"user": "test_user", "passwd": "test_pass"}
        content, status_code = await self.engine_client.rearrange_experts(request_dict)

        expected_content = {"code": 1, "msg": "redundant expert is disabled"}
        self.assertEqual(content, expected_content)
        self.assertEqual(status_code.value, 400)  # BAD_REQUEST

    async def test_rearrange_experts_invalid_credentials(self):
        """Test rearrange_experts with invalid user/password."""
        # Mock eplb_config with enable_eplb = True
        mock_eplb_config = Mock()
        mock_eplb_config.enable_eplb = True
        mock_eplb_config.redundant_expert_api_user = "valid_user"
        mock_eplb_config.redundant_expert_api_password = "valid_pass"
        mock_config = Mock()
        mock_config.eplb_config = mock_eplb_config
        mock_config.parallel_config.tensor_parallel_rank = 0
        self.engine_client.config = mock_config

        request_dict = {"user": "invalid_user", "passwd": "invalid_pass"}
        content, status_code = await self.engine_client.rearrange_experts(request_dict)

        expected_content = {"code": 1, "msg": "user or passwd is invalid"}
        self.assertEqual(content, expected_content)
        self.assertEqual(status_code.value, 401)  # UNAUTHORIZED

    async def test_rearrange_experts_non_rank_zero(self):
        """Test rearrange_experts from non-zero rank."""
        # Mock eplb_config with enable_eplb = True
        mock_eplb_config = Mock()
        mock_eplb_config.enable_eplb = True
        mock_eplb_config.redundant_expert_api_user = "test_user"
        mock_eplb_config.redundant_expert_api_password = "test_pass"
        mock_config = Mock()
        mock_config.eplb_config = mock_eplb_config
        mock_config.parallel_config.tensor_parallel_rank = 2  # Non-zero rank
        self.engine_client.config = mock_config

        request_dict = {"user": "test_user", "passwd": "test_pass"}
        content, status_code = await self.engine_client.rearrange_experts(request_dict)

        expected_content = {"code": 1, "msg": "actual rank 2, expect rank 0"}
        self.assertEqual(content, expected_content)
        self.assertEqual(status_code.value, 400)  # BAD_REQUEST

    async def test_rearrange_experts_recv_expert_weight_invalid_data(self):
        """Test rearrange_experts recv_expert_weight action with invalid data."""
        # Mock eplb_config
        mock_eplb_config = Mock()
        mock_eplb_config.enable_eplb = True
        mock_eplb_config.redundant_expert_api_user = "test_user"
        mock_eplb_config.redundant_expert_api_password = "test_pass"
        mock_config = Mock()
        mock_config.eplb_config = mock_eplb_config
        mock_config.parallel_config.tensor_parallel_rank = 0
        self.engine_client.config = mock_config

        request_dict = {
            "user": "test_user",
            "passwd": "test_pass",
            "action": "recv_expert_weight",  # Missing "data" field
        }
        content, status_code = await self.engine_client.rearrange_experts(request_dict)

        expected_content = {"code": 1, "msg": "data not in request or data is not a list"}
        self.assertEqual(content, expected_content)
        self.assertEqual(status_code.value, 400)  # BAD_REQUEST

    async def test_rearrange_experts_invalid_action(self):
        """Test rearrange_experts with invalid action."""
        # Mock eplb_config
        mock_eplb_config = Mock()
        mock_eplb_config.enable_eplb = True
        mock_eplb_config.redundant_expert_api_user = "test_user"
        mock_eplb_config.redundant_expert_api_password = "test_pass"
        mock_config = Mock()
        mock_config.eplb_config = mock_eplb_config
        mock_config.parallel_config.tensor_parallel_rank = 0
        self.engine_client.config = mock_config

        request_dict = {"user": "test_user", "passwd": "test_pass", "action": "invalid_action"}
        content, status_code = await self.engine_client.rearrange_experts(request_dict)

        expected_content = {"code": 1, "msg": "invalid action invalid_action"}
        self.assertEqual(content, expected_content)
        self.assertEqual(status_code.value, 400)  # BAD_REQUEST


if __name__ == "__main__":
    unittest.main()