diff --git a/fastdeploy/engine/engine.py b/fastdeploy/engine/engine.py
index 17692129c..47ec60243 100644
--- a/fastdeploy/engine/engine.py
+++ b/fastdeploy/engine/engine.py
@@ -165,12 +165,6 @@ class LLMEngine(object):
             disable_any_whitespace=self.cfg.disable_any_whitespace,
         )
 
-    def reset_scheduler(self):
-        """
-        Reset the scheduler to its initial state.
-        """
-        self.scheduler.reset()
-
     def start(self, api_server_pid=None):
         """
         Initializes the engine and starts its sub-services.
diff --git a/fastdeploy/entrypoints/openai/api_server.py b/fastdeploy/entrypoints/openai/api_server.py
index 17e037dac..a0f261841 100644
--- a/fastdeploy/entrypoints/openai/api_server.py
+++ b/fastdeploy/entrypoints/openai/api_server.py
@@ -32,7 +32,8 @@ from fastdeploy.entrypoints.openai.protocol import (ChatCompletionRequest,
                                                     ChatCompletionResponse,
                                                     CompletionRequest,
                                                     CompletionResponse,
-                                                    ErrorResponse)
+                                                    ErrorResponse,
+                                                    ControlSchedulerRequest)
 from fastdeploy.entrypoints.openai.serving_chat import OpenAIServingChat
 from fastdeploy.entrypoints.openai.serving_completion import \
     OpenAIServingCompletion
@@ -273,10 +274,13 @@ def clear_load_weight(request: Request) -> Response:
                         status_code=404)
 
 
-def launch_api_server(args) -> None:
+def launch_api_server() -> None:
     """
     Start the HTTP server
     """
+    if not is_port_available(args.host, args.port):
+        raise Exception(f"The parameter `port`:{args.port} is already in use.")
+
     api_server_logger.info(
         f"launch Fastdeploy api server... port: {args.port}")
     api_server_logger.info(f"args: {args.__dict__}")
@@ -319,6 +323,11 @@ def run_metrics_server():
 
 def launch_metrics_server():
     """Metrics server running the sub thread"""
+    if not is_port_available(args.host, args.metrics_port):
+        raise Exception(
+            f"The parameter `metrics_port`:{args.metrics_port} is already in use."
+        )
+
     prom_dir = cleanup_prometheus_files(True)
     os.environ["PROMETHEUS_MULTIPROC_DIR"] = prom_dir
     metrics_server_thread = threading.Thread(target=run_metrics_server,
@@ -339,10 +348,39 @@ def reset_scheduler():
     if llm_engine is None:
         return Response("Engine not loaded", status_code=500)
 
-    llm_engine.reset_scheduler()
+    llm_engine.scheduler.reset_scheduler()
     return Response("Scheduler Reset Successfully", status_code=200)
 
 
+@controller_app.post("/controller/scheduler")
+def control_scheduler(request: ControlSchedulerRequest):
+    """
+    Control the scheduler behavior with the given parameters.
+    """
+    content = ErrorResponse(object="", message="Scheduler updated successfully", code=0)
+
+    global llm_engine
+    if llm_engine is None:
+        content.message = "Engine is not loaded"
+        content.code = 500
+        return JSONResponse(content=content.model_dump(), status_code=500)
+
+    if request.reset:
+        llm_engine.scheduler.reset_scheduler()
+
+    if request.load_shards_num or request.reallocate_shard:
+        if hasattr(llm_engine.scheduler, "update_config") and callable(llm_engine.scheduler.update_config):
+            llm_engine.scheduler.update_config(
+                load_shards_num=request.load_shards_num,
+                reallocate=request.reallocate_shard)
+        else:
+            content.message = "This scheduler doesn't support the `update_config()` method."
+            content.code = 400
+            return JSONResponse(content=content.model_dump(), status_code=400)
+
+    return JSONResponse(content=content.model_dump(), status_code=200)
+
+
 def run_controller_server():
     """
     run controller server
@@ -358,6 +396,11 @@ def launch_controller_server():
     if args.controller_port < 0:
         return
 
+    if not is_port_available(args.host, args.controller_port):
+        raise Exception(
+            f"The parameter `controller_port`:{args.controller_port} is already in use."
+        )
+
     controller_server_thread = threading.Thread(target=run_controller_server,
                                                 daemon=True)
     controller_server_thread.start()
@@ -366,19 +409,13 @@
 
 def main():
     """Main function"""
-    if not is_port_available(args.host, args.port):
-        raise Exception(f"The parameter `port`:{args.port} is already in use.")
-    if not is_port_available(args.host, args.metrics_port):
-        raise Exception(
-            f"The parameter `metrics_port`:{args.metrics_port} is already in use."
-        )
 
     if load_engine() is None:
         return
 
     launch_controller_server()
     launch_metrics_server()
-    launch_api_server(args)
+    launch_api_server()
 
 
 if __name__ == "__main__":
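A minimal sketch of exercising the new endpoint (illustrative values only: the port stands in for whatever `--controller-port` the server was launched with, and the JSON fields mirror `ControlSchedulerRequest` defined in protocol.py below):

import requests

CONTROLLER = "http://127.0.0.1:9904"  # hypothetical --controller-port value

# Reset only; roughly equivalent to POST /controller/reset_scheduler.
requests.post(f"{CONTROLLER}/controller/scheduler", json={"reset": True})

# Re-shard: a GlobalScheduler routes this through update_config(); a
# LocalScheduler has no update_config(), so the server answers 400.
resp = requests.post(
    f"{CONTROLLER}/controller/scheduler",
    json={"load_shards_num": 8, "reallocate_shard": True},
)
print(resp.status_code, resp.json())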
diff --git a/fastdeploy/entrypoints/openai/protocol.py b/fastdeploy/entrypoints/openai/protocol.py
index d4391e567..6a3b67c0d 100644
--- a/fastdeploy/entrypoints/openai/protocol.py
+++ b/fastdeploy/entrypoints/openai/protocol.py
@@ -542,3 +542,12 @@ class ChatCompletionRequest(BaseModel):
         )
 
         return data
+
+
+class ControlSchedulerRequest(BaseModel):
+    """
+    Control scheduler request to the engine.
+    """
+    reset: Optional[bool] = False
+    load_shards_num: Optional[int] = None
+    reallocate_shard: Optional[bool] = False
\ No newline at end of file
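A small usage sketch of the new model (pydantic v2, consistent with the `model_dump()` calls in api_server.py; the field values are illustrative):

from fastdeploy.entrypoints.openai.protocol import ControlSchedulerRequest

req = ControlSchedulerRequest(load_shards_num=8, reallocate_shard=True)
print(req.reset)         # False, the default
print(req.model_dump())  # {'reset': False, 'load_shards_num': 8, 'reallocate_shard': True}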
diff --git a/fastdeploy/scheduler/global_scheduler.py b/fastdeploy/scheduler/global_scheduler.py
index f3eba6877..fb8cb3a8e 100644
--- a/fastdeploy/scheduler/global_scheduler.py
+++ b/fastdeploy/scheduler/global_scheduler.py
@@ -19,7 +19,6 @@ from typing import List, Optional, Dict, Tuple
 import traceback
 import threading
 import time
-from datetime import datetime
 import random
 import uuid
 import crcmod
@@ -28,7 +27,7 @@ from fastdeploy.scheduler.storage import AdaptedRedis
 from fastdeploy.engine.request import Request, RequestOutput
 from fastdeploy.scheduler.data import ScheduledRequest, ScheduledResponse
 from fastdeploy.scheduler.workers import Workers, Task
-from fastdeploy.utils import llm_logger
+from fastdeploy.utils import scheduler_logger
 from fastdeploy.scheduler import utils
 
 
@@ -51,7 +50,7 @@ class GlobalScheduler(object):
                  topic: str,
                  ttl: int,
                  min_load_score: float,
-                 load_shrads_num: int,
+                 load_shards_num: int,
                  enable_chunked_prefill: bool,
                  max_num_partial_prefills: int,
                  max_long_partial_prefills: int,
@@ -68,7 +67,7 @@ class GlobalScheduler(object):
             topic: Base topic name for queue namespacing
             ttl: Time-to-live in seconds for Redis keys
             min_load_score: Minimum load score for task assignment
-            load_shrads_num: Number of shards for load balancing table
+            load_shards_num: Number of shards for load balancing table
             enable_chunked_prefill: Whether to enable chunked prefill processing
             max_num_partial_prefills: Maximum number of partial prefills allowed
             max_long_partial_prefills: Maximum number of long partial prefills allowed
@@ -84,7 +83,7 @@ class GlobalScheduler(object):
         self.topic = topic
         self.ttl = ttl
         self.min_load_score = min_load_score
-        self.load_shrads_num = load_shrads_num
+        self.load_shards_num = load_shards_num
 
         self.enable_chunked_prefill = enable_chunked_prefill
         self.max_num_partial_prefills = max_num_partial_prefills
@@ -97,14 +96,17 @@ class GlobalScheduler(object):
 
         self.crc16_mutex = threading.Lock()
         self.crc16 = crcmod.predefined.Crc('ccitt-false')
         self.load_slot_for_getting_request = 0
-        self.load_start = 0  # const
-        self.load_num = 50  # const
+        self.load_offset = 0  # const
+        self.load_count = 50  # const
+        self.load_lookup_num = 5  # const
+        self.keep_alive_duration = 30  # const
 
         connection_pool = ConnectionPool(
             host=host, port=port, db=db, password=password, max_connections=10)
         self.client = AdaptedRedis(connection_pool=connection_pool)
-        self.name = self._generate_scheduler_name()
+        self.name, self.shard = self._generate_scheduler_name_and_shard()
+
         self.keep_alive_workers = threading.Thread(
             target=self._keep_alive, daemon=True)
         self.keep_alive_workers.start()
@@ -126,10 +128,32 @@ class GlobalScheduler(object):
             target=self._get_results_worker, daemon=True)
         self.get_response_workers.start()
 
-        llm_logger.info(
+        scheduler_logger.info(
            f"Scheduler: name={self.name} redis_version={self.client.version}")
 
     def _get_hash_slot(self, data: str) -> int:
+        """
+        Calculate the hash slot for a given string using the CRC16 algorithm.
+
+        This method is thread-safe and used for consistent hashing in distributed scheduling.
+        It implements the same CRC16 algorithm (CCITT-FALSE variant) used by Redis Cluster.
+
+        Args:
+            data: Input string to be hashed (typically a scheduler or request identifier)
+
+        Returns:
+            int: A 16-bit hash value (0-65535) representing the calculated slot
+
+        Implementation Details:
+            1. Encodes input string as UTF-8 bytes
+            2. Uses thread-safe CRC16 calculation with mutex protection
+            3. Resets CRC state after each calculation
+            4. Returns raw CRC value without modulo operation
+
+        Note:
+            - The result is typically used with modulo operation for sharding (e.g. % num_shards)
+            - Matches Redis Cluster's slot distribution algorithm for compatibility
+        """
         data = data.encode("utf-8")
         with self.crc16_mutex:
             self.crc16.update(data)
@@ -149,58 +173,66 @@ class GlobalScheduler(object):
         """
         return f"{self.topic}.ins.{scheduler_name}"
 
-    def _generate_scheduler_name(self) -> str:
+    def _generate_scheduler_name_and_shard(self) -> Tuple[str, int]:
         """
-        Generate a unique name for this scheduler instance.
+        Generate a unique scheduler name and calculate its shard assignment.
 
-        Uses hostname/IP and timestamp to create a unique identifier,
-        then registers it in Redis with TTL.
+        This method:
+        1. Creates a unique identifier from the hostname/IP
+        2. Registers the name in Redis with a keep-alive TTL
+        3. Calculates the shard assignment using consistent hashing
+        4. Handles naming conflicts by appending incrementing suffixes
 
         Returns:
-            Unique scheduler name string
+            Tuple[str, int]:
+                - str: Unique scheduler name
+                - int: Assigned shard number (0 to load_shards_num-1)
+
+        Implementation Details:
+            - Uses hostname/IP as base identifier, falls back to UUID if unavailable
+            - Implements conflict resolution with incrementing suffixes
+            - Registers name in Redis with keep-alive duration
+            - Calculates shard using CRC16 hash of the name
+
+        Error Handling:
+            - Logs IP resolution failures
+            - Handles Redis registration conflicts gracefully
+            - Ensures unique name generation even in edge cases
         """
         try:
             _, name = utils.get_hostname_ip()
         except Exception as e:
-            llm_logger.warning(
+            scheduler_logger.warning(
                 f"Scheduler encountered an error while resolving the IP address. {e}")
             name = str(uuid.uuid4())
 
         size = len(name)
-        now = time.time()
-        local_time = datetime.fromtimestamp(now)
-        formatted_time = local_time.strftime(
-            "%Y-%m-%d %H:%M:%S") + f"{local_time.microsecond // 1000:03d}"
-
         count = 1
         while True:
-            if self.client.set(self._instance_name(name), formatted_time, ex=self.ttl, nx=True):
+            if self.client.set(self._instance_name(name), "", ex=self.keep_alive_duration, nx=True):
                 break
             name = f"{name[:size]}:{count}"
             count += 1
-        return name
+
+        shard = self._get_hash_slot(name) % self.load_shards_num
+        self.client.set(self._instance_name(name), self._load_table_name(shard=shard),
+                        ex=self.keep_alive_duration)
+        return name, shard
 
     def _keep_alive(self):
         """
         Background thread that periodically updates the scheduler's TTL in Redis.
 
-        Runs in a loop with interval of TTL/2 to maintain instance registration.
+        Runs in a loop with interval of keep_alive_duration/2 to maintain instance registration.
         """
-        interval_time = self.ttl / 2
         while True:
             try:
-                now = time.time()
-                local_time = datetime.fromtimestamp(now)
-                formatted_time = local_time.strftime(
-                    "%Y-%m-%d %H:%M:%S") + f"{local_time.microsecond // 1000:03d}"
-                self.client.set(self._instance_name(self.name),
-                                formatted_time, ex=self.ttl)
+                self.client.set(self._instance_name(
+                    self.name), self._load_table_name(), ex=self.keep_alive_duration)
+                time.sleep(self.keep_alive_duration / 2)
             except Exception as e:
-                llm_logger.error(f"Scheduler keep alive failed: {e}")
-                interval_time = self.ttl / 10
-
-            time.sleep(interval_time)
-            interval_time = self.ttl / 2
+                scheduler_logger.error(f"Scheduler keep alive failed: {e}")
+                time.sleep(min(3, self.keep_alive_duration / 4))
 
     def _scheduler_name_from_request_queue(self, request_queue: str) -> str:
         """
@@ -243,22 +275,18 @@ class GlobalScheduler(object):
             return f"{self.topic}.resp.{self.name}"
         return f"{self.topic}.resp.{scheduler_name}"
 
-    def _load_table_name(self, request_queue_name: Optional[str] = None, slot: Optional[int] = None) -> str:
+    def _load_table_name(self, shard: Optional[int] = None, slot: Optional[int] = None) -> str:
         """
         Get the Redis sorted set name used for load balancing.
 
         Returns:
             The load score key name
         """
-        if request_queue_name is None:
-            request_queue_name = self._request_queue_name()
-
-        if slot is None:
-            slot = self._get_hash_slot(
-                request_queue_name) % self.load_shrads_num
-        else:
-            slot %= self.load_shrads_num
-        return f"{self.topic}.load.{slot}"
+        if shard is None and slot is not None:
+            shard = slot % self.load_shards_num
+        if shard is None:
+            shard = self.shard
+        return f"{self.topic}.load.{shard}"
 
     @staticmethod
     def calc_required_blocks(token_num, block_size):
@@ -330,11 +358,11 @@ class GlobalScheduler(object):
 
             self.client.zincrby(self._load_table_name(),
                                 len(serialized_requests), self.name, rem_amount=0, ttl=self.ttl)
-        llm_logger.info(
+        scheduler_logger.info(
             f"Scheduler has enqueued some requests: {requests}")
 
         if duplicate:
-            llm_logger.warning(
+            scheduler_logger.warning(
                 "Scheduler has received some duplicated requests: "
                 f"{[task for task in tasks if task.reason is not None]}")
         return tasks
@@ -375,7 +403,7 @@ class GlobalScheduler(object):
         """
 
         if available_blocks <= reserved_output_blocks or batch < 1:
-            llm_logger.debug(
+            scheduler_logger.debug(
                 f"Scheduler's resource are insufficient: available_blocks={available_blocks} "
                 f"reserved_output_blocks={reserved_output_blocks} batch={batch} "
                 f"max_num_batched_tokens={max_num_batched_tokens}")
@@ -406,15 +434,17 @@ class GlobalScheduler(object):
                                    for element in elements]
 
         extend_scheduler_names = []
+        extend_scheduler_load_table_name = ""
         if len(serialized_requests) == 0 and len(batches) > 0:
-            for _ in range(min(5, self.load_shrads_num)):
+            for _ in range(min(self.load_lookup_num, self.load_shards_num)):
+                extend_scheduler_load_table_name = self._load_table_name(
+                    slot=self.load_slot_for_getting_request)
                 serialized_members = self.client.zrangebyscore(
-                    self._load_table_name(
-                        slot=self.load_slot_for_getting_request),
+                    extend_scheduler_load_table_name,
                     self.min_load_score,
                     float("+inf"),
-                    start=self.load_start,
-                    num=self.load_num)
+                    start=self.load_offset,
+                    num=self.load_count)
                 self.load_slot_for_getting_request += 1
                 if len(serialized_members) > 0:
                     break
@@ -433,23 +463,18 @@ class GlobalScheduler(object):
 
                 elements = self.client.lpop(lucky_request_queue_name, batches[0])
                 if elements is not None and len(elements) > 0:
-                    self.client.zincrby(
-                        self._load_table_name(
-                            request_queue_name=lucky_request_queue_name),
-                        -len(elements), lucky, rem_amount=0, ttl=self.ttl)
+                    self.client.zincrby(extend_scheduler_load_table_name,
+                                        -len(elements), lucky, rem_amount=0, ttl=self.ttl)
                    serialized_requests += [(lucky_request_queue_name, element)
                                             for element in elements]
-                    llm_logger.info(
+                    scheduler_logger.info(
                        f"Scheduler {self.name} has stolen some requests from another lucky one. "
                         f"(name={lucky} num={len(serialized_requests)})")
                 else:
                     exist_num = self.client.exists(self._instance_name(lucky))
                     if exist_num == 0:
-                        if self.client.zrem(
-                                self._load_table_name(
-                                    request_queue_name=lucky_request_queue_name),
-                                lucky):
-                            llm_logger.info(
+                        if self.client.zrem(extend_scheduler_load_table_name, lucky):
+                            scheduler_logger.info(
                                 f"Scheduler {lucky} has been removed")
 
         # blocked read
@@ -465,12 +490,12 @@ class GlobalScheduler(object):
             request_queue_name = element[0].decode("utf-8")
             scheduler_name = self._scheduler_name_from_request_queue(
                 request_queue_name)
-            self.client.zincrby(
-                self._load_table_name(request_queue_name=request_queue_name),
-                -1, scheduler_name, rem_amount=0, ttl=self.ttl)
+            load_table_name = extend_scheduler_load_table_name if scheduler_name != self.name else self._load_table_name()
+            self.client.zincrby(load_table_name,
+                                -1, scheduler_name, rem_amount=0, ttl=self.ttl)
             serialized_requests.append((request_queue_name, element[1]))
             if scheduler_name != self.name:
-                llm_logger.info(
+                scheduler_logger.info(
                     f"Scheduler {self.name} has stolen a request from another scheduler. (name={scheduler_name})")
 
         long_partial_requests = 0
@@ -526,12 +551,12 @@ class GlobalScheduler(object):
                 if request.request_queue_name == local_request_queue_name:
                     continue
 
-                self._mark_request(request)
+                # self._mark_request(request)
                 if request.request_id not in self.stolen_requests:
                     self.stolen_requests[request.request_id] = request
                     continue
 
-                llm_logger.error(
+                scheduler_logger.error(
                     f"Scheduler has received a duplicate request from others: {request}")
 
         requests: List[Request] = [
@@ -548,19 +573,18 @@ class GlobalScheduler(object):
                     serialized_requests)
                 scheduler_name = self._scheduler_name_from_request_queue(
                     request_queue_name)
-                self.client.zincrby(
-                    self._load_table_name(
-                        request_queue_name=request_queue_name),
-                    len(serialized_requests), scheduler_name, ttl=self.ttl)
+                load_table_name = extend_scheduler_load_table_name if scheduler_name != self.name else self._load_table_name()
+                self.client.zincrby(load_table_name,
+                                    len(serialized_requests), scheduler_name, ttl=self.ttl)
 
-            llm_logger.info(
+            scheduler_logger.info(
                 f"Scheduler has put remaining request into the queue: {len(remaining_request)}")
 
         if len(requests) == 0:
-            llm_logger.debug(
+            scheduler_logger.debug(
                 f"Scheduler has put all just-pulled request into the queue: {len(remaining_request)}")
         if len(requests) > 0:
-            llm_logger.info(
+            scheduler_logger.info(
                 f"Scheduler has pulled some request: {[request.request_id for request in requests]}")
 
         return requests
@@ -600,7 +624,7 @@ class GlobalScheduler(object):
             if response.request_id in stolen_request_id_request_queue:
                 response_queue_name = stolen_request_id_response_queue[response.request_id]
                 request_queue_name = stolen_request_id_request_queue[response.request_id]
-                self._unmark_response(response, request_queue_name)
+                # self._unmark_response(response, request_queue_name)
 
                 if response_queue_name not in stolen_responses:
                     stolen_responses[response_queue_name] = []
@@ -608,7 +632,7 @@ class GlobalScheduler(object):
                 stolen_responses[response_queue_name].append(
                     response.serialize())
                 continue
-            llm_logger.error(
+            scheduler_logger.error(
                 f"Scheduler has recieved a non-existent response from engine: {[response]}")
 
         with self.mutex:
@@ -624,7 +648,7 @@ class GlobalScheduler(object):
                    self.local_response_not_empty.notify_all()
 
         if len(finished_request_ids) > 0:
-            llm_logger.info(
+            scheduler_logger.info(
                 f"Scheduler has received some finished responses: {finished_request_ids}")
 
         for response_queue_name, responses in stolen_responses.items():
@@ -681,15 +705,15 @@ class GlobalScheduler(object):
 
             with self.mutex:
                 for request_id, contents in responses.items():
                     if request_id not in self.local_responses:
-                        llm_logger.error(
+                        scheduler_logger.error(
                             "Scheduler has received some non-existent response from the queue. "
                             f"response:{contents} queue:{self._response_queue_name()}")
                         continue
 
                     self.local_responses[request_id] += contents
                     self.local_response_not_empty.notify_all()
         except Exception as e:
-            llm_logger.error(f"Scheduler get_results_worker exception: {e} "
-                             f"traceback: {traceback.format_exc()}")
+            scheduler_logger.error(f"Scheduler get_results_worker exception: {e} "
+                                   f"traceback: {traceback.format_exc()}")
 
@@ -718,7 +742,7 @@ class GlobalScheduler(object):
             - Thread-safe operation using condition variables
             - Short timeout avoids blocking while maintaining responsiveness
             - First call may return empty to batch small responses
-            - Automatically logs finished requests via llm_logger
+            - Automatically logs finished requests via scheduler_logger
         """
 
         first = True
@@ -754,7 +778,7 @@ class GlobalScheduler(object):
 
                 if finished:
                     del self.local_responses[request_id]
-                    llm_logger.info(
+                    scheduler_logger.info(
                         f"Scheduler has pulled a finished response: {[request_id]}")
 
         return results
@@ -787,4 +811,41 @@ class GlobalScheduler(object):
             self.client.zrem(self._load_table_name(), self.name)
             self.local_responses = dict()
             self.stolen_requests = dict()
-        llm_logger.info("Scheduler has been reset")
+        scheduler_logger.info("Scheduler has been reset")
+
+    def update_config(self, load_shards_num: Optional[int], reallocate: Optional[bool]):
+        """
+        Update the scheduler's configuration parameters dynamically.
+
+        This method allows runtime modification of:
+        - Total number of load balancing shards
+        - Current instance's shard assignment
+
+        Args:
+            load_shards_num: New total number of load balancing shards (must be > 0)
+            reallocate: If True, recalculates this instance's shard assignment
+
+        Effects:
+            - Updates internal load balancing configuration
+            - Optionally reallocates this instance to a new shard
+            - Logs configuration changes for audit purposes
+
+        Note:
+            - Changes take effect immediately for new operations
+            - Existing in-progress operations continue with old configuration
+            - Reallocation may affect request distribution pattern
+        """
+        with self.mutex:
+            old_load_shards_num = self.load_shards_num
+            old_shard = self.shard
+
+            if load_shards_num:
+                self.load_shards_num = load_shards_num
+
+            if reallocate:
+                self.shard = self._get_hash_slot(
+                    self.name) % self.load_shards_num
+
+            scheduler_logger.info("Scheduler has reloaded config, "
+                                  f"load_shards_num({old_load_shards_num} => {self.load_shards_num}) "
+                                  f"shard({old_shard} => {self.shard})")
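For reference, the shard assignment above can be reproduced standalone; this sketch mirrors `_get_hash_slot` plus the modulo step in `_generate_scheduler_name_and_shard` (the instance name and shard count are made up):

import crcmod.predefined

# Same CRC16 variant ('ccitt-false') the scheduler uses for slot hashing.
crc16 = crcmod.predefined.mkPredefinedCrcFun('ccitt-false')

def shard_of(scheduler_name: str, load_shards_num: int) -> int:
    # CRC16 over the UTF-8 encoded name, then modulo into the shard table.
    return crc16(scheduler_name.encode("utf-8")) % load_shards_num

print(shard_of("192.168.1.7", 4))  # hypothetical instance name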
scheduler_logger.info("Scheduler has been reset") def _recycle(self, request_id: Optional[str] = None): """ @@ -189,10 +189,10 @@ class LocalScheduler(object): self.ids += valid_ids self.requests_not_empty.notify_all() - llm_logger.info(f"Scheduler has enqueued some requests: {valid_ids}") + scheduler_logger.info(f"Scheduler has enqueued some requests: {valid_ids}") if len(duplicated_ids) > 0: - llm_logger.warning( + scheduler_logger.warning( f"Scheduler has received some duplicated requests: {duplicated_ids}" ) @@ -234,7 +234,7 @@ class LocalScheduler(object): List of Request objects ready for processing """ if available_blocks <= reserved_output_blocks or batch < 1: - llm_logger.debug( + scheduler_logger.debug( f"Scheduler's resource are insufficient: available_blocks={available_blocks} " f"reserved_output_blocks={reserved_output_blocks} batch={batch} " f"max_num_batched_tokens={max_num_batched_tokens}") @@ -277,12 +277,12 @@ class LocalScheduler(object): self.ids_read_cursor += len(requests) if len(batch_ids) > 0 and len(requests) == 0: - llm_logger.debug( + scheduler_logger.debug( f"Scheduler has put all just-pulled request into the queue: {len(batch_ids)}" ) if len(requests) > 0: - llm_logger.info( + scheduler_logger.info( f"Scheduler has pulled some request: {[request.request_id for request in requests]}" ) @@ -303,14 +303,14 @@ class LocalScheduler(object): response.request_id for response in responses if response.finished ] if len(finished_responses) > 0: - llm_logger.info( + scheduler_logger.info( f"Scheduler has received some finished responses: {finished_responses}" ) with self.mutex: for response in responses: if response.request_id not in self.requests: - llm_logger.warning( + scheduler_logger.warning( f"Scheduler has received a expired response: {[response.request_id]}" ) continue @@ -342,7 +342,7 @@ class LocalScheduler(object): - Thread-safe operation using condition variables - Has a short timeout (0.001s) to avoid blocking - Automatically recycles completed requests to free memory - - Logs finished requests via llm_logger + - Logs finished requests via scheduler_logger """ def _get_results(): @@ -364,7 +364,7 @@ class LocalScheduler(object): if finished: self._recycle(request_id) - llm_logger.info( + scheduler_logger.info( f"Scheduler has pulled a finished response: {[request_id]}" ) return results diff --git a/fastdeploy/scheduler/workers.py b/fastdeploy/scheduler/workers.py index 74b53fc99..64be8945e 100644 --- a/fastdeploy/scheduler/workers.py +++ b/fastdeploy/scheduler/workers.py @@ -18,7 +18,7 @@ from typing import Callable, List, Any, Dict, Optional import functools import threading import traceback -from fastdeploy.utils import llm_logger +from fastdeploy.utils import scheduler_logger class Task: @@ -163,7 +163,7 @@ class Workers: try: results = self.work(tasks) except Exception as e: - llm_logger.error( + scheduler_logger.error( f"Worker {self.name} execute error: {e}, traceback: {traceback.format_exc()}") continue