# FastDeploy/fastdeploy/scheduler/splitwise_scheduler.py
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import copy
import hashlib
import math
import random
import threading
import time
from collections import deque
from typing import List
import orjson
import redis
from fastdeploy.engine.request import (CompletionOutput, Request,
RequestMetrics, RequestOutput)
from fastdeploy.utils import scheduler_logger as logger
class SplitWiseSchedulerConfig(object):
"""SplitWise Scheduler Configuration"""
def __init__(
self,
nodeid=None,
host="127.0.0.1", # redis host
port=6379, # redis port
password=None, # redis password
topic="fd", # redis topic
ttl=900,
release_load_expire_period=600, #s
sync_period=5, #ms
expire_period=3000, #ms
clear_expired_nodes_period=60, #s
reader_parallel=4,
reader_batch_size=200,
writer_parallel=4,
writer_batch_size=200,
**kwargs):
if nodeid is None:
import uuid
nodeid = str(uuid.uuid4())
self.nodeid = nodeid
self.redis_host = host
self.redis_port = port
self.redis_password = password
self.redis_topic = topic
self.ttl = ttl
self.release_load_expire_period = release_load_expire_period
self.sync_period = sync_period
self.expire_period = expire_period / 1000.
self.clear_expired_nodes_period = clear_expired_nodes_period
self.reader_parallel = reader_parallel
self.reader_batch_size = reader_batch_size
self.writer_parallel = writer_parallel
self.writer_batch_size = writer_batch_size
def check(self):
"""check argument"""
pass
def print(self):
"""
print config
"""
logger.info("LocalScheduler Configuration Information :")
for k, v in self.__dict__.items():
logger.info("{:<20}:{:<6}{}".format(k, "", v))
logger.info(
"=============================================================")
class SplitWiseScheduler(object):
"""
SplitWise Scheduler
"""
def __init__(self, config):
self.scheduler = APIScheduler(config)
self.infer = InferScheduler(config)
def start(self, role, host, disaggregated):
"""
        Start APIScheduler and InferScheduler background threads
"""
logger.info(
f"Scheduler Start With: role:{role}, host:{host}, disaggregated:{disaggregated}"
)
self.infer.start(role, host, disaggregated)
self.scheduler.start()
def reset_nodeid(self, nodeid):
"""
reset node id
"""
self.scheduler.nodeid = nodeid
self.infer.nodeid = nodeid
def put_requests(self, reqs: List[Request]):
"""
put requests to global splitwise scheduler
"""
return self.scheduler.put_requests(reqs)
    def get_results(self, request_ids=None):  # request_ids is currently unused
"""
get results from global splitwise scheduler
"""
return self.scheduler.get_results()
def get_requests(self,
available_blocks,
block_size,
reserved_output_blocks,
max_num_batched_tokens,
batch=1):
"""
        get scheduled requests from the global splitwise scheduler
"""
if available_blocks <= reserved_output_blocks or batch < 1:
logger.info(
f"Scheduler's resource are insufficient: available_blocks={available_blocks} "
f"reserved_output_blocks={reserved_output_blocks} batch={batch} "
f"max_num_batched_tokens={max_num_batched_tokens}")
return []
return self.infer.get_requests(available_blocks, block_size,
reserved_output_blocks,
max_num_batched_tokens, batch)
def put_results(self, results: List[RequestOutput]):
"""
put results to global splitwise scheduler
"""
return self.infer.put_results(results)
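# Rough end-to-end sketch of the facade above (hedged: the Request objects,
# the "mixed" role and the numeric budgets are illustrative assumptions):
#
#   scheduler = SplitWiseScheduler(SplitWiseSchedulerConfig())
#   scheduler.start(role="mixed", host="127.0.0.1", disaggregated={})
#   scheduler.put_requests(reqs)       # API side: enqueue for scheduling
#   outputs = scheduler.get_results()  # API side: {req_id: [RequestOutput]}
#   batch = scheduler.get_requests(available_blocks=1000, block_size=64,
#                                  reserved_output_blocks=100,
#                                  max_num_batched_tokens=8192, batch=8)
#   scheduler.put_results(results)     # infer side: push outputs back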
class NodeInfo(object):
"""
Infer Node Info: load, rdma/ipc info
"""
@classmethod
    def load_from(cls, nodeid, info):
        """
        load node info from a serialized string
"""
health = orjson.loads(info)
ts = health["ts"]
role = health["role"]
load = int(health["load"])
host = health["host"]
disaggregated = health["disaggregated"]
return NodeInfo(nodeid, role, host, disaggregated, load, ts)
    def __init__(self, nodeid, role, host, disaggregated, load, ts=None):
        # note: a time.time() default would be evaluated once at class
        # definition time, so resolve the timestamp per instance instead
        self.nodeid = nodeid
        self.ts = ts if ts is not None else time.time()
self.host = host
self.disaggregated = disaggregated
self.role = role
self.lock = threading.Lock()
self.load = load
self.reqs = dict()
def __repr__(self):
return f"{self.nodeid}({self.load})"
def expired(self, expire_period):
"""
        Used by APIScheduler to check whether the node has expired
"""
now = time.time()
return (now - self.ts) > expire_period
def serialize(self):
"""
        Used by InferScheduler to sync load
"""
self.ts = time.time()
health = {
"ts": self.ts,
"role": self.role,
"load": self.load,
"host": self.host,
"disaggregated": self.disaggregated
}
return orjson.dumps(health)
def __lt__(self, other):
return self.load < other.load
def expire_reqs(self, ttl):
"""
        Used by InferScheduler to clear expired reqs
"""
cur_time = time.time()
with self.lock:
            expired = set()
            for req_id, pairs in self.reqs.items():
                load, arrival_time = pairs
                if cur_time - arrival_time > ttl:
                    logger.error(
                        f"InferScheduler Expire Reqs({req_id}), arrival({arrival_time}), ttl({ttl})"
                    )
                    expired.add((req_id, load))
            for req_id, load in expired:
if req_id in self.reqs:
self.load -= load
del self.reqs[req_id]
def add_req(self, req_id, load):
"""
        Used by InferScheduler to record scheduled reqs (waiting or running)
"""
with self.lock:
if req_id not in self.reqs:
self.reqs[req_id] = [load, time.time()]
self.load += load
def update_req_timestamp(self, req_ids):
"""
        Used by InferScheduler to refresh req timestamps
"""
cur_time = time.time()
with self.lock:
for req_id in req_ids:
if req_id in self.reqs:
self.reqs[req_id][1] = cur_time
def finish_req(self, req_id):
"""
        Used by InferScheduler to clear finished reqs
"""
with self.lock:
if req_id in self.reqs:
load = self.reqs[req_id][0]
self.load -= load
del self.reqs[req_id]
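# NodeInfo round trip, as used between InferScheduler (serialize) and
# APIScheduler (load_from); the field values here are illustrative:
#
#   node = NodeInfo("node-1", "prefill", "10.0.0.1",
#                   {"transfer_protocol": ["rdma"]}, load=0)
#   blob = node.serialize()  # orjson bytes with a fresh ts
#   same = NodeInfo.load_from("node-1", blob)
#   assert same.role == "prefill" and not same.expired(3.0)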
class ResultReader(object):
"""
    ResultReader uses an async thread to continuously fetch infer results from redis
"""
def __init__(self, client, idx, batch=200, ttl=900):
self.idx = idx
self.batch = batch
self.client = client
self.data = deque()
self.ttl = ttl
self.reqs = dict()
self.out_buffer = dict()
self.lock = threading.Lock()
self.thread = threading.Thread(target=self.run)
self.thread.start()
def add_req(self, req):
"""
        add a req to the reader; the reader will asynchronously fetch its infer results from redis
"""
with self.lock:
self.reqs[req.request_id] = {"arrival_time": req.arrival_time}
self.out_buffer[req.request_id] = []
def read(self):
"""
        batch read infer results
        returns: dict mapping req_id to a list of RequestOutput
"""
items = []
size = len(self.data)
for i in range(size):
items.append(self.data.pop())
outputs = dict()
group_tokens = dict()
finish_reqs = set()
for item in items:
req_id = item.request_id
is_error = item.error_code != 200
if is_error or item.finished:
finish_reqs.add(req_id)
if is_error or item.outputs.send_idx == 0:
outputs[req_id] = [item]
continue
if req_id not in group_tokens:
group_tokens[req_id] = []
group_tokens[req_id].append(item)
with self.lock:
for key in finish_reqs:
if key in self.reqs:
del self.reqs[key]
for req_id, items in outputs.items():
if req_id in self.out_buffer:
items.extend(self.out_buffer[req_id])
del self.out_buffer[req_id]
for req_id, items in group_tokens.items():
if req_id in self.out_buffer:
self.out_buffer[req_id].extend(items)
continue
if req_id not in outputs:
outputs[req_id] = []
outputs[req_id].extend(items)
return outputs
def run(self):
"""
        continuously fetch infer results from redis
"""
while True:
try:
keys = []
cur_time = time.time()
with self.lock:
expired_reqs = set()
for req_id, req in self.reqs.items():
if cur_time - req.get("arrival_time",
cur_time) > self.ttl:
result = RequestOutput(
request_id=req_id,
prompt="",
prompt_token_ids=[],
outputs=CompletionOutput(-1, -1, []),
metrics=RequestMetrics(
arrival_time=req["arrival_time"]),
error_code=500,
error_msg=f"Req({req_id}) is expired({self.ttl})")
self.data.appendleft(result)
logger.error(
f"Req({req_id}) is expired({self.ttl})")
expired_reqs.add(req_id)
continue
keys.append(req_id)
for req_id in expired_reqs:
del self.reqs[req_id]
if len(keys) == 0:
time.sleep(0.01)
continue
total = self.sync_results(keys)
if total == 0:
time.sleep(0.01)
except Exception as e:
logger.error(
f"ResultsReader{self.idx} sync results error: {str(e)}")
def sync_results(self, keys):
"""
        fetch infer results from redis for the given keys
"""
total = 0
for key in keys:
results = self.client.rpop(key, self.batch)
if results is None or len(results) == 0:
continue
#logger.info(f"Rpop {self.idx}: {len(results)}")
total += len(results)
for result in results:
try:
#logger.info(f"Scheduler Get Results: {result}")
data = orjson.loads(result)
result = RequestOutput.from_dict(data)
self.data.appendleft(result)
except Exception as e:
logger.error(f"Parse Result Error:{e}, {result}")
return total
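# Result transport assumed by sync_results (a sketch): each suffixed
# request_id is its own redis list; ResultWriter lpush-es serialized
# RequestOutput dicts and the reader drains up to `batch` of them per poll
# with a counted rpop (redis >= 6.2):
#
#   results = client.rpop("req-42#0", 200)  # None when the key is empty
#   outputs = [RequestOutput.from_dict(orjson.loads(r))
#              for r in results or []]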
class APIScheduler(object):
"""
    APIScheduler: puts requests into the global schedule and collects the corresponding infer results
"""
def __init__(self, config):
self.nodeid = config.nodeid
self.reader_parallel = config.reader_parallel
self.reader_batch_size = config.reader_batch_size
self.expire_period = config.expire_period
self.clear_expired_nodes_period = config.clear_expired_nodes_period
self.ttl = config.ttl
self.topic = config.redis_topic
self.cluster_key = f"{self.topic}.cluster"
self.client = redis.Redis(host=config.redis_host,
port=config.redis_port,
password=config.redis_password)
self.req_cond = threading.Condition()
self.reqs_queue = deque()
self.readers = []
def start(self):
"""
        start background threads
"""
for i in range(self.reader_parallel):
reader = ResultReader(self.client, i, self.reader_batch_size,
self.ttl)
self.readers.append(reader)
self.clear_expired_nodes_thread = threading.Thread(
target=self.loop_clear_expired_nodes)
self.clear_expired_nodes_thread.start()
self.schedule_thread = threading.Thread(target=self.loop_schedule)
self.schedule_thread.start()
def put_requests(self, reqs):
"""
        put requests into the local req queue; they will be scheduled asynchronously
"""
ret = []
with self.req_cond:
for req in reqs:
self.reqs_queue.appendleft(req)
ret.append((req.request_id, None))
self.req_cond.notify_all()
return ret
def get_results(self):
"""
        get infer results from the local queue; results are fetched asynchronously from redis
"""
outputs = dict()
for reader in self.readers:
outs = reader.read()
outputs.update(outs)
return outputs
def loop_schedule(self):
"""
        continuously schedule reqs based on global load states
"""
reader_idx = 0
while True:
try:
                with self.req_cond:
                    # a loop guards against spurious wakeups
                    while len(self.reqs_queue) == 0:
                        self.req_cond.wait()
pnodes, dnodes, mnodes = self.sync_cluster()
if len(mnodes) == 0 and (len(pnodes) == 0 or len(dnodes) == 0):
logger.error(
f"No Schedule Nodes: mixed:{len(mnodes)}, prefill:{len(pnodes)}, decode:{len(dnodes)}"
)
time.sleep(1)
continue
req = self.reqs_queue.pop()
reader = self.readers[reader_idx]
reader.add_req(req)
reader_idx = (reader_idx + 1) % len(self.readers)
self.schedule(req, pnodes, dnodes, mnodes)
except IndexError:
continue
except Exception as e:
logger.error(f"APIScheduler Schedule req error: {str(e)}")
def schedule(self, req, pnodes, dnodes, mnodes):
"""
        schedule a req to the corresponding redis node queues
"""
pnodes.extend(mnodes)
pnodes.sort()
pnode = self.select_pd(req, pnodes, "prefill")
if pnode.role == "mixed":
req.disaggregate_info = None
req_str = orjson.dumps(req.to_dict())
pkey = f"ReqQ_{pnode.nodeid}"
#logger.info(f"Schedule Req {req_str} to Mixed")
self.client.lpush(pkey, req_str)
else:
dnodes.sort()
dnode = self.select_pd(req, dnodes, "decode")
disaggregated = copy.deepcopy(dnode.disaggregated)
transfer_protocol = disaggregated["transfer_protocol"]
            # both protocols present implies len(transfer_protocol) > 1
            if "ipc" in transfer_protocol and "rdma" in transfer_protocol:
if pnode.host == dnode.host:
disaggregated["transfer_protocol"] = "ipc"
else:
disaggregated["transfer_protocol"] = "rdma"
else:
disaggregated["transfer_protocol"] = transfer_protocol[0]
req.disaggregate_info = disaggregated
pkey, dkey = f"ReqQ_{pnode.nodeid}", f"ReqQ_{dnode.nodeid}"
req_str = orjson.dumps(req.to_dict())
#logger.info(f"Schedule Req {req_str}")
self.client.lpush(dkey, req_str)
self.client.lpush(pkey, req_str)
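    # In the disaggregated branch above, one request lands on two node queues:
    # the prefill and decode nodes each pop their own "ReqQ_<nodeid>" list,
    # and "ipc" is chosen over "rdma" only when both nodes share a host.
    # Rough trace (hypothetical node ids):
    #
    #   pnode=p0(host=A), dnode=d1(host=B), protocols=["ipc", "rdma"]
    #     -> disaggregated["transfer_protocol"] = "rdma"
    #     -> lpush ReqQ_d1, then lpush ReqQ_p0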
def sync_cluster(self):
"""
fetch cluster load states from redis
"""
clusters = self.client.hgetall(self.cluster_key)
pnodes, dnodes, mnodes = [], [], []
for nodeid, info in clusters.items():
node = NodeInfo.load_from(nodeid.decode(), info)
if node.expired(self.expire_period):
logger.error(f"node {nodeid} is expired: {info}")
continue
if node.role == "prefill":
pnodes.append(node)
elif node.role == "decode":
dnodes.append(node)
elif node.role == "mixed":
mnodes.append(node)
else:
logger.error(f"Invalid Role: {node.role} {info}")
return pnodes, dnodes, mnodes
def loop_clear_expired_nodes(self):
"""
        periodically clear expired nodes' stale data from redis
"""
while True:
try:
expire_nodes = set()
clusters = self.client.hgetall(self.cluster_key)
for nodeid, info in clusters.items():
node = NodeInfo.load_from(nodeid.decode(), info)
if node.expired(self.clear_expired_nodes_period):
expire_nodes.add(nodeid)
for nodeid in expire_nodes:
#logger.info(f"clear expired nodes: {nodeid}")
self.client.hdel(self.cluster_key, nodeid)
time.sleep(self.clear_expired_nodes_period)
            except Exception as e:
                logger.error(
                    f"APIScheduler clear expired nodes error: {str(e)}")
def select_pd(self, req, nodes, role):
"""
select a prefill/decode/mixed node based on load states
"""
def select(req, nodes, blur_step):
min_load = nodes[0].load
blur_max = min_load + blur_step
blur_idx = 0
for idx, node in enumerate(nodes):
if node.load >= blur_max:
break
blur_idx = idx
node = random.choice(nodes[:blur_idx + 1])
logger.info(
f"Schedule Req {req.request_id}(len:{req.prompt_token_ids_len}) to {node}"
)
return node
if role == "prefill" or role == "mixed":
size = req.prompt_token_ids_len
rate = 2 if size < 1000 else 10
pblur_step = max(100, min(500, int(size / rate)))
pnode = select(req, nodes, pblur_step)
return pnode
elif role == "decode":
dblur_step = min(len(nodes), 10)
dnode = select(req, nodes, dblur_step)
return dnode
raise Exception(f"Invalid Role: {role}")
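# Worked example of the prefill blur window above (numbers illustrative):
# an 800-token prompt gives rate=2, so pblur_step = max(100, min(500, 400))
# = 400; every node with load below nodes[0].load + 400 is inside the blur
# window and one is picked uniformly at random. A 5000-token prompt gives
# rate=10 -> step 500; a 150-token prompt bottoms out at the floor of 100.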
class ResultWriter(object):
"""
    ResultWriter uses an async thread to continuously write infer results to redis
"""
def __init__(self, client, idx, batch, ttl=900):
self.idx = idx
self.batch = batch
self.client = client
self.data = deque()
self.cond = threading.Condition()
self.thread = threading.Thread(target=self.run)
self.ttl = ttl
def start(self):
"""start backup thread"""
self.thread.start()
def put(self, key, items):
"""
put infer results to writer
"""
with self.cond:
for item in items:
self.data.appendleft((key, item))
self.cond.notify_all()
def run(self):
"""
        continuously batch-write infer results to redis
"""
while True:
try:
                with self.cond:
                    # re-check under the lock in a loop; with a plain `if`,
                    # `size` stays 0 after the first wakeup
                    while len(self.data) == 0:
                        self.cond.wait()
                    size = min(len(self.data), self.batch)
groups = dict()
for i in range(size):
key, item = self.data.pop()
if key not in groups:
groups[key] = []
groups[key].append(item)
for key, items in groups.items():
#s = time.time()
with self.client.pipeline() as pipe:
pipe.multi()
pipe.lpush(key, *items)
pipe.expire(key, math.ceil(self.ttl))
pipe.execute()
#self.client.lpush(key, *items)
#e = time.time()
#logger.info(f"Lpush {self.idx}: {key} used {e-s} {len(items)} items")
except Exception as e:
logger.error(f"ResultWriter write error: {str(e)}")
class InferScheduler(object):
"""
    InferScheduler: pulls scheduled requests into the local queue and writes results to redis
"""
def __init__(self, config):
self.nodeid = config.nodeid
self.writer_parallel = config.writer_parallel
self.writer_batch_size = config.writer_batch_size
self.sync_period = config.sync_period
self.topic = config.redis_topic
self.cluster_key = f"{self.topic}.cluster"
self.ttl = config.ttl
self.release_load_expire_period = config.release_load_expire_period
self.client = redis.Redis(host=config.redis_host,
port=config.redis_port,
password=config.redis_password)
self.reqs_queue = deque()
self.writers = []
def start(self, role, host, disaggregated):
"""
        start background threads
"""
for i in range(self.writer_parallel):
writer = ResultWriter(self.client, i, self.writer_batch_size,
self.ttl)
writer.start()
self.writers.append(writer)
self.getreq_thread = threading.Thread(target=self.loop_get_reqs)
self.getreq_thread.start()
self.role = role
self.host = host
self.node = NodeInfo(self.nodeid, role, host, disaggregated, 0)
self.report_thread = threading.Thread(target=self.routine_report)
self.report_thread.start()
self.expire_reqs_thread = threading.Thread(
target=self.loop_expire_reqs)
self.expire_reqs_thread.start()
def routine_report(self):
"""
        periodically report node info: load, health
"""
while True:
try:
info = self.node.serialize()
self.client.hset(self.cluster_key, self.nodeid, info)
time.sleep(self.sync_period / 1000.)
except Exception as e:
logger.error(f"InferScheduler routine report error: {str(e)}")
def loop_expire_reqs(self):
"""
        periodically clear expired reqs
"""
while True:
try:
self.node.expire_reqs(self.release_load_expire_period)
time.sleep(60)
            except Exception as e:
                logger.error(f"InferScheduler expire reqs error: {str(e)}")
def loop_get_reqs(self):
"""
        continuously pull globally scheduled reqs into the local queue
"""
def select_writer(req):
req_id = req.request_id
md5 = hashlib.md5()
md5.update(req_id.encode())
writer_idx = int(md5.hexdigest(), 16) % len(self.writers)
return writer_idx
batch = 50
while True:
try:
key = f"ReqQ_{self.nodeid}"
reqs = self.client.rpop(key, batch)
if reqs is None:
ret = self.client.brpop([key], timeout=1)
if ret is None:
continue
reqs = [ret[1]]
for req_str in reqs:
req = orjson.loads(req_str)
req = Request.from_dict(req)
writer_idx = select_writer(req)
logger.info(
f"Infer Scheduler Get Req: {req.request_id} writer idx {writer_idx}"
)
req.request_id = f"{req.request_id}#{writer_idx}"
if self.role == "prefill" or self.role == "mixed":
self.reqs_queue.append(req)
self.node.add_req(req.request_id,
req.prompt_token_ids_len)
else:
self.node.add_req(req.request_id, 1)
except Exception as e:
logger.error(f"InferScheduler loop get reqs error: {str(e)}")
def get_requests(self, available_blocks, block_size,
reserved_output_blocks, max_num_batched_tokens, batch):
"""
get scheduled reqs from local reqs queue
"""
if len(self.reqs_queue) == 0:
return []
reqs = []
required_blocks = 0
current_prefill_tokens = 0
cur_time = time.time()
for i in range(batch):
try:
req = self.reqs_queue.popleft()
if cur_time - req.arrival_time > self.ttl:
logger.error(
f"req({req.request_id}) is expired({self.ttl}) when InferScheduler Get Requests"
)
self.node.finish_req(req.request_id)
continue
current_prefill_tokens += req.prompt_token_ids_len
required_input_blocks = (req.prompt_token_ids_len +
block_size - 1) // block_size
required_blocks += required_input_blocks + reserved_output_blocks
if required_blocks > available_blocks or current_prefill_tokens > max_num_batched_tokens:
self.reqs_queue.appendleft(req)
return reqs
#logger.info(f"Get Requests from Scheduler: {req.request_id}")
reqs.append(req)
            except IndexError:  # reqs_queue exhausted
return reqs
return reqs
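    # Admission arithmetic above, with illustrative numbers: for
    # block_size=64 and reserved_output_blocks=100, a 1000-token prompt
    # needs (1000 + 63) // 64 = 16 input blocks, i.e. 116 blocks total;
    # requests are admitted until required_blocks would exceed
    # available_blocks or the summed prompt tokens would exceed
    # max_num_batched_tokens.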
def put_results(self, results):
"""
        put infer results to the corresponding writer's local queue
"""
groups = dict()
req_ids = set()
for result in results:
if result.error_code != 200 or result.finished:
self.node.finish_req(result.request_id)
logger.info(
f"{result.request_id} finished, node load is {self.node.load}"
)
req_ids.add(result.request_id)
            req_id, idx = result.request_id.rsplit("#", 1)
result.request_id = req_id
key = (req_id, int(idx))
if key not in groups:
groups[key] = list()
if self.role == "prefill" and result.outputs.send_idx == 0:
result.finished = False
result_str = orjson.dumps(result.to_dict())
#if self.role == "prefill" or result.error_code != 200 or result.finished:
# logger.info(f"Infer Put Finish Result: {result_str}")
groups[key].append(result_str)
self.node.update_req_timestamp(req_ids)
for key, outputs in groups.items():
req_id, idx = key
self.writers[idx].put(req_id, outputs)