# FastDeploy/fastdeploy/scheduler/global_scheduler.py
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import time
from typing import Dict, List, Optional, Tuple

from redis import ConnectionPool

from fastdeploy.engine.request import Request, RequestOutput
from fastdeploy.metrics.metrics import main_process_metrics
from fastdeploy.scheduler.data import ScheduledRequest, ScheduledResponse
from fastdeploy.scheduler.storage import AdaptedRedis
from fastdeploy.scheduler.workers import Workers
from fastdeploy.utils import llm_logger


class GlobalScheduler(object):
    """
    Redis-backed global scheduler.

    Requests and responses are exchanged through shared Redis queues under a
    common topic, so that multiple engine instances can pull work from and
    write results back to the same deployment.
    """

    def __init__(self,
host: str,
port: int,
db: int,
password: Optional[str],
topic: str,
ttl: int,
remote_write_time: int,
wait_response_timeout: float
):
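        """
        Args:
            host/port/db/password: Redis connection settings.
            topic: key prefix that namespaces this deployment's queues.
            ttl: lifetime in seconds of a request's uniqueness marker;
                buffered responses inherit what is left of it.
            remote_write_time: seconds reserved for writing responses back,
                subtracted from the remaining ttl.
            wait_response_timeout: blocking wait in seconds when polling for
                responses (clamped to at least 1.0).
        """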
self.topic = topic
self.ttl = ttl
self.remote_write_time = remote_write_time
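        # Never wait less than one second when blocking for responses.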
self.wait_response_timeout = 1.0 if wait_response_timeout < 1.0 else wait_response_timeout
self.wait_request_timeout = 10
connection_pool = ConnectionPool(
host=host, port=port, db=db, password=password, max_connections=10)
self.client = AdaptedRedis(connection_pool=connection_pool)
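        # Background workers: one pushes requests, one pushes results and
        # five poll the per-request response queues.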
self.put_request_workers = Workers(
"put_request_worker", self._put_requests_worker, max_batch_size=5)
self.put_request_workers.start(size=1)
self.put_response_workers = Workers(
"put_response_worker", self._put_results_worker, max_batch_size=50)
self.put_response_workers.start(size=1)
self.get_response_workers = Workers(
"get_response_worker", self._get_results_worker, max_batch_size=1)
self.get_response_workers.start(size=5)
self.response_max_batch = 50
llm_logger.info(f"Scheduler: redis version is {self.client.version}")
    def _request_queue_name(self):
        return f"{self.topic}.request"

    def _response_queue_name(self, id: str):
        return f"{self.topic}.response.{id}"

    def _unique_key_name(self, id: str):
        return f"{self.topic}.unique.{id}"

    @staticmethod
def calc_required_blocks(token_num, block_size):
"""calculate required blocks for given token number"""
return (token_num + block_size - 1) // block_size

    def _put_requests_worker(self, tasks: List[Tuple[str, Request]]) -> List[Tuple[str, Optional[str]]]:
"""
add requests to shared cache
"""
requests: List[ScheduledRequest] = [
ScheduledRequest(request) for _, request in tasks]
# check the uniqueness of the request_id
valid_requests: List[ScheduledRequest] = list()
duplicated_ids: List[str] = list()
for request in requests:
unique_key = self._unique_key_name(request.id)
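            # SET NX with a TTL acts as a shared "already seen" marker.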
if self.client.set(unique_key, "", ex=self.ttl, nx=True):
valid_requests.append(request)
else:
duplicated_ids.append(request.id)
        # add valid requests to the shared request queue
        serialized_requests = [request.serialize()
                               for request in valid_requests]
        if len(serialized_requests) > 0:
            self.client.rpush(self._request_queue_name(),
                              *serialized_requests)
            llm_logger.info(
                f"Scheduler has put some requests: {[request.id for request in valid_requests]}")
main_process_metrics.num_requests_waiting.inc(len(valid_requests))
if len(duplicated_ids) > 0:
llm_logger.warning(
f"Scheduler has received some duplicated requests: {duplicated_ids}")
results = [(request.id, None) for request in valid_requests]
results += [(request_id, "duplicated request_id")
for request_id in duplicated_ids]
return results

    def put_requests(self, requests: List[Request]) -> List[Tuple[str, Optional[str]]]:
"""
add requests to scheduler
"""
tasks: List[Tuple[str, Request]] = [
(request.request_id, request) for request in requests]
self.put_request_workers.put_tasks(tasks)
return self.put_request_workers.get_results(10, 0.005)

    def get_requests(self, available_blocks, block_size, reserved_output_blocks,
                     max_num_batched_tokens, batch=1) -> List[Request]:
        """
        pull requests from the shared queue, bounded by the available KV-cache
        blocks and the prefill token budget; blocks briefly when the queue
        is empty
        """
if available_blocks <= reserved_output_blocks or batch < 1:
            llm_logger.debug(
                f"Scheduler's resources are insufficient: available_blocks={available_blocks} "
                f"reserved_output_blocks={reserved_output_blocks} batch={batch} "
                f"max_num_batched_tokens={max_num_batched_tokens}")
return []
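        # Split the requested batch into two near-equal lpop chunks
        # (e.g. batch=5 -> [3, 2]).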
batches = []
piece = (batch + 1) // 2
while batch > 0:
batch -= piece
if batch >= 0:
batches.append(piece)
else:
batches.append(piece + batch)
serialized_requests = []
for bs in batches:
bs_data = self.client.lpop(self._request_queue_name(), bs)
if bs_data is None:
break
serialized_requests += bs_data
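        # Nothing buffered: block for up to wait_request_timeout seconds;
        # blpop returns a (key, value) pair.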
if len(serialized_requests) == 0:
blocked_data = self.client.blpop(
self._request_queue_name(), self.wait_request_timeout)
if blocked_data is None:
return []
serialized_requests = blocked_data[1:]
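        # Admit requests while the KV-cache block budget and the prefill token
        # budget allow; the rest is pushed back to the front of the queue.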
required_total_blocks = 0
current_prefill_tokens = 0
remaining_request = []
requests: List[Request] = []
for serialized_request in serialized_requests:
if len(remaining_request) > 0:
remaining_request.append(serialized_request)
continue
request: ScheduledRequest = ScheduledRequest.unserialize(
serialized_request)
if (time.time() - request.scheduled_time) > self.ttl:
                llm_logger.info(
                    f"Request expired before being pulled from the scheduler: {[request.id]}")
continue
required_input_blocks = self.calc_required_blocks(
request.size, block_size)
current_prefill_tokens += request.size
required_total_blocks += required_input_blocks + reserved_output_blocks
if required_total_blocks > available_blocks or current_prefill_tokens > max_num_batched_tokens:
remaining_request.append(serialized_request)
continue
requests.append(request.raw)
if len(remaining_request) > 0:
self.client.lpush(self._request_queue_name(), *remaining_request)
if len(requests) > 0:
            llm_logger.info(
                f"Scheduler has pulled some requests: {[request.request_id for request in requests]}")
main_process_metrics.num_requests_running.inc(len(requests))
main_process_metrics.num_requests_waiting.dec(len(requests))
return requests

    def _put_results_worker(self, tasks: List[Tuple[str, RequestOutput]]):
        """
        write results to the per-request response queues in the shared cache
        """
responses: List[ScheduledResponse] = [
ScheduledResponse(result) for _, result in tasks]
sorted_responses = sorted(
responses, key=lambda response: f"{response.id}.{response.index}")
finished_responses = [
response.id for response in responses if response.finished]
if len(finished_responses) > 0:
            llm_logger.info(
                f"Scheduler has received finished responses: {finished_responses}")
group = dict()
for response in sorted_responses:
serialized_response = response.serialize()
if response.id not in group:
group[response.id] = [serialized_response]
continue
group[response.id].append(serialized_response)
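        # Give each response queue the remaining lifetime of its request's
        # unique key, minus the time reserved for the remote write.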
        for response_id, serialized_responses in group.items():
            ttl = self.client.ttl(self._unique_key_name(
                response_id)) - self.remote_write_time
            if ttl <= 0:
                llm_logger.warning(
                    f"Scheduler has received an expired response: {[response_id]}")
continue
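            # Atomically append the responses and refresh the queue's expiry.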
with self.client.pipeline() as pipe:
pipe.multi()
                pipe.rpush(self._response_queue_name(
                    response_id), *serialized_responses)
pipe.expire(self._response_queue_name(response_id), ttl)
pipe.execute()

    def put_results(self, results: List[RequestOutput]):
"""
add results to shared cache
"""
tasks: List[Tuple[str, RequestOutput]] = [
(result.request_id, result) for result in results]
self.put_response_workers.put_tasks(tasks)

    def _get_results_worker(self, tasks: List[Tuple[str, str]]) -> List[Tuple[str, List[ScheduledResponse]]]:
        """
        pop buffered responses for a single request from the shared cache,
        blocking briefly when none are available
        """
if len(tasks) != 1:
raise ValueError(
f"Tasks size of _get_results_worker must be 1. ({len(tasks)})")
task_id, request_id = tasks[0]
key = self._response_queue_name(request_id)
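        # Drain up to response_max_batch buffered responses; if none are
        # buffered yet, block for up to wait_response_timeout seconds.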
size = self.client.llen(key)
size = min(size, self.response_max_batch)
serialized_responses = None
if size > 0:
serialized_responses = self.client.lpop(key, size)
if serialized_responses is None or len(serialized_responses) == 0:
blocked_data = self.client.blpop(key, self.wait_response_timeout)
if blocked_data is None:
return []
serialized_responses = blocked_data[1:]
output = [(task_id, [])]
for serialized_response in serialized_responses:
response = ScheduledResponse.unserialize(serialized_response)
output[0][1].append(response)
return output

    def get_results(self, request_ids: List[str]) -> Dict[str, List[RequestOutput]]:
        """
        collect available responses for the given request ids, grouped per
        request and ordered by chunk index
        """
tasks = [(request_id, request_id) for request_id in request_ids]
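        # deduplication=True keeps the same request id from being queued for
        # polling more than once at a time.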
self.get_response_workers.put_tasks(tasks, deduplication=True)
batch_responses: List[Tuple[str, List[ScheduledResponse]]] = self.get_response_workers.get_results(
10, self.wait_response_timeout)
results = dict()
for _, responses in batch_responses:
for response in responses:
if response.id not in results:
results[response.id] = []
results[response.id].append(response)
if response.finished:
llm_logger.info(
f"Scheduler has pulled a finished response: {[response.id]}")
request_ids = list(results.keys())
for request_id in request_ids:
results[request_id] = sorted(
results[request_id], key=lambda response: f"{response.id}.{response.index}")
results[request_id] = [
result.raw for result in results[request_id]]
return results
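
# A minimal usage sketch (illustrative only; the connection settings and the
# `requests` / `engine_outputs` variables below are assumptions, not part of
# this module):
#
#     scheduler = GlobalScheduler(host="127.0.0.1", port=6379, db=0,
#                                 password=None, topic="fd", ttl=900,
#                                 remote_write_time=3,
#                                 wait_response_timeout=2.0)
#     # API side: enqueue requests, then poll for their outputs.
#     scheduler.put_requests(requests)
#     outputs = scheduler.get_results([r.request_id for r in requests])
#     # Engine side: pull admissible work and write results back.
#     work = scheduler.get_requests(available_blocks=1024, block_size=64,
#                                   reserved_output_blocks=128,
#                                   max_num_batched_tokens=8192, batch=8)
#     scheduler.put_results(engine_outputs)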