mirror of
				https://github.com/PaddlePaddle/FastDeploy.git
				synced 2025-10-31 11:56:44 +08:00 
			
		
		
		
	
		
			
				
	
	
		
			285 lines
		
	
	
		
			10 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			285 lines
		
	
	
		
			10 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| """
 | |
| # Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
 | |
| #
 | |
| # Licensed under the Apache License, Version 2.0 (the "License");
 | |
| # you may not use this file except in compliance with the License.
 | |
| # You may obtain a copy of the License at
 | |
| #
 | |
| #     http://www.apache.org/licenses/LICENSE-2.0
 | |
| #
 | |
| # Unless required by applicable law or agreed to in writing, software
 | |
| # distributed under the License is distributed on an "AS IS" BASIS,
 | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | |
| # See the License for the specific language governing permissions and
 | |
| # limitations under the License.
 | |
| """
 | |
| 
 | |
| import redis
 | |
| 
 | |
| from fastdeploy.utils import llm_logger
 | |
| 
 | |
| from .global_scheduler import GlobalScheduler
 | |
| from .local_scheduler import LocalScheduler
 | |
| from .splitwise_scheduler import SplitWiseScheduler, SplitWiseSchedulerConfig
 | |
| 
 | |
| 
 | |
class LocalSchedulerConfig:
    """
    Settings holder for the in-process LocalScheduler.

    Attributes:
        max_size: Maximum number of concurrent requests (-1 for unlimited)
        ttl: Time-to-live in seconds for request expiration
        max_model_len: Maximum model context length in tokens
        enable_chunked_prefill: Whether chunked prefill processing is enabled
        max_num_partial_prefills: Max partial prefill operations allowed
        max_long_partial_prefills: Max long-running partial prefill ops
        long_prefill_token_threshold: Token count threshold for long prefill
    """

    def __init__(
        self,
        max_size: int = -1,
        ttl: int = 900,
        max_model_len: int = 8192,
        enable_chunked_prefill: bool = False,
        max_num_partial_prefills: int = 1,
        max_long_partial_prefills: int = 1,
        long_prefill_token_threshold: int = 0,
        **kwargs,
    ):
        """
        Store the LocalScheduler settings on the instance.

        Args:
            max_size: Maximum concurrent requests (-1 for unlimited, 0 for disabled)
            ttl: Time-to-live in seconds for request expiration (default 900s)
            max_model_len: Maximum model context length in tokens
            enable_chunked_prefill: Whether to enable chunked prefill processing
            max_num_partial_prefills: Max partial prefill operations allowed
            max_long_partial_prefills: Max long-running partial prefill ops
            long_prefill_token_threshold: Token count threshold for long prefill
            **kwargs: Extra keyword arguments; accepted and ignored

        Note:
            A long_prefill_token_threshold of 0 is replaced by
            int(max_model_len * 0.04), i.e. 4% of max_model_len rounded down.
        """
        # Resolve the "0 means auto" threshold up front so it is stored once.
        resolved_threshold = long_prefill_token_threshold
        if resolved_threshold == 0:
            resolved_threshold = int(max_model_len * 0.04)

        # Assignment order is kept stable because print() walks __dict__ in
        # insertion order.
        self.max_size = max_size
        self.ttl = ttl

        self.max_model_len = max_model_len
        self.enable_chunked_prefill = enable_chunked_prefill
        self.max_num_partial_prefills = max_num_partial_prefills
        self.max_long_partial_prefills = max_long_partial_prefills
        self.long_prefill_token_threshold = resolved_threshold

    def check(self):
        """
        Validate the configuration values.

        This is a no-op: no value is rejected for the local scheduler.
        """
        return None

    def print(self):
        """
        Write every stored setting to the log, one line per attribute.
        """
        row_format = "{:<20}:{:<6}{}"
        llm_logger.info("LocalScheduler Configuration Information :")
        for attr_name, attr_value in vars(self).items():
            llm_logger.info(row_format.format(attr_name, "", attr_value))
        llm_logger.info("=============================================================")
 | |
| 
 | |
| 
 | |
class GlobalSchedulerConfig:
    """
    Configuration class for GlobalScheduler (Redis-based).

    Attributes:
        host: Redis server hostname
        port: Redis server port
        db: Redis database number
        password: Optional Redis password
        topic: Namespace prefix for queues
        ttl: Time-to-live in seconds for Redis keys
        min_load_score: Minimum load score for task assignment
        load_shards_num: Number of load balancing shards
        max_model_len: Maximum model context length in tokens
        enable_chunked_prefill: Whether chunked prefill processing is enabled
        max_num_partial_prefills: Max partial prefill operations allowed
        max_long_partial_prefills: Max long-running partial prefill ops
        long_prefill_token_threshold: Token count threshold for long prefill
    """

    def __init__(
        self,
        host: str = "127.0.0.1",
        port: int = 6379,
        db: int = 0,
        password=None,
        topic: str = "default",
        ttl: int = 900,
        min_load_score: float = 3,
        max_model_len: int = 8192,
        load_shards_num: int = 1,
        enable_chunked_prefill: bool = False,
        max_num_partial_prefills: int = 1,
        max_long_partial_prefills: int = 1,
        long_prefill_token_threshold: int = 0,
        **kwargs,
    ):
        """
        Initialize GlobalScheduler (Redis-based) configuration.

        Args:
            host: Redis server hostname (default "127.0.0.1")
            port: Redis server port (default 6379)
            db: Redis database number (default 0)
            password: Optional Redis password
            topic: Namespace prefix for queues (default "default")
            ttl: Time-to-live in seconds for Redis keys (default 900s)
            min_load_score: Minimum load score for task assignment (default 3)
            max_model_len: Maximum model context length in tokens
            load_shards_num: Number of load balancing shards
            enable_chunked_prefill: Whether to enable chunked prefill processing
            max_num_partial_prefills: Max partial prefill operations allowed
            max_long_partial_prefills: Max long-running partial prefill ops
            long_prefill_token_threshold: Token count threshold for long prefill
            **kwargs: Extra keyword arguments; accepted and ignored

        Note:
            - If long_prefill_token_threshold is 0, it is replaced by
              int(max_model_len * 0.04), i.e. 4% of max_model_len rounded down
            - No validation happens here; call check() to validate
        """
        self.host = host
        self.port = port
        self.db = db
        self.password = password
        self.topic = topic
        self.ttl = ttl
        self.min_load_score = min_load_score
        self.load_shards_num = load_shards_num

        self.max_model_len = max_model_len
        self.enable_chunked_prefill = enable_chunked_prefill
        self.max_num_partial_prefills = max_num_partial_prefills
        self.max_long_partial_prefills = max_long_partial_prefills
        self.long_prefill_token_threshold = long_prefill_token_threshold
        if self.long_prefill_token_threshold == 0:
            self.long_prefill_token_threshold = int(self.max_model_len * 0.04)

    def check(self):
        """
        Validate the numeric settings, then test the Redis connection.

        Raises:
            ValueError: If ttl is not positive, min_load_score is below 1,
                or load_shards_num is below 1
            Exception: If the Redis ping returns a falsy response; errors
                raised by the redis client itself propagate unchanged
        """
        # Each message states exactly the bound its condition enforces.
        if self.ttl <= 0:
            raise ValueError("ttl should be greater than 0")
        if self.min_load_score < 1:
            raise ValueError("min_load_score should be no less than 1")
        if self.load_shards_num < 1:
            raise ValueError("load_shards_num should be greater than 0")

        # Keyword arguments keep the call independent of the client's
        # positional parameter order.
        r = redis.Redis(host=self.host, port=self.port, db=self.db, password=self.password)
        try:
            if not r.ping():
                raise Exception("connect to redis failed")
        finally:
            # Always release the probe connection, even when ping fails.
            r.close()

    def print(self):
        """
        Print the current configuration to logs, masking the password.

        The password is always logged as "******" (even when unset); the
        instance attribute itself is never modified.
        """
        llm_logger.info("GlobalScheduler Configuration Information :")
        for k, v in self.__dict__.items():
            # Mask on a local copy so self.password stays intact if logging
            # raises or another thread reads it concurrently.
            if k == "password":
                v = "******"
            llm_logger.info("{:<20}:{:<6}{}".format(k, "", v))
        llm_logger.info("=============================================================")
 | |
| 
 | |
| 
 | |
class SchedulerConfig:
    """
    Factory class for scheduler configurations.

    Creates the appropriate config object based on scheduler type
    ("local", "global" or "splitwise") and builds the matching scheduler.
    """

    def __init__(self, name="local", **kwargs):
        """
        Initialize scheduler configuration factory.

        Args:
            name: Scheduler type: "local" (LocalSchedulerConfig),
                "global" (GlobalSchedulerConfig) or "splitwise"
                (SplitWiseSchedulerConfig)
            **kwargs: Configuration parameters forwarded to the selected
                config class

        Note:
            An unrecognized name is not rejected here: ``config`` is left as
            None and the error is reported later by check(), print() or
            scheduler(). No validation is performed in the constructor.
        """
        self.name = name
        self.config = None

        # The three names are mutually exclusive, so at most one branch runs.
        if name == "local":
            self.config = LocalSchedulerConfig(**kwargs)
        elif name == "global":
            self.config = GlobalSchedulerConfig(**kwargs)
        elif name == "splitwise":
            self.config = SplitWiseSchedulerConfig(**kwargs)

    def _require_config(self):
        """Return the underlying config, or raise if none was created."""
        if self.config is None:
            raise Exception(f"Unknown scheduler type {self.name}")
        return self.config

    def check(self):
        """
        Validate the configuration.

        Raises:
            Exception: If invalid scheduler type is specified; errors from
                the underlying config's check() propagate unchanged
        """
        if self.name not in ["local", "global", "splitwise"]:
            raise Exception(f"Unknown scheduler type {self.name}")

        self.config.check()

    def print(self):
        """
        Print the current configuration to logs.

        Raises:
            Exception: If no config exists because the scheduler type was
                not recognized
        """
        self._require_config().print()

    def scheduler(self):
        """
        Create a scheduler instance based on the configuration.

        Returns:
            GlobalScheduler when name is "global", SplitWiseScheduler when
            name is "splitwise", otherwise LocalScheduler

        Raises:
            Exception: If no config exists because the scheduler type was
                not recognized
        """
        # Fail with a clear message instead of dereferencing None in the
        # LocalScheduler fallback below.
        config = self._require_config()

        if self.name == "global":
            return GlobalScheduler(
                host=config.host,
                port=config.port,
                db=config.db,
                password=config.password,
                topic=config.topic,
                ttl=config.ttl,
                min_load_score=config.min_load_score,
                load_shards_num=config.load_shards_num,
                enable_chunked_prefill=config.enable_chunked_prefill,
                max_num_partial_prefills=config.max_num_partial_prefills,
                max_long_partial_prefills=config.max_long_partial_prefills,
                long_prefill_token_threshold=config.long_prefill_token_threshold,
            )

        if self.name == "splitwise":
            return SplitWiseScheduler(config)

        return LocalScheduler(
            max_size=config.max_size,
            ttl=config.ttl,
            enable_chunked_prefill=config.enable_chunked_prefill,
            max_num_partial_prefills=config.max_num_partial_prefills,
            max_long_partial_prefills=config.max_long_partial_prefills,
            long_prefill_token_threshold=config.long_prefill_token_threshold,
        )
 | 
