mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
* [Optimize] Improve perf for fd response token with internal adapter * fix * fix bug * fix ci * fix ci * fix ci * fix ci
669 lines
25 KiB
Python
669 lines
25 KiB
Python
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
import sys
|
|
import time
|
|
import unittest
|
|
from multiprocessing import Queue
|
|
from unittest.mock import Mock, patch
|
|
|
|
# Mock all external dependencies before importing anything.
# One shared logger mock lets tests assert on what the scheduler logs.
mock_logger = Mock()


class MockEnv:
    """Stand-in for ``fastdeploy.envs`` exposing only the setting these tests use.

    The timeout is a real float (not a Mock) so it can be compared with
    other floats.
    """

    FD_EP_BATCHED_TOKEN_TIMEOUT = 0.1


mock_envs = MockEnv()
|
|
|
|
# Mock threading module to prevent real thread creation.
# The real module is bound to the name ``threading`` here, *before* the
# sys.modules swap below, so code in this file that uses this name still gets
# genuine locks and condition variables; anything importing ``threading``
# afterwards receives the mock instead.
# NOTE(review): none of the sys.modules entries set in this block are restored,
# so later imports in the same interpreter also see the mocks — confirm this is
# acceptable for the test runner in use.
import threading

mock_threading = Mock(
    Thread=Mock(),
    Lock=Mock(return_value=Mock()),
    Condition=Mock(return_value=Mock()),
)
sys.modules["threading"] = mock_threading

# Stub every fastdeploy module the scheduler imports; ``fastdeploy.envs`` gets
# the MockEnv instance so its timeout setting is a real float.
for _module_name in (
    "fastdeploy",
    "fastdeploy.utils",
    "fastdeploy.engine",
    "fastdeploy.engine.request",
    "fastdeploy.scheduler",
    "fastdeploy.scheduler.local_scheduler",
    "fastdeploy.scheduler.data",
):
    sys.modules[_module_name] = Mock()
del _module_name
sys.modules["fastdeploy.envs"] = mock_envs

# get_logger(...) always hands back the shared logger mock.
sys.modules["fastdeploy.utils"].get_logger = Mock(return_value=mock_logger)
|
|
|
|
|
|
# Mock the Request, RequestOutput, and ScheduledResponse classes
|
|
class MockRequest:
    """Minimal stand-in for ``fastdeploy.engine.request.Request``.

    Args:
        request_id: Identifier the scheduler stores requests under.
        prompt_tokens_ids_len: Prompt length reported to the scheduler
            (defaults to 10).
    """

    def __init__(self, request_id, prompt_tokens_ids_len=10):
        # Captured at construction time; tests overwrite it to simulate age.
        created_at = time.time()
        self.request_id = request_id
        self.prompt_tokens_ids_len = prompt_tokens_ids_len
        self.schedule_time = created_at
        # ``raw`` points back at the request itself, so code that unwraps
        # ``.raw`` gets this object.
        self.raw = self
|
|
|
|
|
|
class MockRequestOutput:
    """Minimal stand-in for ``fastdeploy.engine.request.RequestOutput``.

    Args:
        request_id: Identifier of the request this output belongs to.
        finished: Whether this output is the final one for the request.
    """

    def __init__(self, request_id, finished=False):
        self.request_id, self.finished = request_id, finished
|
|
|
|
|
|
class MockScheduledResponse:
    """Minimal stand-in for ``fastdeploy.scheduler.data.ScheduledResponse``.

    Copies ``request_id`` and ``finished`` from the wrapped request output and
    exposes itself as ``raw``.
    """

    def __init__(self, request_output):
        for field_name in ("request_id", "finished"):
            setattr(self, field_name, getattr(request_output, field_name))
        self.raw = self
|
|
|
|
|
|
# Mock LocalScheduler base class
|
|
class MockLocalScheduler:
    """In-memory stand-in for ``fastdeploy.scheduler.local_scheduler.LocalScheduler``.

    State:
        requests: pending requests keyed by ``request_id``.
        ids: request ids in arrival order.
        ids_read_cursor: index into ``ids`` (presumably the next unread id);
            this mock only moves it when recycling.
        responses: buffered responses keyed by ``request_id``.

    Locks and condition variables are created from the module-level
    ``threading`` name in this file.
    """

    def __init__(
        self,
        max_size,
        ttl,
        enable_chunked_prefill,
        max_num_partial_prefills,
        max_long_partial_prefills,
        long_prefill_token_threshold,
    ):
        # Only max_size and ttl affect this mock; the prefill settings are
        # accepted for signature compatibility and otherwise ignored.
        self.max_size, self.ttl = max_size, ttl
        self.mutex = threading.Lock()
        self.requests_not_empty = threading.Condition()
        self.responses_not_empty = threading.Condition()
        self.requests = {}
        self.responses = {}
        self.ids = []
        self.ids_read_cursor = 0
        self.batch_responses_per_step = []

    def calc_required_blocks(self, token_len, block_size):
        """Return ``ceil(token_len / block_size)`` using integer arithmetic."""
        return (token_len + block_size - 1) // block_size

    def put_requests(self, requests):
        """Enqueue requests whose id is not already known, then wake waiters."""
        with self.mutex:
            for request in requests:
                rid = request.request_id
                if rid in self.requests:
                    continue  # duplicate id: the first submission wins
                self.requests[rid] = request
                self.ids.append(rid)
            with self.requests_not_empty:
                self.requests_not_empty.notify_all()

    def get_results(self):
        """Drain all buffered responses without blocking.

        Returns:
            A flat list of every buffered response, in dict insertion order, or
            ``[]`` when no request has a non-empty response list.
        """
        with self.responses_not_empty:
            # Never waits: an empty buffer simply yields an empty list.
            if not any(self.responses.values()):
                return []
            drained = [item for bucket in self.responses.values() for item in bucket]
            self.responses.clear()
            return drained

    def _recycle(self, request_id=None):
        """Mock implementation of _recycle method.

        With ``request_id``: forget that request and its responses; outside the
        ``"decode"`` role also drop it from ``ids`` and step the cursor back by
        one (never below 0).

        Without ``request_id``: when ``max_size`` is positive and exceeded,
        evict every request older than ``ttl`` seconds, scanning ``ids`` in
        order and stopping at the first id with no stored request.
        """
        if request_id is not None:
            self.requests.pop(request_id, None)
            self.responses.pop(request_id, None)
            if getattr(self, "splitwise_role", None) == "decode":
                return  # decode role keeps the id list untouched
            if request_id in self.ids:
                self.ids.remove(request_id)
                self.ids_read_cursor = max(0, self.ids_read_cursor - 1)
            return

        if self.max_size <= 0 or len(self.requests) <= self.max_size:
            return

        now = time.time()
        stale = []
        for rid in self.ids:
            if rid not in self.requests:
                break
            if now - self.requests[rid].schedule_time >= self.ttl:
                stale.append(rid)

        for rid in stale:
            self.requests.pop(rid, None)
            self.responses.pop(rid, None)
            if rid in self.ids:
                self.ids.remove(rid)

        if stale:
            self.ids_read_cursor = max(0, self.ids_read_cursor - len(stale))
|
|
|
|
|
|
# Set up the mock classes in the modules: attach each mock to its stubbed module
# so that importing these names from those modules yields the mocks.
for _module_name, _attr_name, _replacement in (
    ("fastdeploy.engine.request", "Request", MockRequest),
    ("fastdeploy.engine.request", "RequestOutput", MockRequestOutput),
    ("fastdeploy.scheduler.data", "ScheduledResponse", MockScheduledResponse),
    ("fastdeploy.scheduler.local_scheduler", "LocalScheduler", MockLocalScheduler),
):
    setattr(sys.modules[_module_name], _attr_name, _replacement)
del _module_name, _attr_name, _replacement
|
|
|
|
# Now we can import the dp_scheduler module with all dependencies mocked
|
|
import importlib.util
|
|
import os
|
|
|
|
# Load dp_scheduler.py directly from the source tree (two directories above this
# test file) instead of importing it through the fastdeploy package.
_dp_scheduler_path = os.path.join(os.path.dirname(__file__), "../../fastdeploy/scheduler/dp_scheduler.py")
spec = importlib.util.spec_from_file_location("dp_scheduler", _dp_scheduler_path)
dp_scheduler_module = importlib.util.module_from_spec(spec)

# Pre-seed the module namespace with the mocks.
# NOTE(review): exec_module below runs the module's own import statements, which
# presumably rebind these names (from the stubs in sys.modules); confirm whether
# these pre-assignments still have any effect.
dp_scheduler_module.envs = mock_envs
dp_scheduler_module.get_logger = Mock(return_value=mock_logger)
dp_scheduler_module.threading = mock_threading

# Execute the module body; its imports resolve against the stubbed sys.modules.
spec.loader.exec_module(dp_scheduler_module)
|
|
|
|
# Extract the classes we want to test from the freshly loaded module.
DPLocalScheduler = dp_scheduler_module.DPLocalScheduler
DPScheduler = dp_scheduler_module.DPScheduler

# Override the scheduler_logger to use our mock: wrap the real __init__ so every
# DPLocalScheduler instance logs to ``mock_logger`` once construction finishes.
import functools

original_init = DPLocalScheduler.__init__


# functools.wraps keeps the original initializer's name, docstring and
# ``__wrapped__`` link, so introspection (e.g. inspect.signature) still reports
# the real constructor signature instead of ``(*args, **kwargs)``.
@functools.wraps(original_init)
def patched_init(self, *args, **kwargs):
    """Run the original ``DPLocalScheduler.__init__``, then install the mock logger."""
    original_init(self, *args, **kwargs)
    self.scheduler_logger = mock_logger


DPLocalScheduler.__init__ = patched_init
|
|
|
|
|
|
class TestDPLocalScheduler(unittest.TestCase):
    """Test cases for DPLocalScheduler class."""

    @staticmethod
    def _ticking_clock(start=100.0, step=0.1, max_ticks=1000):
        """Yield fake timestamps for a patched ``time.time``.

        Each value is ``step`` seconds after the previous one, so any
        elapsed-time check eventually crosses its timeout. The sequence is
        bounded so a stuck polling loop fails loudly instead of hanging, with
        a clearer message than the ``StopIteration`` an exhausted list gives.
        """
        for tick in range(max_ticks):
            yield start + tick * step
        raise AssertionError(f"time.time() was called more than {max_ticks} times; the timeout logic may be stuck")

    def setUp(self):
        """Set up test fixtures."""
        self.scheduler = DPLocalScheduler(
            max_size=100,
            ttl=60,
            enable_chunked_prefill=True,
            max_num_partial_prefills=4,
            max_long_partial_prefills=2,
            long_prefill_token_threshold=1024,
            splitwise_role="prefill",
        )

    def test_initialization_with_default_role(self):
        """Test scheduler initialization with default splitwise_role."""
        scheduler = DPLocalScheduler(
            max_size=50,
            ttl=30,
            enable_chunked_prefill=False,
            max_num_partial_prefills=2,
            max_long_partial_prefills=1,
            long_prefill_token_threshold=512,
        )
        self.assertEqual(scheduler.splitwise_role, "prefill")
        self.assertEqual(scheduler.max_size, 50)
        self.assertEqual(scheduler.ttl, 30)

    def test_initialization_with_custom_role(self):
        """Test scheduler initialization with custom splitwise_role."""
        scheduler = DPLocalScheduler(
            max_size=50,
            ttl=30,
            enable_chunked_prefill=False,
            max_num_partial_prefills=2,
            max_long_partial_prefills=1,
            long_prefill_token_threshold=512,
            splitwise_role="decode",
        )
        self.assertEqual(scheduler.splitwise_role, "decode")

    def test_put_results_with_finished_requests(self):
        """Test putting results with finished requests."""
        # Reset mock logger so only this test's calls are inspected.
        mock_logger.reset_mock()

        results = [
            MockRequestOutput("req1", finished=True),
            MockRequestOutput("req2", finished=False),
            MockRequestOutput("req3", finished=True),
        ]

        # Mock the condition variable to avoid threading issues.
        with patch.object(self.scheduler, "responses_not_empty"):
            self.scheduler.put_results(results)

        # The finished request ids should appear in the logged message.
        self.assertTrue(mock_logger.info.called)
        call_args = mock_logger.info.call_args[0][0]
        self.assertIn("finished responses", call_args)
        self.assertIn("req1", call_args)
        self.assertIn("req3", call_args)

    def test_put_results_with_new_responses(self):
        """Test putting results with new responses."""
        results = [MockRequestOutput("new_req", finished=False)]

        # Initially no responses
        self.assertNotIn("new_req", self.scheduler.responses)

        # Mock the condition variable to avoid threading issues.
        with patch.object(self.scheduler, "responses_not_empty"):
            self.scheduler.put_results(results)

        self.assertIn("new_req", self.scheduler.responses)
        self.assertEqual(len(self.scheduler.responses["new_req"]), 1)

    def test_put_results_with_existing_responses(self):
        """Test putting results with existing responses."""
        results1 = [MockRequestOutput("existing_req", finished=False)]
        results2 = [MockRequestOutput("existing_req", finished=True)]

        # Both calls run with the condition variable mocked; previously the
        # second call ran against the real condition, unlike every other call.
        with patch.object(self.scheduler, "responses_not_empty"):
            self.scheduler.put_results(results1)
            self.assertEqual(len(self.scheduler.responses["existing_req"]), 1)

            # A second result for the same id is appended to the same list.
            self.scheduler.put_results(results2)
            self.assertEqual(len(self.scheduler.responses["existing_req"]), 2)

    def test_recycle_specific_request_id(self):
        """Test recycling a specific request ID."""
        self.scheduler.requests["req1"] = MockRequest("req1")
        self.scheduler.responses["req1"] = [MockScheduledResponse(MockRequestOutput("req1"))]
        self.scheduler.ids = ["req1", "req2"]
        self.scheduler.ids_read_cursor = 1

        self.scheduler._recycle("req1")

        self.assertNotIn("req1", self.scheduler.requests)
        self.assertNotIn("req1", self.scheduler.responses)
        self.assertEqual(self.scheduler.ids, ["req2"])
        self.assertEqual(self.scheduler.ids_read_cursor, 0)

    def test_recycle_specific_request_id_decode_role(self):
        """Test recycling a specific request ID in decode role."""
        scheduler = DPLocalScheduler(
            max_size=100,
            ttl=60,
            enable_chunked_prefill=True,
            max_num_partial_prefills=4,
            max_long_partial_prefills=2,
            long_prefill_token_threshold=1024,
            splitwise_role="decode",
        )

        scheduler.requests["req1"] = MockRequest("req1")
        scheduler.responses["req1"] = [MockScheduledResponse(MockRequestOutput("req1"))]
        scheduler.ids = ["req1", "req2"]
        scheduler.ids_read_cursor = 1

        # Recycle specific request (should not modify ids in decode role)
        scheduler._recycle("req1")

        self.assertNotIn("req1", scheduler.requests)
        self.assertNotIn("req1", scheduler.responses)
        self.assertEqual(scheduler.ids, ["req1", "req2"])  # Should not change in decode role
        self.assertEqual(scheduler.ids_read_cursor, 1)  # Should not change in decode role

    def test_recycle_with_max_size_zero(self):
        """Test recycling when max_size is 0 (unlimited)."""
        scheduler = DPLocalScheduler(
            max_size=0,
            ttl=60,
            enable_chunked_prefill=True,
            max_num_partial_prefills=4,
            max_long_partial_prefills=2,
            long_prefill_token_threshold=1024,
        )

        scheduler.requests["req1"] = MockRequest("req1")
        scheduler.responses["req1"] = [MockScheduledResponse(MockRequestOutput("req1"))]
        scheduler.ids = ["req1"]

        # Should return early without recycling
        scheduler._recycle()

        self.assertIn("req1", scheduler.requests)
        self.assertIn("req1", scheduler.responses)

    def test_recycle_under_max_size(self):
        """Test recycling when under max_size limit."""
        self.scheduler.requests["req1"] = MockRequest("req1")
        self.scheduler.requests["req2"] = MockRequest("req2")
        self.scheduler.ids = ["req1", "req2"]

        # Should return early without recycling
        self.scheduler._recycle()

        self.assertIn("req1", self.scheduler.requests)
        self.assertIn("req2", self.scheduler.requests)

    @patch("time.time")
    def test_recycle_expired_requests(self, mock_time):
        """Test recycling expired requests."""
        # max_size=1 so that holding two requests triggers recycling.
        scheduler = DPLocalScheduler(
            max_size=1,
            ttl=60,
            enable_chunked_prefill=True,
            max_num_partial_prefills=4,
            max_long_partial_prefills=2,
            long_prefill_token_threshold=1024,
        )

        # "Now" is t=100.0 for every time.time() call in this test.
        mock_time.return_value = 100.0

        # Expired: schedule_time = 30.0 is 70 seconds ago, beyond ttl=60.
        expired_request = MockRequest("expired_req")
        expired_request.schedule_time = 30.0

        # Fresh: schedule_time = 80.0 is 20 seconds ago, within ttl=60.
        fresh_request = MockRequest("fresh_req")
        fresh_request.schedule_time = 80.0

        scheduler.requests["expired_req"] = expired_request
        scheduler.requests["fresh_req"] = fresh_request
        scheduler.ids = ["expired_req", "fresh_req"]
        scheduler.ids_read_cursor = 2

        scheduler._recycle()

        self.assertNotIn("expired_req", scheduler.requests)
        self.assertIn("fresh_req", scheduler.requests)
        self.assertEqual(scheduler.ids, ["fresh_req"])
        self.assertEqual(scheduler.ids_read_cursor, 1)

    def test_get_requests_insufficient_resources(self):
        """Test getting requests when resources are insufficient."""
        mock_logger.reset_mock()

        # 5 available blocks < 10 reserved output blocks.
        with patch.object(self.scheduler, "requests_not_empty"):
            requests = self.scheduler.get_requests(
                available_blocks=5, block_size=16, reserved_output_blocks=10, max_num_batched_tokens=1024, batch=1
            )

        self.assertEqual(requests, [])
        # The logger should have been called for insufficient resources.
        self.assertTrue(mock_logger.debug.called)
        call_args = mock_logger.debug.call_args[0][0]
        self.assertIn("insufficient", call_args.lower())

    def test_get_requests_insufficient_batch(self):
        """Test getting requests when batch size is insufficient."""
        with patch.object(self.scheduler, "requests_not_empty"):
            requests = self.scheduler.get_requests(
                available_blocks=20, block_size=16, reserved_output_blocks=10, max_num_batched_tokens=1024, batch=0
            )

        self.assertEqual(requests, [])

    @patch("time.time")
    @patch.object(dp_scheduler_module, "envs")
    def test_get_requests_no_requests_available(self, patched_envs, mock_time):
        """Test getting requests when no requests are available."""
        # Give the scheduler a real float timeout to compare elapsed time against.
        patched_envs.FD_EP_BATCHED_TOKEN_TIMEOUT = 0.1

        # Every time.time() call advances the fake clock by 0.1s.
        mock_time.side_effect = self._ticking_clock()

        with patch.object(self.scheduler, "requests_not_empty"):
            requests = self.scheduler.get_requests(
                available_blocks=20, block_size=16, reserved_output_blocks=10, max_num_batched_tokens=1024, batch=1
            )

        # Should return empty list after timeout
        self.assertEqual(requests, [])

    def test_get_requests_successful_batching(self):
        """Test successful request batching."""
        mock_request = MockRequest("test_req", prompt_tokens_ids_len=10)
        self.scheduler.requests["test_req"] = mock_request
        self.scheduler.ids = ["test_req"]

        # Each request needs only one block, well within the 20 available.
        self.scheduler.calc_required_blocks = Mock(return_value=1)

        requests = self.scheduler.get_requests(
            available_blocks=20, block_size=16, reserved_output_blocks=10, max_num_batched_tokens=1024, batch=1
        )

        self.assertEqual(len(requests), 1)
        self.assertEqual(requests[0].request_id, "test_req")

    @patch("time.time")
    @patch.object(dp_scheduler_module, "envs")
    def test_get_requests_timeout(self, patched_envs, mock_time):
        """Test request batching with timeout."""
        patched_envs.FD_EP_BATCHED_TOKEN_TIMEOUT = 0.1

        # Every time.time() call advances the fake clock by 0.1s; note that
        # constructing MockRequest below already consumes one tick.
        mock_time.side_effect = self._ticking_clock()

        mock_request = MockRequest("test_req", prompt_tokens_ids_len=10)
        self.scheduler.requests["test_req"] = mock_request
        self.scheduler.ids = ["test_req"]

        # 50 required blocks exceed the 20 available, so nothing can be batched.
        self.scheduler.calc_required_blocks = Mock(return_value=50)

        with patch.object(self.scheduler, "requests_not_empty"):
            requests = self.scheduler.get_requests(
                available_blocks=20, block_size=16, reserved_output_blocks=10, max_num_batched_tokens=1024, batch=1
            )

        # Should return empty due to timeout
        self.assertEqual(requests, [])
|
|
|
|
|
|
class TestDPScheduler(unittest.TestCase):
    """Test cases for DPScheduler class."""

    def setUp(self):
        """Set up test fixtures."""
        self.dp_scheduler = DPScheduler(
            max_size=100,
            ttl=60,
            enable_chunked_prefill=True,
            max_num_partial_prefills=4,
            max_long_partial_prefills=2,
            long_prefill_token_threshold=1024,
            splitwise_role="prefill",
        )

    def test_initialization(self):
        """Test DPScheduler initialization."""
        self.assertIsNotNone(self.dp_scheduler._scheduler)
        self.assertEqual(self.dp_scheduler._scheduler.splitwise_role, "prefill")

    def test_get_unhandled_request_num(self):
        """Test getting number of unhandled requests."""
        self.assertEqual(self.dp_scheduler.get_unhandled_request_num(), 0)

        # Add a request directly to the internal scheduler.
        self.dp_scheduler._scheduler.requests["test_req"] = MockRequest("test_req")

        self.assertEqual(self.dp_scheduler.get_unhandled_request_num(), 1)

    def test_put_results(self):
        """Test putting results to DPScheduler."""
        results = [MockRequestOutput("test_req", finished=True)]

        # Mock the condition variable to avoid threading issues.
        with patch.object(self.dp_scheduler._scheduler, "responses_not_empty"):
            self.dp_scheduler.put_results(results)

        # Results should land in the internal scheduler.
        self.assertIn("test_req", self.dp_scheduler._scheduler.responses)

    def test_get_requests_delegates_to_scheduler(self):
        """Test that get_requests delegates to internal scheduler."""
        expected_requests = [MockRequest("test_req")]
        self.dp_scheduler._scheduler.get_requests = Mock(return_value=expected_requests)

        requests = self.dp_scheduler.get_requests(
            available_blocks=20, block_size=16, reserved_output_blocks=10, max_num_batched_tokens=1024, batch=1
        )

        # Keyword arguments are forwarded positionally, in signature order.
        self.dp_scheduler._scheduler.get_requests.assert_called_once_with(20, 16, 10, 1024, 1)
        self.assertEqual(requests, expected_requests)

    def test_put_requests_missing_dp_rank(self):
        """Test put_requests raises error when dp_rank is missing."""
        # MockRequest has no dp_rank attribute.
        requests = [MockRequest("test_req")]

        with self.assertRaises(ValueError) as cm:
            self.dp_scheduler.put_requests(requests)

        self.assertIn("missing the 'dp_rank' attribute", str(cm.exception))

    @patch("threading.Thread")
    def test_put_requests_success(self, mock_thread):
        """Test successful put_requests with dp_rank."""
        # Mock queues instead of real multiprocessing queues to avoid threading issues.
        request_queues = [Mock(), Mock(), Mock()]
        result_queue = Mock()

        # Start the scheduler - threads are mocked, so nothing runs.
        self.dp_scheduler.start(0, request_queues, result_queue)

        mock_request1 = MockRequest("test_req1")
        mock_request1.dp_rank = 0
        mock_request2 = MockRequest("test_req2")
        mock_request2.dp_rank = 1

        results = self.dp_scheduler.put_requests([mock_request1, mock_request2])

        self.assertEqual(len(results), 2)
        self.assertEqual(results[0], ("test_req1", None))
        self.assertEqual(results[1], ("test_req2", None))

        # Each request goes to the queue indexed by its dp_rank.
        request_queues[0].put.assert_called_once_with(mock_request1)
        request_queues[1].put.assert_called_once_with(mock_request2)

    @patch("threading.Thread")
    def test_start_creates_threads(self, mock_thread):
        """Test that start creates and starts threads."""
        # Threads are mocked, so nothing consumes these queues. Plain mocks also
        # avoid allocating real multiprocessing pipes and semaphores that this
        # test would never close.
        request_queues = [Mock(), Mock()]
        result_queue = Mock()

        self.dp_scheduler.start(0, request_queues, result_queue)

        # Should create 2 threads
        self.assertEqual(mock_thread.call_count, 2)

        # Both threads should be started. Every Thread(...) call returns the same
        # mock, so two starts show up as two calls on it.
        self.assertEqual(mock_thread.return_value.start.call_count, 2)
|
|
|
|
|
|
class TestDPIntegration(unittest.TestCase):
    """Integration-style checks that drive DPScheduler without real threads."""

    def test_end_to_end_request_flow(self):
        """Test end-to-end request flow through DP scheduler - without real threads."""
        dp_scheduler = DPScheduler(
            max_size=10,
            ttl=30,
            enable_chunked_prefill=True,
            max_num_partial_prefills=2,
            max_long_partial_prefills=1,
            long_prefill_token_threshold=512,
        )

        # Stub out start() and wire up by hand the state the test needs, so no
        # threads are ever created.
        with patch.object(dp_scheduler, "start") as mock_start:
            rank0_queue, rank1_queue = Mock(), Mock()
            dp_scheduler.dp_rank = 0
            dp_scheduler.request_queues = [rank0_queue, rank1_queue]
            dp_scheduler.result_queue = Mock()
            dp_scheduler.scheduler_logger = mock_logger
            dp_scheduler._scheduler.scheduler_logger = mock_logger

            request = MockRequest("integration_req")
            request.dp_rank = 0
            rank0_queue.put = Mock()

            outcome = dp_scheduler.put_requests([request])

            self.assertEqual(len(outcome), 1)
            self.assertEqual(outcome[0], ("integration_req", None))
            # A dp_rank of 0 routes the request to the first queue.
            rank0_queue.put.assert_called_once_with(request)
            # start() was never invoked, so no threads were spawned.
            mock_start.assert_not_called()


if __name__ == "__main__":
    unittest.main(verbosity=2)
|