mirror of https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-04 08:16:42 +08:00
Sync v2.0 version of code to github repo
@@ -16,19 +16,31 @@
from typing import List, Optional, Dict, Tuple
import traceback
import threading
import time
from datetime import datetime
import random
import uuid
import crcmod

from redis import ConnectionPool

from fastdeploy.scheduler.storage import AdaptedRedis
from fastdeploy.engine.request import Request, RequestOutput
from fastdeploy.metrics.metrics import main_process_metrics
from fastdeploy.scheduler.data import ScheduledRequest, ScheduledResponse
from fastdeploy.scheduler.workers import Workers, Task
from fastdeploy.utils import llm_logger
from fastdeploy.scheduler import utils


class GlobalScheduler(object):
    """
    A distributed task scheduler that manages request/response queues using Redis.

    This class provides functionality for:
    - Enqueuing and dequeuing requests
    - Load balancing across multiple scheduler instances
    - Handling request/response lifecycle
    - Maintaining worker health checks
    """

    def __init__(self,
@@ -38,96 +50,328 @@ class GlobalScheduler(object):
                 password: Optional[str],
                 topic: str,
                 ttl: int,
                 min_load_score: float,
                 load_shrads_num: int,
                 enable_chunked_prefill: bool,
                 max_num_partial_prefills: int,
                 max_long_partial_prefills: int,
                 long_prefill_token_threshold: int,
                 ):
        """
        Initialize the GlobalScheduler with Redis connection and scheduling parameters.

        Args:
            host: Redis server hostname
            port: Redis server port
            db: Redis database number
            password: Optional password for Redis authentication
            topic: Base topic name for queue namespacing
            ttl: Time-to-live in seconds for Redis keys
            min_load_score: Minimum load score for task assignment
            load_shrads_num: Number of shards for load balancing table
            enable_chunked_prefill: Whether to enable chunked prefill processing
            max_num_partial_prefills: Maximum number of partial prefills allowed
            max_long_partial_prefills: Maximum number of long partial prefills allowed
            long_prefill_token_threshold: Token count threshold for long prefills

        Initializes:
            - Redis connection pool and client
            - Worker threads for request/response handling
            - Load balancing and request stealing mechanisms
            - Response tracking structures
        """
        self.topic = topic
        self.ttl = ttl
        self.min_load_score = min_load_score
        self.load_shrads_num = load_shrads_num

        self.enable_chunked_prefill = enable_chunked_prefill
        self.max_num_partial_prefills = max_num_partial_prefills
        self.max_long_partial_prefills = max_long_partial_prefills
        self.long_prefill_token_threshold = long_prefill_token_threshold

        self.blpop_request_timeout = 2
        self.blpop_response_timeout = 10

        self.crc16_mutex = threading.Lock()
        self.crc16 = crcmod.predefined.Crc('ccitt-false')
        self.load_slot_for_getting_request = 0
        self.load_start = 0  # const
        self.load_num = 50  # const

        connection_pool = ConnectionPool(
            host=host, port=port, db=db, password=password, max_connections=10)
        self.client = AdaptedRedis(connection_pool=connection_pool)

        self.name = self._generate_scheduler_name()
        self.keep_alive_workers = threading.Thread(
            target=self._keep_alive, daemon=True)
        self.keep_alive_workers.start()

        self.put_requests_workers = Workers(
            "put_requests_workers", self._put_requests_worker, 20)
        self.put_requests_workers.start(1)

        self.put_results_workers = Workers(
            "put_results_workers", self._put_results_worker, 300)
        self.put_results_workers.start(1)

        self.mutex = threading.Lock()
        self.local_response_not_empty = threading.Condition(self.mutex)
        self.local_responses: Dict[str, List[ScheduledResponse]] = dict()
        self.stolen_requests: Dict[str, ScheduledRequest] = dict()

        self.get_response_workers = threading.Thread(
            target=self._get_results_worker, daemon=True)
        self.get_response_workers.start()

        llm_logger.info(
            f"Scheduler: name={self.name} redis_version={self.client.version}")

    def _get_hash_slot(self, data: str) -> int:
        """
        Compute a CRC16 value for the given string (used to pick a load-table shard).
        """
        data = data.encode("utf-8")
        with self.crc16_mutex:
            self.crc16.update(data)
            crc_value = self.crc16.crcValue
            self.crc16.crcValue = self.crc16.initCrc
        return crc_value
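
    # Illustrative sketch (not part of this commit): how a request queue name is
    # mapped onto a load-table shard. "scheduler" is a hypothetical GlobalScheduler
    # instance with load_shrads_num=4, and the CRC value below is made up.
    #
    #     crc = scheduler._get_hash_slot("mytopic.req.host-1")   # e.g. 48879
    #     slot = crc % scheduler.load_shrads_num                 # 48879 % 4 == 3
    #     # the queue's load score is then tracked in the sorted set "mytopic.load.3"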

    def _instance_name(self, scheduler_name: str) -> str:
        """
        Generate the Redis key name for a scheduler instance.

        Args:
            scheduler_name: Name of the scheduler instance

        Returns:
            Formatted Redis key name
        """
        return f"{self.topic}.ins.{scheduler_name}"

    def _generate_scheduler_name(self) -> str:
        """
        Generate a unique name for this scheduler instance.

        Uses hostname/IP and timestamp to create a unique identifier,
        then registers it in Redis with TTL.

        Returns:
            Unique scheduler name string
        """
        try:
            _, name = utils.get_hostname_ip()
        except Exception as e:
            llm_logger.warning(
                f"Scheduler encountered an error while resolving the IP address. {e}")
            name = str(uuid.uuid4())

        size = len(name)
        now = time.time()
        local_time = datetime.fromtimestamp(now)
        formatted_time = local_time.strftime(
            "%Y-%m-%d %H:%M:%S") + f"{local_time.microsecond // 1000:03d}"

        count = 1
        while True:
            if self.client.set(self._instance_name(name), formatted_time, ex=self.ttl, nx=True):
                break
            name = f"{name[:size]}:{count}"
            count += 1
        return name
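
    # Illustrative sketch (not part of this commit): name generation and collision
    # handling, assuming utils.get_hostname_ip() resolved to "10.0.0.5". If the key
    # "mytopic.ins.10.0.0.5" already exists in Redis, the loop retries with
    # "10.0.0.5:1", "10.0.0.5:2", ... until the SET NX call succeeds.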

    def _keep_alive(self):
        """
        Background thread that periodically updates the scheduler's TTL in Redis.

        Runs in a loop with an interval of TTL/2 to maintain instance registration.
        """
        interval_time = self.ttl / 2
        while True:
            try:
                now = time.time()
                local_time = datetime.fromtimestamp(now)
                formatted_time = local_time.strftime(
                    "%Y-%m-%d %H:%M:%S") + f"{local_time.microsecond // 1000:03d}"
                self.client.set(self._instance_name(self.name),
                                formatted_time, ex=self.ttl)
            except Exception as e:
                llm_logger.error(f"Scheduler keep alive failed: {e}")
                interval_time = self.ttl / 10

            time.sleep(interval_time)
            interval_time = self.ttl / 2
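
    # Worked example (illustrative only, not part of this commit): with ttl=900s
    # the instance key is refreshed roughly every 450s (ttl/2); after a failed
    # refresh the next attempt happens after 90s (ttl/10) before returning to the
    # normal ttl/2 cadence.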

    def _scheduler_name_from_request_queue(self, request_queue: str) -> str:
        """
        Extract scheduler name from a request queue name.

        Args:
            request_queue: Full request queue name

        Returns:
            The scheduler name portion of the queue name
        """
        prefix_len = len(f"{self.topic}.req.")
        return request_queue[prefix_len:]

    def _request_queue_name(self, scheduler_name: Optional[str] = None) -> str:
        """
        Generate the Redis request queue name for a scheduler.

        Args:
            scheduler_name: Optional specific scheduler name, defaults to current instance

        Returns:
            Formatted request queue name
        """
        if scheduler_name is None:
            return f"{self.topic}.req.{self.name}"
        return f"{self.topic}.req.{scheduler_name}"

    def _response_queue_name(self, scheduler_name: Optional[str] = None) -> str:
        """
        Generate the Redis response queue name for a scheduler.

        Args:
            scheduler_name: Optional specific scheduler name, defaults to current instance

        Returns:
            Formatted response queue name
        """
        if scheduler_name is None:
            return f"{self.topic}.resp.{self.name}"
        return f"{self.topic}.resp.{scheduler_name}"

    def _load_table_name(self, request_queue_name: Optional[str] = None, slot: Optional[int] = None) -> str:
        """
        Get the Redis sorted set name used for load balancing.

        Returns:
            The load score key name
        """
        if request_queue_name is None:
            request_queue_name = self._request_queue_name()

        if slot is None:
            slot = self._get_hash_slot(
                request_queue_name) % self.load_shrads_num
        else:
            slot %= self.load_shrads_num
        return f"{self.topic}.load.{slot}"

    @staticmethod
    def calc_required_blocks(token_num, block_size):
        """
        Calculate the number of blocks needed for a given number of tokens.

        Args:
            token_num: Number of tokens
            block_size: Size of each block

        Returns:
            Number of blocks required (rounded up)
        """
        return (token_num + block_size - 1) // block_size
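
    # Worked example (illustrative only, not part of this commit): with
    # block_size=64, a 130-token prompt needs
    #     (130 + 64 - 1) // 64 == 193 // 64 == 3
    # blocks, i.e. ceil(130 / 64).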

    @staticmethod
    def _mark_request(request: ScheduledRequest):
        """
        Mark a stolen request with the original queue name.

        Args:
            request: The request to mark
        """
        request.request_id = f"mark<{request.request_queue_name}>{request.request_id}"

    @staticmethod
    def _unmark_response(response: ScheduledResponse, request_queue_name: str):
        """
        Remove marking from a response that came from a stolen request.

        Args:
            response: The response to unmark
            request_queue_name: Original request queue name
        """
        mark = f"mark<{request_queue_name}>"
        if not response.request_id.startswith(mark):
            return
        response.request_id = response.request_id[len(mark):]
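
    # Illustrative sketch (not part of this commit): a request "abc-123" stolen
    # from the queue "mytopic.req.node-2" is tracked locally as
    #     "mark<mytopic.req.node-2>abc-123"
    # while it is being processed; _unmark_response() strips the "mark<...>"
    # prefix again before the response is pushed back toward node-2.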

    def _put_requests_worker(self, tasks: List[Task]) -> List[Task]:
        """
        Worker method that adds requests to the shared Redis cache.

        Args:
            tasks: List of tasks containing requests to enqueue

        Returns:
            List of processed tasks (some may be marked as duplicates)
        """
        duplicate = False
        requests: List[ScheduledRequest] = []
        with self.mutex:
            for task in tasks:
                request = ScheduledRequest(
                    task.raw, self._request_queue_name(), self._response_queue_name())
                task.raw = None

                if request.request_id in self.local_responses:
                    task.reason = "duplicate request_id"
                    duplicate = True
                    continue
                requests.append(request)
                self.local_responses[request.request_id] = []

        if len(requests) > 0:
            serialized_requests = [request.serialize() for request in requests]
            self.client.rpush(self._request_queue_name(),
                              *serialized_requests, ttl=self.ttl)
            self.client.zincrby(self._load_table_name(),
                                len(serialized_requests), self.name,
                                rem_amount=0, ttl=self.ttl)
            llm_logger.info(
                f"Scheduler has enqueued some requests: {requests}")

        if duplicate:
            llm_logger.warning(
                "Scheduler has received some duplicated requests: "
                f"{[task for task in tasks if task.reason is not None]}")
        return tasks

    def put_requests(self, requests: List[Request]) -> List[Tuple[str, Optional[str]]]:
        """
        Public method to add new requests to the scheduler.

        Args:
            requests: List of Request objects to schedule

        Returns:
            List of tuples containing (request_id, error_reason) for each request
        """
        tasks: List[Task] = []
        for request in requests:
            task = Task(request.request_id, request)
            tasks.append(task)

        self.put_requests_workers.add_tasks(tasks)
        results = self.put_requests_workers.get_results(10, 0.001)
        return [(result.id, result.reason) for result in results]
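
    # Illustrative sketch (not part of this commit): put_requests() reports
    # per-request status through the (request_id, reason) tuples it returns,
    # where reason is None on success. The Request objects below are hypothetical.
    #
    #     statuses = scheduler.put_requests([req_a, req_b])
    #     # e.g. [("req-a", None), ("req-b", "duplicate request_id")]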

    def get_requests(self, available_blocks, block_size, reserved_output_blocks,
                     max_num_batched_tokens, batch=1) -> List[Request]:
        """
        Get requests from the shared cache based on available resources.

        Args:
            available_blocks: Number of available processing blocks
            block_size: Size of each processing block
            reserved_output_blocks: Blocks reserved for output
            max_num_batched_tokens: Maximum tokens that can be batched
            batch: Preferred batch size

        Returns:
            List of Request objects ready for processing
        """
        if available_blocks <= reserved_output_blocks or batch < 1:
@@ -137,161 +381,410 @@ class GlobalScheduler(object):
                f"max_num_batched_tokens={max_num_batched_tokens}")
            return []

        mini_batch = (batch + 1) // 2
        batches = []
        for _ in range(2):
            if batch >= mini_batch:
                batches.append(mini_batch)
                batch -= mini_batch
                continue

            if batch > 0:
                batches.append(batch)
                batch = 0

        local_request_queue_name = self._request_queue_name()
        serialized_requests: List[Tuple[str, bytes]] = []
        for bs in batches:
            elements = self.client.lpop(
                local_request_queue_name, bs, ttl=self.ttl)
            if elements is None:
                break
            self.client.zincrby(self._load_table_name(), -len(elements),
                                self.name, rem_amount=0, ttl=self.ttl)
            serialized_requests += [(local_request_queue_name, element)
                                    for element in elements]

        extend_scheduler_names = []
        if len(serialized_requests) == 0 and len(batches) > 0:
            for _ in range(min(5, self.load_shrads_num)):
                serialized_members = self.client.zrangebyscore(
                    self._load_table_name(
                        slot=self.load_slot_for_getting_request),
                    self.min_load_score,
                    float("+inf"),
                    start=self.load_start,
                    num=self.load_num)
                self.load_slot_for_getting_request += 1
                if len(serialized_members) > 0:
                    break

            members = [member.decode("utf-8") for member in serialized_members]
            if len(members) > 0:
                extend_scheduler_names = random.sample(
                    members, k=min(10, len(members)))
                extend_scheduler_names = [
                    name for name in extend_scheduler_names if name != self.name]

            # find lucky one
            if len(extend_scheduler_names) > 0:
                lucky = random.choice(extend_scheduler_names)
                lucky_request_queue_name = self._request_queue_name(lucky)

                elements = self.client.lpop(lucky_request_queue_name, batches[0])
                if elements is not None and len(elements) > 0:
                    self.client.zincrby(
                        self._load_table_name(
                            request_queue_name=lucky_request_queue_name),
                        -len(elements), lucky, rem_amount=0, ttl=self.ttl)
                    serialized_requests += [(lucky_request_queue_name, element)
                                            for element in elements]
                    llm_logger.info(
                        f"Scheduler {self.name} has stolen some requests from another lucky one. "
                        f"(name={lucky} num={len(serialized_requests)})")
                else:
                    exist_num = self.client.exists(self._instance_name(lucky))
                    if exist_num == 0:
                        if self.client.zrem(
                                self._load_table_name(
                                    request_queue_name=lucky_request_queue_name),
                                lucky):
                            llm_logger.info(
                                f"Scheduler {lucky} has been removed")

        # blocked read
        if len(serialized_requests) == 0:
            request_queue_names = [local_request_queue_name]
            request_queue_names += [
                self._request_queue_name(name) for name in extend_scheduler_names]

            element = self.client.blpop(
                request_queue_names, self.blpop_request_timeout)
            if element is None:
                return []
            request_queue_name = element[0].decode("utf-8")
            scheduler_name = self._scheduler_name_from_request_queue(
                request_queue_name)
            self.client.zincrby(
                self._load_table_name(request_queue_name=request_queue_name),
                -1, scheduler_name, rem_amount=0, ttl=self.ttl)
            serialized_requests.append((request_queue_name, element[1]))
            if scheduler_name != self.name:
                llm_logger.info(
                    f"Scheduler {self.name} has stolen a request from another scheduler. (name={scheduler_name})")

        long_partial_requests = 0
        short_partial_requests = 0
        required_total_blocks = 0
        current_prefill_tokens = 0
        remaining_request: List[Tuple[str, bytes]] = []
        scheduled_requests: List[ScheduledRequest] = []
        for request_queue_name, serialized_request in serialized_requests:
            if len(remaining_request) > 0:
                remaining_request.append(
                    (request_queue_name, serialized_request))
                continue

            request: ScheduledRequest = ScheduledRequest.unserialize(
                serialized_request)
            if (time.time() - request.scheduled_time) > self.ttl:
                llm_logger.info(
                    f"Request has expired when getting a request from the scheduler: {[request.request_id]}")
                continue

            required_input_blocks = self.calc_required_blocks(
                request.prompt_tokens_ids_len, block_size)

            current_prefill_tokens += request.prompt_tokens_ids_len
            required_total_blocks += required_input_blocks + reserved_output_blocks

            if required_total_blocks > available_blocks:
                remaining_request.append(
                    (request_queue_name, serialized_request))
                continue

            if self.enable_chunked_prefill:
                if request.prompt_tokens_ids_len > self.long_prefill_token_threshold:
                    long_partial_requests += 1
                    if long_partial_requests > self.max_long_partial_prefills:
                        remaining_request.append(
                            (request_queue_name, serialized_request))
                        continue
                else:
                    short_partial_requests += 1

                if short_partial_requests + long_partial_requests > self.max_num_partial_prefills:
                    remaining_request.append(
                        (request_queue_name, serialized_request))
                    continue
            else:
                if current_prefill_tokens > max_num_batched_tokens:
                    remaining_request.append(
                        (request_queue_name, serialized_request))
                    continue

            scheduled_requests.append(request)

        if len(scheduled_requests) > 0:
            with self.mutex:
                for request in scheduled_requests:
                    if request.request_queue_name == local_request_queue_name:
                        continue

                    self._mark_request(request)
                    if request.request_id not in self.stolen_requests:
                        self.stolen_requests[request.request_id] = request
                        continue

                    llm_logger.error(
                        f"Scheduler has received a duplicate request from others: {request}")

        requests: List[Request] = [
            request.raw for request in scheduled_requests]
        if len(remaining_request) > 0:
            group: Dict[str, List] = dict()
            for request_queue_name, serialized_request in remaining_request:
                if request_queue_name not in group:
                    group[request_queue_name] = []
                group[request_queue_name].append(serialized_request)

            for request_queue_name, serialized_requests in group.items():
                self.client.lpush(request_queue_name, *serialized_requests)
                scheduler_name = self._scheduler_name_from_request_queue(
                    request_queue_name)
                self.client.zincrby(
                    self._load_table_name(
                        request_queue_name=request_queue_name),
                    len(serialized_requests), scheduler_name, ttl=self.ttl)

            if len(requests) == 0:
                llm_logger.debug(
                    f"Scheduler has put all just-pulled request into the queue: {len(remaining_request)}")

        if len(requests) > 0:
            llm_logger.info(
                f"Scheduler has pulled some request: {[request.request_id for request in requests]}")
            main_process_metrics.num_requests_running.inc(len(requests))
            main_process_metrics.num_requests_waiting.dec(len(requests))
        return requests
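
    # Worked example (illustrative only, not part of this commit): with
    # block_size=64 and reserved_output_blocks=10, a request whose prompt is
    # 130 tokens needs calc_required_blocks(130, 64) + 10 == 3 + 10 == 13 blocks,
    # so it is only scheduled while the running total stays within
    # available_blocks; otherwise it is pushed back onto its request queue.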

    def _put_results_worker(self, tasks: List[Task]):
        """
        Worker method that adds task results to the appropriate queues.

        Args:
            tasks: List of completed tasks with results
        """
        # count = 0  # for test

        with self.mutex:
            local_request_ids = set(self.local_responses.keys())

            stolen_request_id_request_queue = dict()
            stolen_request_id_response_queue = dict()
            for request_id, request in self.stolen_requests.items():
                stolen_request_id_request_queue[request_id] = request.request_queue_name
                stolen_request_id_response_queue[request_id] = request.response_queue_name

        finished_request_ids: List[str] = list()
        local_responses: Dict[str, List[ScheduledResponse]] = dict()
        stolen_responses: Dict[str, List[bytes]] = dict()

        for task in tasks:
            response = ScheduledResponse(task.raw)
            if response.finished:
                finished_request_ids.append(response.request_id)

            if response.request_id in local_request_ids:
                if response.request_id not in local_responses:
                    local_responses[response.request_id] = []
                local_responses[response.request_id].append(response)
                continue

            if response.request_id in stolen_request_id_request_queue:
                response_queue_name = stolen_request_id_response_queue[response.request_id]
                request_queue_name = stolen_request_id_request_queue[response.request_id]
                self._unmark_response(response, request_queue_name)

                if response_queue_name not in stolen_responses:
                    stolen_responses[response_queue_name] = []
                stolen_responses[response_queue_name].append(
                    response.serialize())
                continue

            llm_logger.error(
                f"Scheduler has received a non-existent response from engine: {[response]}")

        with self.mutex:
            for request_id, responses in local_responses.items():
                self.local_responses[request_id] += responses
                # count += len(responses)  # for test

            for request_id in finished_request_ids:
                if request_id in self.stolen_requests:
                    del self.stolen_requests[request_id]

            if len(local_responses) > 0:
                self.local_response_not_empty.notify_all()

        if len(finished_request_ids) > 0:
            llm_logger.info(
                f"Scheduler has received some finished responses: {finished_request_ids}")

        for response_queue_name, responses in stolen_responses.items():
            self.client.rpush(response_queue_name, *responses, ttl=self.ttl)
            # count += len(responses)  # for test
        # return [Task("", count)]  # for test
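
    # Illustrative sketch (not part of this commit): a response for a request this
    # scheduler enqueued itself is appended to self.local_responses and wakes
    # get_results(); a response for a request stolen from "mytopic.req.node-2" is
    # unmarked and RPUSHed to the originating scheduler's response queue, e.g.
    # "mytopic.resp.node-2".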

    def put_results(self, results: List[RequestOutput]):
        """
        Public method to add processing results back to the scheduler.

        Args:
            results: List of RequestOutput objects to return
        """
        tasks: List[Task] = [Task(result.request_id, result)
                             for result in results]
        self.put_results_workers.add_tasks(tasks)

        # ---- for test ----
        # task_results = self.put_results_workers.get_results(10, 0.001)
        # amount = 0
        # for task_result in task_results:
        #     amount += task_result.raw
        # return amount
        # ---- for test ----

    def _get_results_worker(self):
        """
        Background worker that continuously fetches results from Redis.

        Handles both bulk and blocking operations for efficiency.
        Runs in an infinite loop until scheduler shutdown.
        """
        while True:
            try:
                serialized_responses = self.client.lpop(
                    self._response_queue_name(), 300, ttl=self.ttl)

                if serialized_responses is None or len(serialized_responses) == 0:
                    element = self.client.blpop(
                        [self._response_queue_name()], self.blpop_response_timeout)
                    if element is None or len(element) == 0:
                        continue
                    serialized_responses = [element[1]]

                responses: Dict[str, List[ScheduledResponse]] = dict()
                for serialized_response in serialized_responses:
                    response = ScheduledResponse.unserialize(
                        serialized_response)
                    if response.request_id not in responses:
                        responses[response.request_id] = []
                    responses[response.request_id].append(response)

                with self.mutex:
                    for request_id, contents in responses.items():
                        if request_id not in self.local_responses:
                            llm_logger.error(
                                "Scheduler has received some non-existent response from the queue. "
                                f"response:{contents} queue:{self._response_queue_name()}")
                            continue
                        self.local_responses[request_id] += contents
                    self.local_response_not_empty.notify_all()
            except Exception as e:
                llm_logger.error(f"Scheduler get_results_worker exception: {e} "
                                 f"traceback: {traceback.format_exc()}")

    def get_results(self) -> Dict[str, List[RequestOutput]]:
        """
        Retrieve all available results from the distributed scheduler.

        This method:
        - Waits for new responses using a condition variable (timeout=0.001s)
        - Returns all currently available responses
        - Automatically removes completed requests from local tracking
        - Logs finished requests

        Behavior Details:
        1. For the first call with fewer than 64 pending responses, returns an empty dict
        2. Subsequent calls return all available responses
        3. Uses thread-safe operations with condition variables
        4. Automatically cleans up completed request tracking

        Returns:
            Dict[str, List[RequestOutput]]:
                A dictionary where:
                - Key is the request ID
                - Value is a list of RequestOutput objects for that request
                Completed requests are automatically removed from tracking

        Note:
            - Thread-safe operation using condition variables
            - Short timeout avoids blocking while maintaining responsiveness
            - First call may return empty to batch small responses
            - Automatically logs finished requests via llm_logger
        """
        first = True

        def _get_results() -> Dict[str, List[ScheduledResponse]]:
            nonlocal first
            responses: Dict[str, List[ScheduledResponse]] = dict()

            count = 0
            for _, contents in self.local_responses.items():
                count += len(contents)

            if first and count < 64:
                first = False
                return responses

            request_ids = list(self.local_responses.keys())
            for request_id in request_ids:
                responses[request_id] = self.local_responses[request_id]
                self.local_responses[request_id] = []
            return responses

        with self.local_response_not_empty:
            responses: Dict[str, List[ScheduledResponse]] = self.local_response_not_empty.wait_for(
                _get_results, 0.001)

            results: Dict[str, List[RequestOutput]] = dict()
            for request_id, resps in responses.items():
                finished = False
                results[request_id] = []
                for resp in resps:
                    results[request_id].append(resp.raw)
                    finished |= resp.finished

                if finished:
                    del self.local_responses[request_id]
                    llm_logger.info(
                        f"Scheduler has pulled a finished response: {[request_id]}")
            return results
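
    # Illustrative sketch (not part of this commit): a polling loop on the engine
    # side. The scheduler instance and stream_back() helper are hypothetical.
    #
    #     outputs = scheduler.get_results()          # Dict[str, List[RequestOutput]]
    #     for request_id, chunks in outputs.items():
    #         stream_back(request_id, chunks)
    #     # the very first call may return {} while fewer than 64 responses are pending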

    def reset(self):
        """
        Reset the scheduler to its initial state by:
        1. Clearing all Redis queues associated with this scheduler instance
        2. Removing this instance from the load balancing table
        3. Clearing in-memory tracking of responses and stolen requests

        This method is thread-safe and should be called when:
        - The scheduler needs to be cleanly restarted
        - Recovering from critical errors
        - Preparing for graceful shutdown

        Effects:
        - Deletes the request and response queues in Redis
        - Removes this scheduler's entry from the load balancing sorted set
        - Clears the local_responses dictionary tracking pending responses
        - Clears the stolen_requests dictionary tracking requests taken from other schedulers

        Note:
        - Uses the scheduler's mutex to ensure thread safety
        - Does not affect other scheduler instances in the cluster
        - After reset, the scheduler will need to be reinitialized to be usable again
        """
        with self.mutex:
            self.client.delete(self._request_queue_name(),
                               self._response_queue_name())
            self.client.zrem(self._load_table_name(), self.name)
            self.local_responses = dict()
            self.stolen_requests = dict()
        llm_logger.info("Scheduler has been reset")