| """
 | |
| # Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
 | |
| #
 | |
| # Licensed under the Apache License, Version 2.0 (the "License");
 | |
| # you may not use this file except in compliance with the License.
 | |
| # You may obtain a copy of the License at
 | |
| #
 | |
| #     http://www.apache.org/licenses/LICENSE-2.0
 | |
| #
 | |
| # Unless required by applicable law or agreed to in writing, software
 | |
| # distributed under the License is distributed on an "AS IS" BASIS,
 | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | |
| # See the License for the specific language governing permissions and
 | |
| # limitations under the License.
 | |
| """
 | |
| 
 | |
| import threading
 | |
| import time
 | |
| from typing import Dict, List, Optional, Tuple
 | |
| 
 | |
| from fastdeploy.engine.request import Request, RequestOutput
 | |
| from fastdeploy.scheduler.data import ScheduledRequest, ScheduledResponse
 | |
| from fastdeploy.utils import scheduler_logger
 | |
| 
 | |
| 
 | |
class LocalScheduler:
    """
    A local in-memory task scheduler for request/response management.

    This class provides functionality for:
    - Enqueuing and dequeuing requests
    - Managing request lifecycle with TTL
    - Handling request/response flow
    - Thread-safe operations with condition variables
    """

    def __init__(
        self,
        max_size: int,
        ttl: int,
        enable_chunked_prefill: bool,
        max_num_partial_prefills: int,
        max_long_partial_prefills: int,
        long_prefill_token_threshold: int,
    ):
        """
        Initializes a local in-memory scheduler for managing inference requests.

        Args:
            max_size: Maximum number of concurrent requests the scheduler can hold (0 for unlimited)
            ttl: Time-to-live in seconds before a queued request is considered expired
            enable_chunked_prefill: Whether to enable chunked prefill processing
            max_num_partial_prefills: Maximum number of concurrent partial prefill operations
            max_long_partial_prefills: Maximum number of concurrent long partial prefill operations
            long_prefill_token_threshold: Prompt-token count above which a request counts as a long prefill

        Initializes:
            - Thread synchronization primitives (mutex, condition variables)
            - Request and response tracking structures
            - Chunked prefill configuration parameters
            - Request queue management system

        Note:
            - Uses thread-safe operations for concurrent access
            - Automatically recycles expired requests based on TTL
            - Supports both batched and individual request processing
        """
        self.max_size = max_size
        self.ttl = ttl
        self.mutex = threading.Lock()

        self.enable_chunked_prefill = enable_chunked_prefill
        self.max_num_partial_prefills = max_num_partial_prefills
        self.max_long_partial_prefills = max_long_partial_prefills
        self.long_prefill_token_threshold = long_prefill_token_threshold

        # Queue of request IDs in arrival order; ids_read_cursor marks how many
        # of them have already been handed out by get_requests.
        self.ids_read_cursor = 0
        self.ids: List[str] = list()

        self.requests: Dict[str, ScheduledRequest] = dict()
        self.responses: Dict[str, List[ScheduledResponse]] = dict()

        self.wait_request_timeout = 10
        self.wait_response_timeout = 0.001

        self.requests_not_empty = threading.Condition(self.mutex)
        self.responses_not_empty = threading.Condition(self.mutex)

    def reset(self):
        """
        Reset the local scheduler to its initial empty state by:
        1. Resetting the request ID tracking cursor to 0
        2. Clearing all stored request IDs
        3. Clearing all pending requests
        4. Clearing all cached responses

        This method is thread-safe and should be called when:
        - The scheduler needs to be cleanly restarted
        - Recovering from critical errors
        - Preparing for graceful shutdown

        Effects:
        - Resets ids_read_cursor to 0 (request processing position)
        - Clears the ids list tracking all request IDs
        - Clears the requests dictionary tracking pending requests
        - Clears the responses dictionary tracking received responses

        Note:
        - Uses the scheduler's mutex to ensure thread safety
        - Does not affect the scheduler's configuration parameters (max_size, ttl, etc.)
        - After reset, the scheduler will be empty but still operational
        """
        with self.mutex:
            self.ids_read_cursor = 0
            self.ids = list()
            self.requests = dict()
            self.responses = dict()
        scheduler_logger.info("Scheduler has been reset")

    def _recycle(self, request_id: Optional[str] = None):
        """
        Clean up expired or completed requests to free memory.

        Args:
            request_id: Optional specific request ID to remove.
                        If None, removes expired requests, but only once the
                        scheduler holds more than max_size requests.
        """
        if request_id is not None:
            self.requests.pop(request_id, None)
            self.responses.pop(request_id, None)
            # A finished request has already been read, so the cursor shifts
            # back by one when its ID is removed from the queue.
            self.ids.remove(request_id)
            self.ids_read_cursor -= 1
            return

        if self.max_size <= 0:
            return

        if len(self.requests) <= self.max_size:
            return

        now = time.time()
        expired_ids = []
        for request_id in self.ids:
            request = self.requests[request_id]
            if now - request.schedule_time < self.ttl:
                break
            expired_ids.append(request.request_id)

        for expired_id in expired_ids:
            self.requests.pop(expired_id, None)
            self.responses.pop(expired_id, None)

        # Expired IDs form a prefix of self.ids (oldest first), so drop them in
        # one slice; popping by index while iterating would skip elements.
        del self.ids[: len(expired_ids)]

        if len(expired_ids) > 0:
            if len(expired_ids) > self.ids_read_cursor:
                self.ids_read_cursor = 0
            else:
                self.ids_read_cursor -= len(expired_ids)

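    # Expiry example (illustrative, values made up): with max_size=100 and
    # ttl=600, a request enqueued more than 10 minutes ago is only evicted once
    # the scheduler holds over 100 requests; with max_size=0 the TTL sweep is
    # skipped entirely and nothing is evicted.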
    def put_requests(self, requests: List[Request]) -> List[Tuple[str, Optional[str]]]:
        """
        Add new requests to the scheduler queue.

        Args:
            requests: List of Request objects to enqueue

        Returns:
            List of tuples containing (request_id, error_message) for each request.
            error_message is None for successful enqueues.
        """
        with self.mutex:
            self._recycle()
            if self.max_size > 0 and len(self.requests) + len(requests) > self.max_size:
                msg = f"Exceeding the max length of the local scheduler (max_size={self.max_size})"
                return [(request.request_id, msg) for request in requests]

            valid_ids = []
            duplicated_ids = []
            for request in requests:
                if request.request_id in self.requests:
                    duplicated_ids.append(request.request_id)
                else:
                    scheduled_request = ScheduledRequest(request)
                    self.requests[scheduled_request.request_id] = scheduled_request
                    valid_ids.append(scheduled_request.request_id)

            self.ids += valid_ids
            self.requests_not_empty.notify_all()
        scheduler_logger.info(f"Scheduler has enqueued some requests: {valid_ids}")

        if len(duplicated_ids) > 0:
            scheduler_logger.warning(f"Scheduler has received some duplicated requests: {duplicated_ids}")

        results = [(request_id, None) for request_id in valid_ids]
        results += [(request_id, "duplicated request_id") for request_id in duplicated_ids]
        return results

    def calc_required_blocks(self, token_num, block_size):
        """
        Calculate the number of blocks needed for a given number of tokens.

        Args:
            token_num: Number of tokens
            block_size: Size of each block

        Returns:
            Number of blocks required (rounded up)
        """
        return (token_num + block_size - 1) // block_size

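    # Worked example (illustrative): (token_num + block_size - 1) // block_size
    # is ceiling division. With token_num=130 and block_size=64:
    #   (130 + 64 - 1) // 64 = 193 // 64 = 3
    # i.e. two full blocks for 128 tokens plus one block for the remaining 2.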
    def get_requests(
        self,
        available_blocks,
        block_size,
        reserved_output_blocks,
        max_num_batched_tokens,
        batch=1,
    ) -> List[Request]:
        """
        Retrieve requests from the scheduler based on available resources.

        Args:
            available_blocks: Number of available processing blocks
            block_size: Size of each processing block
            reserved_output_blocks: Blocks reserved for output
            max_num_batched_tokens: Maximum tokens that can be batched
            batch: Preferred batch size

        Returns:
            List of Request objects ready for processing
        """
        if available_blocks <= reserved_output_blocks or batch < 1:
            scheduler_logger.debug(
                f"Scheduler's resources are insufficient: available_blocks={available_blocks} "
                f"reserved_output_blocks={reserved_output_blocks} batch={batch} "
                f"max_num_batched_tokens={max_num_batched_tokens}"
            )
            return []

        with self.requests_not_empty:
            # wait_for returns the predicate's result: the next slice of up to
            # `batch` unread request IDs, or an empty list on timeout.
            batch_ids = self.requests_not_empty.wait_for(
                lambda: self.ids[self.ids_read_cursor : self.ids_read_cursor + batch],
                self.wait_request_timeout,
            )

            required_total_blocks = 0
            current_prefill_tokens = 0
            requests: List[Request] = []
            long_partial_requests, short_partial_requests = 0, 0
            for request_id in batch_ids:
                request = self.requests[request_id]
                required_input_blocks = self.calc_required_blocks(request.prompt_tokens_ids_len, block_size)
                current_prefill_tokens += request.prompt_tokens_ids_len
                required_total_blocks += required_input_blocks + reserved_output_blocks
                if required_total_blocks > available_blocks:
                    break

                if self.enable_chunked_prefill:
                    if request.prompt_tokens_ids_len > self.long_prefill_token_threshold:
                        # long request
                        long_partial_requests += 1
                        if long_partial_requests > self.max_long_partial_prefills:
                            break
                    else:
                        short_partial_requests += 1

                    if short_partial_requests + long_partial_requests > self.max_num_partial_prefills:
                        break
                else:
                    if current_prefill_tokens > max_num_batched_tokens:
                        break

                requests.append(request.raw)
            self.ids_read_cursor += len(requests)

        if len(batch_ids) > 0 and len(requests) == 0:
            scheduler_logger.debug(f"Scheduler pulled requests but scheduled none; they remain queued: {len(batch_ids)}")

        if len(requests) > 0:
            scheduler_logger.info(f"Scheduler has pulled some requests: {[request.request_id for request in requests]}")

        return requests

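    # Budget example (illustrative, values made up): with block_size=64,
    # reserved_output_blocks=32 and available_blocks=512, a request with 130
    # prompt tokens costs calc_required_blocks(130, 64) + 32 = 35 blocks, so at
    # most 512 // 35 = 14 such requests fit in one pull (further capped by
    # `batch` and, without chunked prefill, by max_num_batched_tokens).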
    def put_results(self, results: List[RequestOutput]):
        """
        Add processing results back to the scheduler.

        Args:
            results: List of RequestOutput objects containing results
        """
        responses: List[ScheduledResponse] = [ScheduledResponse(result) for result in results]

        finished_responses = [response.request_id for response in responses if response.finished]
        if len(finished_responses) > 0:
            scheduler_logger.info(f"Scheduler has received some finished responses: {finished_responses}")

        with self.mutex:
            for response in responses:
                if response.request_id not in self.requests:
                    scheduler_logger.warning(f"Scheduler has received an expired response: {[response.request_id]}")
                    continue

                if response.request_id not in self.responses:
                    self.responses[response.request_id] = [response]
                    continue
                self.responses[response.request_id].append(response)
            self.responses_not_empty.notify_all()

    def get_results(self) -> Dict[str, List[RequestOutput]]:
        """
        Retrieve all available results from the scheduler and clean up completed requests.

        This method:
        - Waits for new responses using a condition variable
        - Returns all currently available responses
        - Automatically removes completed requests from the scheduler
        - Logs finished requests

        Returns:
            Dict[str, List[RequestOutput]]:
                A dictionary where:
                - Key is the request ID
                - Value is a list of RequestOutput objects for that request
                Completed requests are automatically removed from the scheduler

        Note:
            - Thread-safe operation using condition variables
            - Has a short timeout (wait_response_timeout, 0.001s) to avoid blocking
            - Automatically recycles completed requests to free memory
            - Logs finished requests via scheduler_logger
        """

        def _get_results():
            # The predicate both checks for and drains pending responses;
            # wait_for hands back the drained dict (empty if it timed out).
            responses = self.responses
            self.responses = dict()
            return responses

        with self.responses_not_empty:
            responses = self.responses_not_empty.wait_for(_get_results, self.wait_response_timeout)

            results = dict()
            for request_id, resps in responses.items():
                finished = False
                results[request_id] = []
                for resp in resps:
                    results[request_id].append(resp.raw)
                    finished |= resp.finished

                if finished:
                    self._recycle(request_id)
                    scheduler_logger.info(f"Scheduler has pulled a finished response: {[request_id]}")
            return results
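
# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal round trip, assuming `reqs` is a List[Request] produced by the
# engine and `outputs` is the List[RequestOutput] from running inference;
# all numeric values below are made-up placeholders:
#
#     scheduler = LocalScheduler(
#         max_size=0,                      # unlimited queue
#         ttl=600,                         # expire queued requests after 10 min
#         enable_chunked_prefill=False,
#         max_num_partial_prefills=1,
#         max_long_partial_prefills=1,
#         long_prefill_token_threshold=8192,
#     )
#     statuses = scheduler.put_requests(reqs)   # [(request_id, error_or_None), ...]
#     batch = scheduler.get_requests(
#         available_blocks=512, block_size=64,
#         reserved_output_blocks=32, max_num_batched_tokens=4096, batch=8,
#     )
#     # ... run inference on `batch`, producing `outputs` ...
#     scheduler.put_results(outputs)
#     done = scheduler.get_results()             # {request_id: [RequestOutput, ...]}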
