FastDeploy/fastdeploy/scheduler/local_scheduler.py
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import threading
import time
from typing import Dict, List, Optional, Tuple
from fastdeploy.engine.request import Request, RequestOutput
from fastdeploy.scheduler.data import ScheduledRequest, ScheduledResponse
from fastdeploy.utils import llm_logger
class LocalScheduler(object):
"""
A local in-memory task scheduler for request/response management.
This class provides functionality for:
- Enqueuing and dequeuing requests
- Managing request lifecycle with TTL
- Handling request/response flow
- Thread-safe operations with condition variables
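Typical flow (an illustrative sketch; `requests` and `outputs` are assumed
to be Request/RequestOutput objects built by the surrounding engine):

    scheduler.put_requests(requests)        # enqueue new work
    batch = scheduler.get_requests(...)     # pull requests that fit the resources
    scheduler.put_results(outputs)          # push generated outputs back
    finished = scheduler.get_results()      # drain responses keyed by request_id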
"""
def __init__(
self,
max_size: int,
ttl: int,
enable_chunked_prefill: bool,
max_num_partial_prefills: int,
max_long_partial_prefills: int,
long_prefill_token_threshold: int,
):
"""
Initializes a local in-memory scheduler for managing inference requests.
Args:
max_size: Maximum number of concurrent requests the scheduler can handle (0 for unlimited)
ttl: Time-to-live in seconds for requests before automatic timeout
enable_chunked_prefill: Whether to enable chunked prefill processing
max_num_partial_prefills: Maximum number of partial prefill operations allowed
max_long_partial_prefills: Maximum number of long-running partial prefill operations
long_prefill_token_threshold: Token count threshold to classify as long prefill
Initializes:
- Thread synchronization primitives (mutex, condition variables)
- Request and response tracking structures
- Chunked prefill configuration parameters
- Request queue management system
Note:
- Uses thread-safe operations for concurrent access
- Automatically recycles expired requests based on TTL
- Supports both batched and individual request processing
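Example (a minimal sketch; the values below are illustrative assumptions,
not recommended defaults):

    scheduler = LocalScheduler(
        max_size=1024,                      # recycling only kicks in above this size
        ttl=900,                            # queued requests older than 15 min may be dropped
        enable_chunked_prefill=True,
        max_num_partial_prefills=4,         # at most 4 partial prefills per pulled batch
        max_long_partial_prefills=2,        # of which at most 2 may be long prompts
        long_prefill_token_threshold=2048,  # prompts above this token count are "long"
    )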
"""
self.max_size = max_size
self.ttl = ttl
self.mutex = threading.Lock()
self.enable_chunked_prefill = enable_chunked_prefill
self.max_num_partial_prefills = max_num_partial_prefills
self.max_long_partial_prefills = max_long_partial_prefills
self.long_prefill_token_threshold = long_prefill_token_threshold
self.ids_read_cursor = 0
self.ids: List[str] = list()
self.requests: Dict[str, ScheduledRequest] = dict()
self.responses: Dict[str, List[ScheduledResponse]] = dict()
self.wait_request_timeout = 10
self.wait_response_timeout = 0.001
self.requests_not_empty = threading.Condition(self.mutex)
self.responses_not_empty = threading.Condition(self.mutex)
def reset(self):
"""
Reset the local scheduler to its initial empty state by:
1. Resetting the request ID tracking cursor to 0
2. Clearing all stored request IDs
3. Clearing all pending requests
4. Clearing all cached responses
This method is thread-safe and should be called when:
- The scheduler needs to be cleanly restarted
- Recovering from critical errors
- Preparing for graceful shutdown
Effects:
- Resets the ids_read_cursor to 0 (request processing position)
- Clears the ids list tracking all request IDs
- Clears the requests dictionary tracking pending requests
- Clears the responses dictionary tracking received responses
Note:
- Uses the scheduler's mutex to ensure thread safety
- Does not affect the scheduler's configuration parameters (max_size, ttl, etc.)
- After reset, the scheduler will be empty but still operational
"""
with self.mutex:
self.ids_read_cursor = 0
self.ids = list()
self.requests = dict()
self.responses = dict()
llm_logger.info("Scheduler has been reset")
def _recycle(self, request_id: Optional[str] = None):
"""
Clean up expired or completed requests to free memory.
Args:
request_id: Optional specific request ID to remove.
If None, expired requests (older than ttl) are removed,
but only once the queue has grown beyond max_size.
"""
if request_id is not None:
self.requests.pop(request_id, None)
self.responses.pop(request_id, None)
self.ids.remove(request_id)
self.ids_read_cursor -= 1
return
if self.max_size <= 0:
return
if len(self.requests) <= self.max_size:
return
now = time.time()
expired_ids = []
for request_id in self.ids:
request = self.requests[request_id]
if (now - request.schedule_time < self.ttl):
break
expired_ids.append(request.request_id)
for expired_id in expired_ids:
    self.requests.pop(expired_id, None)
    self.responses.pop(expired_id, None)
# Expired ids sit at the head of self.ids (oldest first), so drop them in one
# slice instead of popping by index while the list shrinks underneath us.
del self.ids[:len(expired_ids)]
self.ids_read_cursor = max(0, self.ids_read_cursor - len(expired_ids))
def put_requests(
self, requests: List[Request]) -> List[Tuple[str, Optional[str]]]:
"""
Add new requests to the scheduler queue.
Args:
requests: List of Request objects to enqueue
Returns:
List of tuples containing (request_id, error_message) for each request.
error_message is None for successful enqueues.
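Example (sketch; assumes `reqs` is a list of fastdeploy.engine.request.Request
objects built elsewhere):

    for request_id, error in scheduler.put_requests(reqs):
        if error is not None:
            llm_logger.warning(f"Enqueue of {request_id} rejected: {error}")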
"""
with self.mutex:
self._recycle()
if self.max_size > 0 and len(
self.requests) + len(requests) > self.max_size:
msg = f"Exceeding the max length of the local scheduler (max_size={self.max_size})"
return [(request.request_id, msg) for request in requests]
valid_ids = []
duplicated_ids = []
for request in requests:
if request.request_id in self.requests:
duplicated_ids.append(request.request_id)
else:
scheduled_request = ScheduledRequest(request)
self.requests[
scheduled_request.request_id] = scheduled_request
valid_ids.append(scheduled_request.request_id)
self.ids += valid_ids
self.requests_not_empty.notify_all()
llm_logger.info(f"Scheduler has enqueued some requests: {valid_ids}")
if len(duplicated_ids) > 0:
llm_logger.warning(
f"Scheduler has received some duplicated requests: {duplicated_ids}"
)
results = [(request_id, None) for request_id in valid_ids]
results += [(request_id, "duplicated request_id")
for request_id in duplicated_ids]
return results
def calc_required_blocks(self, token_num, block_size):
"""
Calculate the number of blocks needed for a given number of tokens.
Args:
token_num: Number of tokens
block_size: Size of each block
Returns:
Number of blocks required (rounded up)
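Example (ceiling division):

    # 1000 tokens with block_size=64 -> (1000 + 63) // 64 = 16 blocks
    scheduler.calc_required_blocks(1000, 64)  # returns 16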
"""
return (token_num + block_size - 1) // block_size
def get_requests(self,
available_blocks,
block_size,
reserved_output_blocks,
max_num_batched_tokens,
batch=1) -> List[Request]:
"""
Retrieve requests from the scheduler based on available resources.
Args:
available_blocks: Number of available processing blocks
block_size: Size of each processing block
reserved_output_blocks: Blocks reserved for output
max_num_batched_tokens: Maximum tokens that can be batched
batch: Preferred batch size
Returns:
List of Request objects ready for processing
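Example (sketch; the resource numbers are illustrative assumptions):

    batch = scheduler.get_requests(
        available_blocks=512,
        block_size=64,
        reserved_output_blocks=32,
        max_num_batched_tokens=8192,
        batch=8,
    )
    # Each returned item is the raw engine Request, ready for the executor.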
"""
if available_blocks <= reserved_output_blocks or batch < 1:
llm_logger.debug(
f"Scheduler's resource are insufficient: available_blocks={available_blocks} "
f"reserved_output_blocks={reserved_output_blocks} batch={batch} "
f"max_num_batched_tokens={max_num_batched_tokens}")
return []
with self.requests_not_empty:
batch_ids = self.requests_not_empty.wait_for(
lambda: self.ids[self.ids_read_cursor:self.ids_read_cursor +
batch], self.wait_request_timeout)
required_total_blocks = 0
current_prefill_tokens = 0
requests: List[Request] = []
long_partial_requests, short_partial_requests = 0, 0
for request_id in batch_ids:
request = self.requests[request_id]
required_input_blocks = self.calc_required_blocks(
request.prompt_tokens_ids_len, block_size)
current_prefill_tokens += request.prompt_tokens_ids_len
required_total_blocks += required_input_blocks + reserved_output_blocks
if required_total_blocks > available_blocks:
break
if self.enable_chunked_prefill:
if request.prompt_tokens_ids_len > self.long_prefill_token_threshold:
# long prefill request
long_partial_requests += 1
if long_partial_requests > self.max_long_partial_prefills:
break
else:
short_partial_requests += 1
if short_partial_requests + long_partial_requests > self.max_num_partial_prefills:
break
else:
if current_prefill_tokens > max_num_batched_tokens:
break
requests.append(request.raw)
self.ids_read_cursor += len(requests)
if len(batch_ids) > 0 and len(requests) == 0:
llm_logger.debug(
f"Scheduler has put all just-pulled request into the queue: {len(batch_ids)}"
)
if len(requests) > 0:
llm_logger.info(
f"Scheduler has pulled some request: {[request.request_id for request in requests]}"
)
return requests
def put_results(self, results: List[RequestOutput]):
"""
Add processing results back to the scheduler.
Args:
results: List of RequestOutput objects containing results
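Example (sketch; assumes `outputs` is a list of RequestOutput objects
produced by the engine for previously scheduled requests):

    scheduler.put_results(outputs)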
"""
responses: List[ScheduledResponse] = [
ScheduledResponse(result) for result in results
]
finished_responses = [
response.request_id for response in responses if response.finished
]
if len(finished_responses) > 0:
llm_logger.info(
f"Scheduler has received some finished responses: {finished_responses}"
)
with self.mutex:
for response in responses:
if response.request_id not in self.requests:
llm_logger.warning(
f"Scheduler has received a expired response: {[response.request_id]}"
)
continue
if response.request_id not in self.responses:
self.responses[response.request_id] = [response]
continue
self.responses[response.request_id].append(response)
self.responses_not_empty.notify_all()
def get_results(self) -> Dict[str, List[RequestOutput]]:
"""
Retrieve all available results from the scheduler and clean up completed requests.
This method:
- Waits for new responses using a condition variable
- Returns all currently available responses
- Automatically removes completed requests from the scheduler
- Logs finished requests
Returns:
Dict[str, List[RequestOutput]]:
A dictionary where:
- Key is the request ID
- Value is a list of RequestOutput objects for that request
Completed requests are automatically removed from the scheduler
Note:
- Thread-safe operation using condition variables
- Has a short timeout (0.001s) to avoid blocking
- Automatically recycles completed requests to free memory
- Logs finished requests via llm_logger
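Example (sketch; `stream_to_client` stands in for a hypothetical consumer):

    for request_id, outputs in scheduler.get_results().items():
        for output in outputs:  # RequestOutput objects in arrival order
            stream_to_client(request_id, output)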
"""
def _get_results():
responses = self.responses
self.responses = dict()
return responses
with self.responses_not_empty:
responses = self.responses_not_empty.wait_for(
_get_results, self.wait_response_timeout)
results = dict()
for request_id, resps in responses.items():
finished = False
results[request_id] = []
for resp in resps:
results[request_id].append(resp.raw)
finished |= resp.finished
if finished:
self._recycle(request_id)
llm_logger.info(
f"Scheduler has pulled a finished response: {[request_id]}"
)
return results