""" # Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ import threading import time from typing import Dict, List, Optional, Tuple from fastdeploy.engine.request import Request, RequestOutput from fastdeploy.scheduler.data import ScheduledRequest, ScheduledResponse from fastdeploy.utils import llm_logger class LocalScheduler(object): """ A local in-memory task scheduler for request/response management. This class provides functionality for: - Enqueuing and dequeuing requests - Managing request lifecycle with TTL - Handling request/response flow - Thread-safe operations with condition variables """ def __init__( self, max_size: int, ttl: int, enable_chunked_prefill: bool, max_num_partial_prefills: int, max_long_partial_prefills: int, long_prefill_token_threshold: int, ): """ Initializes a local in-memory scheduler for managing inference requests. Args: max_size: Maximum number of concurrent requests the scheduler can handle (0 for unlimited) ttl: Time-to-live in seconds for requests before automatic timeout enable_chunked_prefill: Whether to enable chunked prefill processing max_num_partial_prefills: Maximum number of partial prefill operations allowed max_long_partial_prefills: Maximum number of long-running partial prefill operations long_prefill_token_threshold: Token count threshold to classify as long prefill Initializes: - Thread synchronization primitives (mutex, condition variables) - Request and response tracking structures - Chunked prefill configuration parameters - Request queue management system Note: - Uses thread-safe operations for concurrent access - Automatically recycles expired requests based on TTL - Supports both batched and individual request processing """ self.max_size = max_size self.ttl = ttl self.mutex = threading.Lock() self.enable_chunked_prefill = enable_chunked_prefill self.max_num_partial_prefills = max_num_partial_prefills self.max_long_partial_prefills = max_long_partial_prefills self.long_prefill_token_threshold = long_prefill_token_threshold self.ids_read_cursor = 0 self.ids: List[str] = list() self.requests: Dict[str, ScheduledRequest] = dict() self.responses: Dict[str, List[ScheduledResponse]] = dict() self.wait_request_timeout = 10 self.wait_response_timeout = 0.001 self.requests_not_empty = threading.Condition(self.mutex) self.responses_not_empty = threading.Condition(self.mutex) def reset(self): """ Reset the local scheduler to its initial empty state by: 1. Resetting the request ID tracking cursor to 0 2. Clearing all stored request IDs 3. Clearing all pending requests 4. 

    def _recycle(self, request_id: Optional[str] = None):
        """
        Clean up expired or completed requests to free memory.

        Args:
            request_id: Optional specific request ID to remove. If None,
                removes expired requests once the queue exceeds max_size.
        """
        if request_id is not None:
            self.requests.pop(request_id, None)
            self.responses.pop(request_id, None)
            self.ids.remove(request_id)
            # A recycled request was already scheduled, so its index lies
            # before the read cursor; shift the cursor back by one.
            self.ids_read_cursor -= 1
            return

        if self.max_size <= 0:
            return

        if len(self.requests) <= self.max_size:
            return

        now = time.time()
        expired_ids = []
        for request_id in self.ids:
            request = self.requests[request_id]
            if now - request.schedule_time < self.ttl:
                break
            expired_ids.append(request.request_id)

        for expired_id in expired_ids:
            self.requests.pop(expired_id, None)
            self.responses.pop(expired_id, None)

        # Expired IDs form a prefix of self.ids (ordered by schedule time),
        # so drop them with a single slice. Popping by a rising index while
        # the list shrinks would skip every other entry.
        del self.ids[:len(expired_ids)]

        if len(expired_ids) > 0:
            if len(expired_ids) > self.ids_read_cursor:
                self.ids_read_cursor = 0
            else:
                self.ids_read_cursor -= len(expired_ids)

    def put_requests(
            self,
            requests: List[Request]) -> List[Tuple[str, Optional[str]]]:
        """
        Add new requests to the scheduler queue.

        Args:
            requests: List of Request objects to enqueue

        Returns:
            List of (request_id, error_message) tuples, one per request.
            error_message is None for successful enqueues.
        """
        with self.mutex:
            self._recycle()

            if self.max_size > 0 and len(
                    self.requests) + len(requests) > self.max_size:
                msg = f"Exceeding the max length of the local scheduler (max_size={self.max_size})"
                return [(request.request_id, msg) for request in requests]

            valid_ids = []
            duplicated_ids = []
            for request in requests:
                if request.request_id in self.requests:
                    duplicated_ids.append(request.request_id)
                else:
                    scheduled_request = ScheduledRequest(request)
                    self.requests[
                        scheduled_request.request_id] = scheduled_request
                    valid_ids.append(scheduled_request.request_id)

            self.ids += valid_ids
            self.requests_not_empty.notify_all()

        llm_logger.info(f"Scheduler has enqueued some requests: {valid_ids}")
        if len(duplicated_ids) > 0:
            llm_logger.warning(
                f"Scheduler has received some duplicated requests: {duplicated_ids}"
            )

        results = [(request_id, None) for request_id in valid_ids]
        results += [(request_id, "duplicated request_id")
                    for request_id in duplicated_ids]
        return results

    def calc_required_blocks(self, token_num, block_size):
        """
        Calculate the number of blocks needed for a given number of tokens.

        Args:
            token_num: Number of tokens
            block_size: Size of each block

        Returns:
            Number of blocks required (rounded up)
        """
        return (token_num + block_size - 1) // block_size
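
    # calc_required_blocks is ceiling division without floats:
    #   (token_num + block_size - 1) // block_size == ceil(token_num / block_size)
    # e.g. 1000 tokens with block_size=64 -> (1000 + 63) // 64 = 16 blocks.
    # get_requests (below) then adds reserved_output_blocks on top of this
    # per-request input cost when checking available_blocks.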

    def get_requests(self,
                     available_blocks,
                     block_size,
                     reserved_output_blocks,
                     max_num_batched_tokens,
                     batch=1) -> List[Request]:
        """
        Retrieve requests from the scheduler based on available resources.

        Args:
            available_blocks: Number of available processing blocks
            block_size: Size of each processing block
            reserved_output_blocks: Blocks reserved for output
            max_num_batched_tokens: Maximum number of tokens per batch
            batch: Preferred batch size

        Returns:
            List of Request objects ready for processing
        """
        if available_blocks <= reserved_output_blocks or batch < 1:
            llm_logger.debug(
                f"Scheduler's resources are insufficient: available_blocks={available_blocks} "
                f"reserved_output_blocks={reserved_output_blocks} batch={batch} "
                f"max_num_batched_tokens={max_num_batched_tokens}")
            return []

        with self.requests_not_empty:
            batch_ids = self.requests_not_empty.wait_for(
                lambda: self.ids[self.ids_read_cursor:self.ids_read_cursor +
                                 batch], self.wait_request_timeout)

            required_total_blocks = 0
            current_prefill_tokens = 0
            requests: List[Request] = []
            long_partial_requests, short_partial_requests = 0, 0
            for request_id in batch_ids:
                request = self.requests[request_id]
                required_input_blocks = self.calc_required_blocks(
                    request.prompt_tokens_ids_len, block_size)
                current_prefill_tokens += request.prompt_tokens_ids_len
                required_total_blocks += required_input_blocks + reserved_output_blocks
                if required_total_blocks > available_blocks:
                    break

                if self.enable_chunked_prefill:
                    if request.prompt_tokens_ids_len > self.long_prefill_token_threshold:
                        # long request
                        long_partial_requests += 1
                        if long_partial_requests > self.max_long_partial_prefills:
                            break
                    else:
                        short_partial_requests += 1

                    if short_partial_requests + long_partial_requests > self.max_num_partial_prefills:
                        break
                else:
                    if current_prefill_tokens > max_num_batched_tokens:
                        break

                requests.append(request.raw)

            self.ids_read_cursor += len(requests)

        if len(batch_ids) > 0 and len(requests) == 0:
            llm_logger.debug(
                f"Scheduler could not fit any of the {len(batch_ids)} candidate "
                f"requests into the available resources; they remain queued")
        if len(requests) > 0:
            llm_logger.info(
                f"Scheduler has pulled some requests: {[request.request_id for request in requests]}"
            )

        return requests

    def put_results(self, results: List[RequestOutput]):
        """
        Add processing results back to the scheduler.

        Args:
            results: List of RequestOutput objects containing results
        """
        responses: List[ScheduledResponse] = [
            ScheduledResponse(result) for result in results
        ]

        finished_responses = [
            response.request_id for response in responses if response.finished
        ]
        if len(finished_responses) > 0:
            llm_logger.info(
                f"Scheduler has received some finished responses: {finished_responses}"
            )

        with self.mutex:
            for response in responses:
                if response.request_id not in self.requests:
                    llm_logger.warning(
                        f"Scheduler has received an expired response: {[response.request_id]}"
                    )
                    continue

                if response.request_id not in self.responses:
                    self.responses[response.request_id] = [response]
                    continue
                self.responses[response.request_id].append(response)

            self.responses_not_empty.notify_all()
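
    # End-to-end flow sketch (schematic; Request/RequestOutput construction
    # is elided because their signatures live in fastdeploy.engine.request,
    # the resource numbers are hypothetical, and run_inference is a
    # placeholder for the caller's engine step):
    #
    #   scheduler.put_requests(requests)     # enqueue, returns (id, err) pairs
    #   batch = scheduler.get_requests(
    #       available_blocks=512, block_size=64,
    #       reserved_output_blocks=16, max_num_batched_tokens=8192,
    #       batch=8)
    #   outputs = run_inference(batch)       # outputs: List[RequestOutput]
    #   scheduler.put_results(outputs)       # hand results back
    #   results = scheduler.get_results()    # {request_id: [RequestOutput, ...]}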

    def get_results(self) -> Dict[str, List[RequestOutput]]:
        """
        Retrieve all available results from the scheduler and clean up
        completed requests.

        This method:
        - Waits for new responses using a condition variable
        - Returns all currently available responses
        - Automatically removes completed requests from the scheduler
        - Logs finished requests

        Returns:
            Dict[str, List[RequestOutput]]: A dictionary where:
            - Key is the request ID
            - Value is the list of RequestOutput objects for that request

        Note:
            - Thread-safe operation using condition variables
            - Uses a short timeout (0.001s) to avoid blocking
            - Automatically recycles completed requests to free memory
            - Logs finished requests via llm_logger
        """

        def _get_results():
            responses = self.responses
            self.responses = dict()
            return responses

        with self.responses_not_empty:
            responses = self.responses_not_empty.wait_for(
                _get_results, self.wait_response_timeout)

            results = dict()
            for request_id, resps in responses.items():
                finished = False
                results[request_id] = []
                for resp in resps:
                    results[request_id].append(resp.raw)
                    finished |= resp.finished
                if finished:
                    self._recycle(request_id)
                    llm_logger.info(
                        f"Scheduler has pulled a finished response: {[request_id]}"
                    )
        return results
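

if __name__ == "__main__":
    # A minimal smoke-test sketch, not part of the FastDeploy test suite.
    # It exercises only the pieces that need no engine or real Request
    # objects: block arithmetic, draining an empty scheduler, and reset.
    # The constructor arguments are hypothetical values, not defaults.
    scheduler = LocalScheduler(
        max_size=0,
        ttl=900,
        enable_chunked_prefill=False,
        max_num_partial_prefills=1,
        max_long_partial_prefills=1,
        long_prefill_token_threshold=2048,
    )
    # Ceiling division: 1000 tokens in 64-token blocks -> 16 blocks.
    assert scheduler.calc_required_blocks(1000, 64) == 16
    # An empty scheduler yields no results after the short wait timeout.
    assert scheduler.get_results() == {}
    scheduler.reset()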