"""
|
|
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""
|
|
import copy
|
|
import hashlib
|
|
import math
|
|
import random
|
|
import threading
|
|
import time
|
|
from collections import deque
|
|
from typing import List
|
|
|
|
import orjson
|
|
import redis
|
|
|
|
from fastdeploy.engine.request import (CompletionOutput, Request,
|
|
RequestMetrics, RequestOutput)
|
|
from fastdeploy.utils import scheduler_logger as logger
|
|
|
|
|
|


class SplitWiseSchedulerConfig(object):
    """SplitWise Scheduler Configuration"""

    def __init__(
            self,
            nodeid=None,
            host="127.0.0.1",  # redis host
            port=6379,  # redis port
            password=None,  # redis password
            topic="fd",  # redis topic
            ttl=900,  # s
            release_load_expire_period=600,  # s
            sync_period=5,  # ms
            expire_period=3000,  # ms
            clear_expired_nodes_period=60,  # s
            reader_parallel=4,
            reader_batch_size=200,
            writer_parallel=4,
            writer_batch_size=200,
            **kwargs):

        if nodeid is None:
            import uuid
            nodeid = str(uuid.uuid4())
        self.nodeid = nodeid

        self.redis_host = host
        self.redis_port = port
        self.redis_password = password
        self.redis_topic = topic
        self.ttl = ttl
        self.release_load_expire_period = release_load_expire_period

        self.sync_period = sync_period
        self.expire_period = expire_period / 1000.
        self.clear_expired_nodes_period = clear_expired_nodes_period
        self.reader_parallel = reader_parallel
        self.reader_batch_size = reader_batch_size
        self.writer_parallel = writer_parallel
        self.writer_batch_size = writer_batch_size

    def check(self):
        """check arguments"""
        pass

    def print(self):
        """
        print config
        """
        logger.info("SplitWiseScheduler Configuration Information :")
        for k, v in self.__dict__.items():
            logger.info("{:<20}:{:<6}{}".format(k, "", v))
        logger.info(
            "=============================================================")


class SplitWiseScheduler(object):
    """
    SplitWise Scheduler
    """

    def __init__(self, config):
        self.scheduler = APIScheduler(config)
        self.infer = InferScheduler(config)

    def start(self, role, host, disaggregated):
        """
        Start APIScheduler and InferScheduler background threads
        """
        logger.info(
            f"Scheduler Start With: role:{role}, host:{host}, disaggregated:{disaggregated}"
        )
        self.infer.start(role, host, disaggregated)
        self.scheduler.start()

    def reset_nodeid(self, nodeid):
        """
        reset node id
        """
        self.scheduler.nodeid = nodeid
        self.infer.nodeid = nodeid

    def put_requests(self, reqs: List[Request]):
        """
        put requests to the global splitwise scheduler
        """
        return self.scheduler.put_requests(reqs)

    def get_results(self, request_ids=None):
        """
        get results from the global splitwise scheduler
        (request_ids is accepted for interface compatibility but unused)
        """
        return self.scheduler.get_results()

    def get_requests(self,
                     available_blocks,
                     block_size,
                     reserved_output_blocks,
                     max_num_batched_tokens,
                     batch=1):
        """
        get scheduled requests from the global splitwise scheduler
        """
        if available_blocks <= reserved_output_blocks or batch < 1:
            logger.info(
                f"Scheduler's resources are insufficient: available_blocks={available_blocks} "
                f"reserved_output_blocks={reserved_output_blocks} batch={batch} "
                f"max_num_batched_tokens={max_num_batched_tokens}")
            return []
        return self.infer.get_requests(available_blocks, block_size,
                                       reserved_output_blocks,
                                       max_num_batched_tokens, batch)

    def put_results(self, results: List[RequestOutput]):
        """
        put results to the global splitwise scheduler
        """
        return self.infer.put_results(results)
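
# Wiring sketch (illustrative; `engine` and its methods are hypothetical
# placeholders for the inference engine that consumes this scheduler):
#
#     scheduler = SplitWiseScheduler(SplitWiseSchedulerConfig())
#     scheduler.start(role="mixed", host="127.0.0.1",
#                     disaggregated={"transfer_protocol": ["ipc"]})
#     scheduler.put_requests(reqs)               # API side: enqueue requests
#     outs = scheduler.get_results()             # API side: poll infer results
#     batch = scheduler.get_requests(            # engine side: pull work
#         available_blocks=1024, block_size=64,
#         reserved_output_blocks=32, max_num_batched_tokens=8192, batch=8)
#     scheduler.put_results(engine.step(batch))  # engine side: push outputs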


class NodeInfo(object):
    """
    Infer Node Info: load, rdma/ipc info
    """

    @classmethod
    def load_from(cls, nodeid, info):
        """
        load node info from a serialized string
        """
        health = orjson.loads(info)
        ts = health["ts"]
        role = health["role"]
        load = int(health["load"])
        host = health["host"]
        disaggregated = health["disaggregated"]
        return cls(nodeid, role, host, disaggregated, load, ts)

    def __init__(self, nodeid, role, host, disaggregated, load, ts=None):
        self.nodeid = nodeid
        # resolve the default at call time: `ts=time.time()` in the signature
        # would be evaluated only once, at import time
        self.ts = ts if ts is not None else time.time()
        self.host = host
        self.disaggregated = disaggregated
        self.role = role
        self.lock = threading.Lock()
        self.load = load
        self.reqs = dict()

    def __repr__(self):
        return f"{self.nodeid}({self.load})"

    def expired(self, expire_period):
        """
        APIScheduler uses this to check whether the node is expired
        """
        now = time.time()
        return (now - self.ts) > expire_period

    def serialize(self):
        """
        InferScheduler uses this to sync load
        """
        self.ts = time.time()
        health = {
            "ts": self.ts,
            "role": self.role,
            "load": self.load,
            "host": self.host,
            "disaggregated": self.disaggregated
        }
        return orjson.dumps(health)

    def __lt__(self, other):
        return self.load < other.load

    def expire_reqs(self, ttl):
        """
        InferScheduler uses this to clear expired reqs
        """
        cur_time = time.time()
        with self.lock:
            expired = set()
            for req_id, pairs in self.reqs.items():
                load, arrival_time = pairs
                if cur_time - arrival_time > ttl:
                    logger.error(
                        f"InferScheduler Expire Reqs({req_id}), arrival({arrival_time}), ttl({ttl})"
                    )
                    expired.add((req_id, load))
            for req_id, load in expired:
                if req_id in self.reqs:
                    self.load -= load
                    del self.reqs[req_id]

    def add_req(self, req_id, load):
        """
        InferScheduler uses this to record scheduled reqs (waiting or running)
        """
        with self.lock:
            if req_id not in self.reqs:
                self.reqs[req_id] = [load, time.time()]
                self.load += load

    def update_req_timestamp(self, req_ids):
        """
        InferScheduler uses this to refresh req timestamps
        """
        cur_time = time.time()
        with self.lock:
            for req_id in req_ids:
                if req_id in self.reqs:
                    self.reqs[req_id][1] = cur_time

    def finish_req(self, req_id):
        """
        InferScheduler uses this to clear finished reqs
        """
        with self.lock:
            if req_id in self.reqs:
                load = self.reqs[req_id][0]
                self.load -= load
                del self.reqs[req_id]
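
# Round-trip sketch (illustrative values): a node periodically serializes its
# health into the "{topic}.cluster" redis hash, and the API side rebuilds it:
#
#     node = NodeInfo("node-1", "prefill", "10.0.0.7",
#                     {"transfer_protocol": ["rdma"]}, load=0)
#     blob = node.serialize()                    # JSON bytes with a fresh ts
#     same = NodeInfo.load_from("node-1", blob)  # reconstructed on the API side
#     assert not same.expired(3.0)               # fresh ts, so not expired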


class ResultReader(object):
    """
    ResultReader uses an async thread to continuously fetch infer results from redis
    """

    def __init__(self, client, idx, batch=200, ttl=900):
        self.idx = idx
        self.batch = batch
        self.client = client
        self.data = deque()
        self.ttl = ttl

        self.reqs = dict()
        self.out_buffer = dict()
        self.lock = threading.Lock()

        self.thread = threading.Thread(target=self.run)
        self.thread.start()

    def add_req(self, req):
        """
        add a req to the reader; the reader will asynchronously fetch its
        infer results from redis
        """
        with self.lock:
            self.reqs[req.request_id] = {"arrival_time": req.arrival_time}
            self.out_buffer[req.request_id] = []

    def read(self):
        """
        batch read infer results
        returns: dict(req_id, [RequestOutput])
        """
        items = []
        size = len(self.data)
        for i in range(size):
            items.append(self.data.pop())

        outputs = dict()
        group_tokens = dict()
        finish_reqs = set()
        for item in items:
            req_id = item.request_id

            is_error = item.error_code != 200

            if is_error or item.finished:
                finish_reqs.add(req_id)

            if is_error or item.outputs.send_idx == 0:
                outputs[req_id] = [item]
                continue

            if req_id not in group_tokens:
                group_tokens[req_id] = []
            group_tokens[req_id].append(item)

        with self.lock:
            for key in finish_reqs:
                if key in self.reqs:
                    del self.reqs[key]

            for req_id, items in outputs.items():
                if req_id in self.out_buffer:
                    items.extend(self.out_buffer[req_id])
                    del self.out_buffer[req_id]

            for req_id, items in group_tokens.items():
                if req_id in self.out_buffer:
                    self.out_buffer[req_id].extend(items)
                    continue

                if req_id not in outputs:
                    outputs[req_id] = []
                outputs[req_id].extend(items)

        return outputs

    def run(self):
        """
        continuously fetch infer results from redis
        """
        while True:
            try:
                keys = []
                cur_time = time.time()
                with self.lock:
                    expired_reqs = set()
                    for req_id, req in self.reqs.items():
                        if cur_time - req.get("arrival_time",
                                              cur_time) > self.ttl:
                            result = RequestOutput(
                                request_id=req_id,
                                prompt="",
                                prompt_token_ids=[],
                                outputs=CompletionOutput(-1, -1, []),
                                metrics=RequestMetrics(
                                    arrival_time=req["arrival_time"]),
                                error_code=500,
                                error_msg=f"Req({req_id}) is expired({self.ttl})")
                            self.data.appendleft(result)

                            logger.error(
                                f"Req({req_id}) is expired({self.ttl})")
                            expired_reqs.add(req_id)
                            continue
                        keys.append(req_id)
                    for req_id in expired_reqs:
                        del self.reqs[req_id]

                if len(keys) == 0:
                    time.sleep(0.01)
                    continue

                total = self.sync_results(keys)
                if total == 0:
                    time.sleep(0.01)
            except Exception as e:
                logger.error(
                    f"ResultsReader{self.idx} sync results error: {str(e)}")

    def sync_results(self, keys):
        """
        fetch infer results from redis for the given keys
        """
        total = 0
        for key in keys:
            results = self.client.rpop(key, self.batch)
            if results is None or len(results) == 0:
                continue
            #logger.info(f"Rpop {self.idx}: {len(results)}")
            total += len(results)
            for result in results:
                try:
                    #logger.info(f"Scheduler Get Results: {result}")
                    data = orjson.loads(result)
                    result = RequestOutput.from_dict(data)
                    self.data.appendleft(result)
                except Exception as e:
                    logger.error(f"Parse Result Error:{e}, {result}")
        return total
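
# Flow note: each in-flight request owns a redis list keyed by its request_id.
# The InferScheduler's ResultWriters lpush serialized RequestOutput entries
# onto that list, and sync_results() above rpop-s up to `batch` of them per
# pass, so entries drain oldest-first.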


class APIScheduler(object):
    """
    APIScheduler: puts requests into the global schedule and collects infer results
    """

    def __init__(self, config):
        self.nodeid = config.nodeid
        self.reader_parallel = config.reader_parallel
        self.reader_batch_size = config.reader_batch_size
        self.expire_period = config.expire_period
        self.clear_expired_nodes_period = config.clear_expired_nodes_period
        self.ttl = config.ttl
        self.topic = config.redis_topic
        self.cluster_key = f"{self.topic}.cluster"

        self.client = redis.Redis(host=config.redis_host,
                                  port=config.redis_port,
                                  password=config.redis_password)

        self.req_cond = threading.Condition()
        self.reqs_queue = deque()
        self.readers = []

    def start(self):
        """
        start background threads
        """
        for i in range(self.reader_parallel):
            reader = ResultReader(self.client, i, self.reader_batch_size,
                                  self.ttl)
            self.readers.append(reader)

        self.clear_expired_nodes_thread = threading.Thread(
            target=self.loop_clear_expired_nodes)
        self.clear_expired_nodes_thread.start()

        self.schedule_thread = threading.Thread(target=self.loop_schedule)
        self.schedule_thread.start()

    def put_requests(self, reqs):
        """
        put requests into the local req queue; reqs will be scheduled asynchronously
        """
        ret = []
        with self.req_cond:
            for req in reqs:
                self.reqs_queue.appendleft(req)
                ret.append((req.request_id, None))
            self.req_cond.notify_all()
        return ret

    def get_results(self):
        """
        get infer results from the local queue; results are fetched
        asynchronously from redis
        """
        outputs = dict()
        for reader in self.readers:
            outs = reader.read()
            outputs.update(outs)
        return outputs

    def loop_schedule(self):
        """
        loop scheduling reqs based on global load states
        """
        reader_idx = 0
        while True:
            try:
                with self.req_cond:
                    if len(self.reqs_queue) == 0:
                        self.req_cond.wait()

                pnodes, dnodes, mnodes = self.sync_cluster()
                if len(mnodes) == 0 and (len(pnodes) == 0 or len(dnodes) == 0):
                    logger.error(
                        f"No Schedule Nodes: mixed:{len(mnodes)}, prefill:{len(pnodes)}, decode:{len(dnodes)}"
                    )
                    time.sleep(1)
                    continue

                req = self.reqs_queue.pop()

                reader = self.readers[reader_idx]
                reader.add_req(req)
                reader_idx = (reader_idx + 1) % len(self.readers)

                self.schedule(req, pnodes, dnodes, mnodes)
            except IndexError:
                continue
            except Exception as e:
                logger.error(f"APIScheduler Schedule req error: {str(e)}")

    def schedule(self, req, pnodes, dnodes, mnodes):
        """
        schedule a req onto the corresponding redis node queues
        """
        pnodes.extend(mnodes)
        pnodes.sort()
        pnode = self.select_pd(req, pnodes, "prefill")
        if pnode.role == "mixed":
            req.disaggregate_info = None
            req_str = orjson.dumps(req.to_dict())
            pkey = f"ReqQ_{pnode.nodeid}"
            #logger.info(f"Schedule Req {req_str} to Mixed")
            self.client.lpush(pkey, req_str)
        else:
            dnodes.sort()
            dnode = self.select_pd(req, dnodes, "decode")
            disaggregated = copy.deepcopy(dnode.disaggregated)
            transfer_protocol = disaggregated["transfer_protocol"]
            if len(
                    transfer_protocol
            ) > 1 and "ipc" in transfer_protocol and "rdma" in transfer_protocol:
                if pnode.host == dnode.host:
                    disaggregated["transfer_protocol"] = "ipc"
                else:
                    disaggregated["transfer_protocol"] = "rdma"
            else:
                disaggregated["transfer_protocol"] = transfer_protocol[0]
            req.disaggregate_info = disaggregated
            pkey, dkey = f"ReqQ_{pnode.nodeid}", f"ReqQ_{dnode.nodeid}"
            req_str = orjson.dumps(req.to_dict())
            #logger.info(f"Schedule Req {req_str}")
            self.client.lpush(dkey, req_str)
            self.client.lpush(pkey, req_str)
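
    # Redis layout recap (as used above): node health lives in the hash
    # "{topic}.cluster"; each node pops work from the list "ReqQ_{nodeid}";
    # for a disaggregated req the same payload is pushed to both the prefill
    # and decode node queues, with transfer_protocol resolved to "ipc" when
    # both nodes share a host and "rdma" otherwise.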

    def sync_cluster(self):
        """
        fetch cluster load states from redis
        """
        clusters = self.client.hgetall(self.cluster_key)
        pnodes, dnodes, mnodes = [], [], []
        for nodeid, info in clusters.items():
            node = NodeInfo.load_from(nodeid.decode(), info)
            if node.expired(self.expire_period):
                logger.error(f"node {nodeid} is expired: {info}")
                continue
            if node.role == "prefill":
                pnodes.append(node)
            elif node.role == "decode":
                dnodes.append(node)
            elif node.role == "mixed":
                mnodes.append(node)
            else:
                logger.error(f"Invalid Role: {node.role} {info}")
        return pnodes, dnodes, mnodes

    def loop_clear_expired_nodes(self):
        """
        loop clearing expired nodes' dirty data in redis
        """
        while True:
            try:
                expire_nodes = set()
                clusters = self.client.hgetall(self.cluster_key)
                for nodeid, info in clusters.items():
                    node = NodeInfo.load_from(nodeid.decode(), info)
                    if node.expired(self.clear_expired_nodes_period):
                        expire_nodes.add(nodeid)
                for nodeid in expire_nodes:
                    #logger.info(f"clear expired nodes: {nodeid}")
                    self.client.hdel(self.cluster_key, nodeid)
                time.sleep(self.clear_expired_nodes_period)
            except Exception as e:
                logger.error(
                    f"APIScheduler clear expired nodes error: {str(e)}")

    def select_pd(self, req, nodes, role):
        """
        select a prefill/decode/mixed node based on load states
        """

        def select(req, nodes, blur_step):
            min_load = nodes[0].load
            blur_max = min_load + blur_step
            blur_idx = 0
            for idx, node in enumerate(nodes):
                if node.load >= blur_max:
                    break
                blur_idx = idx
            node = random.choice(nodes[:blur_idx + 1])
            logger.info(
                f"Schedule Req {req.request_id}(len:{req.prompt_token_ids_len}) to {node}"
            )
            return node

        if role == "prefill" or role == "mixed":
            size = req.prompt_token_ids_len
            rate = 2 if size < 1000 else 10
            pblur_step = max(100, min(500, int(size / rate)))
            pnode = select(req, nodes, pblur_step)
            return pnode
        elif role == "decode":
            dblur_step = min(len(nodes), 10)
            dnode = select(req, nodes, dblur_step)
            return dnode

        raise Exception(f"Invalid Role: {role}")
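
    # Worked example of the blurred pick above (hypothetical loads): with
    # nodes sorted by load [3, 4, 10, 12] and blur_step=2, blur_max=5, so the
    # scan stops at load 10 and random.choice draws from [3, 4]. Randomizing
    # within the blur window avoids stampeding the single least-loaded node.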


class ResultWriter(object):
    """
    ResultWriter uses an async thread to continuously write infer results to redis
    """

    def __init__(self, client, idx, batch, ttl=900):
        self.idx = idx
        self.batch = batch
        self.client = client
        self.data = deque()
        self.cond = threading.Condition()
        self.thread = threading.Thread(target=self.run)
        self.ttl = ttl

    def start(self):
        """start the background thread"""
        self.thread.start()

    def put(self, key, items):
        """
        put infer results to the writer
        """
        with self.cond:
            for item in items:
                self.data.appendleft((key, item))
            self.cond.notify_all()

    def run(self):
        """
        continuously batch write infer results to redis
        """
        while True:
            try:
                with self.cond:
                    if len(self.data) == 0:
                        self.cond.wait()
                    # re-read the size after waking: items may have arrived
                    # while this thread was blocked on the condition
                    size = min(len(self.data), self.batch)
                    groups = dict()
                    for i in range(size):
                        key, item = self.data.pop()
                        if key not in groups:
                            groups[key] = []
                        groups[key].append(item)
                for key, items in groups.items():
                    #s = time.time()
                    with self.client.pipeline() as pipe:
                        pipe.multi()
                        pipe.lpush(key, *items)
                        pipe.expire(key, math.ceil(self.ttl))
                        pipe.execute()
                    #self.client.lpush(key, *items)
                    #e = time.time()
                    #logger.info(f"Lpush {self.idx}: {key} used {e-s} {len(items)} items")
            except Exception as e:
                logger.error(f"ResultWriter write error: {str(e)}")


class InferScheduler(object):
    """
    InferScheduler: pulls scheduled requests into a local queue and writes results to redis
    """

    def __init__(self, config):
        self.nodeid = config.nodeid
        self.writer_parallel = config.writer_parallel
        self.writer_batch_size = config.writer_batch_size
        self.sync_period = config.sync_period
        self.topic = config.redis_topic
        self.cluster_key = f"{self.topic}.cluster"
        self.ttl = config.ttl
        self.release_load_expire_period = config.release_load_expire_period

        self.client = redis.Redis(host=config.redis_host,
                                  port=config.redis_port,
                                  password=config.redis_password)

        self.reqs_queue = deque()
        self.writers = []

    def start(self, role, host, disaggregated):
        """
        start background threads
        """
        for i in range(self.writer_parallel):
            writer = ResultWriter(self.client, i, self.writer_batch_size,
                                  self.ttl)
            writer.start()
            self.writers.append(writer)

        self.getreq_thread = threading.Thread(target=self.loop_get_reqs)
        self.getreq_thread.start()

        self.role = role
        self.host = host
        self.node = NodeInfo(self.nodeid, role, host, disaggregated, 0)

        self.report_thread = threading.Thread(target=self.routine_report)
        self.report_thread.start()

        self.expire_reqs_thread = threading.Thread(
            target=self.loop_expire_reqs)
        self.expire_reqs_thread.start()

    def routine_report(self):
        """
        routinely report node info: load, health
        """
        while True:
            try:
                info = self.node.serialize()
                self.client.hset(self.cluster_key, self.nodeid, info)
                time.sleep(self.sync_period / 1000.)
            except Exception as e:
                logger.error(f"InferScheduler routine report error: {str(e)}")

    def loop_expire_reqs(self):
        """
        loop clearing expired reqs
        """
        while True:
            try:
                self.node.expire_reqs(self.release_load_expire_period)
                time.sleep(60)
            except Exception as e:
                logger.error(f"InferScheduler expire reqs error: {str(e)}")

    def loop_get_reqs(self):
        """
        loop pulling globally scheduled reqs into the local queue
        """

        def select_writer(req):
            req_id = req.request_id
            md5 = hashlib.md5()
            md5.update(req_id.encode())
            writer_idx = int(md5.hexdigest(), 16) % len(self.writers)
            return writer_idx

        batch = 50
        while True:
            try:
                key = f"ReqQ_{self.nodeid}"
                reqs = self.client.rpop(key, batch)
                if reqs is None:
                    ret = self.client.brpop([key], timeout=1)
                    if ret is None:
                        continue
                    reqs = [ret[1]]

                for req_str in reqs:
                    req = orjson.loads(req_str)
                    req = Request.from_dict(req)
                    writer_idx = select_writer(req)
                    logger.info(
                        f"Infer Scheduler Get Req: {req.request_id} writer idx {writer_idx}"
                    )
                    req.request_id = f"{req.request_id}#{writer_idx}"
                    if self.role == "prefill" or self.role == "mixed":
                        self.reqs_queue.append(req)
                        self.node.add_req(req.request_id,
                                          req.prompt_token_ids_len)
                    else:
                        self.node.add_req(req.request_id, 1)
            except Exception as e:
                logger.error(f"InferScheduler loop get reqs error: {str(e)}")

    def get_requests(self, available_blocks, block_size,
                     reserved_output_blocks, max_num_batched_tokens, batch):
        """
        get scheduled reqs from the local reqs queue
        """
        if len(self.reqs_queue) == 0:
            return []

        reqs = []
        required_blocks = 0
        current_prefill_tokens = 0
        cur_time = time.time()
        for i in range(batch):
            try:
                req = self.reqs_queue.popleft()
                if cur_time - req.arrival_time > self.ttl:
                    logger.error(
                        f"req({req.request_id}) is expired({self.ttl}) when InferScheduler Get Requests"
                    )
                    self.node.finish_req(req.request_id)
                    continue
                current_prefill_tokens += req.prompt_token_ids_len
                required_input_blocks = (req.prompt_token_ids_len +
                                         block_size - 1) // block_size
                required_blocks += required_input_blocks + reserved_output_blocks
                if required_blocks > available_blocks or current_prefill_tokens > max_num_batched_tokens:
                    self.reqs_queue.appendleft(req)
                    return reqs
                #logger.info(f"Get Requests from Scheduler: {req.request_id}")
                reqs.append(req)
            except Exception:
                return reqs
        return reqs
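
    # Admission arithmetic sketch (hypothetical numbers): with block_size=64,
    # reserved_output_blocks=32 and a 1000-token prompt, required_input_blocks
    # = ceil(1000 / 64) = 16, so the req consumes 16 + 32 = 48 blocks of the
    # available budget; a req that would overflow the budget is pushed back
    # and admission stops for this call.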

    def put_results(self, results):
        """
        put infer results into the corresponding writers' local queues
        """
        groups = dict()
        req_ids = set()
        for result in results:
            if result.error_code != 200 or result.finished:
                self.node.finish_req(result.request_id)
                logger.info(
                    f"{result.request_id} finished, node load is {self.node.load}"
                )

            req_ids.add(result.request_id)

            req_id, idx = result.request_id.split("#")
            result.request_id = req_id

            key = (req_id, int(idx))
            if key not in groups:
                groups[key] = list()

            if self.role == "prefill" and result.outputs.send_idx == 0:
                result.finished = False

            result_str = orjson.dumps(result.to_dict())
            #if self.role == "prefill" or result.error_code != 200 or result.finished:
            #    logger.info(f"Infer Put Finish Result: {result_str}")
            groups[key].append(result_str)

        self.node.update_req_timestamp(req_ids)

        for key, outputs in groups.items():
            req_id, idx = key
            self.writers[idx].put(req_id, outputs)
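
    # Routing note: loop_get_reqs suffixes each request_id with "#<writer_idx>"
    # (derived from an md5 hash), and put_results splits that suffix back off,
    # so all chunks of one request funnel through a single ResultWriter and are
    # appended to the redis list keyed by the original request_id.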