# Files
# FastDeploy/tests/entrypoints/test_engine_client.py
# 2025-12-04 10:38:51 +08:00
#
# 2000 lines
# 88 KiB
# Python
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import asyncio
import os
import time
import unittest
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock, Mock, patch

import numpy as np

from fastdeploy.entrypoints.engine_client import EngineClient
from fastdeploy.utils import EngineError, ParameterError
class DummyConfig(SimpleNamespace):
    """Attribute bag whose unset attributes read as None instead of raising."""

    def __getattr__(self, attr):
        # Python only calls __getattr__ after normal lookup fails, so every
        # keyword given to the constructor is still returned as stored.
        return None
class TestEngineClient(unittest.IsolatedAsyncioTestCase):
    """EngineClient construction from an FDConfig, plus add_requests behaviour."""

    async def asyncSetUp(self):
        """Build an EngineClient whose external dependencies are all mocked.

        ``self.engine_config`` holds the reference values that
        test_engine_client_initialized_by_fd_config compares against.
        """
        # Tokenizer mock reporting a vocabulary size of 1000 via sp_model,
        # vocab and len() alike.
        mock_tokenizer = Mock()
        mock_tokenizer.sp_model = Mock()
        mock_tokenizer.sp_model.__len__ = Mock(return_value=1000)
        mock_tokenizer.vocab = Mock()
        mock_tokenizer.vocab.__len__ = Mock(return_value=1000)
        mock_tokenizer.__len__ = Mock(return_value=1000)

        # ModelConfig mock; values mirror self.engine_config.model_config below.
        mock_model_config = Mock()
        mock_model_config.enable_mm = True
        mock_model_config.enable_logprob = True
        mock_model_config.max_model_len = 1024

        # FDConfig mock handed to EngineClient.
        mock_config = Mock()
        mock_config.model_config = mock_model_config
        mock_config.cache_config = Mock()
        mock_config.cache_config.max_processor_cache = 10
        mock_config.cache_config.enable_prefix_caching = True
        mock_config.eplb_config = Mock()
        mock_config.eplb_config.enable_eplb = False
        mock_config.parallel_config = Mock()
        mock_config.parallel_config.tensor_parallel_rank = 0
        mock_config.parallel_config.local_data_parallel_id = 0
        mock_config.parallel_config.tensor_parallel_size = 1
        mock_config.scheduler_config = Mock()
        mock_config.scheduler_config.splitwise_role = None
        mock_config.limit_mm_per_prompt = 5
        mock_config.mm_processor_kwargs = {}
        mock_config.tool_parser = None
        mock_config.structured_outputs_config = Mock()
        mock_config.structured_outputs_config.reasoning_parser = None

        # InputPreprocessor(...).create_processor() yields a processor that
        # carries the tokenizer mock.
        mock_input_processor = Mock()
        mock_processor = Mock()
        mock_processor.tokenizer = mock_tokenizer
        mock_input_processor.create_processor.return_value = mock_processor

        mock_platform = Mock()
        mock_platform.is_iluvatar.return_value = False
        mock_platform.max_chips_per_node = 8

        # Every IPCSignal(...) call returns an instance whose .value is [0].
        mock_ipcsignal = Mock()
        mock_signal_instance = Mock()
        mock_signal_instance.value = np.array([0])
        mock_ipcsignal.return_value = mock_signal_instance

        mock_envs = Mock()
        mock_envs.FD_SUPPORT_MAX_CONNECTIONS = 100

        # patch.multiple already replaces IPCSignal, DealerConnectionManager,
        # InputPreprocessor, FileLock and StatefulSemaphore; separate
        # single-name patches for them would only be overridden.
        with (
            patch.multiple(
                "fastdeploy.entrypoints.engine_client",
                InputPreprocessor=Mock(return_value=mock_input_processor),
                ZmqIpcClient=Mock,
                IPCSignal=mock_ipcsignal,
                StatefulSemaphore=Mock,
                DealerConnectionManager=Mock,
                FileLock=Mock,
                main_process_metrics=Mock(),
                current_platform=mock_platform,
                envs=mock_envs,
            ),
            patch("fastdeploy.metrics.metrics.main_process_metrics", Mock()),
            patch("os.getenv", return_value="50"),
        ):
            self.engine_config = DummyConfig(
                model_config=DummyConfig(enable_mm=True, enable_logprob=True, max_model_len=1024),
                cache_config=DummyConfig(enable_prefix_caching=True, max_processor_cache=10),
                scheduler_config=DummyConfig(splitwise_role="mixed", max_num_seqs=128),
                parallel_config=DummyConfig(tensor_parallel_size=1),
                structured_outputs_config=DummyConfig(reasoning_parser="reasoning_parser"),
                eplb_config=DummyConfig(enable_eplb=True, eplb_max_tokens=1024),
            )
            self.engine_client = EngineClient(pid=1234, port=8080, fd_config=mock_config, workers=1)
        self.engine_client.zmq_client = MagicMock()

    def test_engine_client_initialized_by_fd_config(self):
        """Config values mirrored onto EngineClient must equal self.engine_config.

        enable_mm, enable_logprob and enable_prefix_caching are excluded from
        the generic loop and checked individually against their config group.
        """
        handled_separately = ("enable_mm", "enable_logprob", "enable_prefix_caching")
        for config_group in self.engine_config.__dict__.values():
            for config_name, config_value in config_group.__dict__.items():
                if not hasattr(self.engine_client, config_name):
                    continue
                if config_name in handled_separately:
                    continue
                assert getattr(self.engine_client, config_name) == config_value
        # enable_mm and enable_logprob are copied from model_config.
        assert getattr(self.engine_client, "enable_mm") == self.engine_config.model_config.enable_mm
        assert getattr(self.engine_client, "enable_logprob") == self.engine_config.model_config.enable_logprob
        # enable_prefix_caching is copied from cache_config.
        assert (
            getattr(self.engine_client, "enable_prefix_caching")
            == self.engine_config.cache_config.enable_prefix_caching
        )

    async def test_add_request(self):
        """add_requests copies the top-level chat_template into chat_template_kwargs."""
        request = {
            "request_id": "test-request-id",
            "chat_template_kwargs": {"enable_thinking": True},
            "prompt_token_ids": [1],
            "chat_template": "Hello",
            "max_tokens": 20,
            "tools": [1],
        }
        await self.engine_client.add_requests(request)
        assert "chat_template" in request["chat_template_kwargs"], "'chat_template' not found in 'chat_template_kwargs'"
        assert request["chat_template_kwargs"]["chat_template"] == "Hello"
        # Only the top-level "tools" entry is checked; whether add_requests
        # also mirrors it into chat_template_kwargs is not asserted here.
        assert request["tools"] == [1]
class TestEngineClientValidParameters(unittest.TestCase):
"""Test cases for EngineClient.valid_parameters method"""
def setUp(self):
    """Construct an EngineClient over a MagicMock FDConfig for valid_parameters tests."""
    # Tokenizer reporting a 1000-entry vocabulary through sp_model and vocab.
    tokenizer = MagicMock()
    tokenizer.sp_model = MagicMock()
    tokenizer.sp_model.__len__ = MagicMock(return_value=1000)
    tokenizer.vocab = MagicMock()
    tokenizer.vocab.__len__ = MagicMock(return_value=1000)

    data_processor = MagicMock()
    data_processor.tokenizer = tokenizer

    model_config = MagicMock()
    model_config.enable_mm = False

    # FDConfig stand-in with explicit values for the fields pinned below.
    fd_config = MagicMock()
    fd_config.model_config = model_config
    fd_config.eplb_config = MagicMock()
    fd_config.eplb_config.enable_eplb = False
    fd_config.parallel_config = MagicMock()
    fd_config.parallel_config.tensor_parallel_rank = 0
    fd_config.parallel_config.local_data_parallel_id = 0
    fd_config.parallel_config.tensor_parallel_size = 1
    fd_config.scheduler_config = MagicMock()
    fd_config.scheduler_config.splitwise_role = None
    fd_config.cache_config = MagicMock()
    fd_config.cache_config.enable_prefix_caching = False
    fd_config.cache_config.max_processor_cache = 0
    fd_config.limit_mm_per_prompt = 5
    fd_config.mm_processor_kwargs = {}
    fd_config.structured_outputs_config = MagicMock()
    fd_config.structured_outputs_config.reasoning_parser = None
    fd_config.tool_parser = None

    # One flat `with` instead of six nested ones; same patches, same order.
    with (
        patch("fastdeploy.entrypoints.engine_client.IPCSignal") as ipc_signal_cls,
        patch("fastdeploy.entrypoints.engine_client.StatefulSemaphore") as semaphore_cls,
        patch("fastdeploy.entrypoints.engine_client.DealerConnectionManager") as connection_manager_cls,
        patch("fastdeploy.entrypoints.engine_client.FileLock") as file_lock_cls,
        patch("fastdeploy.config.ModelConfig") as model_config_cls,
        patch("fastdeploy.entrypoints.engine_client.InputPreprocessor") as input_preprocessor_cls,
    ):
        ipc_signal_cls.return_value = MagicMock()
        semaphore_cls.return_value = MagicMock()
        connection_manager_cls.return_value = MagicMock()
        file_lock_cls.return_value = MagicMock()
        model_config_cls.return_value = model_config
        preprocessor = MagicMock()
        preprocessor.create_processor.return_value = data_processor
        input_preprocessor_cls.return_value = preprocessor
        self.engine_client = EngineClient(pid=1234, port=8080, fd_config=fd_config, workers=1)

    client = self.engine_client
    client.zmq_client = Mock()
    client.zmq_client.send_json = Mock()
    client.zmq_client.send_pyobj = Mock()

    # Plain attribute overrides applied after construction.
    overrides = {
        "max_logprobs": 20,
        "enable_logprob": True,
        "ori_vocab_size": 1000,
        "enable_prefix_caching": False,
        "enable_splitwise": False,
        "disable_prefix_mm": False,
        "max_model_len": 1024,
        "enable_mm": False,
        "config": fd_config,
        "max_chips_per_node": 8,
        "tensor_parallel_size": 1,
        "is_master": True,
    }
    for attr, value in overrides.items():
        setattr(client, attr, value)

    # Every status signal reads as np.array([0]).
    for signal_attr in (
        "worker_healthy_live_signal",
        "model_weights_status_signal",
        "kv_cache_status_signal",
        "prefix_tree_status_signal",
    ):
        fake_signal = Mock()
        fake_signal.value = np.array([0])
        setattr(client, signal_attr, fake_signal)

    lock = Mock()
    lock.__enter__ = Mock(return_value=None)
    lock.__exit__ = Mock(return_value=None)
    client.clear_update_lock = lock
def test_max_logprobs_valid_values(self):
    """Both a positive cap (20) and -1 (unlimited) are accepted for max_logprobs."""
    for cap in (20, -1):
        self.engine_client.max_logprobs = cap
        # Any exception raised here fails the test.
        self.engine_client.valid_parameters({"request_id": "test"})
def test_max_logprobs_invalid_values(self):
    """max_logprobs below -1 is rejected with a descriptive ValueError."""
    self.engine_client.max_logprobs = -2
    with self.assertRaises(ValueError) as ctx:
        self.engine_client.valid_parameters({"request_id": "test"})
    message = str(ctx.exception)
    for fragment in ("max_logprobs", "must be >= -1", "got -2"):
        self.assertIn(fragment, message)
def test_max_logprobs_exceeds_vocab_size(self):
    """max_logprobs larger than the original vocab size (1000) is a ValueError."""
    client = self.engine_client
    client.max_logprobs = 1500
    client.ori_vocab_size = 1000
    with self.assertRaises(ValueError) as ctx:
        client.valid_parameters({"request_id": "test"})
    message = str(ctx.exception)
    for fragment in ("max_logprobs", "must be <= vocab_size", "1000", "got 1500"):
        self.assertIn(fragment, message)
def test_max_logprobs_unlimited(self):
    """max_logprobs = -1 (unlimited) passes validation and stays -1 on the client.

    Whatever effective limit valid_parameters derives from -1 (presumably
    ori_vocab_size; not asserted here) must not be written back to
    engine_client.max_logprobs.
    """
    self.engine_client.max_logprobs = -1
    self.engine_client.ori_vocab_size = 1000
    data = {"request_id": "test"}
    self.engine_client.valid_parameters(data)  # must not raise
    self.assertEqual(self.engine_client.max_logprobs, -1)
def test_prompt_logprobs_valid_values(self):
    """prompt_logprobs accepts a positive value, -1 and an absent value."""
    client = self.engine_client
    client.max_logprobs = 20
    client.enable_logprob = True
    v1_enabled = {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}

    # Positive request within the cap of 20.
    with patch.dict(os.environ, v1_enabled):
        client.valid_parameters({"prompt_logprobs": 10, "request_id": "test"})

    # -1 is paired with an unlimited max_logprobs.
    with patch.dict(os.environ, v1_enabled):
        client.max_logprobs = -1
        client.valid_parameters({"prompt_logprobs": -1, "request_id": "test"})

    # No prompt_logprobs key at all.
    client.valid_parameters({"request_id": "test"})
def test_prompt_logprobs_unlimited_sets_to_vocab_size(self):
    """prompt_logprobs = -1 validates when max_logprobs is also unlimited.

    NOTE(review): the test name implies prompt_logprobs is rewritten to
    ori_vocab_size, but nothing here asserts that. Add an assertion once the
    exact mutation (on ``data`` or elsewhere) is confirmed.
    """
    self.engine_client.max_logprobs = -1  # unlimited, so prompt_logprobs = -1 is allowed
    self.engine_client.enable_logprob = True
    self.engine_client.ori_vocab_size = 1000
    with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
        data = {"prompt_logprobs": -1, "request_id": "test"}
        self.engine_client.valid_parameters(data)  # must not raise
def test_prompt_logprobs_disabled_when_fd_use_get_save_output_v1_disabled(self):
    """With FD_USE_GET_SAVE_OUTPUT_V1=0 a prompt_logprobs request is a ParameterError."""
    request = {"prompt_logprobs": 10, "request_id": "test"}
    with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "0"}), self.assertRaises(ParameterError) as ctx:
        self.engine_client.valid_parameters(request)
    self.assertIn(
        "prompt_logprobs is not support when FD_USE_GET_SAVE_OUTPUT_V1 is disabled", str(ctx.exception)
    )
def test_prompt_logprobs_disabled_logprob(self):
    """prompt_logprobs while enable_logprob is off raises ParameterError."""
    self.engine_client.enable_logprob = False
    with self.assertRaises(ParameterError) as ctx:
        self.engine_client.valid_parameters({"prompt_logprobs": 10, "request_id": "test"})
    expected = "`enable_logprob` is disabled, please enable it in startup config."
    self.assertIn(expected, str(ctx.exception))
def test_prompt_logprobs_disabled_when_prefix_caching_enabled(self):
    """prompt_logprobs is refused while prefix caching is on, even with V1 output enabled."""
    client = self.engine_client
    client.enable_prefix_caching = True
    client.enable_logprob = True
    request = {"prompt_logprobs": 10, "request_id": "test"}
    with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}), self.assertRaises(ParameterError) as ctx:
        client.valid_parameters(request)
    self.assertIn("prompt_logprobs is not support when prefix caching is enabled", str(ctx.exception))
def test_prompt_logprobs_invalid_values(self):
    """prompt_logprobs below -1 is a ValueError that names the offending value."""
    self.engine_client.enable_logprob = True
    request = {"prompt_logprobs": -2, "request_id": "test"}
    with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}), self.assertRaises(ValueError) as ctx:
        self.engine_client.valid_parameters(request)
    error_text = str(ctx.exception)
    for snippet in ("prompt_logprobs", "must be a non-negative value or -1", "current value is -2"):
        self.assertIn(snippet, error_text)
def test_prompt_logprobs_exceeds_max_logprobs(self):
    """prompt_logprobs above max_logprobs (15 > 10) is a ValueError quoting both numbers."""
    client = self.engine_client
    client.max_logprobs = 10
    client.enable_logprob = True
    request = {"prompt_logprobs": 15, "request_id": "test"}
    with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}), self.assertRaises(ValueError) as ctx:
        client.valid_parameters(request)
    error_text = str(ctx.exception)
    for snippet in ("prompt_logprobs", "exceeds maximum allowed value", "15", "10"):
        self.assertIn(snippet, error_text)
def test_top_logprobs_validation_with_fd_use_get_save_output_v1_enabled(self):
    """top_logprobs validation when FD_USE_GET_SAVE_OUTPUT_V1 is enabled.

    Covers -1 (paired with max_logprobs = -1), a value within the cap, a
    value below -1, and a value above the cap.
    """
    self.engine_client.enable_logprob = True
    with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
        # -1 is exercised with max_logprobs unlimited as well.
        self.engine_client.max_logprobs = -1
        data = {"logprobs": True, "top_logprobs": -1, "request_id": "test"}
        self.engine_client.valid_parameters(data)  # must not raise

        # Restore a finite cap of 20 for the remaining sub-cases of this test.
        self.engine_client.max_logprobs = 20

        data = {"logprobs": True, "top_logprobs": 10, "request_id": "test"}
        self.engine_client.valid_parameters(data)  # must not raise

        data = {"logprobs": True, "top_logprobs": -2, "request_id": "test"}
        with self.assertRaises(ValueError) as context:
            self.engine_client.valid_parameters(data)
        self.assertIn("must be a non-negative value or -1", str(context.exception))
        self.assertIn("current value is -2", str(context.exception))

        data = {"logprobs": True, "top_logprobs": 25, "request_id": "test"}
        with self.assertRaises(ValueError) as context:
            self.engine_client.valid_parameters(data)
        self.assertIn("exceeds maximum allowed value", str(context.exception))
def test_top_logprobs_validation_with_fd_use_get_save_output_v1_disabled(self):
    """With FD_USE_GET_SAVE_OUTPUT_V1=0, top_logprobs must lie in [0, 20]."""
    client = self.engine_client
    client.max_logprobs = 20
    client.enable_logprob = True

    def build(top_logprobs):
        return {"logprobs": True, "top_logprobs": top_logprobs, "request_id": "test"}

    with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "0"}):
        # Negative values are out of range on this path, -1 included.
        with self.assertRaises(ValueError) as ctx:
            client.valid_parameters(build(-1))
        self.assertIn("top_logprobs must be between 0 and 20", str(ctx.exception))
        self.assertIn("current value is -1", str(ctx.exception))

        # Values above the cap are rejected as well.
        with self.assertRaises(ValueError) as ctx:
            client.valid_parameters(build(25))
        self.assertIn(
            "Number of top_logprobs requested (25) exceeds maximum allowed value (20)", str(ctx.exception)
        )

        # An in-range value passes.
        client.valid_parameters(build(10))
def test_top_logprobs_disabled_logprob(self):
    """top_logprobs is refused with ParameterError while enable_logprob is off."""
    self.engine_client.enable_logprob = False
    payload = {"logprobs": True, "top_logprobs": 10, "request_id": "test"}
    with self.assertRaises(ParameterError) as ctx:
        self.engine_client.valid_parameters(payload)
    self.assertIn("disabled", str(ctx.exception))
def test_top_logprobs_invalid_type(self):
    """A string top_logprobs is reported as an invalid type via ParameterError."""
    self.engine_client.enable_logprob = True
    payload = {"logprobs": True, "top_logprobs": "10", "request_id": "test"}
    with self.assertRaises(ParameterError) as ctx:
        self.engine_client.valid_parameters(payload)
    text = str(ctx.exception)
    for needle in ("top_logprobs", "Invalid type", "expected int"):
        self.assertIn(needle, text)
def test_logprobs_invalid_type(self):
    """A string logprobs flag is reported as an invalid type via ParameterError."""
    self.engine_client.enable_logprob = True
    payload = {"logprobs": "true", "request_id": "test"}
    with self.assertRaises(ParameterError) as ctx:
        self.engine_client.valid_parameters(payload)
    text = str(ctx.exception)
    for needle in ("logprobs", "Invalid type"):
        self.assertIn(needle, text)
def test_logprobs_disabled(self):
    """logprobs=True while enable_logprob is off raises ParameterError."""
    self.engine_client.enable_logprob = False
    payload = {"logprobs": True, "request_id": "test"}
    with self.assertRaises(ParameterError) as ctx:
        self.engine_client.valid_parameters(payload)
    self.assertIn("disabled", str(ctx.exception))
def test_unlimited_max_logprobs_with_prompt_logprobs(self):
    """With max_logprobs = -1, a large prompt_logprobs (1000) is accepted."""
    client = self.engine_client
    client.max_logprobs = -1
    client.enable_logprob = True
    request = {"prompt_logprobs": 1000, "request_id": "test"}
    with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
        client.valid_parameters(request)
def test_unlimited_max_logprobs_with_top_logprobs(self):
    """With max_logprobs = -1, a large top_logprobs (1000) is accepted."""
    client = self.engine_client
    client.max_logprobs = -1
    client.enable_logprob = True
    request = {"logprobs": True, "top_logprobs": 1000, "request_id": "test"}
    with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
        client.valid_parameters(request)
def test_edge_case_zero_values(self):
    """Zero is accepted for prompt_logprobs (V1 on) and for top_logprobs (V1 off)."""
    client = self.engine_client
    client.max_logprobs = 20
    client.enable_logprob = True
    cases = (
        ("1", {"prompt_logprobs": 0, "request_id": "test"}),
        ("0", {"logprobs": True, "top_logprobs": 0, "request_id": "test"}),
    )
    for flag, request in cases:
        with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": flag}):
            client.valid_parameters(request)
def test_valid_parameters(self):
    """A temperature of 0 is replaced by 1e-6 in the request dict."""
    payload = {
        "request_id": "test-request-id",
        "chat_template_kwargs": {"enable_thinking": True},
        "prompt_token_ids": [1],
        "chat_template": "Hello",
        "max_tokens": 20,
        "tools": [1],
        "temperature": 0,
    }
    self.engine_client.valid_parameters(payload)
    self.assertEqual(payload["temperature"], 1e-6)
@unittest.skip(
    "Stale: passes pre-fd_config keyword arguments to EngineClient (setUp uses fd_config=...); "
    "as an async test on a synchronous TestCase it never actually executed"
)
async def test_init_basic_parameters(self):
    """EngineClient initialization with explicit keyword parameters.

    Skipped explicitly. As an ``async def`` on a synchronous
    ``unittest.TestCase`` the coroutine was never awaited, so it reported a
    pass without running. Its constructor keywords (model_name_or_path,
    tokenizer, config=...) also differ from the
    ``EngineClient(pid=..., port=..., fd_config=..., workers=...)`` form used
    by setUp. TODO(review): port to the fd_config-based constructor and
    re-enable.
    """
    mock_model_config = Mock()
    mock_model_config.enable_mm = False
    mock_input_processor = Mock()
    mock_processor = Mock()
    mock_input_processor.create_processor.return_value = mock_processor
    mock_platform = Mock()
    mock_platform.is_iluvatar.return_value = False
    mock_ipcsignal = Mock()
    mock_signal_instance = Mock()
    mock_signal_instance.value = np.array([0])
    mock_ipcsignal.return_value = mock_signal_instance
    mock_envs = Mock()
    mock_envs.FD_SUPPORT_MAX_CONNECTIONS = 100
    with (
        patch.multiple(
            "fastdeploy.entrypoints.engine_client",
            InputPreprocessor=Mock(return_value=mock_input_processor),
            current_platform=mock_platform,
            IPCSignal=mock_ipcsignal,
            StatefulSemaphore=Mock,
            DealerConnectionManager=Mock,
            FileLock=Mock,
            work_process_metrics=Mock(),
            envs=mock_envs,
        ),
        patch("os.getenv", return_value="50"),
    ):
        mock_config = Mock()
        mock_config.model_config = Mock()
        mock_config.model_config.enable_mm = False
        client = EngineClient(
            model_name_or_path="test_model",
            tokenizer=Mock(),
            max_model_len=2048,
            tensor_parallel_size=2,
            pid=5678,
            port=9090,
            limit_mm_per_prompt=3,
            mm_processor_kwargs={"test": "value"},
            config=mock_config,
            reasoning_parser=None,
            data_parallel_size=1,
            enable_logprob=False,
            workers=2,
            tool_parser=None,
            enable_prefix_caching=True,
            splitwise_role="master",
            max_processor_cache=100,
        )
        self.assertEqual(client.max_model_len, 2048)
        self.assertEqual(client.enable_logprob, False)
        self.assertEqual(client.enable_prefix_caching, True)
        self.assertEqual(client.enable_splitwise, True)
def test_format_and_add_data_without_request_id(self):
    """format_and_add_data is expected to add a request_id when it is missing.

    Driven through asyncio.run. As an ``async def`` on this synchronous
    TestCase the coroutine was never awaited and no assertion executed.
    """
    prompts = {"prompt_token_ids": [1, 2, 3], "max_tokens": 50}
    with patch.object(self.engine_client, "add_requests") as mock_add:
        mock_add.return_value = None
        result = asyncio.run(self.engine_client.format_and_add_data(prompts))
        self.assertIn("request_id", prompts)
        self.assertEqual(result, prompts["prompt_token_ids"])
        mock_add.assert_called_once_with(prompts)
def test_format_and_add_data_with_max_tokens_default(self):
    """Without max_tokens, format_and_add_data is expected to use max_model_len - 1.

    Driven through asyncio.run. As an ``async def`` on this synchronous
    TestCase the coroutine was never awaited and no assertion executed.
    """
    prompts = {"request_id": "test-id", "prompt_token_ids": [1, 2, 3]}
    with patch.object(self.engine_client, "add_requests") as mock_add:
        mock_add.return_value = None
        asyncio.run(self.engine_client.format_and_add_data(prompts))
        self.assertEqual(prompts["max_tokens"], self.engine_client.max_model_len - 1)
def test_check_mm_disable_prefix_cache_with_disabled_cache(self):
    """_check_mm_disable_prefix_cache is expected to be False when disable_prefix_mm is off.

    Now a plain ``def``. As ``async def`` on this synchronous TestCase it only
    built a never-awaited coroutine, so the assertion never ran. The body
    awaits nothing, so no event loop is needed.
    """
    self.engine_client.disable_prefix_mm = False
    task = {"multimodal_inputs": {"token_type_ids": [1, 2, 3]}}
    result = self.engine_client._check_mm_disable_prefix_cache(task)
    self.assertFalse(result)
def test_check_mm_disable_prefix_cache_with_no_multimodal_data(self):
    """_check_mm_disable_prefix_cache is expected to be False for empty multimodal_inputs.

    Now a plain ``def``. As ``async def`` on this synchronous TestCase it only
    built a never-awaited coroutine, so the assertion never ran. The body
    awaits nothing, so no event loop is needed.
    """
    self.engine_client.disable_prefix_mm = True
    task = {"multimodal_inputs": []}
    result = self.engine_client._check_mm_disable_prefix_cache(task)
    self.assertFalse(result)
def test_check_mm_disable_prefix_cache_with_multimodal_data(self):
    """_check_mm_disable_prefix_cache is expected to be True when multimodal tokens are present.

    Now a plain ``def``. As ``async def`` on this synchronous TestCase it only
    built a never-awaited coroutine, so the assertion never ran. The body
    awaits nothing, so no event loop is needed.
    """
    self.engine_client.disable_prefix_mm = True
    task = {"multimodal_inputs": {"token_type_ids": [1, 0, 2]}}
    result = self.engine_client._check_mm_disable_prefix_cache(task)
    self.assertTrue(result)
def test_add_requests_successful_processing(self):
    """add_requests folds chat_template into kwargs, records the prompt length,
    drops messages and sends exactly one task.

    Driven through asyncio.run. As an ``async def`` on this synchronous
    TestCase the coroutine was never awaited and no assertion executed.
    """
    task = {
        "request_id": "test-id",
        "chat_template_kwargs": {"existing": "value"},
        "chat_template": "test_template",
        "prompt_token_ids": [1, 2, 3, 4, 5],
        "max_tokens": 100,
        "min_tokens": 1,
        "messages": "test message",
    }
    self.engine_client.data_processor.process_request_dict = Mock()
    with patch.object(self.engine_client, "_send_task") as mock_send:
        asyncio.run(self.engine_client.add_requests(task))
        self.assertEqual(task["chat_template_kwargs"]["chat_template"], "test_template")
        self.assertEqual(task["prompt_token_ids_len"], 5)
        self.assertNotIn("messages", task)
        mock_send.assert_called_once()
def test_add_requests_with_coroutine_processor(self):
    """add_requests is expected to await an async process_request_dict exactly once.

    Driven through asyncio.run. As an ``async def`` on this synchronous
    TestCase the coroutine was never awaited and no assertion executed.
    """
    task = {"request_id": "test-id", "prompt_token_ids": [1, 2, 3], "max_tokens": 100}
    async_mock = AsyncMock()
    self.engine_client.data_processor.process_request_dict = async_mock
    with patch.object(self.engine_client, "_send_task"):
        asyncio.run(self.engine_client.add_requests(task))
        async_mock.assert_called_once()
def test_add_requests_with_multimodal_prefix_cache_error(self):
    """Multimodal input with prefix caching and disable_prefix_mm on is expected to raise.

    Driven through asyncio.run. As an ``async def`` on this synchronous
    TestCase the coroutine was never awaited and no assertion executed.
    """
    self.engine_client.enable_mm = True
    self.engine_client.enable_prefix_caching = True
    self.engine_client.disable_prefix_mm = True
    task = {
        "request_id": "test-id",
        "prompt_token_ids": [1, 2, 3],
        "multimodal_inputs": {"token_type_ids": [1, 0, 1]},
    }
    with self.assertRaises(Exception):  # expected EngineError; type not pinned
        asyncio.run(self.engine_client.add_requests(task))
def test_add_requests_input_length_validation_error(self):
    """A prompt of max_model_len (1024) tokens is expected to be rejected.

    Driven through asyncio.run. As an ``async def`` on this synchronous
    TestCase the coroutine was never awaited and no assertion executed.
    """
    task = {"request_id": "test-id", "prompt_token_ids": list(range(1024)), "min_tokens": 1}
    with self.assertRaises(Exception):  # expected EngineError; type not pinned
        asyncio.run(self.engine_client.add_requests(task))
def test_add_requests_stop_sequences_validation(self):
    """25 stop sequences against FD_MAX_STOP_SEQS_NUM = 20 is expected to be rejected.

    Driven through asyncio.run. As an ``async def`` on this synchronous
    TestCase the coroutine was never awaited and no assertion executed.
    """
    task = {
        "request_id": "test-id",
        "prompt_token_ids": [1, 2, 3],
        "stop_seqs_len": list(range(25)),
    }
    with patch("fastdeploy.entrypoints.engine_client.envs") as mock_envs:
        mock_envs.FD_MAX_STOP_SEQS_NUM = 20
        mock_envs.FD_STOP_SEQS_MAX_LEN = 100
        with self.assertRaises(Exception):  # expected EngineError; type not pinned
            asyncio.run(self.engine_client.add_requests(task))
def test_add_requests_with_n_parameter_multiple_requests(self):
    """n=3 is expected to fan out into three _send_task calls.

    Driven through asyncio.run. As an ``async def`` on this synchronous
    TestCase the coroutine was never awaited and no assertion executed.
    """
    task = {"request_id": "test-id_1", "prompt_token_ids": [1, 2, 3], "n": 3, "max_tokens": 100}
    with patch.object(self.engine_client, "_send_task") as mock_send:
        asyncio.run(self.engine_client.add_requests(task))
        # NOTE(review): per-copy indices (believed to be 3, 4, 5 for the
        # "_1" suffix) are not asserted; only the call count is.
        self.assertEqual(mock_send.call_count, 3)
def test_send_task_without_multimodal(self):
    """With enable_mm off, _send_task passes the task to zmq_client.send_json once."""
    client = self.engine_client
    client.enable_mm = False
    payload = {"test": "data"}
    client._send_task(payload)
    client.zmq_client.send_json.assert_called_once_with(payload)
def test_send_task_with_multimodal(self):
    """With enable_mm on, _send_task passes the task to zmq_client.send_pyobj once."""
    client = self.engine_client
    client.enable_mm = True
    payload = {"test": "multimodal_data"}
    client._send_task(payload)
    client.zmq_client.send_pyobj.assert_called_once_with(payload)
def test_valid_parameters_max_tokens_valid(self):
    """max_tokens = 100 (within max_model_len) passes validation."""
    # Passing means no exception escapes.
    self.engine_client.valid_parameters({"max_tokens": 100})
def test_valid_parameters_max_tokens_too_small(self):
    """max_tokens = 0 (below 1) is rejected."""
    payload = {"max_tokens": 0}
    # Exact exception type is not pinned by this test.
    with self.assertRaises(Exception):
        self.engine_client.valid_parameters(payload)
def test_valid_parameters_max_tokens_too_large(self):
    """valid_parameters rejects max_tokens above max_model_len.

    setUp pins max_model_len to 1024, so 2048 is well above the limit; the
    exact equality boundary (max_tokens == max_model_len) is not exercised.
    """
    data = {"max_tokens": 2048}
    with self.assertRaises(Exception):  # exact exception type not pinned
        self.engine_client.valid_parameters(data)
def test_valid_parameters_reasoning_max_tokens_adjustment(self):
    """reasoning_max_tokens above max_tokens is clamped to max_tokens and warned about once."""
    payload = {"max_tokens": 50, "reasoning_max_tokens": 100, "request_id": "test-id"}
    with patch("fastdeploy.entrypoints.engine_client.api_server_logger") as logger:
        self.engine_client.valid_parameters(payload)
        self.assertEqual(payload["reasoning_max_tokens"], 50)
        logger.warning.assert_called_once()
def test_valid_parameters_temperature_zero_adjustment(self):
    """A temperature of exactly zero is raised in place to 1e-6."""
    request = dict(temperature=0)
    self.engine_client.valid_parameters(request)
    self.assertEqual(request["temperature"], 1e-6)
def test_valid_parameters_logprobs_disabled_when_enabled(self):
    """Requesting logprobs must fail while the client has logprobs disabled."""
    client = self.engine_client
    client.enable_logprob = False
    with self.assertRaises(Exception):  # ParameterError
        client.valid_parameters({"logprobs": True})
def test_valid_parameters_logprobs_with_invalid_type(self):
    """A string value for logprobs must be rejected."""
    bad_request = {"logprobs": "invalid"}
    # Callable form of assertRaises; expected error is ParameterError.
    self.assertRaises(Exception, self.engine_client.valid_parameters, bad_request)
def test_valid_parameters_top_logprobs_disabled(self):
    """top_logprobs must be rejected while logprobs support is turned off."""
    self.engine_client.enable_logprob = False
    request = {"logprobs": True, "top_logprobs": 5}
    # Callable form of assertRaises; expected error is ParameterError.
    self.assertRaises(Exception, self.engine_client.valid_parameters, request)
def test_valid_parameters_top_logprobs_invalid_type(self):
    """A string top_logprobs must be rejected even when logprobs are enabled."""
    client = self.engine_client
    client.enable_logprob = True
    with self.assertRaises(Exception):  # ParameterError
        client.valid_parameters({"logprobs": True, "top_logprobs": "invalid"})
def test_valid_parameters_top_logprobs_negative(self):
    """A negative top_logprobs must be rejected."""
    client = self.engine_client
    client.enable_logprob = True
    with self.assertRaises(Exception):  # ParameterError
        client.valid_parameters({"logprobs": True, "top_logprobs": -1})
def test_valid_parameters_top_logprobs_too_large(self):
    """top_logprobs above the upper bound of 20 must be rejected."""
    client = self.engine_client
    client.enable_logprob = True
    with self.assertRaises(Exception):  # ParameterError
        client.valid_parameters({"logprobs": True, "top_logprobs": 25})
def test_valid_parameters_top_logprobs_valid(self):
    """An in-range top_logprobs is accepted when logprobs are enabled."""
    client = self.engine_client
    client.enable_logprob = True
    # No explicit assertion: any exception propagates and fails the test.
    client.valid_parameters({"logprobs": True, "top_logprobs": 10})
def test_check_health_healthy(self):
    """A heartbeat stamped just now is reported as healthy with an empty message."""
    now = time.time()
    self.engine_client.worker_healthy_live_signal.value = np.array([now])
    healthy, msg = self.engine_client.check_health()
    self.assertTrue(healthy)
    self.assertEqual(msg, "")
def test_check_health_unhealthy_timeout(self):
    """A heartbeat older than the threshold is reported as unhealthy."""
    # One minute old, well past the 30-second threshold used below.
    stale_heartbeat = time.time() - 60
    self.engine_client.worker_healthy_live_signal.value = np.array([stale_heartbeat])
    healthy, msg = self.engine_client.check_health(time_interval_threashold=30)
    self.assertFalse(healthy)
    self.assertEqual(msg, "Worker Service Not Healthy")
def test_is_workers_alive_normal(self):
    """Workers count as alive, with an empty message, while the weight signal reads NORMAL."""
    with patch("fastdeploy.entrypoints.engine_client.ModelWeightsStatus") as weights_status:
        weights_status.NORMAL = 0
        client = self.engine_client
        client.model_weights_status_signal.value = np.array([0])
        alive, reason = client.is_workers_alive()
        self.assertTrue(alive)
        self.assertEqual(reason, "")
def test_is_workers_alive_no_weights(self):
    """A weight signal other than NORMAL makes is_workers_alive report no weights."""
    with patch("fastdeploy.entrypoints.engine_client.ModelWeightsStatus") as weights_status:
        weights_status.NORMAL = 0
        client = self.engine_client
        client.model_weights_status_signal.value = np.array([1])
        alive, reason = client.is_workers_alive()
        self.assertFalse(alive)
        self.assertEqual(reason, "No model weight enabled")
def test_update_model_weight_already_normal(self):
    """update_model_weight reports success with an empty message when weights are already NORMAL."""
    with patch("fastdeploy.entrypoints.engine_client.ModelWeightsStatus") as weights_status:
        weights_status.NORMAL = 0
        self.engine_client.model_weights_status_signal.value = np.array([0])
        succeeded, detail = self.engine_client.update_model_weight()
        self.assertTrue(succeeded)
        self.assertEqual(detail, "")
def test_update_model_weight_already_updating(self):
    """update_model_weight refuses to start while an update is already in progress."""
    with patch("fastdeploy.entrypoints.engine_client.ModelWeightsStatus") as weights_status:
        weights_status.NORMAL, weights_status.UPDATING = 0, 1
        self.engine_client.model_weights_status_signal.value = np.array([1])
        succeeded, detail = self.engine_client.update_model_weight()
        self.assertFalse(succeeded)
        self.assertEqual(detail, "worker is updating model weight already")
def test_update_model_weight_clearing(self):
    """update_model_weight refuses to start while weights are being cleared."""
    with patch("fastdeploy.entrypoints.engine_client.ModelWeightsStatus") as weights_status:
        weights_status.NORMAL, weights_status.CLEARING = 0, -1
        self.engine_client.model_weights_status_signal.value = np.array([-1])
        succeeded, detail = self.engine_client.update_model_weight()
        self.assertFalse(succeeded)
        self.assertEqual(detail, "worker is clearing model weight, cannot update now")
def test_update_model_weight_timeout(self):
    """update_model_weight gives up with a timeout message when the signals never reach NORMAL."""
    module = "fastdeploy.entrypoints.engine_client"
    with (
        patch(f"{module}.ModelWeightsStatus") as weights_status,
        patch(f"{module}.KVCacheStatus") as kv_status,
        patch(f"{module}.PrefixTreeStatus") as prefix_status,
    ):
        # All three status enums share the same numeric layout in this test.
        for status in (weights_status, kv_status, prefix_status):
            status.NORMAL = 0
            status.UPDATING = 1
            status.CLEARED = -2
        client = self.engine_client
        client.enable_prefix_caching = True
        # Each signal starts CLEARED so the updating loop is entered; the test
        # expects them to stay there until the 1-second timeout fires.
        for signal in (
            client.model_weights_status_signal,
            client.kv_cache_status_signal,
            client.prefix_tree_status_signal,
        ):
            signal.value = np.array([-2])
        ok, msg = client.update_model_weight(timeout=1)
        self.assertFalse(ok)
        self.assertEqual(msg, "Update model weight timeout")
def test_clear_load_weight_already_cleared(self):
    """clear_load_weight reports success with an empty message when weights are already CLEARED."""
    with patch("fastdeploy.entrypoints.engine_client.ModelWeightsStatus") as weights_status:
        weights_status.CLEARED = -2
        self.engine_client.model_weights_status_signal.value = np.array([-2])
        succeeded, detail = self.engine_client.clear_load_weight()
        self.assertTrue(succeeded)
        self.assertEqual(detail, "")
def test_clear_load_weight_already_clearing(self):
    """clear_load_weight refuses to start while a clear is already in progress."""
    with patch("fastdeploy.entrypoints.engine_client.ModelWeightsStatus") as weights_status:
        weights_status.CLEARED, weights_status.CLEARING = -2, -1
        self.engine_client.model_weights_status_signal.value = np.array([-1])
        succeeded, detail = self.engine_client.clear_load_weight()
        self.assertFalse(succeeded)
        self.assertEqual(detail, "worker is clearing model weight already")
def test_clear_load_weight_updating(self):
    """clear_load_weight refuses to start while weights are being updated."""
    with patch("fastdeploy.entrypoints.engine_client.ModelWeightsStatus") as weights_status:
        weights_status.CLEARED, weights_status.CLEARING, weights_status.UPDATING = -2, -1, 1
        self.engine_client.model_weights_status_signal.value = np.array([1])
        succeeded, detail = self.engine_client.clear_load_weight()
        self.assertFalse(succeeded)
        self.assertEqual(detail, "worker is updating model weight, cannot clear now")
def test_clear_load_weight_timeout(self):
    """clear_load_weight reports a timeout when the signals never reach CLEARED."""
    module = "fastdeploy.entrypoints.engine_client"
    with (
        patch(f"{module}.ModelWeightsStatus") as weights_status,
        patch(f"{module}.KVCacheStatus") as kv_status,
        patch(f"{module}.PrefixTreeStatus") as prefix_status,
    ):
        weights_status.NORMAL = 0
        # CLEARED/CLEARING share the same values across all three enums.
        for status in (weights_status, kv_status, prefix_status):
            status.CLEARED = -2
            status.CLEARING = -1
        client = self.engine_client
        client.enable_prefix_caching = True
        # Each signal starts NORMAL so the clearing loop is entered; the test
        # expects them to stay there until the 1-second timeout fires.
        for signal in (
            client.model_weights_status_signal,
            client.kv_cache_status_signal,
            client.prefix_tree_status_signal,
        ):
            signal.value = np.array([0])
        ok, msg = client.clear_load_weight(timeout=1)
        self.assertFalse(ok)
        self.assertEqual(msg, "Clear model weight timeout")
def test_check_model_weight_status(self):
    """check_model_weight_status is truthy for a negative signal and falsy for zero."""
    signal = self.engine_client.model_weights_status_signal
    # A negative value marks an abnormal weight state.
    signal.value = np.array([-1])
    self.assertTrue(self.engine_client.check_model_weight_status())
    # A non-negative value marks a normal weight state.
    signal.value = np.array([0])
    self.assertFalse(self.engine_client.check_model_weight_status())
def test_create_zmq_client(self):
    """create_zmq_client builds a ZmqIpcClient, connects it, and stores it on the client."""
    fake_client = Mock()
    target = "fastdeploy.entrypoints.engine_client.ZmqIpcClient"
    with patch(target, return_value=fake_client) as zmq_factory:
        self.engine_client.create_zmq_client("test_model", "test_mode")
        zmq_factory.assert_called_once_with("test_model", "test_mode")
        fake_client.connect.assert_called_once()
        self.assertEqual(self.engine_client.zmq_client, fake_client)
async def test_init_with_multimodal_prefix_cache(self):
    """Test EngineClient initialization with multimodal prefix cache enabled.

    Constructs a real EngineClient while patching InputPreprocessor,
    current_platform, IPCSignal, envs and os.getenv. It also forces
    is_mm_model_disable_prefix_cache to return True. It then checks that
    enable_mm, enable_prefix_caching and disable_prefix_mm are all truthy on
    the client.

    NOTE(review): is_mm_model_disable_prefix_cache is patched on
    fastdeploy.cache_manager.cache_data. If engine_client binds that function
    by name (``from ... import ...``), this patch would not reach it. Confirm
    the import style in engine_client.
    """
    # Multimodal model config; client.enable_mm is asserted True below.
    mock_model_config = Mock()
    mock_model_config.enable_mm = True
    mock_config = Mock()
    mock_config.model_config = mock_model_config
    # EPLB stays disabled for this construction.
    mock_config.eplb_config = Mock()
    mock_config.eplb_config.enable_eplb = False
    with (
        patch("fastdeploy.entrypoints.engine_client.InputPreprocessor") as mock_processor_class,
        patch("fastdeploy.entrypoints.engine_client.current_platform") as mock_platform,
        patch("fastdeploy.entrypoints.engine_client.IPCSignal") as mock_ipcsignal,
        patch("fastdeploy.entrypoints.engine_client.envs") as mock_envs,
        # Every os.getenv call made while these patches are active returns "50".
        patch("os.getenv", return_value="50"),
        patch("fastdeploy.cache_manager.cache_data.is_mm_model_disable_prefix_cache", return_value=True),
    ):
        mock_platform.is_iluvatar.return_value = False
        # InputPreprocessor(...) -> mock_input_processor; .create_processor() -> mock_processor.
        mock_input_processor = Mock()
        mock_processor_class.return_value = mock_input_processor
        mock_processor = Mock()
        mock_input_processor.create_processor.return_value = mock_processor
        # Any IPCSignal(...) call returns this shared mock holding a zeroed value.
        mock_signal_instance = Mock()
        mock_signal_instance.value = np.array([0])
        mock_ipcsignal.return_value = mock_signal_instance
        # Concrete int so the client reads a number rather than an auto-created Mock attribute.
        mock_envs.FD_SUPPORT_MAX_CONNECTIONS = 100
        client = EngineClient(
            model_name_or_path="test_model",
            tokenizer=Mock(),
            max_model_len=2048,
            tensor_parallel_size=1,
            pid=5678,
            port=8080,
            limit_mm_per_prompt=5,
            mm_processor_kwargs={},
            config=mock_config,
            reasoning_parser=None,
            data_parallel_size=1,
            enable_logprob=True,
            workers=1,
            tool_parser=None,
            enable_prefix_caching=True,  # Enable prefix caching
            splitwise_role=None,
            max_processor_cache=0,
        )
        self.assertTrue(client.enable_mm)
        self.assertTrue(client.enable_prefix_caching)
        # Expected to come from the patched is_mm_model_disable_prefix_cache (see NOTE above).
        self.assertTrue(client.disable_prefix_mm)
async def test_init_as_worker_node(self):
    """Test EngineClient initialization as worker node (not master).

    tensor_parallel_size (16) is set above the mocked
    current_platform.max_chips_per_node (8). The client is expected to end up
    with is_master == False.
    """
    # Non-multimodal model config with EPLB disabled.
    mock_model_config = Mock()
    mock_model_config.enable_mm = False
    mock_config = Mock()
    mock_config.model_config = mock_model_config
    mock_config.eplb_config = Mock()
    mock_config.eplb_config.enable_eplb = False
    with (
        patch("fastdeploy.entrypoints.engine_client.InputPreprocessor") as mock_processor_class,
        patch("fastdeploy.entrypoints.engine_client.current_platform") as mock_platform,
        patch("fastdeploy.entrypoints.engine_client.IPCSignal") as mock_ipcsignal,
        patch("fastdeploy.entrypoints.engine_client.envs") as mock_envs,
        # Every os.getenv call made while these patches are active returns "50".
        patch("os.getenv", return_value="50"),
    ):
        mock_platform.is_iluvatar.return_value = False
        # Chips-per-node limit that tensor_parallel_size=16 below exceeds.
        mock_platform.max_chips_per_node = 8
        # InputPreprocessor(...) -> mock_input_processor; .create_processor() -> mock_processor.
        mock_input_processor = Mock()
        mock_processor_class.return_value = mock_input_processor
        mock_processor = Mock()
        mock_input_processor.create_processor.return_value = mock_processor
        # Any IPCSignal(...) call returns this shared mock holding a zeroed value.
        mock_signal_instance = Mock()
        mock_signal_instance.value = np.array([0])
        mock_ipcsignal.return_value = mock_signal_instance
        # Concrete int so the client reads a number rather than an auto-created Mock attribute.
        mock_envs.FD_SUPPORT_MAX_CONNECTIONS = 100
        # Use tensor_parallel_size > max_chips_per_node to make it a worker
        client = EngineClient(
            model_name_or_path="test_model",
            tokenizer=Mock(),
            max_model_len=2048,
            tensor_parallel_size=16,  # Large number to make it a worker
            pid=5678,
            port=8080,
            limit_mm_per_prompt=5,
            mm_processor_kwargs={},
            config=mock_config,
            reasoning_parser=None,
            data_parallel_size=1,
            enable_logprob=True,
            workers=1,
            tool_parser=None,
            enable_prefix_caching=False,
            splitwise_role=None,
            max_processor_cache=0,
        )
        self.assertFalse(client.is_master)
async def test_format_and_add_data(self):
    """format_and_add_data forwards the prompt fields plus a request_id to add_requests."""
    prompt_payload = {"prompt_token_ids": [1, 2, 3], "max_tokens": 50}
    with patch.object(self.engine_client, "add_requests") as add_requests_mock:
        add_requests_mock.return_value = None
        await self.engine_client.format_and_add_data(prompt_payload)
        add_requests_mock.assert_called_once()
        # First positional argument passed to add_requests.
        forwarded = add_requests_mock.call_args.args[0]
        self.assertIn("request_id", forwarded)
        self.assertEqual(forwarded["prompt_token_ids"], [1, 2, 3])
        self.assertEqual(forwarded["max_tokens"], 50)
async def test_rearrange_experts_disabled(self):
    """rearrange_experts answers 400 with code 1 when EPLB is switched off."""
    cfg = Mock()
    cfg.eplb_config = Mock(enable_eplb=False)
    self.engine_client.config = cfg
    body, status = await self.engine_client.rearrange_experts({"user": "test", "passwd": "test"})
    self.assertEqual(body["code"], 1)
    self.assertEqual(body["msg"], "redundant expert is disabled")
    self.assertEqual(status, 400)
async def test_get_per_expert_tokens_stats_disabled(self):
    """get_per_expert_tokens_stats answers 400 with code 1 when EPLB is switched off."""
    cfg = Mock()
    cfg.eplb_config = Mock(enable_eplb=False)
    self.engine_client.config = cfg
    body, status = await self.engine_client.get_per_expert_tokens_stats({"user": "test", "passwd": "test"})
    self.assertEqual(body["code"], 1)
    self.assertEqual(body["msg"], "redundant expert is disabled")
    self.assertEqual(status, 400)
async def test_get_per_expert_tokens_stats_invalid_auth(self):
    """Credentials that do not match the EPLB config yield a 401."""
    eplb = Mock(
        enable_eplb=True,
        redundant_expert_api_user="correct_user",
        redundant_expert_api_password="correct_pass",
    )
    self.engine_client.config = Mock(
        eplb_config=eplb,
        parallel_config=Mock(tensor_parallel_rank=0),
    )
    wrong_credentials = {"user": "wrong_user", "passwd": "wrong_pass"}
    body, status = await self.engine_client.get_per_expert_tokens_stats(wrong_credentials)
    self.assertEqual(body["code"], 1)
    self.assertEqual(body["msg"], "user or passwd is invalid")
    self.assertEqual(status, 401)
async def test_get_per_expert_tokens_stats_success(self):
    """Valid credentials on rank 0 return each rank's local stats as lists, with a 200."""
    self.engine_client.config = Mock(
        eplb_config=Mock(
            enable_eplb=True,
            redundant_expert_api_user="test_user",
            redundant_expert_api_password="test_pass",
        ),
        parallel_config=Mock(tensor_parallel_rank=0),
    )
    # A single rank's local token stats; no clear signals are involved here.
    self.engine_client.local_experts_token_stats_array_list = [Mock(value=np.array([1, 2, 3]))]
    self.engine_client.signal_clear_experts_token_stats_list = []
    credentials = {"user": "test_user", "passwd": "test_pass"}
    body, status = await self.engine_client.get_per_expert_tokens_stats(credentials)
    self.assertEqual(body["code"], 0)
    self.assertEqual(body["msg"], "ok")
    self.assertIn("data", body)
    self.assertEqual(body["data"], [[1, 2, 3]])
    self.assertEqual(status, 200)
async def test_get_per_expert_tokens_stats_clear_stat(self):
    """clear_stat=True sets each rank's clear signal to 1 and still returns 200."""
    self.engine_client.config = Mock(
        eplb_config=Mock(
            enable_eplb=True,
            redundant_expert_api_user="test_user",
            redundant_expert_api_password="test_pass",
        ),
        parallel_config=Mock(tensor_parallel_rank=0),
    )
    clear_signal = Mock(value=np.array([0]))
    self.engine_client.signal_clear_experts_token_stats_list = [clear_signal]
    self.engine_client.local_experts_token_stats_array_list = [Mock(value=np.array([1, 2, 3]))]
    query = {"user": "test_user", "passwd": "test_pass", "clear_stat": True}
    body, status = await self.engine_client.get_per_expert_tokens_stats(query)
    self.assertEqual(body["code"], 0)
    self.assertEqual(body["msg"], "ok")
    # The clear flag must have been raised.
    self.assertEqual(clear_signal.value[0], 1)
    self.assertEqual(status, 200)
async def test_check_redundant_disabled(self):
    """check_redundant answers 400 with code 1 when EPLB is switched off."""
    cfg = Mock()
    cfg.eplb_config = Mock(enable_eplb=False)
    self.engine_client.config = cfg
    body, status = await self.engine_client.check_redundant({"user": "test", "passwd": "test"})
    self.assertEqual(body["code"], 1)
    self.assertEqual(body["msg"], "redundant expert is disabled")
    self.assertEqual(status, 400)
async def test_check_redundant_invalid_auth(self):
    """check_redundant rejects mismatched credentials with a 401."""
    eplb = Mock(
        enable_eplb=True,
        redundant_expert_api_user="correct_user",
        redundant_expert_api_password="correct_pass",
    )
    self.engine_client.config = Mock(
        eplb_config=eplb,
        parallel_config=Mock(tensor_parallel_rank=0),
    )
    wrong_credentials = {"user": "wrong_user", "passwd": "wrong_pass"}
    body, status = await self.engine_client.check_redundant(wrong_credentials)
    self.assertEqual(body["code"], 1)
    self.assertEqual(body["msg"], "user or passwd is invalid")
    self.assertEqual(status, 401)
async def test_check_redundant_wrong_rank(self):
    """check_redundant must be served by TP rank 0; other ranks get a 400."""
    eplb = Mock(
        enable_eplb=True,
        redundant_expert_api_user="test_user",
        redundant_expert_api_password="test_pass",
    )
    # Rank 1 is deliberately not the expected rank 0.
    self.engine_client.config = Mock(
        eplb_config=eplb,
        parallel_config=Mock(tensor_parallel_rank=1),
    )
    body, status = await self.engine_client.check_redundant({"user": "test_user", "passwd": "test_pass"})
    self.assertEqual(body["code"], 1)
    self.assertIn("actual rank 1, expect rank 0", body["msg"])
    self.assertEqual(status, 400)
async def test_check_redundant_status_unknown(self):
    """An unrecognised rearrange signal value is reported as status "unknown"."""
    eplb = Mock(
        enable_eplb=True,
        redundant_expert_api_user="test_user",
        redundant_expert_api_password="test_pass",
    )
    self.engine_client.config = Mock(
        eplb_config=eplb,
        parallel_config=Mock(tensor_parallel_rank=0),
    )
    # 999 does not correspond to any RearrangeExpertStatus member.
    self.engine_client.rearrange_experts_signal = Mock(value=np.array([999]))
    with patch("fastdeploy.entrypoints.engine_client.RearrangeExpertStatus") as status_enum:
        # Constructing the status from the raw value fails.
        status_enum.side_effect = Exception("Invalid status")
        query = {"user": "test_user", "passwd": "test_pass", "action": ""}
        body, status = await self.engine_client.check_redundant(query)
        self.assertEqual(body["code"], 0)
        self.assertEqual(body["msg"], "ok")
        self.assertEqual(body["status"], "unknown")
        self.assertEqual(status, 200)
async def test_check_redundant_status_known(self):
    """A recognised rearrange signal value is reported by its enum member name."""
    eplb = Mock(
        enable_eplb=True,
        redundant_expert_api_user="test_user",
        redundant_expert_api_password="test_pass",
    )
    self.engine_client.config = Mock(
        eplb_config=eplb,
        parallel_config=Mock(tensor_parallel_rank=0),
    )
    self.engine_client.rearrange_experts_signal = Mock(value=np.array([0]))  # FREE status
    with patch("fastdeploy.entrypoints.engine_client.RearrangeExpertStatus") as status_enum:
        # `name` cannot be passed to the Mock constructor (it is reserved), so set it afterwards.
        member = Mock()
        member.name = "FREE"
        status_enum.return_value = member
        query = {"user": "test_user", "passwd": "test_pass", "action": ""}
        body, status = await self.engine_client.check_redundant(query)
        self.assertEqual(body["code"], 0)
        self.assertEqual(body["msg"], "ok")
        self.assertEqual(body["status"], "FREE")
        self.assertEqual(status, 200)
async def test_check_redundant_check_load_weight_result(self):
    """check_load_weight_result returns the first element of each rank's result signal."""
    eplb = Mock(
        enable_eplb=True,
        redundant_expert_api_user="test_user",
        redundant_expert_api_password="test_pass",
    )
    self.engine_client.config = Mock(
        eplb_config=eplb,
        parallel_config=Mock(tensor_parallel_rank=0),
    )
    self.engine_client.update_weight_from_disk_result_list = [
        Mock(value=np.array([1, 2, 3])),
        Mock(value=np.array([4, 5, 6])),
    ]
    query = {"user": "test_user", "passwd": "test_pass", "action": "check_load_weight_result"}
    body, status = await self.engine_client.check_redundant(query)
    self.assertEqual(body["code"], 0)
    self.assertEqual(body["msg"], "ok")
    self.assertIn("data", body)
    # Only value[0] of each result signal is reported.
    self.assertEqual(body["data"], [1, 4])
    self.assertEqual(status, 200)
async def test_check_redundant_invalid_action(self):
    """An unrecognised action yields no content and HTTP 200."""
    eplb = Mock(
        enable_eplb=True,
        redundant_expert_api_user="test_user",
        redundant_expert_api_password="test_pass",
    )
    self.engine_client.config = Mock(
        eplb_config=eplb,
        parallel_config=Mock(tensor_parallel_rank=0),
    )
    query = {"user": "test_user", "passwd": "test_pass", "action": "invalid_action"}
    body, status = await self.engine_client.check_redundant(query)
    self.assertIsNone(body)
    self.assertEqual(status, 200)
def test_init_eplb_signals_non_zero_rank(self):
    """A non-zero tensor-parallel rank must return early without creating any EPLB signal."""
    cfg = Mock()
    cfg.parallel_config = Mock(tensor_parallel_rank=1, local_data_parallel_id=0)
    # Point both config attributes at the same mock so whichever one is read reports rank 1.
    self.engine_client.fd_config = cfg
    self.engine_client.config = cfg
    # Patch IPCSignal so no shared-memory objects are actually created.
    with patch("fastdeploy.entrypoints.engine_client.IPCSignal") as ipc_signal_cls:
        self.engine_client.init_eplb_signals("test_suffix")
        ipc_signal_cls.assert_not_called()
        for attr in ("rearrange_experts_signal", "signal_clear_experts_token_stats_list"):
            self.assertFalse(hasattr(self.engine_client, attr))
def test_init_eplb_signals_rank_zero_success(self):
    """Test init_eplb_signals successful initialization for rank 0.

    With tensor_parallel_size=4 and local_data_parallel_id=2, the test expects:
    24 IPCSignal constructions, a "8080_dp2" suffix in at least one call, four
    entries in each per-rank signal list, and the base signal attributes set.
    """
    # Model / EPLB / parallel config consumed by init_eplb_signals.
    mock_model_config = Mock()
    mock_model_config.num_hidden_layers = 12
    mock_model_config.moe_num_experts = 8
    mock_eplb_config = Mock()
    mock_eplb_config.redundant_expert_ip_shm_size = 1024
    mock_parallel_config = Mock()
    mock_parallel_config.tensor_parallel_rank = 0
    mock_parallel_config.local_data_parallel_id = 2
    mock_parallel_config.tensor_parallel_size = 4
    mock_config = Mock()
    mock_config.model_config = mock_model_config
    mock_config.eplb_config = mock_eplb_config
    mock_config.parallel_config = mock_parallel_config
    self.engine_client.config = mock_config
    self.engine_client.fd_config = mock_config  # Also set fd_config for proper access
    self.engine_client.tensor_parallel_size = 4  # Set this to match the config
    with patch("fastdeploy.entrypoints.engine_client.IPCSignal") as mock_ipcsignal:
        # Every IPCSignal(...) call returns the same mock object.
        mock_signal = Mock()
        mock_ipcsignal.return_value = mock_signal
        self.engine_client.init_eplb_signals("8080")
        # Expected: 4 base signals + 5 per-rank signals * 4 TP ranks = 24 IPCSignal calls.
        self.assertEqual(mock_ipcsignal.call_count, 24)  # 4 TP ranks * 5 signals each + 4 base signals = 24 total
        # The suffix passed to IPCSignal should embed the data parallel ID ("<suffix>_dp<id>").
        call_args_list = mock_ipcsignal.call_args_list
        dp_suffix_found = any("8080_dp2" in str(call) for call in call_args_list)
        self.assertTrue(dp_suffix_found)
        # Each per-rank list should hold one entry per TP rank (4).
        self.assertEqual(len(self.engine_client.signal_clear_experts_token_stats_list), 4)
        self.assertEqual(len(self.engine_client.local_experts_token_stats_array_list), 4)
        self.assertEqual(len(self.engine_client.expert_tokens_stats_array_list), 4)
        self.assertEqual(len(self.engine_client.signal_update_weight_from_disk_array_list), 4)
        self.assertEqual(len(self.engine_client.update_weight_from_disk_result_list), 4)
        # Base (non-per-rank) signal attributes should exist on the client.
        self.assertTrue(hasattr(self.engine_client, "rearrange_experts_signal"))
        self.assertTrue(hasattr(self.engine_client, "rearrange_experts_ips_size_signal"))
        self.assertTrue(hasattr(self.engine_client, "shm_rearrange_experts_ips_list"))
        self.assertTrue(hasattr(self.engine_client, "signal_update_weight_from_tensor_array"))
def test_init_eplb_signals_array_dimensions(self):
    """Test init_eplb_signals creates arrays with correct dimensions.

    Inspects the ``array`` keyword passed to each IPCSignal call. Expert token
    stats arrays are expected to be (num_hidden_layers, moe_num_experts) =
    (6, 4). Single-value signals are expected to have shape (1,).

    NOTE(review): if a name filter below matches no call, its loop asserts
    nothing and the check passes vacuously. Consider asserting that each
    filtered list is non-empty.
    """
    # 6 layers x 4 experts determines the expected 2-D stats shape.
    mock_model_config = Mock()
    mock_model_config.num_hidden_layers = 6
    mock_model_config.moe_num_experts = 4
    mock_eplb_config = Mock()
    mock_eplb_config.redundant_expert_ip_shm_size = 512
    mock_parallel_config = Mock()
    mock_parallel_config.tensor_parallel_rank = 0
    mock_parallel_config.local_data_parallel_id = 1
    mock_parallel_config.tensor_parallel_size = 2
    mock_config = Mock()
    mock_config.model_config = mock_model_config
    mock_config.eplb_config = mock_eplb_config
    mock_config.parallel_config = mock_parallel_config
    self.engine_client.config = mock_config
    self.engine_client.tensor_parallel_size = 2  # Set this to match mock_parallel_config.tensor_parallel_size
    self.engine_client.fd_config = mock_config  # Also set fd_config to ensure proper access
    with patch("fastdeploy.entrypoints.engine_client.IPCSignal") as mock_ipcsignal:
        mock_signal = Mock()
        mock_ipcsignal.return_value = mock_signal
        self.engine_client.init_eplb_signals("9090")
        # Check that IPCSignal was called with arrays of correct shape
        call_args_list = mock_ipcsignal.call_args_list
        # Select calls by substring match on the call's repr (which includes the name kwarg).
        all_experts_token_stats_calls = [call for call in call_args_list if "all_experts_token_stats" in str(call)]
        local_experts_token_stats_calls = [
            call for call in call_args_list if "local_experts_token_stats" in str(call)
        ]
        # These should be 2D arrays with shape (6, 4)
        for call in all_experts_token_stats_calls:
            # call[1] is the kwargs dict of the recorded call.
            array_arg = call[1]["array"]
            self.assertEqual(array_arg.shape, (6, 4))  # (num_hidden_layers, moe_num_experts)
        for call in local_experts_token_stats_calls:
            array_arg = call[1]["array"]
            self.assertEqual(array_arg.shape, (6, 4))  # (num_hidden_layers, moe_num_experts)
        # Check that single-element signals have shape (1,)
        single_element_calls = [
            call
            for call in call_args_list
            if "rearrange_experts_status" in str(call)
            or "rearrange_experts_ips_size" in str(call)
            or "signal_update_weight_from_tensor" in str(call)
        ]
        for call in single_element_calls:
            array_arg = call[1]["array"]
            self.assertEqual(array_arg.shape, (1,))  # Single element array
def test_init_eplb_signals_suffix_format(self):
    """Test init_eplb_signals uses correct suffix format.

    With suffix "7777", local_data_parallel_id=3 and a single TP rank, the test
    expects DP-level signals to use "7777_dp3" and per-rank signals to use
    "7777_dp3_tp0".
    """
    mock_model_config = Mock()
    mock_model_config.num_hidden_layers = 4
    mock_model_config.moe_num_experts = 2
    mock_eplb_config = Mock()
    mock_eplb_config.redundant_expert_ip_shm_size = 256
    mock_parallel_config = Mock()
    mock_parallel_config.tensor_parallel_rank = 0
    # Data parallel id 3 should appear as "_dp3" in the suffixes.
    mock_parallel_config.local_data_parallel_id = 3
    mock_parallel_config.tensor_parallel_size = 1
    mock_config = Mock()
    mock_config.model_config = mock_model_config
    mock_config.eplb_config = mock_eplb_config
    mock_config.parallel_config = mock_parallel_config
    self.engine_client.config = mock_config
    self.engine_client.fd_config = mock_config  # Set fd_config as well
    # Ensure tensor_parallel_size is set correctly
    self.engine_client.tensor_parallel_size = 1
    with patch("fastdeploy.entrypoints.engine_client.IPCSignal") as mock_ipcsignal:
        mock_signal = Mock()
        mock_ipcsignal.return_value = mock_signal
        self.engine_client.init_eplb_signals("7777")
        # Check suffix format
        call_args_list = mock_ipcsignal.call_args_list
        # DP-level signal: suffix is "<suffix>_dp<local_data_parallel_id>".
        dp_calls = [call for call in call_args_list if "rearrange_experts_status" in str(call)]
        self.assertEqual(len(dp_calls), 1)
        self.assertEqual(dp_calls[0][1]["suffix"], "7777_dp3")
        # Per-rank signal: suffix additionally carries "_tp<rank>".
        tp_calls = [call for call in call_args_list if "signal_clear_experts_token_stats" in str(call)]
        self.assertEqual(len(tp_calls), 1)
        self.assertEqual(tp_calls[0][1]["suffix"], "7777_dp3_tp0")
def test_init_eplb_signals_list_initialization(self):
    """Test init_eplb_signals properly initializes all signal lists.

    Starting from empty lists with three TP ranks, each per-rank signal list
    is expected to hold exactly three entries afterwards.
    """
    mock_model_config = Mock()
    mock_model_config.num_hidden_layers = 2
    mock_model_config.moe_num_experts = 2
    mock_eplb_config = Mock()
    mock_eplb_config.redundant_expert_ip_shm_size = 128
    mock_parallel_config = Mock()
    mock_parallel_config.tensor_parallel_rank = 0
    mock_parallel_config.local_data_parallel_id = 0
    mock_parallel_config.tensor_parallel_size = 3
    mock_config = Mock()
    mock_config.model_config = mock_model_config
    mock_config.eplb_config = mock_eplb_config
    mock_config.parallel_config = mock_parallel_config
    self.engine_client.config = mock_config
    self.engine_client.tensor_parallel_size = 3  # Set this to match mock_parallel_config.tensor_parallel_size
    self.engine_client.fd_config = mock_config  # Also set fd_config to ensure proper access
    # Ensure lists start empty, so the lengths below count only what this call added or created.
    self.engine_client.signal_clear_experts_token_stats_list = []
    self.engine_client.local_experts_token_stats_array_list = []
    self.engine_client.expert_tokens_stats_array_list = []
    self.engine_client.signal_update_weight_from_disk_array_list = []
    self.engine_client.update_weight_from_disk_result_list = []
    with patch("fastdeploy.entrypoints.engine_client.IPCSignal") as mock_ipcsignal:
        mock_signal = Mock()
        mock_ipcsignal.return_value = mock_signal
        self.engine_client.init_eplb_signals("6666")
        # Check that all lists have correct length (3 TP ranks)
        self.assertEqual(len(self.engine_client.signal_clear_experts_token_stats_list), 3)
        self.assertEqual(len(self.engine_client.local_experts_token_stats_array_list), 3)
        self.assertEqual(len(self.engine_client.expert_tokens_stats_array_list), 3)
        self.assertEqual(len(self.engine_client.signal_update_weight_from_disk_array_list), 3)
        self.assertEqual(len(self.engine_client.update_weight_from_disk_result_list), 3)
async def test_init_iluvatar_platform(self):
    """Construct an EngineClient while the platform reports itself as Iluvatar.

    InputPreprocessor, current_platform, IPCSignal, envs and os.getenv are all
    patched, so construction runs purely against mocks. With a single tensor
    parallel rank the resulting client is expected to be the master.
    """
    fd_config = Mock(
        model_config=Mock(enable_mm=False),
        eplb_config=Mock(enable_eplb=False),
    )
    input_processor = Mock(**{"create_processor.return_value": Mock()})
    ipc_signal = Mock(value=np.array([0]))

    engine_kwargs = dict(
        model_name_or_path="test_model",
        tokenizer=Mock(),
        max_model_len=2048,
        tensor_parallel_size=1,
        pid=5678,
        port=8080,
        limit_mm_per_prompt=5,
        mm_processor_kwargs={},
        config=fd_config,
        reasoning_parser=None,
        data_parallel_size=1,
        enable_logprob=True,
        workers=1,
        tool_parser=None,
        enable_prefix_caching=False,
        splitwise_role=None,
        max_processor_cache=0,
    )

    with (
        patch("fastdeploy.entrypoints.engine_client.InputPreprocessor", return_value=input_processor),
        patch("fastdeploy.entrypoints.engine_client.current_platform", **{"is_iluvatar.return_value": True}),
        patch("fastdeploy.entrypoints.engine_client.IPCSignal", return_value=ipc_signal),
        patch("fastdeploy.entrypoints.engine_client.envs", FD_SUPPORT_MAX_CONNECTIONS=100),
        patch("os.getenv", return_value="50"),
    ):
        client = EngineClient(**engine_kwargs)

    # A single TP rank is expected to yield a master client, Iluvatar or not.
    self.assertTrue(client.is_master)
def test_check_mm_disable_prefix_cache_without_multimodal_data(self):
    """_check_mm_disable_prefix_cache returns False for a task with all-zero token_type_ids."""
    self.engine_client.disable_prefix_mm = True
    text_only_task = {"multimodal_inputs": {"token_type_ids": [0] * 3}}  # ids sum to 0
    self.assertFalse(self.engine_client._check_mm_disable_prefix_cache(text_only_task))
async def test_add_requests_multimodal_prefix_cache_error(self):
    """add_requests rejects a multimodal task while prefix caching is enabled.

    enable_mm, enable_prefix_caching and disable_prefix_mm are all switched on, and
    the task's token_type_ids contain non-zero entries. An EngineError carrying
    error_code 400 is expected.
    """
    client = self.engine_client
    client.enable_mm = True
    client.enable_prefix_caching = True
    client.disable_prefix_mm = True
    client.data_processor = Mock(process_request_dict=Mock())

    multimodal_task = {
        "request_id": "test_request",
        "user": "test_user",
        "multimodal_inputs": {"token_type_ids": [1, 1, 0, 1]},  # non-zero ids: multimodal data present
        "prompt_token_ids": [1, 2, 3],
        "max_tokens": 100,
    }

    with self.assertRaises(EngineError) as raised:
        await client.add_requests(multimodal_task)

    error = raised.exception
    self.assertIn("does not support processing requests containing multimodal data", str(error))
    self.assertEqual(error.error_code, 400)
async def test_add_requests_input_too_long_error(self):
    """add_requests rejects a prompt that leaves no room for min_tokens.

    max_model_len is 10, the prompt has 8 tokens and min_tokens is 2. The expected
    error message reports ``input_ids_len (8) + min_tokens(2) >= max_model_len(10)``,
    so the rejection is driven by min_tokens, not max_tokens. The raised EngineError
    must carry error_code 400.
    """
    self.engine_client.max_model_len = 10
    self.engine_client.data_processor = Mock()
    self.engine_client.data_processor.process_request_dict = Mock()
    task = {
        "request_id": "test_request",
        "user": "test_user",
        "prompt_token_ids": [1, 2, 3, 4, 5, 6, 7, 8],  # input_ids_len = 8
        "max_tokens": 5,  # not part of the length check asserted below
        "min_tokens": 2,  # 8 + 2 = 10 >= max_model_len (10) -> rejected
    }
    with self.assertRaises(EngineError) as context:
        await self.engine_client.add_requests(task)
    self.assertIn("Input text is too long", str(context.exception))
    self.assertIn("input_ids_len (8) + min_tokens(2) >= max_model_len(10)", str(context.exception))
    self.assertEqual(context.exception.error_code, 400)
@patch("fastdeploy.entrypoints.engine_client.envs.FD_MAX_STOP_SEQS_NUM", 3)
async def test_add_requests_stop_seqs_num_exceeds_limit(self):
    """add_requests rejects a task declaring more stop sequences than allowed.

    FD_MAX_STOP_SEQS_NUM is patched to 3 while the task lists four stop-sequence
    lengths, so an EngineError with error_code 400 is expected.
    """
    self.engine_client.data_processor = Mock(process_request_dict=Mock())
    too_many_stops = {
        "request_id": "test_request",
        "user": "test_user",
        "prompt_token_ids": [1, 2, 3],
        "max_tokens": 10,
        "stop_seqs_len": [10, 20, 30, 40],  # four entries against a limit of three
    }

    with self.assertRaises(EngineError) as raised:
        await self.engine_client.add_requests(too_many_stops)

    message = str(raised.exception)
    self.assertIn("Length of stop ([10, 20, 30, 40]) exceeds the limit max_stop_seqs_num(3)", message)
    # Matched verbatim (including "lager"); must stay in sync with the production message text.
    self.assertIn("Please reduce the number of stop or set a lager max_stop_seqs_num", message)
    self.assertEqual(raised.exception.error_code, 400)
@patch("fastdeploy.entrypoints.engine_client.envs.FD_STOP_SEQS_MAX_LEN", 5)
async def test_add_requests_single_stop_seq_len_exceeds_limit(self):
    """add_requests rejects a task with one stop sequence longer than allowed.

    FD_STOP_SEQS_MAX_LEN is patched to 5 and one declared stop-sequence length is
    10, so an EngineError with error_code 400 is expected.
    """
    self.engine_client.data_processor = Mock(process_request_dict=Mock())
    overlong_stop = {
        "request_id": "test_request",
        "user": "test_user",
        "prompt_token_ids": [1, 2, 3],
        "max_tokens": 10,
        "stop_seqs_len": [3, 10, 2],  # the middle entry (10) exceeds the limit of 5
    }

    with self.assertRaises(EngineError) as raised:
        await self.engine_client.add_requests(overlong_stop)

    message = str(raised.exception)
    self.assertIn("Length of stop_seqs(10) exceeds the limit stop_seqs_max_len(5)", message)
    self.assertIn("Please reduce the length of stop sequences or set a larger stop_seqs_max_len", message)
    self.assertEqual(raised.exception.error_code, 400)
async def test_rearrange_experts_eplb_disabled(self):
    """rearrange_experts answers 400 with an explanatory message when EPLB is turned off."""
    self.engine_client.config = Mock(eplb_config=Mock(enable_eplb=False))

    credentials = {"user": "test_user", "passwd": "test_pass"}
    content, status = await self.engine_client.rearrange_experts(credentials)

    self.assertEqual(content, {"code": 1, "msg": "redundant expert is disabled"})
    self.assertEqual(status.value, 400)  # HTTP 400 Bad Request
async def test_rearrange_experts_invalid_credentials(self):
    """rearrange_experts answers 401 when user/passwd do not match the configured EPLB API credentials."""
    eplb = Mock(
        enable_eplb=True,
        redundant_expert_api_user="valid_user",
        redundant_expert_api_password="valid_pass",
    )
    self.engine_client.config = Mock(eplb_config=eplb, **{"parallel_config.tensor_parallel_rank": 0})

    wrong_credentials = {"user": "invalid_user", "passwd": "invalid_pass"}
    content, status = await self.engine_client.rearrange_experts(wrong_credentials)

    self.assertEqual(content, {"code": 1, "msg": "user or passwd is invalid"})
    self.assertEqual(status.value, 401)  # HTTP 401 Unauthorized
async def test_rearrange_experts_non_rank_zero(self):
    """rearrange_experts answers 400 when called on a tensor-parallel rank other than 0."""
    eplb = Mock(
        enable_eplb=True,
        redundant_expert_api_user="test_user",
        redundant_expert_api_password="test_pass",
    )
    # Rank 2 stands in for any non-zero rank.
    self.engine_client.config = Mock(eplb_config=eplb, **{"parallel_config.tensor_parallel_rank": 2})

    credentials = {"user": "test_user", "passwd": "test_pass"}
    content, status = await self.engine_client.rearrange_experts(credentials)

    self.assertEqual(content, {"code": 1, "msg": "actual rank 2, expect rank 0"})
    self.assertEqual(status.value, 400)  # HTTP 400 Bad Request
async def test_rearrange_experts_recv_expert_weight_invalid_data(self):
    """The recv_expert_weight action answers 400 when the request carries no "data" list."""
    eplb = Mock(
        enable_eplb=True,
        redundant_expert_api_user="test_user",
        redundant_expert_api_password="test_pass",
    )
    self.engine_client.config = Mock(eplb_config=eplb, **{"parallel_config.tensor_parallel_rank": 0})

    # Valid credentials and action, but the "data" field is deliberately left out.
    request = {"user": "test_user", "passwd": "test_pass", "action": "recv_expert_weight"}
    content, status = await self.engine_client.rearrange_experts(request)

    self.assertEqual(content, {"code": 1, "msg": "data not in request or data is not a list"})
    self.assertEqual(status.value, 400)  # HTTP 400 Bad Request
async def test_rearrange_experts_invalid_action(self):
    """rearrange_experts answers 400 and echoes the action name when the action is unknown."""
    eplb = Mock(
        enable_eplb=True,
        redundant_expert_api_user="test_user",
        redundant_expert_api_password="test_pass",
    )
    self.engine_client.config = Mock(eplb_config=eplb, **{"parallel_config.tensor_parallel_rank": 0})

    request = {"user": "test_user", "passwd": "test_pass", "action": "invalid_action"}
    content, status = await self.engine_client.rearrange_experts(request)

    self.assertEqual(content, {"code": 1, "msg": "invalid action invalid_action"})
    self.assertEqual(status.value, 400)  # HTTP 400 Bad Request
if __name__ == "__main__":
    # Executed as a script: collect and run the TestCase classes defined in this module.
    unittest.main(module="__main__")