Global scheduler supports configuring hot updates (#2812)

This commit is contained in:
lddfym
2025-07-11 13:39:30 +08:00
committed by GitHub
parent 94691bcd90
commit ec986642df
6 changed files with 215 additions and 114 deletions

@@ -19,7 +19,6 @@ from typing import List, Optional, Dict, Tuple
import traceback
import threading
import time
from datetime import datetime
import random
import uuid
import crcmod
@@ -28,7 +27,7 @@ from fastdeploy.scheduler.storage import AdaptedRedis
from fastdeploy.engine.request import Request, RequestOutput
from fastdeploy.scheduler.data import ScheduledRequest, ScheduledResponse
from fastdeploy.scheduler.workers import Workers, Task
from fastdeploy.utils import llm_logger
from fastdeploy.utils import scheduler_logger
from fastdeploy.scheduler import utils
@@ -51,7 +50,7 @@ class GlobalScheduler(object):
topic: str,
ttl: int,
min_load_score: float,
load_shrads_num: int,
load_shards_num: int,
enable_chunked_prefill: bool,
max_num_partial_prefills: int,
max_long_partial_prefills: int,
@@ -68,7 +67,7 @@ class GlobalScheduler(object):
topic: Base topic name for queue namespacing
ttl: Time-to-live in seconds for Redis keys
min_load_score: Minimum load score for task assignment
load_shrads_num: Number of shards for load balancing table
load_shards_num: Number of shards for load balancing table
enable_chunked_prefill: Whether to enable chunked prefill processing
max_num_partial_prefills: Maximum number of partial prefills allowed
max_long_partial_prefills: Maximum number of long partial prefills allowed
@@ -84,7 +83,7 @@ class GlobalScheduler(object):
self.topic = topic
self.ttl = ttl
self.min_load_score = min_load_score
self.load_shrads_num = load_shrads_num
self.load_shards_num = load_shards_num
self.enable_chunked_prefill = enable_chunked_prefill
self.max_num_partial_prefills = max_num_partial_prefills
@@ -97,14 +96,17 @@ class GlobalScheduler(object):
self.crc16_mutex = threading.Lock()
self.crc16 = crcmod.predefined.Crc('ccitt-false')
self.load_slot_for_getting_request = 0
self.load_start = 0 # const
self.load_num = 50 # const
self.load_offset = 0 # const
self.load_count = 50 # const
self.load_lookup_num = 5 # const
self.keep_alive_duration = 30 # const
connection_pool = ConnectionPool(
host=host, port=port, db=db, password=password, max_connections=10)
self.client = AdaptedRedis(connection_pool=connection_pool)
self.name = self._generate_scheduler_name()
self.name, self.shard = self._generate_scheduler_name_and_shard()
self.keep_alive_workers = threading.Thread(
target=self._keep_alive, daemon=True)
self.keep_alive_workers.start()
@@ -126,10 +128,32 @@ class GlobalScheduler(object):
target=self._get_results_worker, daemon=True)
self.get_response_workers.start()
llm_logger.info(
scheduler_logger.info(
f"Scheduler: name={self.name} redis_version={self.client.version}")
def _get_hash_slot(self, data: str) -> int:
"""
Calculate the hash slot for a given string using CRC16 algorithm.
This method is thread-safe and used for consistent hashing in distributed scheduling.
It implements the same CRC16 algorithm (CCITT-FALSE variant) used by Redis Cluster.
Args:
data: Input string to be hashed (typically a scheduler or request identifier)
Returns:
int: A 16-bit hash value (0-65535) representing the calculated slot
Implementation Details:
1. Encodes input string as UTF-8 bytes
2. Uses thread-safe CRC16 calculation with mutex protection
3. Resets CRC state after each calculation
4. Returns raw CRC value without modulo operation
Note:
- The result is typically used with modulo operation for sharding (e.g. % num_shards)
- Matches Redis Cluster's slot distribution algorithm for compatibility
"""
data = data.encode("utf-8")
with self.crc16_mutex:
self.crc16.update(data)
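For reference, the slot calculation described in this docstring can be reproduced standalone. The sketch below is illustrative: it uses the crcmod preset named in the diff, but creates a fresh Crc object per call (so no state reset is needed) and folds the caller's modulo into a single helper; the name hash_slot is hypothetical.

# Standalone sketch of the CRC16 slot calculation described above.
import crcmod.predefined

def hash_slot(data: str, num_shards: int) -> int:
    # Hash the UTF-8 bytes with the CCITT-FALSE variant, Redis-Cluster style.
    crc16 = crcmod.predefined.Crc('ccitt-false')
    crc16.update(data.encode("utf-8"))
    # Reduce the raw 16-bit CRC to a shard index.
    return crc16.crcValue % num_shards

# e.g. hash_slot("scheduler-10.0.0.1", 4) always maps to the same value in 0..3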
@@ -149,58 +173,66 @@ class GlobalScheduler(object):
"""
return f"{self.topic}.ins.{scheduler_name}"
def _generate_scheduler_name(self) -> str:
def _generate_scheduler_name_and_shard(self) -> Tuple[str, int]:
"""
Generate a unique name for this scheduler instance.
Generate a unique scheduler name and calculate its shard assignment.
Uses hostname/IP and timestamp to create a unique identifier,
then registers it in Redis with TTL.
This method:
1. Creates a unique identifier using hostname/IP and timestamp
2. Registers the name in Redis with TTL
3. Calculates the shard assignment using consistent hashing
4. Handles naming conflicts by appending incrementing suffixes
Returns:
Unique scheduler name string
Tuple[str, int]:
- str: Unique scheduler name
- int: Assigned shard number (0 to load_shards_num-1)
Implementation Details:
- Uses hostname/IP as base identifier, falls back to UUID if unavailable
- Implements conflict resolution with incrementing suffixes
- Registers name in Redis with keep-alive duration
- Calculates shard using CRC16 hash of the name
Error Handling:
- Logs IP resolution failures
- Handles Redis registration conflicts gracefully
- Ensures unique name generation even in edge cases
"""
try:
_, name = utils.get_hostname_ip()
except Exception as e:
llm_logger.warning(
scheduler_logger.warning(
f"Scheduler encountered an error while resolving the IP address. {e}")
name = str(uuid.uuid4())
size = len(name)
now = time.time()
local_time = datetime.fromtimestamp(now)
formatted_time = local_time.strftime(
"%Y-%m-%d %H:%M:%S") + f"{local_time.microsecond // 1000:03d}"
count = 1
while True:
if self.client.set(self._instance_name(name), formatted_time, ex=self.ttl, nx=True):
if self.client.set(self._instance_name(name), "", ex=self.keep_alive_duration, nx=True):
break
name = f"{name[:size]}:{count}"
count += 1
return name
shard = self._get_hash_slot(name) % self.load_shards_num
self.client.set(self._instance_name(name), self._load_table_name(shard=shard),
ex=self.keep_alive_duration)
return name, shard
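Read as a whole, the registration flow amounts to: claim a unique instance key with SET NX, append a numeric suffix on conflict, then derive the shard from the final name and store the owned load-table name under the key. A condensed sketch under stated assumptions: a plain redis-py client, a socket lookup standing in for utils.get_hostname_ip(), key formats taken from this diff, and hash_slot from the sketch above.

import socket
import uuid

def register_scheduler(client, topic: str, load_shards_num: int,
                       keep_alive_duration: int = 30):
    try:
        base = socket.gethostbyname(socket.gethostname())  # hostname/IP as the base id
    except Exception:
        base = str(uuid.uuid4())                           # fall back to a random UUID
    name, count = base, 1
    # SET NX acts as the conflict check: a live instance keeps its name.
    while not client.set(f"{topic}.ins.{name}", "", ex=keep_alive_duration, nx=True):
        name = f"{base}:{count}"
        count += 1
    shard = hash_slot(name, load_shards_num)
    # Overwrite the key with the owned load table name so peers can discover it.
    client.set(f"{topic}.ins.{name}", f"{topic}.load.{shard}", ex=keep_alive_duration)
    return name, shard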
def _keep_alive(self):
"""
Background thread that periodically updates the scheduler's TTL in Redis.
Runs in a loop with interval of TTL/2 to maintain instance registration.
Runs in a loop with interval of keep_alive_duration/2 to maintain instance registration.
"""
interval_time = self.ttl / 2
while True:
try:
now = time.time()
local_time = datetime.fromtimestamp(now)
formatted_time = local_time.strftime(
"%Y-%m-%d %H:%M:%S") + f"{local_time.microsecond // 1000:03d}"
self.client.set(self._instance_name(self.name),
formatted_time, ex=self.ttl)
self.client.set(self._instance_name(
self.name), self._load_table_name(), ex=self.keep_alive_duration)
time.sleep(self.keep_alive_duration / 2)
except Exception as e:
llm_logger.error(f"Scheduler keep alive failed: {e}")
interval_time = self.ttl / 10
time.sleep(interval_time)
interval_time = self.ttl / 2
scheduler_logger.error(f"Scheduler keep alive failed: {e}")
time.sleep(min(3, self.keep_alive_duration / 4))
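The refresh loop itself is small; a minimal sketch of the thread body under the same assumptions (a standalone function instead of the class method, print standing in for scheduler_logger):

import time

def keep_alive_loop(client, topic: str, name: str, load_table: str,
                    keep_alive_duration: int = 30):
    while True:
        try:
            # Re-write the instance key with a fresh TTL well before it expires.
            client.set(f"{topic}.ins.{name}", load_table, ex=keep_alive_duration)
            time.sleep(keep_alive_duration / 2)
        except Exception as exc:
            print(f"keep alive failed: {exc}")           # stand-in for scheduler_logger
            time.sleep(min(3, keep_alive_duration / 4))  # retry sooner after a failure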
def _scheduler_name_from_request_queue(self, request_queue: str) -> str:
"""
@@ -243,22 +275,18 @@ class GlobalScheduler(object):
return f"{self.topic}.resp.{self.name}"
return f"{self.topic}.resp.{scheduler_name}"
def _load_table_name(self, request_queue_name: Optional[str] = None, slot: Optional[int] = None) -> str:
def _load_table_name(self, shard: Optional[int] = None, slot: Optional[int] = None) -> str:
"""
Get the Redis sorted set name used for load balancing.
Returns:
The load score key name
"""
if request_queue_name is None:
request_queue_name = self._request_queue_name()
if slot is None:
slot = self._get_hash_slot(
request_queue_name) % self.load_shrads_num
else:
slot %= self.load_shrads_num
return f"{self.topic}.load.{slot}"
if shard is None and slot is not None:
shard = slot % self.load_shards_num
if shard is None:
shard = self.shard
return f"{self.topic}.load.{shard}"
@staticmethod
def calc_required_blocks(token_num, block_size):
@@ -330,11 +358,11 @@ class GlobalScheduler(object):
self.client.zincrby(self._load_table_name(),
len(serialized_requests), self.name,
rem_amount=0, ttl=self.ttl)
llm_logger.info(
scheduler_logger.info(
f"Scheduler has enqueued some requests: {requests}")
if duplicate:
llm_logger.warning(
scheduler_logger.warning(
"Scheduler has received some duplicated requests: "
f"{[task for task in tasks if task.reason is not None]}")
return tasks
@@ -375,7 +403,7 @@ class GlobalScheduler(object):
"""
if available_blocks <= reserved_output_blocks or batch < 1:
llm_logger.debug(
scheduler_logger.debug(
f"Scheduler's resource are insufficient: available_blocks={available_blocks} "
f"reserved_output_blocks={reserved_output_blocks} batch={batch} "
f"max_num_batched_tokens={max_num_batched_tokens}")
@@ -406,15 +434,17 @@ class GlobalScheduler(object):
for element in elements]
extend_scheduler_names = []
extend_scheduler_load_table_name = ""
if len(serialized_requests) == 0 and len(batches) > 0:
for _ in range(min(5, self.load_shrads_num)):
for _ in range(min(self.load_lookup_num, self.load_shards_num)):
extend_scheduler_load_table_name = self._load_table_name(
slot=self.load_slot_for_getting_request)
serialized_members = self.client.zrangebyscore(
self._load_table_name(
slot=self.load_slot_for_getting_request),
extend_scheduler_load_table_name,
self.min_load_score,
float("+inf"),
start=self.load_start,
num=self.load_num)
start=self.load_offset,
num=self.load_count)
self.load_slot_for_getting_request += 1
if len(serialized_members) > 0:
break
@@ -433,23 +463,18 @@ class GlobalScheduler(object):
elements = self.client.lpop(lucky_request_queue_name, batches[0])
if elements is not None and len(elements) > 0:
self.client.zincrby(
self._load_table_name(
request_queue_name=lucky_request_queue_name),
-len(elements), lucky, rem_amount=0, ttl=self.ttl)
self.client.zincrby(extend_scheduler_load_table_name,
-len(elements), lucky, rem_amount=0, ttl=self.ttl)
serialized_requests += [(lucky_request_queue_name, element)
for element in elements]
llm_logger.info(
scheduler_logger.info(
f"Scheduler {self.name} has stolen some requests from another lucky one. "
f"(name={lucky} num={len(serialized_requests)})")
else:
exist_num = self.client.exists(self._instance_name(lucky))
if exist_num == 0:
if self.client.zrem(
self._load_table_name(
request_queue_name=lucky_request_queue_name),
lucky):
llm_logger.info(
if self.client.zrem(extend_scheduler_load_table_name, lucky):
scheduler_logger.info(
f"Scheduler {lucky} has been removed")
# blocked read
@@ -465,12 +490,12 @@ class GlobalScheduler(object):
request_queue_name = element[0].decode("utf-8")
scheduler_name = self._scheduler_name_from_request_queue(
request_queue_name)
self.client.zincrby(
self._load_table_name(request_queue_name=request_queue_name),
-1, scheduler_name, rem_amount=0, ttl=self.ttl)
load_table_name = extend_scheduler_load_table_name if scheduler_name != self.name else self._load_table_name()
self.client.zincrby(load_table_name,
-1, scheduler_name, rem_amount=0, ttl=self.ttl)
serialized_requests.append((request_queue_name, element[1]))
if scheduler_name != self.name:
llm_logger.info(
scheduler_logger.info(
f"Scheduler {self.name} has stolen a request from another scheduler. (name={scheduler_name})")
long_partial_requests = 0
@@ -526,12 +551,12 @@ class GlobalScheduler(object):
if request.request_queue_name == local_request_queue_name:
continue
self._mark_request(request)
# self._mark_request(request)
if request.request_id not in self.stolen_requests:
self.stolen_requests[request.request_id] = request
continue
llm_logger.error(
scheduler_logger.error(
f"Scheduler has received a duplicate request from others: {request}")
requests: List[Request] = [
@@ -548,19 +573,18 @@ class GlobalScheduler(object):
serialized_requests)
scheduler_name = self._scheduler_name_from_request_queue(
request_queue_name)
self.client.zincrby(
self._load_table_name(
request_queue_name=request_queue_name),
len(serialized_requests), scheduler_name, ttl=self.ttl)
load_table_name = extend_scheduler_load_table_name if scheduler_name != self.name else self._load_table_name()
self.client.zincrby(load_table_name,
len(serialized_requests), scheduler_name, ttl=self.ttl)
llm_logger.info(
scheduler_logger.info(
f"Scheduler has put remaining request into the queue: {len(remaining_request)}")
if len(requests) == 0:
llm_logger.debug(
scheduler_logger.debug(
f"Scheduler has put all just-pulled request into the queue: {len(remaining_request)}")
if len(requests) > 0:
llm_logger.info(
scheduler_logger.info(
f"Scheduler has pulled some request: {[request.request_id for request in requests]}")
return requests
@@ -600,7 +624,7 @@ class GlobalScheduler(object):
if response.request_id in stolen_request_id_request_queue:
response_queue_name = stolen_request_id_response_queue[response.request_id]
request_queue_name = stolen_request_id_request_queue[response.request_id]
self._unmark_response(response, request_queue_name)
# self._unmark_response(response, request_queue_name)
if response_queue_name not in stolen_responses:
stolen_responses[response_queue_name] = []
@@ -608,7 +632,7 @@ class GlobalScheduler(object):
response.serialize())
continue
llm_logger.error(
scheduler_logger.error(
f"Scheduler has recieved a non-existent response from engine: {[response]}")
with self.mutex:
@@ -624,7 +648,7 @@ class GlobalScheduler(object):
self.local_response_not_empty.notify_all()
if len(finished_request_ids) > 0:
llm_logger.info(
scheduler_logger.info(
f"Scheduler has received some finished responses: {finished_request_ids}")
for response_queue_name, responses in stolen_responses.items():
@@ -681,15 +705,15 @@ class GlobalScheduler(object):
with self.mutex:
for request_id, contents in responses.items():
if request_id not in self.local_responses:
llm_logger.error(
scheduler_logger.error(
"Scheduler has received some non-existent response from the queue. "
f"response:{contents} queue:{self._response_queue_name()}")
continue
self.local_responses[request_id] += contents
self.local_response_not_empty.notify_all()
except Exception as e:
llm_logger.error(f"Scheduler get_results_worker exception: {e} "
f"traceback: {traceback.format_exc()}")
scheduler_logger.error(f"Scheduler get_results_worker exception: {e} "
f"traceback: {traceback.format_exc()}")
def get_results(self) -> Dict[str, List[RequestOutput]]:
"""
@@ -718,7 +742,7 @@ class GlobalScheduler(object):
- Thread-safe operation using condition variables
- Short timeout avoids blocking while maintaining responsiveness
- First call may return empty to batch small responses
- Automatically logs finished requests via llm_logger
- Automatically logs finished requests via scheduler_logger
"""
first = True
@@ -754,7 +778,7 @@ class GlobalScheduler(object):
if finished:
del self.local_responses[request_id]
llm_logger.info(
scheduler_logger.info(
f"Scheduler has pulled a finished response: {[request_id]}")
return results
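The batching behaviour described in the docstring boils down to a condition variable guarded by the scheduler's mutex plus a short wait timeout. A self-contained sketch of that pattern, as an illustrative class rather than the scheduler's actual buffer:

import threading

class ResponseBuffer:
    def __init__(self):
        self.mutex = threading.Lock()
        self.not_empty = threading.Condition(self.mutex)
        self.responses = {}                       # request_id -> list of outputs

    def put(self, request_id, outputs):
        with self.mutex:
            self.responses.setdefault(request_id, []).extend(outputs)
            self.not_empty.notify_all()           # wake up any waiting get()

    def get(self, timeout: float = 0.001):
        # A short timeout keeps callers responsive even when nothing has arrived.
        with self.not_empty:
            self.not_empty.wait_for(lambda: len(self.responses) > 0, timeout=timeout)
            results, self.responses = self.responses, {}
            return results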
@@ -787,4 +811,41 @@ class GlobalScheduler(object):
self.client.zrem(self._load_table_name(), self.name)
self.local_responses = dict()
self.stolen_requests = dict()
llm_logger.info("Scheduler has been reset")
scheduler_logger.info("Scheduler has been reset")
def update_config(self, load_shards_num: Optional[int], reallocate: Optional[bool]):
"""
Update the scheduler's configuration parameters dynamically.
This method allows runtime modification of:
- Total number of load balancing shards
- Current instance's shard assignment
Args:
load_shards_num: New total number of load balancing shards (must be > 0)
reallocate: If True, recalculates this instance's shard assignment
Effects:
- Updates internal load balancing configuration
- Optionally reallocates this instance to a new shard
- Logs configuration changes for audit purposes
Note:
- Changes take effect immediately for new operations
- Existing in-progress operations continue with old configuration
- Reallocation may affect request distribution pattern
"""
with self.mutex:
old_load_shards_num = self.load_shards_num
old_shard = self.shard
if load_shards_num:
self.load_shards_num = load_shards_num
if reallocate:
self.shard = self._get_hash_slot(
self.name) % self.load_shards_num
scheduler_logger.info("Scheduler has reload config, "
f"load_shards_num({old_load_shards_num} => {self.load_shards_num}) "
f"shard({old_shard} => {self.shard})")