Sync v2.0 version of code to github repo

Jiang-Jia-Jun
2025-06-29 23:29:37 +00:00
parent d151496038
commit 92c2cfa2e7
597 changed files with 78776 additions and 22905 deletions


@@ -16,19 +16,31 @@
from typing import List, Optional, Dict, Tuple
import traceback
import threading
import time
from datetime import datetime
import random
import uuid
import crcmod
from redis import ConnectionPool
from fastdeploy.scheduler.storage import AdaptedRedis
from fastdeploy.engine.request import Request, RequestOutput
from fastdeploy.metrics.metrics import main_process_metrics
from fastdeploy.scheduler.data import ScheduledRequest, ScheduledResponse
from fastdeploy.scheduler.workers import Workers, Task
from fastdeploy.utils import llm_logger
from fastdeploy.scheduler import utils
class GlobalScheduler(object):
"""
A distributed task scheduler that manages request/response queues using Redis.
This class provides functionality for:
- Enqueuing and dequeuing requests
- Load balancing across multiple scheduler instances
- Handling request/response lifecycle
- Maintaining worker health checks
"""
def __init__(self,
@@ -38,96 +50,328 @@ class GlobalScheduler(object):
password: Optional[str],
topic: str,
ttl: int,
min_load_score: float,
load_shrads_num: int,
enable_chunked_prefill: bool,
max_num_partial_prefills: int,
max_long_partial_prefills: int,
long_prefill_token_threshold: int,
):
"""
Initialize the GlobalScheduler with Redis connection and scheduling parameters.
Args:
host: Redis server hostname
port: Redis server port
db: Redis database number
password: Optional password for Redis authentication
topic: Base topic name for queue namespacing
ttl: Time-to-live in seconds for Redis keys
min_load_score: Minimum load score for task assignment
load_shrads_num: Number of shards for load balancing table
enable_chunked_prefill: Whether to enable chunked prefill processing
max_num_partial_prefills: Maximum number of partial prefills allowed
max_long_partial_prefills: Maximum number of long partial prefills allowed
long_prefill_token_threshold: Token count threshold for long prefills
Initializes:
- Redis connection pool and client
- Worker threads for request/response handling
- Load balancing and request stealing mechanisms
- Response tracking structures
"""
self.topic = topic
self.ttl = ttl
self.min_load_score = min_load_score
self.load_shrads_num = load_shrads_num
self.enable_chunked_prefill = enable_chunked_prefill
self.max_num_partial_prefills = max_num_partial_prefills
self.max_long_partial_prefills = max_long_partial_prefills
self.long_prefill_token_threshold = long_prefill_token_threshold
self.blpop_request_timeout = 2
self.blpop_response_timeout = 10
self.crc16_mutex = threading.Lock()
self.crc16 = crcmod.predefined.Crc('ccitt-false')
self.load_slot_for_getting_request = 0
self.load_start = 0  # const
self.load_num = 50  # const
connection_pool = ConnectionPool(
host=host, port=port, db=db, password=password, max_connections=10)
self.client = AdaptedRedis(connection_pool=connection_pool)
self.name = self._generate_scheduler_name()
self.keep_alive_workers = threading.Thread(
target=self._keep_alive, daemon=True)
self.keep_alive_workers.start()
self.put_requests_workers = Workers(
"put_requests_workers", self._put_requests_worker, 20)
self.put_requests_workers.start(1)
self.put_results_workers = Workers(
"put_results_workers", self._put_results_worker, 300)
self.put_results_workers.start(1)
self.mutex = threading.Lock()
self.local_response_not_empty = threading.Condition(self.mutex)
self.local_responses: Dict[str, List[ScheduledResponse]] = dict()
self.stolen_requests: Dict[str, ScheduledRequest] = dict()
self.get_response_workers = threading.Thread(
target=self._get_results_worker, daemon=True)
self.get_response_workers.start()
llm_logger.info(
f"Scheduler: name={self.name} redis_version={self.client.version}")
def _get_hash_slot(self, data: str) -> int:
data = data.encode("utf-8")
with self.crc16_mutex:
self.crc16.update(data)
crc_value = self.crc16.crcValue
self.crc16.crcValue = self.crc16.initCrc
return crc_value
def _instance_name(self, scheduler_name: str) -> str:
"""
Generate the Redis key name for a scheduler instance.
Args:
scheduler_name: Name of the scheduler instance
Returns:
Formatted Redis key name
"""
return f"{self.topic}.ins.{scheduler_name}"
def _generate_scheduler_name(self) -> str:
"""
Generate a unique name for this scheduler instance.
Uses hostname/IP and timestamp to create a unique identifier,
then registers it in Redis with TTL.
Returns:
Unique scheduler name string
"""
try:
_, name = utils.get_hostname_ip()
except Exception as e:
llm_logger.warning(
f"Scheduler encountered an error while resolving the IP address. {e}")
name = str(uuid.uuid4())
size = len(name)
now = time.time()
local_time = datetime.fromtimestamp(now)
formatted_time = local_time.strftime(
"%Y-%m-%d %H:%M:%S") + f"{local_time.microsecond // 1000:03d}"
count = 1
while True:
if self.client.set(self._instance_name(name), formatted_time, ex=self.ttl, nx=True):
break
name = f"{name[:size]}:{count}"
count += 1
return name
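# Example (illustrative): if "10.0.0.5" is already registered by another live
# instance, the candidates tried next are "10.0.0.5:1", "10.0.0.5:2", ... until
# a SET NX against the instance key succeeds.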
def _keep_alive(self):
"""
Background thread that periodically updates the scheduler's TTL in Redis.
Runs in a loop with interval of TTL/2 to maintain instance registration.
"""
interval_time = self.ttl / 2
while True:
try:
now = time.time()
local_time = datetime.fromtimestamp(now)
formatted_time = local_time.strftime(
"%Y-%m-%d %H:%M:%S") + f"{local_time.microsecond // 1000:03d}"
self.client.set(self._instance_name(self.name),
formatted_time, ex=self.ttl)
except Exception as e:
llm_logger.error(f"Scheduler keep alive failed: {e}")
interval_time = self.ttl / 10
time.sleep(interval_time)
interval_time = self.ttl / 2
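# Example cadence (derived from the code above): with ttl=900 the instance key
# is refreshed every 450s (ttl/2); after a failed write the next attempt comes
# after 90s (ttl/10).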
def _scheduler_name_from_request_queue(self, request_queue: str) -> str:
"""
Extract scheduler name from a request queue name.
Args:
request_queue: Full request queue name
Returns:
The scheduler name portion of the queue name
"""
prefix_len = len(f"{self.topic}.req.")
return request_queue[prefix_len:]
def _request_queue_name(self, scheduler_name: Optional[str] = None) -> str:
"""
Generate the Redis request queue name for a scheduler.
Args:
scheduler_name: Optional specific scheduler name, defaults to current instance
Returns:
Formatted request queue name
"""
if scheduler_name is None:
return f"{self.topic}.req.{self.name}"
return f"{self.topic}.req.{scheduler_name}"
def _response_queue_name(self, scheduler_name: Optional[str] = None) -> str:
"""
Generate the Redis response queue name for a scheduler.
Args:
scheduler_name: Optional specific scheduler name, defaults to current instance
Returns:
Formatted response queue name
"""
if scheduler_name is None:
return f"{self.topic}.resp.{self.name}"
return f"{self.topic}.resp.{scheduler_name}"
def _load_table_name(self, request_queue_name: Optional[str] = None, slot: Optional[int] = None) -> str:
"""
Get the Redis sorted set name used for load balancing.
Returns:
The load score key name
"""
if request_queue_name is None:
request_queue_name = self._request_queue_name()
if slot is None:
slot = self._get_hash_slot(
request_queue_name) % self.load_shrads_num
else:
slot %= self.load_shrads_num
return f"{self.topic}.load.{slot}"
@staticmethod
def calc_required_blocks(token_num, block_size):
"""calculate required blocks for given token number"""
"""
Calculate the number of blocks needed for a given number of tokens.
Args:
token_num: Number of tokens
block_size: Size of each block
Returns:
Number of blocks required (rounded up)
"""
return (token_num + block_size - 1) // block_size
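# Example: 1000 prompt tokens with block_size=64 -> (1000 + 64 - 1) // 64 == 16 blocks.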
@staticmethod
def _mark_request(request: ScheduledRequest):
"""
Mark a stolen request with the original queue name.
Args:
request: The request to mark
"""
request.request_id = f"mark<{request.request_queue_name}>{request.request_id}"
@staticmethod
def _unmark_response(response: ScheduledResponse, request_queue_name: str):
"""
Remove marking from a response that came from a stolen request.
Args:
response: The response to unmark
request_queue_name: Original request queue name
"""
mark = f"mark<{request_queue_name}>"
if not response.request_id.startswith(mark):
return
response.request_id = response.request_id[len(mark):]
def _put_requests_worker(self, tasks: List[Task]) -> List[Task]:
"""
Worker method that adds requests to the shared Redis cache.
Args:
tasks: List of tasks containing requests to enqueue
Returns:
List of processed tasks (some may be marked as duplicates)
"""
duplicate = False
requests: List[ScheduledRequest] = []
with self.mutex:
for task in tasks:
request = ScheduledRequest(
task.raw, self._request_queue_name(), self._response_queue_name())
task.raw = None
if request.request_id in self.local_responses:
task.reason = "duplicate request_id"
duplicate = True
continue
requests.append(request)
self.local_responses[request.request_id] = []
if len(requests) > 0:
serialized_requests = [request.serialize() for request in requests]
self.client.rpush(self._request_queue_name(), *
serialized_requests, ttl=self.ttl)
self.client.zincrby(self._load_table_name(),
len(serialized_requests), self.name,
rem_amount=0, ttl=self.ttl)
llm_logger.info(
f"Scheduler has enqueued some requests: {requests}")
if duplicate:
llm_logger.warning(
"Scheduler has received some duplicated requests: "
f"{[task for task in tasks if task.reason is not None]}")
return tasks
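# Mark/unmark round trip (illustrative): a stolen request with request_id "abc"
# taken from "fd.req.host-1" travels as "mark<fd.req.host-1>abc"; _unmark_response
# strips that prefix before the response is pushed back to the owner's queue.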
def put_requests(self, requests: List[Request]) -> List[Tuple[str, Optional[str]]]:
"""
add requests to scheduler
Public method to add new requests to the scheduler.
Args:
requests: List of Request objects to schedule
Returns:
List of tuples containing (request_id, error_reason) for each request
"""
tasks: List[Tuple[str, Request]] = [
(request.request_id, request) for request in requests]
self.put_request_workers.put_tasks(tasks)
return self.put_request_workers.get_results(10, 0.005)
tasks: List[Task] = []
for request in requests:
task = Task(request.request_id, request)
tasks.append(task)
self.put_requests_workers.add_tasks(tasks)
results = self.put_requests_workers.get_results(10, 0.001)
return [(result.id, result.reason) for result in results]
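# Usage sketch (illustrative; assumes Task.reason defaults to None on success):
#
#   outcome = scheduler.put_requests([req_a, req_b])
#   # e.g. [("req-a", None), ("req-b", "duplicate request_id")]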
def get_requests(self, available_blocks, block_size, reserved_output_blocks,
max_num_batched_tokens, batch=1) -> List[Request]:
"""
Get requests from the shared cache based on available resources.
Args:
available_blocks: Number of available processing blocks
block_size: Size of each processing block
reserved_output_blocks: Blocks reserved for output
max_num_batched_tokens: Maximum tokens that can be batched
batch: Preferred batch size
Returns:
List of Request objects ready for processing
"""
if available_blocks <= reserved_output_blocks or batch < 1:
@@ -137,161 +381,410 @@ class GlobalScheduler(object):
f"max_num_batched_tokens={max_num_batched_tokens}")
return []
mini_batch = (batch + 1) // 2
batches = []
for _ in range(2):
if batch >= mini_batch:
batches.append(mini_batch)
batch -= mini_batch
continue
if batch > 0:
batches.append(batch)
batch = 0
local_request_queue_name = self._request_queue_name()
serialized_requests: List[Tuple[str, bytes]] = []
for bs in batches:
elements = self.client.lpop(
local_request_queue_name, bs, ttl=self.ttl)
if elements is None:
break
self.client.zincrby(self._load_table_name(), -
len(elements), self.name, rem_amount=0, ttl=self.ttl)
serialized_requests += [(local_request_queue_name, element)
for element in elements]
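# Example of the split above: batch=5 -> mini_batch=(5+1)//2=3 -> batches=[3, 2],
# so at most two lpop calls are issued against the local request queue.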
extend_scheduler_names = []
if len(serialized_requests) == 0 and len(batches) > 0:
for _ in range(min(5, self.load_shrads_num)):
serialized_members = self.client.zrangebyscore(
self._load_table_name(
slot=self.load_slot_for_getting_request),
self.min_load_score,
float("+inf"),
start=self.load_start,
num=self.load_num)
self.load_slot_for_getting_request += 1
if len(serialized_members) > 0:
break
members = [member.decode("utf-8") for member in serialized_members]
if len(members) > 0:
extend_scheduler_names = random.sample(
members, k=min(10, len(members)))
extend_scheduler_names = [
name for name in extend_scheduler_names if name != self.name]
# find lucky one
if len(extend_scheduler_names) > 0:
lucky = random.choice(extend_scheduler_names)
lucky_request_queue_name = self._request_queue_name(lucky)
elements = self.client.lpop(lucky_request_queue_name, batches[0])
if elements is not None and len(elements) > 0:
self.client.zincrby(
self._load_table_name(
request_queue_name=lucky_request_queue_name),
-len(elements), lucky, rem_amount=0, ttl=self.ttl)
serialized_requests += [(lucky_request_queue_name, element)
for element in elements]
llm_logger.info(
f"Scheduler {self.name} has stolen some requests from another lucky one. "
f"(name={lucky} num={len(serialized_requests)})")
else:
exist_num = self.client.exists(self._instance_name(lucky))
if exist_num == 0:
if self.client.zrem(
self._load_table_name(
request_queue_name=lucky_request_queue_name),
lucky):
llm_logger.info(
f"Scheduler {lucky} has been removed")
# blocked read
if len(serialized_requests) == 0:
request_queue_names = [local_request_queue_name]
request_queue_names += [
self._request_queue_name(name) for name in extend_scheduler_names]
element = self.client.blpop(
request_queue_names, self.blpop_request_timeout)
if element is None:
return []
request_queue_name = element[0].decode("utf-8")
scheduler_name = self._scheduler_name_from_request_queue(
request_queue_name)
self.client.zincrby(
self._load_table_name(request_queue_name=request_queue_name),
-1, scheduler_name, rem_amount=0, ttl=self.ttl)
serialized_requests.append((request_queue_name, element[1]))
if scheduler_name != self.name:
llm_logger.info(
f"Scheduler {self.name} has stolen a request from another scheduler. (name={scheduler_name})")
long_partial_requests = 0
short_partial_requests = 0
required_total_blocks = 0
current_prefill_tokens = 0
remaining_request: List[Tuple[str, bytes]] = []
scheduled_requests: List[ScheduledRequest] = []
for request_queue_name, serialized_request in serialized_requests:
if len(remaining_request) > 0:
remaining_request.append(
(request_queue_name, serialized_request))
continue
request: ScheduledRequest = ScheduledRequest.unserialize(
serialized_request)
if (time.time() - request.scheduled_time) > self.ttl:
llm_logger.info(
f"Request has expired when getting a request from the scheduler: {[request.id]}")
continue
required_input_blocks = self.calc_required_blocks(
request.prompt_tokens_ids_len, block_size)
current_prefill_tokens += request.prompt_tokens_ids_len
required_total_blocks += required_input_blocks + reserved_output_blocks
if required_total_blocks > available_blocks:
remaining_request.append(
(request_queue_name, serialized_request))
continue
if self.enable_chunked_prefill:
if request.prompt_tokens_ids_len > self.long_prefill_token_threshold:
long_partial_requests += 1
if long_partial_requests > self.max_long_partial_prefills:
remaining_request.append(
(request_queue_name, serialized_request))
continue
else:
short_partial_requests += 1
if short_partial_requests + long_partial_requests > self.max_num_partial_prefills:
remaining_request.append(
(request_queue_name, serialized_request))
continue
else:
if current_prefill_tokens > max_num_batched_tokens:
remaining_request.append(
(request_queue_name, serialized_request))
continue
scheduled_requests.append(request)
if len(scheduled_requests) > 0:
with self.mutex:
for request in scheduled_requests:
if request.request_queue_name == local_request_queue_name:
continue
self._mark_request(request)
if request.request_id not in self.stolen_requests:
self.stolen_requests[request.request_id] = request
continue
llm_logger.error(
f"Scheduler has received a duplicate request from others: {request}")
requests: List[Request] = [
request.raw for request in scheduled_requests]
if len(remaining_request) > 0:
group: Dict[str, List] = dict()
for request_queue_name, serialized_request in remaining_request:
if request_queue_name not in group:
group[request_queue_name] = []
group[request_queue_name].append(serialized_request)
for request_queue_name, serialized_requests in group.items():
self.client.lpush(request_queue_name, *
serialized_requests)
scheduler_name = self._scheduler_name_from_request_queue(
request_queue_name)
self.client.zincrby(
self._load_table_name(
request_queue_name=request_queue_name),
len(serialized_requests), scheduler_name, ttl=self.ttl)
llm_logger.info(
f"Scheduler has put remaining request into the queue: {len(remaining_request)}")
if len(requests) == 0:
llm_logger.debug(
f"Scheduler has put all just-pulled request into the queue: {len(remaining_request)}")
if len(requests) > 0:
llm_logger.info(
f"Scheduler has pulled some request: {[request.request_id for request in requests]}")
main_process_metrics.num_requests_running.inc(len(requests))
main_process_metrics.num_requests_waiting.dec(len(requests))
return requests
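# Usage sketch (engine side, illustrative numbers only):
#
#   reqs = scheduler.get_requests(available_blocks=512, block_size=64,
#                                 reserved_output_blocks=32,
#                                 max_num_batched_tokens=8192, batch=8)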
def _put_results_worker(self, tasks: List[Task]):
"""
Worker method that adds task results to the appropriate queues.
Args:
tasks: List of completed tasks with results
"""
# count = 0 # for test
with self.mutex:
local_request_ids = set(self.local_responses.keys())
stolen_request_id_request_queue = dict()
stolen_request_id_response_queue = dict()
for request_id, request in self.stolen_requests.items():
stolen_request_id_request_queue[request_id] = request.request_queue_name
stolen_request_id_response_queue[request_id] = request.response_queue_name
finished_request_ids: List[str] = list()
local_responses: Dict[str, List[ScheduledResponse]] = dict()
stolen_responses: Dict[str, List[bytes]] = dict()
for task in tasks:
response = ScheduledResponse(task.raw)
if response.finished:
finished_request_ids.append(response.request_id)
if response.request_id in local_request_ids:
if response.request_id not in local_responses:
local_responses[response.request_id] = []
local_responses[response.request_id].append(response)
continue
if response.request_id in stolen_request_id_request_queue:
response_queue_name = stolen_request_id_response_queue[response.request_id]
request_queue_name = stolen_request_id_request_queue[response.request_id]
self._unmark_response(response, request_queue_name)
if response_queue_name not in stolen_responses:
stolen_responses[response_queue_name] = []
stolen_responses[response_queue_name].append(
response.serialize())
continue
llm_logger.error(
f"Scheduler has recieved a non-existent response from engine: {[response]}")
with self.mutex:
for request_id, responses in local_responses.items():
self.local_responses[request_id] += responses
# count += len(responses) # for test
for request_id in finished_request_ids:
if request_id in self.stolen_requests:
del self.stolen_requests[request_id]
if len(local_responses) > 0:
self.local_response_not_empty.notify_all()
if len(finished_request_ids) > 0:
llm_logger.info(
f"Scheduler has received a finished response: {finished_responses}")
f"Scheduler has received some finished responses: {finished_request_ids}")
for response_queue_name, responses in stolen_responses.items():
self.client.rpush(response_queue_name, *responses, ttl=self.ttl)
# count += len(responses) # for test
# return [Task("", count)] # for test
def put_results(self, results: List[RequestOutput]):
"""
add results to shared cache
Public method to add processing results back to the scheduler.
Args:
results: List of RequestOutput objects to return
"""
tasks: List[Tuple[str, RequestOutput]] = [
(result.request_id, result) for result in results]
self.put_response_workers.put_tasks(tasks)
tasks: List[Task] = [Task(result.request_id, result)
for result in results]
self.put_results_workers.add_tasks(tasks)
def _get_results_worker(self, tasks: List[Tuple[str, str]]) -> List[Tuple[str, List[ScheduledResponse]]]:
# ---- for test ----
# task_results = self.put_results_workers.get_results(10, 0.001)
# amount = 0
# for task_result in task_results:
# amount += task_result.raw
# return amount
# ---- for test ----
def _get_results_worker(self):
"""
Background worker that continuously fetches results from Redis.
Handles both bulk and blocking operations for efficiency.
Runs in an infinite loop until scheduler shutdown.
"""
while True:
try:
serialized_responses = self.client.lpop(
self._response_queue_name(), 300, ttl=self.ttl)
if serialized_responses is None or len(serialized_responses) == 0:
element = self.client.blpop(
[self._response_queue_name()], self.blpop_response_timeout)
if element is None or len(element) == 0:
continue
serialized_responses = [element[1]]
responses: Dict[str, List[ScheduledResponse]] = dict()
for serialized_response in serialized_responses:
response = ScheduledResponse.unserialize(
serialized_response)
if response.request_id not in responses:
responses[response.request_id] = []
responses[response.request_id].append(response)
with self.mutex:
for request_id, contents in responses.items():
if request_id not in self.local_responses:
llm_logger.error(
"Scheduler has received some non-existent response from the queue. "
f"response:{contents} queue:{self._response_queue_name()}")
continue
self.local_responses[request_id] += contents
self.local_response_not_empty.notify_all()
except Exception as e:
llm_logger.error(f"Scheduler get_results_worker exception: {e} "
f"traceback: {traceback.format_exc()}")
def get_results(self) -> Dict[str, List[RequestOutput]]:
"""
get results blocked from scheduler.
"""
tasks = [(request_id, request_id) for request_id in request_ids]
self.get_response_workers.put_tasks(tasks, deduplication=True)
batch_responses: List[Tuple[str, List[ScheduledResponse]]] = self.get_response_workers.get_results(
10, self.wait_response_timeout)
Retrieve all available results from the distributed scheduler.
results = dict()
for _, responses in batch_responses:
for response in responses:
if response.id not in results:
results[response.id] = []
results[response.id].append(response)
if response.finished:
This method:
- Waits for new responses using a condition variable (timeout=0.001s)
- Returns all currently available responses
- Automatically removes completed requests from local tracking
- Logs finished requests
Behavior Details:
1. The first call returns an empty dict when fewer than 64 responses are pending
2. Subsequent calls return all available responses
3. Uses thread-safe operations with condition variables
4. Automatically cleans up completed request tracking
Returns:
Dict[str, List[RequestOutput]]:
A dictionary where:
- Key is the request ID
- Value is a list of RequestOutput objects for that request
Completed requests are automatically removed from tracking
Note:
- Thread-safe operation using condition variables
- Short timeout avoids blocking while maintaining responsiveness
- First call may return empty to batch small responses
- Automatically logs finished requests via llm_logger
"""
first = True
def _get_results() -> Dict[str, List[ScheduledResponse]]:
nonlocal first
responses: Dict[str, List[ScheduledResponse]] = dict()
count = 0
for _, contents in self.local_responses.items():
count += len(contents)
if first and count < 64:
first = False
return responses
request_ids = list(self.local_responses.keys())
for request_id in request_ids:
responses[request_id] = self.local_responses[request_id]
self.local_responses[request_id] = []
return responses
with self.local_response_not_empty:
responses: Dict[str, List[ScheduledResponse]] = self.local_response_not_empty.wait_for(
_get_results, 0.001)
results: Dict[str, List[RequestOutput]] = dict()
for request_id, resps in responses.items():
finished = False
results[request_id] = []
for resp in resps:
results[request_id].append(resp.raw)
finished |= resp.finished
if finished:
del self.local_responses[request_id]
llm_logger.info(
f"Scheduler has pulled a finished response: {[response.id]}")
f"Scheduler has pulled a finished response: {[request_id]}")
return results
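# Consumer sketch (illustrative): poll get_results() in a loop; each value is the
# list of RequestOutput chunks accumulated for that request_id, and a request
# stops appearing once a finished chunk has been returned.
#
#   for request_id, chunks in scheduler.get_results().items():
#       handle(request_id, chunks)   # handle() is a hypothetical consumer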
def reset(self):
"""
Reset the scheduler to its initial state by:
1. Clearing all Redis queues associated with this scheduler instance
2. Removing this instance from the load balancing table
3. Clearing in-memory tracking of responses and stolen requests
This method is thread-safe and should be called when:
- The scheduler needs to be cleanly restarted
- Recovering from critical errors
- Preparing for graceful shutdown
Effects:
- Deletes the request and response queues in Redis
- Removes this scheduler's entry from the load balancing sorted set
- Clears the local_responses dictionary tracking pending responses
- Clears the stolen_requests dictionary tracking requests taken from other schedulers
Note:
- Uses the scheduler's mutex to ensure thread safety
- Does not affect other scheduler instances in the cluster
- After reset, the scheduler will need to be reinitialized to be usable again
"""
with self.mutex:
self.client.delete(self._request_queue_name(),
self._response_queue_name())
self.client.zrem(self._load_table_name(), self.name)
self.local_responses = dict()
self.stolen_requests = dict()
llm_logger.info("Scheduler has been reset")