Global scheduler supports configuring hot updates (#2812)

This commit is contained in:
lddfym
2025-07-11 13:39:30 +08:00
committed by GitHub
parent 94691bcd90
commit ec986642df
6 changed files with 215 additions and 114 deletions

View File

@@ -165,12 +165,6 @@ class LLMEngine(object):
disable_any_whitespace=self.cfg.disable_any_whitespace,
)
def reset_scheduler(self):
"""
Reset the scheduler to its initial state.
"""
self.scheduler.reset()
def start(self, api_server_pid=None):
"""
Initializes the engine and starts its sub-services.

View File

@@ -32,7 +32,8 @@ from fastdeploy.entrypoints.openai.protocol import (ChatCompletionRequest,
ChatCompletionResponse,
CompletionRequest,
CompletionResponse,
ErrorResponse)
ErrorResponse,
ControlSchedulerRequest)
from fastdeploy.entrypoints.openai.serving_chat import OpenAIServingChat
from fastdeploy.entrypoints.openai.serving_completion import \
OpenAIServingCompletion
@@ -273,10 +274,13 @@ def clear_load_weight(request: Request) -> Response:
status_code=404)
def launch_api_server(args) -> None:
def launch_api_server() -> None:
"""
Launch the HTTP server.
"""
if not is_port_available(args.host, args.port):
raise Exception(f"The parameter `port`:{args.port} is already in use.")
api_server_logger.info(
f"launch Fastdeploy api server... port: {args.port}")
api_server_logger.info(f"args: {args.__dict__}")
@@ -319,6 +323,11 @@ def run_metrics_server():
def launch_metrics_server():
"""Metrics server running the sub thread"""
if not is_port_available(args.host, args.metrics_port):
raise Exception(
f"The parameter `metrics_port`:{args.metrics_port} is already in use."
)
prom_dir = cleanup_prometheus_files(True)
os.environ["PROMETHEUS_MULTIPROC_DIR"] = prom_dir
metrics_server_thread = threading.Thread(target=run_metrics_server,
@@ -339,10 +348,39 @@ def reset_scheduler():
if llm_engine is None:
return Response("Engine not loaded", status_code=500)
llm_engine.reset_scheduler()
llm_engine.scheduler.reset_scheduler()
return Response("Scheduler Reset Successfully", status_code=200)
@controller_app.post("/controller/scheduler")
def control_scheduler(request: ControlSchedulerRequest):
"""
Control the scheduler behavior with the given parameters.
"""
content = ErrorResponse(object="", message="Scheduler updated successfully", code=0)
global llm_engine
if llm_engine is None:
content.message = "Engine is not loaded"
content.code = 500
return JSONResponse(content=content.model_dump(), status_code=500)
if request.reset:
llm_engine.scheduler.reset_scheduler()
if request.load_shards_num or request.reallocate_shard:
if hasattr(llm_engine.scheduler, "update_config") and callable(llm_engine.scheduler.update_config):
llm_engine.scheduler.update_config(
load_shards_num=request.load_shards_num,
reallocate=request.reallocate_shard)
else:
content.message="This scheduler doesn't support the `update_config()` method."
content.code=400
return JSONResponse(content=content.model_dump(), status_code=400)
return JSONResponse(content=content.model_dump(), status_code=200)
def run_controller_server():
"""
Run the controller server.
@@ -358,6 +396,11 @@ def launch_controller_server():
if args.controller_port < 0:
return
if not is_port_available(args.host, args.controller_port):
raise Exception(
f"The parameter `controller_port`:{args.controller_port} is already in use."
)
controller_server_thread = threading.Thread(target=run_controller_server,
daemon=True)
controller_server_thread.start()
@@ -366,19 +409,13 @@ def launch_controller_server():
def main():
"""main函数"""
if not is_port_available(args.host, args.port):
raise Exception(f"The parameter `port`:{args.port} is already in use.")
if not is_port_available(args.host, args.metrics_port):
raise Exception(
f"The parameter `metrics_port`:{args.metrics_port} is already in use."
)
if load_engine() is None:
return
launch_controller_server()
launch_metrics_server()
launch_api_server(args)
launch_api_server()
if __name__ == "__main__":
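For reference, a minimal client sketch of the new hot-update endpoint, assuming the controller server is reachable at 127.0.0.1 and a controller_port of 8866 (both values are illustrative, not defaults from this commit):

import requests

CONTROLLER_URL = "http://127.0.0.1:8866/controller/scheduler"  # host/port are assumptions

# Switch the global scheduler to 8 load-balancing shards and re-pick this instance's shard.
resp = requests.post(CONTROLLER_URL, json={"load_shards_num": 8, "reallocate_shard": True})
print(resp.status_code, resp.json())  # 200 with "Scheduler updated successfully" on success

# Reset the scheduler through the same endpoint.
requests.post(CONTROLLER_URL, json={"reset": True})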

View File

@@ -542,3 +542,12 @@ class ChatCompletionRequest(BaseModel):
)
return data
class ControlSchedulerRequest(BaseModel):
"""
Control scheduler request to the engine.
"""
reset: Optional[bool] = False
load_shards_num: Optional[int] = None
reallocate_shard: Optional[bool] = False
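A small illustration of how this model parses a request body; the payload values are examples only:

from typing import Optional
from pydantic import BaseModel

class ControlSchedulerRequest(BaseModel):
    reset: Optional[bool] = False
    load_shards_num: Optional[int] = None
    reallocate_shard: Optional[bool] = False

print(ControlSchedulerRequest.model_validate({}).model_dump())
# {'reset': False, 'load_shards_num': None, 'reallocate_shard': False}
print(ControlSchedulerRequest.model_validate({"load_shards_num": 8}).model_dump())
# {'reset': False, 'load_shards_num': 8, 'reallocate_shard': False}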

View File

@@ -19,7 +19,6 @@ from typing import List, Optional, Dict, Tuple
import traceback
import threading
import time
from datetime import datetime
import random
import uuid
import crcmod
@@ -28,7 +27,7 @@ from fastdeploy.scheduler.storage import AdaptedRedis
from fastdeploy.engine.request import Request, RequestOutput
from fastdeploy.scheduler.data import ScheduledRequest, ScheduledResponse
from fastdeploy.scheduler.workers import Workers, Task
from fastdeploy.utils import llm_logger
from fastdeploy.utils import scheduler_logger
from fastdeploy.scheduler import utils
@@ -51,7 +50,7 @@ class GlobalScheduler(object):
topic: str,
ttl: int,
min_load_score: float,
load_shrads_num: int,
load_shards_num: int,
enable_chunked_prefill: bool,
max_num_partial_prefills: int,
max_long_partial_prefills: int,
@@ -68,7 +67,7 @@ class GlobalScheduler(object):
topic: Base topic name for queue namespacing
ttl: Time-to-live in seconds for Redis keys
min_load_score: Minimum load score for task assignment
load_shrads_num: Number of shards for load balancing table
load_shards_num: Number of shards for load balancing table
enable_chunked_prefill: Whether to enable chunked prefill processing
max_num_partial_prefills: Maximum number of partial prefills allowed
max_long_partial_prefills: Maximum number of long partial prefills allowed
@@ -84,7 +83,7 @@ class GlobalScheduler(object):
self.topic = topic
self.ttl = ttl
self.min_load_score = min_load_score
self.load_shrads_num = load_shrads_num
self.load_shards_num = load_shards_num
self.enable_chunked_prefill = enable_chunked_prefill
self.max_num_partial_prefills = max_num_partial_prefills
@@ -97,14 +96,17 @@ class GlobalScheduler(object):
self.crc16_mutex = threading.Lock()
self.crc16 = crcmod.predefined.Crc('ccitt-false')
self.load_slot_for_getting_request = 0
self.load_start = 0 # const
self.load_num = 50 # const
self.load_offset = 0 # const
self.load_count = 50 # const
self.load_lookup_num = 5 # const
self.keep_alive_duration = 30 # const
connection_pool = ConnectionPool(
host=host, port=port, db=db, password=password, max_connections=10)
self.client = AdaptedRedis(connection_pool=connection_pool)
self.name = self._generate_scheduler_name()
self.name, self.shard = self._generate_scheduler_name_and_shard()
self.keep_alive_workers = threading.Thread(
target=self._keep_alive, daemon=True)
self.keep_alive_workers.start()
@@ -126,10 +128,32 @@ class GlobalScheduler(object):
target=self._get_results_worker, daemon=True)
self.get_response_workers.start()
llm_logger.info(
scheduler_logger.info(
f"Scheduler: name={self.name} redis_version={self.client.version}")
def _get_hash_slot(self, data: str) -> int:
"""
Calculate the hash slot for a given string using CRC16 algorithm.
This method is thread-safe and used for consistent hashing in distributed scheduling.
It implements the same CRC16 algorithm (CCITT-FALSE variant) used by Redis Cluster.
Args:
data: Input string to be hashed (typically a scheduler or request identifier)
Returns:
int: A 16-bit hash value (0-65535) representing the calculated slot
Implementation Details:
1. Encodes input string as UTF-8 bytes
2. Uses thread-safe CRC16 calculation with mutex protection
3. Resets CRC state after each calculation
4. Returns raw CRC value without modulo operation
Note:
- The result is typically used with modulo operation for sharding (e.g. % num_shards)
- Matches Redis Cluster's slot distribution algorithm for compatibility
"""
data = data.encode("utf-8")
with self.crc16_mutex:
self.crc16.update(data)
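A standalone sketch of the slot-to-shard mapping described above, using the same ccitt-false CRC16 from crcmod; the helper name and the shard count are illustrative rather than part of the FastDeploy API:

import crcmod.predefined

_crc16 = crcmod.predefined.mkPredefinedCrcFun("ccitt-false")

def shard_for(name: str, load_shards_num: int) -> int:
    # Same CCITT-FALSE CRC16 used by Redis Cluster, reduced modulo the shard count.
    return _crc16(name.encode("utf-8")) % load_shards_num

print(shard_for("10.0.0.1", 4))    # deterministic value in [0, 4)
print(shard_for("10.0.0.1:1", 4))  # a suffixed name may land on a different shard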
@@ -149,58 +173,66 @@ class GlobalScheduler(object):
"""
return f"{self.topic}.ins.{scheduler_name}"
def _generate_scheduler_name(self) -> str:
def _generate_scheduler_name_and_shard(self) -> Tuple[str, int]:
"""
Generate a unique name for this scheduler instance.
Generate a unique scheduler name and calculate its shard assignment.
Uses hostname/IP and timestamp to create a unique identifier,
then registers it in Redis with TTL.
This method:
1. Creates a unique identifier using hostname/IP and timestamp
2. Registers the name in Redis with TTL
3. Calculates the shard assignment using consistent hashing
4. Handles naming conflicts by appending incrementing suffixes
Returns:
Unique scheduler name string
Tuple[str, int]:
- str: Unique scheduler name
- int: Assigned shard number (0 to load_shards_num-1)
Implementation Details:
- Uses hostname/IP as base identifier, falls back to UUID if unavailable
- Implements conflict resolution with incrementing suffixes
- Registers name in Redis with keep-alive duration
- Calculates shard using CRC16 hash of the name
Error Handling:
- Logs IP resolution failures
- Handles Redis registration conflicts gracefully
- Ensures unique name generation even in edge cases
"""
try:
_, name = utils.get_hostname_ip()
except Exception as e:
llm_logger.warning(
scheduler_logger.warning(
f"Scheduler encountered an error while resolving the IP address. {e}")
name = str(uuid.uuid4())
size = len(name)
now = time.time()
local_time = datetime.fromtimestamp(now)
formatted_time = local_time.strftime(
"%Y-%m-%d %H:%M:%S") + f"{local_time.microsecond // 1000:03d}"
count = 1
while True:
if self.client.set(self._instance_name(name), formatted_time, ex=self.ttl, nx=True):
if self.client.set(self._instance_name(name), "", ex=self.keep_alive_duration, nx=True):
break
name = f"{name[:size]}:{count}"
count += 1
return name
shard = self._get_hash_slot(name) % self.load_shards_num
self.client.set(self._instance_name(name), self._load_table_name(shard=shard),
ex=self.keep_alive_duration)
return name, shard
def _keep_alive(self):
"""
Background thread that periodically updates the scheduler's TTL in Redis.
Runs in a loop with interval of TTL/2 to maintain instance registration.
Runs in a loop with an interval of keep_alive_duration/2 to maintain the instance registration.
"""
interval_time = self.ttl / 2
while True:
try:
now = time.time()
local_time = datetime.fromtimestamp(now)
formatted_time = local_time.strftime(
"%Y-%m-%d %H:%M:%S") + f"{local_time.microsecond // 1000:03d}"
self.client.set(self._instance_name(self.name),
formatted_time, ex=self.ttl)
self.client.set(self._instance_name(
self.name), self._load_table_name(), ex=self.keep_alive_duration)
time.sleep(self.keep_alive_duration / 2)
except Exception as e:
llm_logger.error(f"Scheduler keep alive failed: {e}")
interval_time = self.ttl / 10
time.sleep(interval_time)
interval_time = self.ttl / 2
scheduler_logger.error(f"Scheduler keep alive failed: {e}")
time.sleep(min(3, self.keep_alive_duration / 4))
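The registration and keep-alive logic above reduces to a SET NX with a TTL followed by periodic refreshes of the same key. A minimal sketch with plain redis-py, assuming a local Redis and illustrative key names:

import time
import redis

client = redis.Redis(host="127.0.0.1", port=6379)  # local Redis assumed
KEEP_ALIVE = 30  # seconds, mirroring keep_alive_duration in the diff

def register(name: str, load_table: str) -> bool:
    # NX makes the write fail if another instance already owns this name.
    return bool(client.set(f"demo.ins.{name}", load_table, ex=KEEP_ALIVE, nx=True))

def keep_alive(name: str, load_table: str) -> None:
    while True:
        client.set(f"demo.ins.{name}", load_table, ex=KEEP_ALIVE)  # refresh the TTL
        time.sleep(KEEP_ALIVE / 2)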
def _scheduler_name_from_request_queue(self, request_queue: str) -> str:
"""
@@ -243,22 +275,18 @@ class GlobalScheduler(object):
return f"{self.topic}.resp.{self.name}"
return f"{self.topic}.resp.{scheduler_name}"
def _load_table_name(self, request_queue_name: Optional[str] = None, slot: Optional[int] = None) -> str:
def _load_table_name(self, shard: Optional[int] = None, slot: Optional[int] = None) -> str:
"""
Get the Redis sorted set name used for load balancing.
Returns:
The load score key name
"""
if request_queue_name is None:
request_queue_name = self._request_queue_name()
if slot is None:
slot = self._get_hash_slot(
request_queue_name) % self.load_shrads_num
else:
slot %= self.load_shrads_num
return f"{self.topic}.load.{slot}"
if shard is None and slot is not None:
shard = slot % self.load_shards_num
if shard is None:
shard = self.shard
return f"{self.topic}.load.{shard}"
@staticmethod
def calc_required_blocks(token_num, block_size):
@@ -330,11 +358,11 @@ class GlobalScheduler(object):
self.client.zincrby(self._load_table_name(),
len(serialized_requests), self.name,
rem_amount=0, ttl=self.ttl)
llm_logger.info(
scheduler_logger.info(
f"Scheduler has enqueued some requests: {requests}")
if duplicate:
llm_logger.warning(
scheduler_logger.warning(
"Scheduler has received some duplicated requests: "
f"{[task for task in tasks if task.reason is not None]}")
return tasks
@@ -375,7 +403,7 @@ class GlobalScheduler(object):
"""
if available_blocks <= reserved_output_blocks or batch < 1:
llm_logger.debug(
scheduler_logger.debug(
f"Scheduler's resource are insufficient: available_blocks={available_blocks} "
f"reserved_output_blocks={reserved_output_blocks} batch={batch} "
f"max_num_batched_tokens={max_num_batched_tokens}")
@@ -406,15 +434,17 @@ class GlobalScheduler(object):
for element in elements]
extend_scheduler_names = []
extend_scheduler_load_table_name = ""
if len(serialized_requests) == 0 and len(batches) > 0:
for _ in range(min(5, self.load_shrads_num)):
for _ in range(min(self.load_lookup_num, self.load_shards_num)):
extend_scheduler_load_table_name = self._load_table_name(
slot=self.load_slot_for_getting_request)
serialized_members = self.client.zrangebyscore(
self._load_table_name(
slot=self.load_slot_for_getting_request),
extend_scheduler_load_table_name,
self.min_load_score,
float("+inf"),
start=self.load_start,
num=self.load_num)
start=self.load_offset,
num=self.load_count)
self.load_slot_for_getting_request += 1
if len(serialized_members) > 0:
break
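The loop above scans the per-shard load tables round-robin until one of them holds a peer above min_load_score. A simplified, self-contained sketch with plain redis-py (key names and constants are illustrative):

import redis

client = redis.Redis()
TOPIC, LOAD_SHARDS_NUM, LOOKUP_NUM, MIN_LOAD_SCORE = "demo", 4, 5, 1.0
slot = 0  # advances across calls, like load_slot_for_getting_request

def find_loaded_peers():
    global slot
    for _ in range(min(LOOKUP_NUM, LOAD_SHARDS_NUM)):
        table = f"{TOPIC}.load.{slot % LOAD_SHARDS_NUM}"
        members = client.zrangebyscore(table, MIN_LOAD_SCORE, "+inf", start=0, num=50)
        slot += 1
        if members:
            return table, members  # keep the table name so later updates hit the same shard
    return None, []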
@@ -433,23 +463,18 @@ class GlobalScheduler(object):
elements = self.client.lpop(lucky_request_queue_name, batches[0])
if elements is not None and len(elements) > 0:
self.client.zincrby(
self._load_table_name(
request_queue_name=lucky_request_queue_name),
-len(elements), lucky, rem_amount=0, ttl=self.ttl)
self.client.zincrby(extend_scheduler_load_table_name,
-len(elements), lucky, rem_amount=0, ttl=self.ttl)
serialized_requests += [(lucky_request_queue_name, element)
for element in elements]
llm_logger.info(
scheduler_logger.info(
f"Scheduler {self.name} has stolen some requests from another lucky one. "
f"(name={lucky} num={len(serialized_requests)})")
else:
exist_num = self.client.exists(self._instance_name(lucky))
if exist_num == 0:
if self.client.zrem(
self._load_table_name(
request_queue_name=lucky_request_queue_name),
lucky):
llm_logger.info(
if self.client.zrem(extend_scheduler_load_table_name, lucky):
scheduler_logger.info(
f"Scheduler {lucky} has been removed")
# blocked read
@@ -465,12 +490,12 @@ class GlobalScheduler(object):
request_queue_name = element[0].decode("utf-8")
scheduler_name = self._scheduler_name_from_request_queue(
request_queue_name)
self.client.zincrby(
self._load_table_name(request_queue_name=request_queue_name),
-1, scheduler_name, rem_amount=0, ttl=self.ttl)
load_table_name = extend_scheduler_load_table_name if scheduler_name != self.name else self._load_table_name()
self.client.zincrby(load_table_name,
-1, scheduler_name, rem_amount=0, ttl=self.ttl)
serialized_requests.append((request_queue_name, element[1]))
if scheduler_name != self.name:
llm_logger.info(
scheduler_logger.info(
f"Scheduler {self.name} has stolen a request from another scheduler. (name={scheduler_name})")
long_partial_requests = 0
@@ -526,12 +551,12 @@ class GlobalScheduler(object):
if request.request_queue_name == local_request_queue_name:
continue
self._mark_request(request)
# self._mark_request(request)
if request.request_id not in self.stolen_requests:
self.stolen_requests[request.request_id] = request
continue
llm_logger.error(
scheduler_logger.error(
f"Scheduler has received a duplicate request from others: {request}")
requests: List[Request] = [
@@ -548,19 +573,18 @@ class GlobalScheduler(object):
serialized_requests)
scheduler_name = self._scheduler_name_from_request_queue(
request_queue_name)
self.client.zincrby(
self._load_table_name(
request_queue_name=request_queue_name),
len(serialized_requests), scheduler_name, ttl=self.ttl)
load_table_name = extend_scheduler_load_table_name if scheduler_name != self.name else self._load_table_name()
self.client.zincrby(load_table_name,
len(serialized_requests), scheduler_name, ttl=self.ttl)
llm_logger.info(
scheduler_logger.info(
f"Scheduler has put remaining request into the queue: {len(remaining_request)}")
if len(requests) == 0:
llm_logger.debug(
scheduler_logger.debug(
f"Scheduler has put all just-pulled request into the queue: {len(remaining_request)}")
if len(requests) > 0:
llm_logger.info(
scheduler_logger.info(
f"Scheduler has pulled some request: {[request.request_id for request in requests]}")
return requests
@@ -600,7 +624,7 @@ class GlobalScheduler(object):
if response.request_id in stolen_request_id_request_queue:
response_queue_name = stolen_request_id_response_queue[response.request_id]
request_queue_name = stolen_request_id_request_queue[response.request_id]
self._unmark_response(response, request_queue_name)
# self._unmark_response(response, request_queue_name)
if response_queue_name not in stolen_responses:
stolen_responses[response_queue_name] = []
@@ -608,7 +632,7 @@ class GlobalScheduler(object):
response.serialize())
continue
llm_logger.error(
scheduler_logger.error(
f"Scheduler has recieved a non-existent response from engine: {[response]}")
with self.mutex:
@@ -624,7 +648,7 @@ class GlobalScheduler(object):
self.local_response_not_empty.notify_all()
if len(finished_request_ids) > 0:
llm_logger.info(
scheduler_logger.info(
f"Scheduler has received some finished responses: {finished_request_ids}")
for response_queue_name, responses in stolen_responses.items():
@@ -681,15 +705,15 @@ class GlobalScheduler(object):
with self.mutex:
for request_id, contents in responses.items():
if request_id not in self.local_responses:
llm_logger.error(
scheduler_logger.error(
"Scheduler has received some non-existent response from the queue. "
f"response:{contents} queue:{self._response_queue_name()}")
continue
self.local_responses[request_id] += contents
self.local_response_not_empty.notify_all()
except Exception as e:
llm_logger.error(f"Scheduler get_results_worker exception: {e} "
f"traceback: {traceback.format_exc()}")
scheduler_logger.error(f"Scheduler get_results_worker exception: {e} "
f"traceback: {traceback.format_exc()}")
def get_results(self) -> Dict[str, List[RequestOutput]]:
"""
@@ -718,7 +742,7 @@ class GlobalScheduler(object):
- Thread-safe operation using condition variables
- Short timeout avoids blocking while maintaining responsiveness
- First call may return empty to batch small responses
- Automatically logs finished requests via llm_logger
- Automatically logs finished requests via scheduler_logger
"""
first = True
@@ -754,7 +778,7 @@ class GlobalScheduler(object):
if finished:
del self.local_responses[request_id]
llm_logger.info(
scheduler_logger.info(
f"Scheduler has pulled a finished response: {[request_id]}")
return results
@@ -787,4 +811,41 @@ class GlobalScheduler(object):
self.client.zrem(self._load_table_name(), self.name)
self.local_responses = dict()
self.stolen_requests = dict()
llm_logger.info("Scheduler has been reset")
scheduler_logger.info("Scheduler has been reset")
def update_config(self, load_shards_num: Optional[int], reallocate: Optional[bool]):
"""
Update the scheduler's configuration parameters dynamically.
This method allows runtime modification of:
- Total number of load balancing shards
- Current instance's shard assignment
Args:
load_shards_num: New total number of load balancing shards (must be > 0)
reallocate: If True, recalculates this instance's shard assignment
Effects:
- Updates internal load balancing configuration
- Optionally reallocates this instance to a new shard
- Logs configuration changes for audit purposes
Note:
- Changes take effect immediately for new operations
- Existing in-progress operations continue with old configuration
- Reallocation may affect request distribution pattern
"""
with self.mutex:
old_load_shards_num = self.load_shards_num
old_shard = self.shard
if load_shards_num:
self.load_shards_num = load_shards_num
if reallocate:
self.shard = self._get_hash_slot(
self.name) % self.load_shards_num
scheduler_logger.info("Scheduler has reload config, "
f"load_shards_num({old_load_shards_num} => {self.load_shards_num}) "
f"shard({old_shard} => {self.shard})")

View File

@@ -20,7 +20,7 @@ from typing import Dict, List, Optional, Tuple
from fastdeploy.engine.request import Request, RequestOutput
from fastdeploy.scheduler.data import ScheduledRequest, ScheduledResponse
from fastdeploy.utils import llm_logger
from fastdeploy.utils import scheduler_logger
class LocalScheduler(object):
@@ -115,7 +115,7 @@ class LocalScheduler(object):
self.ids = list()
self.requests = dict()
self.responses = dict()
llm_logger.info("Scheduler has been reset")
scheduler_logger.info("Scheduler has been reset")
def _recycle(self, request_id: Optional[str] = None):
"""
@@ -189,10 +189,10 @@ class LocalScheduler(object):
self.ids += valid_ids
self.requests_not_empty.notify_all()
llm_logger.info(f"Scheduler has enqueued some requests: {valid_ids}")
scheduler_logger.info(f"Scheduler has enqueued some requests: {valid_ids}")
if len(duplicated_ids) > 0:
llm_logger.warning(
scheduler_logger.warning(
f"Scheduler has received some duplicated requests: {duplicated_ids}"
)
@@ -234,7 +234,7 @@ class LocalScheduler(object):
List of Request objects ready for processing
"""
if available_blocks <= reserved_output_blocks or batch < 1:
llm_logger.debug(
scheduler_logger.debug(
f"Scheduler's resource are insufficient: available_blocks={available_blocks} "
f"reserved_output_blocks={reserved_output_blocks} batch={batch} "
f"max_num_batched_tokens={max_num_batched_tokens}")
@@ -277,12 +277,12 @@ class LocalScheduler(object):
self.ids_read_cursor += len(requests)
if len(batch_ids) > 0 and len(requests) == 0:
llm_logger.debug(
scheduler_logger.debug(
f"Scheduler has put all just-pulled request into the queue: {len(batch_ids)}"
)
if len(requests) > 0:
llm_logger.info(
scheduler_logger.info(
f"Scheduler has pulled some request: {[request.request_id for request in requests]}"
)
@@ -303,14 +303,14 @@ class LocalScheduler(object):
response.request_id for response in responses if response.finished
]
if len(finished_responses) > 0:
llm_logger.info(
scheduler_logger.info(
f"Scheduler has received some finished responses: {finished_responses}"
)
with self.mutex:
for response in responses:
if response.request_id not in self.requests:
llm_logger.warning(
scheduler_logger.warning(
f"Scheduler has received a expired response: {[response.request_id]}"
)
continue
@@ -342,7 +342,7 @@ class LocalScheduler(object):
- Thread-safe operation using condition variables
- Has a short timeout (0.001s) to avoid blocking
- Automatically recycles completed requests to free memory
- Logs finished requests via llm_logger
- Logs finished requests via scheduler_logger
"""
def _get_results():
@@ -364,7 +364,7 @@ class LocalScheduler(object):
if finished:
self._recycle(request_id)
llm_logger.info(
scheduler_logger.info(
f"Scheduler has pulled a finished response: {[request_id]}"
)
return results

View File

@@ -18,7 +18,7 @@ from typing import Callable, List, Any, Dict, Optional
import functools
import threading
import traceback
from fastdeploy.utils import llm_logger
from fastdeploy.utils import scheduler_logger
class Task:
@@ -163,7 +163,7 @@ class Workers:
try:
results = self.work(tasks)
except Exception as e:
llm_logger.error(
scheduler_logger.error(
f"Worker {self.name} execute error: {e}, traceback: {traceback.format_exc()}")
continue