polish code with new pre-commit rule (#2923)

This commit is contained in:
Zero Rains
2025-07-19 23:19:27 +08:00
committed by GitHub
parent b8676d71a8
commit 25698d56d1
424 changed files with 14307 additions and 13518 deletions

View File

@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import copy
import hashlib
import math
@@ -25,34 +26,40 @@ from typing import List
import orjson
import redis
from fastdeploy.engine.request import (CompletionOutput, Request,
RequestMetrics, RequestOutput)
from fastdeploy.engine.request import (
CompletionOutput,
Request,
RequestMetrics,
RequestOutput,
)
from fastdeploy.utils import scheduler_logger as logger
class SplitWiseSchedulerConfig(object):
class SplitWiseSchedulerConfig:
"""SplitWise Scheduler Configuration"""
def __init__(
self,
nodeid=None,
host="127.0.0.1", # redis host
port=6379, # redis port
password=None, # redis password
topic="fd", # redis topic
ttl=900,
release_load_expire_period=600, #s
sync_period=5, #ms
expire_period=3000, #ms
clear_expired_nodes_period=60, #s
reader_parallel=4,
reader_batch_size=200,
writer_parallel=4,
writer_batch_size=200,
**kwargs):
self,
nodeid=None,
host="127.0.0.1", # redis host
port=6379, # redis port
password=None, # redis password
topic="fd", # redis topic
ttl=900,
release_load_expire_period=600, # s
sync_period=5, # ms
expire_period=3000, # ms
clear_expired_nodes_period=60, # s
reader_parallel=4,
reader_batch_size=200,
writer_parallel=4,
writer_batch_size=200,
**kwargs,
):
if nodeid is None:
import uuid
nodeid = str(uuid.uuid4())
self.nodeid = nodeid
@@ -64,7 +71,7 @@ class SplitWiseSchedulerConfig(object):
self.release_load_expire_period = release_load_expire_period
self.sync_period = sync_period
self.expire_period = expire_period / 1000.
self.expire_period = expire_period / 1000.0
self.clear_expired_nodes_period = clear_expired_nodes_period
self.reader_parallel = reader_parallel
self.reader_batch_size = reader_batch_size
@@ -82,13 +89,12 @@ class SplitWiseSchedulerConfig(object):
logger.info("LocalScheduler Configuration Information :")
for k, v in self.__dict__.items():
logger.info("{:<20}:{:<6}{}".format(k, "", v))
logger.info(
"=============================================================")
logger.info("=============================================================")
class SplitWiseScheduler(object):
class SplitWiseScheduler:
"""
SplitWise Scheduler
SplitWise Scheduler
"""
def __init__(self, config):
@@ -97,68 +103,73 @@ class SplitWiseScheduler(object):
def start(self, role, host, disaggregated):
"""
Start APIScheduler and InferScheduler backup threads
Start APIScheduler and InferScheduler backup threads
"""
logger.info(
f"Scheduler Start With: role:{role}, host:{host}, disaggregated:{disaggregated}"
)
logger.info(f"Scheduler Start With: role:{role}, host:{host}, disaggregated:{disaggregated}")
self.infer.start(role, host, disaggregated)
self.scheduler.start()
def reset_nodeid(self, nodeid):
"""
reset node id
reset node id
"""
self.scheduler.nodeid = nodeid
self.infer.nodeid = nodeid
def put_requests(self, reqs: List[Request]):
"""
put requests to global splitwise scheduler
put requests to global splitwise scheduler
"""
return self.scheduler.put_requests(reqs)
def get_results(self, request_ids=[]):
"""
get results from global splitwise scheduler
get results from global splitwise scheduler
"""
return self.scheduler.get_results()
def get_requests(self,
available_blocks,
block_size,
reserved_output_blocks,
max_num_batched_tokens,
batch=1):
def get_requests(
self,
available_blocks,
block_size,
reserved_output_blocks,
max_num_batched_tokens,
batch=1,
):
"""
get scheduled requests from global spltiwise scheduler
get scheduled requests from global spltiwise scheduler
"""
if available_blocks <= reserved_output_blocks or batch < 1:
logger.info(
f"Scheduler's resource are insufficient: available_blocks={available_blocks} "
f"reserved_output_blocks={reserved_output_blocks} batch={batch} "
f"max_num_batched_tokens={max_num_batched_tokens}")
f"max_num_batched_tokens={max_num_batched_tokens}"
)
return []
return self.infer.get_requests(available_blocks, block_size,
reserved_output_blocks,
max_num_batched_tokens, batch)
return self.infer.get_requests(
available_blocks,
block_size,
reserved_output_blocks,
max_num_batched_tokens,
batch,
)
def put_results(self, results: List[RequestOutput]):
"""
put results to global splitwise scheduler
put results to global splitwise scheduler
"""
return self.infer.put_results(results)
class NodeInfo(object):
class NodeInfo:
"""
Infer Node Info: load, rdma/ipc info
Infer Node Info: load, rdma/ipc info
"""
@classmethod
def load_from(self, nodeid, info):
"""
load node info from seiralized string
load node info from seiralized string
"""
health = orjson.loads(info)
ts = health["ts"]
@@ -168,8 +179,7 @@ class NodeInfo(object):
disaggregated = health["disaggregated"]
return NodeInfo(nodeid, role, host, disaggregated, load, ts)
def __init__(self, nodeid, role, host, disaggregated, load,
ts=time.time()):
def __init__(self, nodeid, role, host, disaggregated, load, ts=time.time()):
self.nodeid = nodeid
self.ts = ts
self.host = host
@@ -184,14 +194,14 @@ class NodeInfo(object):
def expired(self, expire_period):
"""
APIScheduler used to check if the node is expired
APIScheduler used to check if the node is expired
"""
now = time.time()
return (now - self.ts) > expire_period
def serialize(self):
"""
InferScheduler used to sync load
InferScheduler used to sync load
"""
self.ts = time.time()
health = {
@@ -199,7 +209,7 @@ class NodeInfo(object):
"role": self.role,
"load": self.load,
"host": self.host,
"disaggregated": self.disaggregated
"disaggregated": self.disaggregated,
}
return orjson.dumps(health)
@@ -208,7 +218,7 @@ class NodeInfo(object):
def expire_reqs(self, ttl):
"""
InferScheduler used to clear expired reqs
InferScheduler used to clear expired reqs
"""
cur_time = time.time()
with self.lock:
@@ -216,9 +226,7 @@ class NodeInfo(object):
for req_id, pairs in self.reqs.items():
load, arrival_time = pairs
if cur_time - arrival_time > ttl:
logger.error(
f"InferScheduler Expire Reqs({req_id}), arrival({arrival_time}), ttl({ttl})"
)
logger.error(f"InferScheduler Expire Reqs({req_id}), arrival({arrival_time}), ttl({ttl})")
expire_reqs.add((req_id, load))
for req_id, load in expire_reqs:
if req_id in self.reqs:
@@ -227,7 +235,7 @@ class NodeInfo(object):
def add_req(self, req_id, load):
"""
InferScheduler used to record scheduled reqs(waiting or running)
InferScheduler used to record scheduled reqs(waiting or running)
"""
with self.lock:
if req_id not in self.reqs:
@@ -236,7 +244,7 @@ class NodeInfo(object):
def update_req_timestamp(self, req_ids):
"""
InferScheduler used to update reqs timestamp
InferScheduler used to update reqs timestamp
"""
cur_time = time.time()
with self.lock:
@@ -246,7 +254,7 @@ class NodeInfo(object):
def finish_req(self, req_id):
"""
InferScheduler used to clear finished reqs
InferScheduler used to clear finished reqs
"""
with self.lock:
if req_id in self.reqs:
@@ -255,9 +263,9 @@ class NodeInfo(object):
del self.reqs[req_id]
class ResultReader(object):
class ResultReader:
"""
ResultReader use an async thread to continue get infer result from redis
ResultReader use an async thread to continue get infer result from redis
"""
def __init__(self, client, idx, batch=200, ttl=900, group=""):
@@ -277,7 +285,7 @@ class ResultReader(object):
def add_req(self, req):
"""
add a req to reader, reader will async fetch infer result from redis
add a req to reader, reader will async fetch infer result from redis
"""
with self.lock:
self.reqs[req.request_id] = {"arrival_time": req.arrival_time}
@@ -285,8 +293,8 @@ class ResultReader(object):
def read(self):
"""
batch read infer results
returns: dict(req_id, [ResultOutput])
batch read infer results
returns: dict(req_id, [ResultOutput])
"""
items = []
size = len(self.data)
@@ -335,7 +343,7 @@ class ResultReader(object):
def run(self):
"""
continue fetch infer results from redis
continue fetch infer results from redis
"""
while True:
try:
@@ -344,21 +352,19 @@ class ResultReader(object):
with self.lock:
expired_reqs = set()
for req_id, req in self.reqs.items():
if cur_time - req.get("arrival_time",
cur_time) > self.ttl:
if cur_time - req.get("arrival_time", cur_time) > self.ttl:
result = RequestOutput(
request_id=req_id,
prompt="",
prompt_token_ids=[],
outputs=CompletionOutput(-1, -1, []),
metrics=RequestMetrics(
arrival_time=req["arrival_time"]),
metrics=RequestMetrics(arrival_time=req["arrival_time"]),
error_code=500,
error_msg=f"Req({req_id}) is expired({self.ttl})")
error_msg=f"Req({req_id}) is expired({self.ttl})",
)
self.data.appendleft(result)
logger.error(
f"Req({req_id}) is expired({self.ttl})")
logger.error(f"Req({req_id}) is expired({self.ttl})")
expired_reqs.add(req_id)
continue
keys.append(req_id)
@@ -373,22 +379,21 @@ class ResultReader(object):
if total == 0:
time.sleep(0.01)
except Exception as e:
logger.error(
f"ResultsReader{self.idx} sync results error: {str(e)}")
logger.error(f"ResultsReader{self.idx} sync results error: {e!s}")
def sync_results(self, keys):
"""
fetch infer results from redis for the give keys
fetch infer results from redis for the give keys
"""
total = 0
if self.group != "":
keys = [self.group]
for key in keys:
#logger.info(f"Sync Results from Redis {key}")
# logger.info(f"Sync Results from Redis {key}")
results = self.client.rpop(key, self.batch)
if results is None or len(results) == 0:
continue
#logger.info(f"Rpop {key} {self.idx}: {len(results)}")
# logger.info(f"Rpop {key} {self.idx}: {len(results)}")
total += len(results)
for result in results:
try:
@@ -401,9 +406,9 @@ class ResultReader(object):
return total
class APIScheduler(object):
class APIScheduler:
"""
APIScheduler: put requests to global schedule, and get recording infer results
APIScheduler: put requests to global schedule, and get recording infer results
"""
def __init__(self, config):
@@ -416,9 +421,11 @@ class APIScheduler(object):
self.topic = config.redis_topic
self.cluster_key = f"{self.topic}.cluster"
self.client = redis.Redis(host=config.redis_host,
port=config.redis_port,
password=config.redis_password)
self.client = redis.Redis(
host=config.redis_host,
port=config.redis_port,
password=config.redis_password,
)
self.req_cond = threading.Condition()
self.reqs_queue = deque()
@@ -426,16 +433,14 @@ class APIScheduler(object):
def start(self):
"""
start backup threads
start backup threads
"""
for i in range(self.reader_parallel):
group = f"{self.nodeid}-{i}"
reader = ResultReader(self.client, i, self.reader_batch_size,
self.ttl, group)
reader = ResultReader(self.client, i, self.reader_batch_size, self.ttl, group)
self.readers.append(reader)
self.clear_expired_nodes_thread = threading.Thread(
target=self.loop_clear_expired_nodes)
self.clear_expired_nodes_thread = threading.Thread(target=self.loop_clear_expired_nodes)
self.clear_expired_nodes_thread.start()
self.schedule_thread = threading.Thread(target=self.loop_schedule)
@@ -443,7 +448,7 @@ class APIScheduler(object):
def put_requests(self, reqs):
"""
put requests to local req queue. reqs will be async scheduled
put requests to local req queue. reqs will be async scheduled
"""
ret = []
with self.req_cond:
@@ -455,7 +460,7 @@ class APIScheduler(object):
def get_results(self):
"""
get infer results from local queue. results is async fetched from redis
get infer results from local queue. results is async fetched from redis
"""
outputs = dict()
for reader in self.readers:
@@ -465,7 +470,7 @@ class APIScheduler(object):
def loop_schedule(self):
"""
loop schedule req based on global load states.
loop schedule req based on global load states.
"""
reader_idx = 0
while True:
@@ -493,11 +498,11 @@ class APIScheduler(object):
except IndexError:
continue
except Exception as e:
logger.error(f"APIScheduler Schedule req error: {str(e)}")
logger.error(f"APIScheduler Schedule req error: {e!s}")
def schedule(self, req, pnodes, dnodes, mnodes, group=""):
"""
schedule an req to according redis node queue
schedule an req to according redis node queue
"""
pnodes.extend(mnodes)
pnodes.sort()
@@ -508,16 +513,14 @@ class APIScheduler(object):
req_dict["group"] = group
req_str = orjson.dumps(req_dict)
pkey = f"ReqQ_{pnode.nodeid}"
#logger.info(f"Schedule Req {req_str} to Mixed")
# logger.info(f"Schedule Req {req_str} to Mixed")
self.client.lpush(pkey, req_str)
else:
dnodes.sort()
dnode = self.select_pd(req, dnodes, "decode")
disaggregated = copy.deepcopy(dnode.disaggregated)
transfer_protocol = disaggregated["transfer_protocol"]
if len(
transfer_protocol
) > 1 and "ipc" in transfer_protocol and "rdma" in transfer_protocol:
if len(transfer_protocol) > 1 and "ipc" in transfer_protocol and "rdma" in transfer_protocol:
if pnode.host == dnode.host:
disaggregated["transfer_protocol"] = "ipc"
else:
@@ -529,13 +532,13 @@ class APIScheduler(object):
req_dict = req.to_dict()
req_dict["group"] = group
req_str = orjson.dumps(req_dict)
#logger.info(f"Schedule Req {req_str}")
# logger.info(f"Schedule Req {req_str}")
self.client.lpush(dkey, req_str)
self.client.lpush(pkey, req_str)
def sync_cluster(self):
"""
fetch cluster load states from redis
fetch cluster load states from redis
"""
clusters = self.client.hgetall(self.cluster_key)
pnodes, dnodes, mnodes = [], [], []
@@ -556,7 +559,7 @@ class APIScheduler(object):
def loop_clear_expired_nodes(self):
"""
loop clear expired node's dirty data in redis
loop clear expired node's dirty data in redis
"""
while True:
try:
@@ -567,16 +570,15 @@ class APIScheduler(object):
if node.expired(self.clear_expired_nodes_period):
expire_nodes.add(nodeid)
for nodeid in expire_nodes:
#logger.info(f"clear expired nodes: {nodeid}")
# logger.info(f"clear expired nodes: {nodeid}")
self.client.hdel(self.cluster_key, nodeid)
time.sleep(self.clear_expired_nodes_period)
except Exception:
logger.error(
"APIScheduler clear expired nodes error: {str(e)}")
logger.error("APIScheduler clear expired nodes error: {str(e)}")
def select_pd(self, req, nodes, role):
"""
select a prefill/decode/mixed node based on load states
select a prefill/decode/mixed node based on load states
"""
def select(req, nodes, blur_step):
@@ -587,10 +589,8 @@ class APIScheduler(object):
if node.load >= blur_max:
break
blur_idx = idx
node = random.choice(nodes[:blur_idx + 1])
logger.info(
f"Schedule Req {req.request_id}(len:{req.prompt_token_ids_len}) to {node}"
)
node = random.choice(nodes[: blur_idx + 1])
logger.info(f"Schedule Req {req.request_id}(len:{req.prompt_token_ids_len}) to {node}")
return node
if role == "prefill" or role == "mixed":
@@ -607,9 +607,9 @@ class APIScheduler(object):
raise Exception(f"Invalid Role: {role}")
class ResultWriter(object):
class ResultWriter:
"""
ResultWriter use an async thread to continue writer infer results to redis
ResultWriter use an async thread to continue writer infer results to redis
"""
def __init__(self, client, idx, batch, ttl=900):
@@ -627,7 +627,7 @@ class ResultWriter(object):
def put(self, key, items):
"""
put infer results to writer
put infer results to writer
"""
with self.cond:
for item in items:
@@ -636,7 +636,7 @@ class ResultWriter(object):
def run(self):
"""
continue batch write infer results to redis
continue batch write infer results to redis
"""
while True:
try:
@@ -644,9 +644,9 @@ class ResultWriter(object):
size = len(self.data)
if size == 0:
self.cond.wait()
#qsize = size
# qsize = size
size = min(size, self.batch)
#logger.info(f"Writer {self.idx} Queue Size: {qsize}, Cur Size: {size}")
# logger.info(f"Writer {self.idx} Queue Size: {qsize}, Cur Size: {size}")
groups = dict()
for i in range(size):
key, item = self.data.pop()
@@ -654,22 +654,22 @@ class ResultWriter(object):
groups[key] = []
groups[key].append(item)
for key, items in groups.items():
#s = time.time()
# s = time.time()
with self.client.pipeline() as pipe:
pipe.multi()
pipe.lpush(key, *items)
pipe.expire(key, math.ceil(self.ttl))
pipe.execute()
#self.client.lpush(key, *items)
#e = time.time()
#logger.info(f"Lpush {self.idx}: {key} used {e-s} {len(items)} items")
# self.client.lpush(key, *items)
# e = time.time()
# logger.info(f"Lpush {self.idx}: {key} used {e-s} {len(items)} items")
except Exception as e:
logger.error(f"ResultWriter write error: {str(e)}")
logger.error(f"ResultWriter write error: {e!s}")
class InferScheduler(object):
class InferScheduler:
"""
InferScheduler: get scheduled requests to local queue, write results to redis
InferScheduler: get scheduled requests to local queue, write results to redis
"""
def __init__(self, config):
@@ -682,20 +682,21 @@ class InferScheduler(object):
self.ttl = config.ttl
self.release_load_expire_period = config.release_load_expire_period
self.client = redis.Redis(host=config.redis_host,
port=config.redis_port,
password=config.redis_password)
self.client = redis.Redis(
host=config.redis_host,
port=config.redis_port,
password=config.redis_password,
)
self.reqs_queue = deque()
self.writers = []
def start(self, role, host, disaggregated):
"""
start backup threads
start backup threads
"""
for i in range(self.writer_parallel):
writer = ResultWriter(self.client, i, self.writer_batch_size,
self.ttl)
writer = ResultWriter(self.client, i, self.writer_batch_size, self.ttl)
writer.start()
self.writers.append(writer)
@@ -709,25 +710,24 @@ class InferScheduler(object):
self.report_thread = threading.Thread(target=self.routine_report)
self.report_thread.start()
self.expire_reqs_thread = threading.Thread(
target=self.loop_expire_reqs)
self.expire_reqs_thread = threading.Thread(target=self.loop_expire_reqs)
self.expire_reqs_thread.start()
def routine_report(self):
"""
routine report node info: load, health
routine report node info: load, health
"""
while True:
try:
info = self.node.serialize()
self.client.hset(self.cluster_key, self.nodeid, info)
time.sleep(self.sync_period / 1000.)
time.sleep(self.sync_period / 1000.0)
except Exception as e:
logger.error(f"InferScheduler routine report error: {str(e)}")
logger.error(f"InferScheduler routine report error: {e!s}")
def loop_expire_reqs(self):
"""
loop clear expired reqs
loop clear expired reqs
"""
while True:
try:
@@ -738,7 +738,7 @@ class InferScheduler(object):
def loop_get_reqs(self):
"""
loop get global scheduled reqs to local queue
loop get global scheduled reqs to local queue
"""
def select_writer(req):
@@ -764,23 +764,26 @@ class InferScheduler(object):
group = req.get("group", "")
req = Request.from_dict(req)
writer_idx = select_writer(req)
logger.info(
f"Infer Scheduler Get Req: {req.request_id} writer idx {writer_idx}"
)
logger.info(f"Infer Scheduler Get Req: {req.request_id} writer idx {writer_idx}")
req.request_id = f"{req.request_id}#{writer_idx}#{group}"
if self.role == "prefill" or self.role == "mixed":
self.reqs_queue.append(req)
self.node.add_req(req.request_id,
req.prompt_token_ids_len)
self.node.add_req(req.request_id, req.prompt_token_ids_len)
else:
self.node.add_req(req.request_id, 1)
except Exception as e:
logger.error(f"InferScheduler loop get reqs error: {str(e)}")
logger.error(f"InferScheduler loop get reqs error: {e!s}")
def get_requests(self, available_blocks, block_size,
reserved_output_blocks, max_num_batched_tokens, batch):
def get_requests(
self,
available_blocks,
block_size,
reserved_output_blocks,
max_num_batched_tokens,
batch,
):
"""
get scheduled reqs from local reqs queue
get scheduled reqs from local reqs queue
"""
if len(self.reqs_queue) == 0:
return []
@@ -793,19 +796,16 @@ class InferScheduler(object):
try:
req = self.reqs_queue.popleft()
if cur_time - req.arrival_time > self.ttl:
logger.error(
f"req({req.request_id}) is expired({self.ttl}) when InferScheduler Get Requests"
)
logger.error(f"req({req.request_id}) is expired({self.ttl}) when InferScheduler Get Requests")
self.node.finish_req(req.request_id)
continue
current_prefill_tokens += req.prompt_token_ids_len
required_input_blocks = (req.prompt_token_ids_len +
block_size - 1) // block_size
required_input_blocks = (req.prompt_token_ids_len + block_size - 1) // block_size
required_blocks += required_input_blocks + reserved_output_blocks
if required_blocks > available_blocks or current_prefill_tokens > max_num_batched_tokens:
self.reqs_queue.appendleft(req)
return reqs
#logger.info(f"Get Requests from Scheduler: {req.request_id}")
# logger.info(f"Get Requests from Scheduler: {req.request_id}")
reqs.append(req)
except Exception:
return reqs
@@ -813,16 +813,14 @@ class InferScheduler(object):
def put_results(self, results):
"""
put infer results to according writer's local queue
put infer results to according writer's local queue
"""
groups = dict()
req_ids = set()
for result in results:
if result.error_code != 200 or result.finished:
self.node.finish_req(result.request_id)
logger.info(
f"{result.request_id} finished, node load is {self.node.load}"
)
logger.info(f"{result.request_id} finished, node load is {self.node.load}")
req_ids.add(result.request_id)
@@ -837,7 +835,7 @@ class InferScheduler(object):
result.finished = False
result_str = orjson.dumps(result.to_dict())
#if self.role == "prefill" or result.error_code != 200 or result.finished:
# if self.role == "prefill" or result.error_code != 200 or result.finished:
# logger.info(f"Infer Put Finish Result: {result_str}")
groups[key].append(result_str)