"""
|
|
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""
|
|
|
|
|
|
import time
from typing import Dict, List, Optional, Tuple

from redis import ConnectionPool

from fastdeploy.engine.request import Request, RequestOutput
from fastdeploy.metrics.metrics import main_process_metrics
from fastdeploy.scheduler.data import ScheduledRequest, ScheduledResponse
from fastdeploy.scheduler.storage import AdaptedRedis
from fastdeploy.scheduler.workers import Workers
from fastdeploy.utils import llm_logger


class GlobalScheduler(object):
    """
    Redis-backed global scheduler.

    Requests and responses are exchanged between engine instances through a
    shared Redis topic: requests go into a common queue, and responses are
    written back to per-request queues.
    """

    def __init__(self,
                 host: str,
                 port: int,
                 db: int,
                 password: Optional[str],
                 topic: str,
                 ttl: int,
                 remote_write_time: int,
                 wait_response_timeout: float):
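        """
        Args:
            host, port, db, password: Redis connection settings.
            topic: key prefix shared by the scheduler instances that exchange
                requests and responses with each other.
            ttl: lifetime in seconds of a request's unique key (and of its
                response queue) in Redis.
            remote_write_time: seconds reserved for writing responses back to
                Redis, subtracted from the remaining ttl.
            wait_response_timeout: blocking timeout in seconds when waiting
                for responses (clamped to at least 1 second).
        """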
        self.topic = topic
        self.ttl = ttl
        self.remote_write_time = remote_write_time
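        # Response waits always block for at least one second.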
        self.wait_response_timeout = 1.0 if wait_response_timeout < 1.0 else wait_response_timeout
        self.wait_request_timeout = 10

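        # All scheduler instances that share the same topic coordinate through
        # this Redis connection.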
        connection_pool = ConnectionPool(
            host=host, port=port, db=db, password=password, max_connections=10)
        self.client = AdaptedRedis(connection_pool=connection_pool)

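        # Background worker pools: one batches outgoing requests, one batches
        # outgoing responses, and five block on Redis waiting for responses.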
        self.put_request_workers = Workers(
            "put_request_worker", self._put_requests_worker, max_batch_size=5)
        self.put_request_workers.start(size=1)

        self.put_response_workers = Workers(
            "put_response_worker", self._put_results_worker, max_batch_size=50)
        self.put_response_workers.start(size=1)

        self.get_response_workers = Workers(
            "get_response_worker", self._get_results_worker, max_batch_size=1)
        self.get_response_workers.start(size=5)
        self.response_max_batch = 50

        llm_logger.info(f"Scheduler: redis version is {self.client.version}")

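    # Redis key layout: "<topic>.request" is the shared request queue,
    # "<topic>.response.<request_id>" holds the responses for one request, and
    # "<topic>.unique.<request_id>" marks a request_id as seen.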
    def _request_queue_name(self):
        return f"{self.topic}.request"

    def _response_queue_name(self, id: str):
        return f"{self.topic}.response.{id}"

    def _unique_key_name(self, id: str):
        return f"{self.topic}.unique.{id}"

    @staticmethod
    def calc_required_blocks(token_num, block_size):
        """Calculate the number of blocks required to hold token_num tokens (rounded up)."""
        return (token_num + block_size - 1) // block_size

    def _put_requests_worker(self, tasks: List[Tuple[str, Request]]) -> List[Tuple[str, Optional[str]]]:
        """
        Push requests to the shared Redis queue.

        Returns one (request_id, error) tuple per task: error is None on
        success and a message for duplicated request_ids.
        """
        requests: List[ScheduledRequest] = [
            ScheduledRequest(request) for _, request in tasks]

        # Check request_id uniqueness: SET NX on a per-request key marks the id
        # as seen; requests whose key already exists are rejected as duplicates.
        valid_requests: List[ScheduledRequest] = list()
        duplicated_ids: List[str] = list()
        for request in requests:
            unique_key = self._unique_key_name(request.id)
            if self.client.set(unique_key, "", ex=self.ttl, nx=True):
                valid_requests.append(request)
            else:
                duplicated_ids.append(request.id)

        # Add valid requests to the shared request queue (skipped when the
        # whole batch was duplicated, since RPUSH requires at least one value).
        if len(valid_requests) > 0:
            serialized_requests = [request.serialize()
                                   for request in valid_requests]
            self.client.rpush(self._request_queue_name(), *serialized_requests)
            llm_logger.info(
                f"Scheduler has put some requests: {[request.id for request in valid_requests]}")
            main_process_metrics.num_requests_waiting.inc(len(valid_requests))

        if len(duplicated_ids) > 0:
            llm_logger.warning(
                f"Scheduler has received some duplicated requests: {duplicated_ids}")

        results = [(request.id, None) for request in valid_requests]
        results += [(request_id, "duplicated request_id")
                    for request_id in duplicated_ids]
        return results

    def put_requests(self, requests: List[Request]) -> List[Tuple[str, Optional[str]]]:
        """
        Add requests to the scheduler and return (request_id, error) tuples.
        """
        tasks: List[Tuple[str, Request]] = [
            (request.request_id, request) for request in requests]
        self.put_request_workers.put_tasks(tasks)
        return self.put_request_workers.get_results(10, 0.005)

    def get_requests(self, available_blocks, block_size, reserved_output_blocks,
                     max_num_batched_tokens, batch=1) -> List[Request]:
        """
        Pull up to `batch` requests from the shared request queue, blocking
        briefly when it is empty, while respecting the block and token budgets.
        """
        if available_blocks <= reserved_output_blocks or batch < 1:
            llm_logger.debug(
                f"Scheduler's resources are insufficient: available_blocks={available_blocks} "
                f"reserved_output_blocks={reserved_output_blocks} batch={batch} "
                f"max_num_batched_tokens={max_num_batched_tokens}")
            return []

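        # Split the requested batch into at most two roughly equal LPOP calls;
        # if an LPOP comes back empty, the remaining ones are skipped.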
        batches = []
        piece = (batch + 1) // 2
        while batch > 0:
            batch -= piece
            if batch >= 0:
                batches.append(piece)
            else:
                batches.append(piece + batch)

        serialized_requests = []
        for bs in batches:
            bs_data = self.client.lpop(self._request_queue_name(), bs)
            if bs_data is None:
                break
            serialized_requests += bs_data

        if len(serialized_requests) == 0:
            blocked_data = self.client.blpop(
                self._request_queue_name(), self.wait_request_timeout)
            if blocked_data is None:
                return []
            serialized_requests = blocked_data[1:]

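        # Admit requests greedily until either the block budget or
        # max_num_batched_tokens would be exceeded; anything left over is
        # pushed back to the head of the queue. Requests older than ttl are dropped.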
        required_total_blocks = 0
        current_prefill_tokens = 0
        remaining_request = []
        requests: List[Request] = []
        for serialized_request in serialized_requests:
            if len(remaining_request) > 0:
                remaining_request.append(serialized_request)
                continue

            request: ScheduledRequest = ScheduledRequest.unserialize(
                serialized_request)
            if (time.time() - request.scheduled_time) > self.ttl:
                llm_logger.info(
                    f"Request has expired when getting a request from the scheduler: {[request.id]}")
                continue

            required_input_blocks = self.calc_required_blocks(
                request.size, block_size)
            current_prefill_tokens += request.size
            required_total_blocks += required_input_blocks + reserved_output_blocks
            if required_total_blocks > available_blocks or current_prefill_tokens > max_num_batched_tokens:
                remaining_request.append(serialized_request)
                continue
            requests.append(request.raw)

        if len(remaining_request) > 0:
            self.client.lpush(self._request_queue_name(), *remaining_request)

        if len(requests) > 0:
            llm_logger.info(
                f"Scheduler has pulled some requests: {[request.request_id for request in requests]}")
            main_process_metrics.num_requests_running.inc(len(requests))
            main_process_metrics.num_requests_waiting.dec(len(requests))
        return requests

    def _put_results_worker(self, tasks: List[Tuple[str, RequestOutput]]):
        """
        Write results back to the per-request response queues in Redis.
        """
        responses: List[ScheduledResponse] = [
            ScheduledResponse(result) for _, result in tasks]
        sorted_responses = sorted(
            responses, key=lambda response: f"{response.id}.{response.index}")

        finished_responses = [
            response.id for response in responses if response.finished]
        if len(finished_responses) > 0:
            llm_logger.info(
                f"Scheduler has received finished responses: {finished_responses}")

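        # Group serialized responses by request id so that each request's
        # responses can be written with a single pipelined RPUSH + EXPIRE.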
        group = dict()
        for response in sorted_responses:
            serialized_response = response.serialize()
            if response.id not in group:
                group[response.id] = [serialized_response]
                continue
            group[response.id].append(serialized_response)

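        # A group is only written if its unique key still has TTL left after
        # reserving remote_write_time seconds for the write itself.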
        for response_id, serialized_responses in group.items():
            ttl = self.client.ttl(self._unique_key_name(
                response_id)) - self.remote_write_time
            if ttl <= 0:
                llm_logger.warning(
                    f"Scheduler has received an expired response: {[response_id]}")
                continue

            with self.client.pipeline() as pipe:
                pipe.multi()
                pipe.rpush(self._response_queue_name(response_id), *serialized_responses)
                pipe.expire(self._response_queue_name(response_id), ttl)
                pipe.execute()

    def put_results(self, results: List[RequestOutput]):
        """
        Queue results to be written back to Redis by the put_response worker.
        """
        tasks: List[Tuple[str, RequestOutput]] = [
            (result.request_id, result) for result in results]
        self.put_response_workers.put_tasks(tasks)

    def _get_results_worker(self, tasks: List[Tuple[str, str]]) -> List[Tuple[str, List[ScheduledResponse]]]:
        """
        Fetch queued responses for a single request id, blocking briefly when
        none are available.
        """
        if len(tasks) != 1:
            raise ValueError(
                f"Tasks size of _get_results_worker must be 1. ({len(tasks)})")

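        # Drain up to response_max_batch responses if any are already queued;
        # otherwise block up to wait_response_timeout for the next one.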
        task_id, request_id = tasks[0]
        key = self._response_queue_name(request_id)
        size = self.client.llen(key)
        size = min(size, self.response_max_batch)

        serialized_responses = None
        if size > 0:
            serialized_responses = self.client.lpop(key, size)

        if serialized_responses is None or len(serialized_responses) == 0:
            blocked_data = self.client.blpop(key, self.wait_response_timeout)
            if blocked_data is None:
                return []
            serialized_responses = blocked_data[1:]

        output = [(task_id, [])]
        for serialized_response in serialized_responses:
            response = ScheduledResponse.unserialize(serialized_response)
            output[0][1].append(response)
        return output

    def get_results(self, request_ids: List[str]) -> Dict[str, List[RequestOutput]]:
        """
        Collect the responses currently available for the given request ids,
        blocking up to wait_response_timeout, and return the raw RequestOutput
        objects grouped by request id.
        """
        tasks = [(request_id, request_id) for request_id in request_ids]
        self.get_response_workers.put_tasks(tasks, deduplication=True)
        batch_responses: List[Tuple[str, List[ScheduledResponse]]] = self.get_response_workers.get_results(
            10, self.wait_response_timeout)

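        # Group responses by request id and return them as raw RequestOutput
        # objects, sorted by their "<id>.<index>" key.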
        results = dict()
        for _, responses in batch_responses:
            for response in responses:
                if response.id not in results:
                    results[response.id] = []
                results[response.id].append(response)
                if response.finished:
                    llm_logger.info(
                        f"Scheduler has pulled a finished response: {[response.id]}")

        request_ids = list(results.keys())
        for request_id in request_ids:
            results[request_id] = sorted(
                results[request_id], key=lambda response: f"{response.id}.{response.index}")
            results[request_id] = [
                result.raw for result in results[request_id]]
        return results
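

# A minimal usage sketch (illustrative only; the connection values and budgets
# below are placeholders, not project defaults):
#
#     scheduler = GlobalScheduler(host="127.0.0.1", port=6379, db=0, password=None,
#                                 topic="fastdeploy", ttl=900, remote_write_time=3,
#                                 wait_response_timeout=2.0)
#     errors = scheduler.put_requests(requests)    # [(request_id, None or "duplicated request_id"), ...]
#     batch = scheduler.get_requests(available_blocks=1024, block_size=64,
#                                    reserved_output_blocks=32,
#                                    max_num_batched_tokens=8192, batch=8)
#     scheduler.put_results(outputs)
#     results = scheduler.get_results(request_ids)  # {request_id: [RequestOutput, ...]}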