[Sync] Update to latest code (#2679)

* [Sync] Update to latest code

* Add new code files

* Add new code files

* update code

* Try to fix build.sh

* Try to fix build.sh

* Update code

* Update requirements.txt

* Update code

---------

Co-authored-by: Jiang-Jia-Jun <jiangjiajun@baidu.com>
Author: Jiang-Jia-Jun (committed by GitHub)
Date: 2025-07-03 15:43:53 +08:00
Parent: d222248d00
Commit: 05c670e593
95 changed files with 9916 additions and 1312 deletions


@@ -260,12 +260,13 @@ class ResultReader(object):
     ResultReader uses an async thread to continuously fetch infer results from redis
     """
-    def __init__(self, client, idx, batch=200, ttl=900):
+    def __init__(self, client, idx, batch=200, ttl=900, group=""):
         self.idx = idx
         self.batch = batch
         self.client = client
         self.data = deque()
         self.ttl = ttl
+        self.group = group
         self.reqs = dict()
         self.out_buffer = dict()
@@ -380,15 +381,18 @@ class ResultReader(object):
         fetch infer results from redis for the given keys
         """
         total = 0
+        if self.group != "":
+            keys = [self.group]
         for key in keys:
+            #logger.info(f"Sync Results from Redis {key}")
             results = self.client.rpop(key, self.batch)
             if results is None or len(results) == 0:
                 continue
-            #logger.info(f"Rpop {self.idx}: {len(results)}")
+            #logger.info(f"Rpop {key} {self.idx}: {len(results)}")
             total += len(results)
             for result in results:
                 try:
-                    #logger.info(f"Scheduler Get Results: {result}")
+                    # logger.info(f"Scheduler Get Results: {result.request_id}")
                     data = orjson.loads(result)
                     result = RequestOutput.from_dict(data)
                     self.data.appendleft(result)
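
With the new `group` field, a ResultReader that belongs to a group drains a single Redis list named after that group instead of polling one list per request. A minimal sketch of that draining pattern, assuming a redis-py client against Redis >= 6.2 (where RPOP accepts a count) and the same orjson codec; `drain_group` is a hypothetical helper, not code from this commit:

```python
import orjson
import redis

client = redis.Redis()

def drain_group(group: str, batch: int = 200) -> list:
    """Pop up to `batch` serialized results from one group-keyed list."""
    raw = client.rpop(group, batch)  # returns None when the list is empty
    if not raw:
        return []
    return [orjson.loads(item) for item in raw]  # one dict per result
```
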
@@ -425,8 +429,9 @@ class APIScheduler(object):
         start backup threads
         """
         for i in range(self.reader_parallel):
+            group = f"{self.nodeid}-{i}"
             reader = ResultReader(self.client, i, self.reader_batch_size,
-                                  self.ttl)
+                                  self.ttl, group)
             self.readers.append(reader)
         self.clear_expired_nodes_thread = threading.Thread(
@@ -481,15 +486,16 @@ class APIScheduler(object):
                 reader = self.readers[reader_idx]
                 reader.add_req(req)
+                group = self.readers[reader_idx].group
                 reader_idx = (reader_idx + 1) % len(self.readers)
-                self.schedule(req, pnodes, dnodes, mnodes)
+                self.schedule(req, pnodes, dnodes, mnodes, group)
             except IndexError:
                 continue
             except Exception as e:
                 logger.error(f"APIScheduler Schedule req error: {str(e)}")

-    def schedule(self, req, pnodes, dnodes, mnodes):
+    def schedule(self, req, pnodes, dnodes, mnodes, group=""):
         """
         schedule a req to the corresponding redis node queue
         """
@@ -498,7 +504,9 @@ class APIScheduler(object):
         pnode = self.select_pd(req, pnodes, "prefill")
         if pnode.role == "mixed":
             req.disaggregate_info = None
-            req_str = orjson.dumps(req.to_dict())
+            req_dict = req.to_dict()
+            req_dict["group"] = group
+            req_str = orjson.dumps(req_dict)
             pkey = f"ReqQ_{pnode.nodeid}"
             #logger.info(f"Schedule Req {req_str} to Mixed")
             self.client.lpush(pkey, req_str)
@@ -518,7 +526,9 @@ class APIScheduler(object):
             disaggregated["transfer_protocol"] = transfer_protocol[0]
             req.disaggregate_info = disaggregated
             pkey, dkey = f"ReqQ_{pnode.nodeid}", f"ReqQ_{dnode.nodeid}"
-            req_str = orjson.dumps(req.to_dict())
+            req_dict = req.to_dict()
+            req_dict["group"] = group
+            req_str = orjson.dumps(req_dict)
             #logger.info(f"Schedule Req {req_str}")
             self.client.lpush(dkey, req_str)
             self.client.lpush(pkey, req_str)
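
Both the mixed and the disaggregated branches now embed the reader group into the serialized request before pushing it onto the target node's queue, which is how the group travels from the API node to the infer node. A hedged sketch of that producer step, assuming `to_dict()` yields a JSON-serializable dict; `push_req` is illustrative only:

```python
import orjson

def push_req(client, req, nodeid: str, group: str = "") -> None:
    req_dict = req.to_dict()
    req_dict["group"] = group  # consumed by InferScheduler on the other side
    client.lpush(f"ReqQ_{nodeid}", orjson.dumps(req_dict))
```
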
@@ -634,7 +644,9 @@ class ResultWriter(object):
             size = len(self.data)
             if size == 0:
                 self.cond.wait()
+            #qsize = size
             size = min(size, self.batch)
+            #logger.info(f"Writer {self.idx} Queue Size: {qsize}, Cur Size: {size}")
             groups = dict()
             for i in range(size):
                 key, item = self.data.pop()
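
The writer wakes up, pops at most `batch` items, and buckets them by destination key so that each key costs a single LPUSH round-trip. A sketch of that batching step under the same assumptions (redis-py client, `data` as a deque of `(key, item)` pairs); `flush_once` is not a name from the diff:

```python
from collections import defaultdict

def flush_once(client, data, batch: int) -> None:
    size = min(len(data), batch)
    groups = defaultdict(list)      # destination key -> list of payloads
    for _ in range(size):
        key, item = data.pop()
        groups[key].append(item)
    for key, items in groups.items():
        client.lpush(key, *items)   # one round-trip per destination key
```
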
@@ -749,12 +761,13 @@ class InferScheduler(object):
             for req_str in reqs:
                 req = orjson.loads(req_str)
+                group = req.get("group", "")
                 req = Request.from_dict(req)
                 writer_idx = select_writer(req)
                 logger.info(
                     f"Infer Scheduler Get Req: {req.request_id} writer idx {writer_idx}"
                 )
-                req.request_id = f"{req.request_id}#{writer_idx}"
+                req.request_id = f"{req.request_id}#{writer_idx}#{group}"
                 if self.role == "prefill" or self.role == "mixed":
                     self.reqs_queue.append(req)
                     self.node.add_req(req.request_id,
@@ -813,10 +826,10 @@ class InferScheduler(object):
                 req_ids.add(result.request_id)
-                req_id, idx = result.request_id.split("#")
+                req_id, idx, group = result.request_id.split("#")
                 result.request_id = req_id
-                key = (req_id, int(idx))
+                key = (req_id if group == "" else group, int(idx))
                 if key not in groups:
                     groups[key] = list()
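
Taken together, the request id now carries both the writer index and the originating reader group as `<req_id>#<writer_idx>#<group>`, and on the way back results are bucketed by group when one is present, falling back to the bare request id for untagged traffic. A small round-trip sketch; both helper names are illustrative, not from the diff:

```python
def tag_request_id(req_id: str, writer_idx: int, group: str) -> str:
    # Applied by InferScheduler when a request is received.
    return f"{req_id}#{writer_idx}#{group}"

def result_key(tagged: str) -> tuple:
    # Applied before results are written back: strip the suffix and
    # bucket by group when present, otherwise by the bare request id.
    req_id, idx, group = tagged.split("#")
    return (req_id if group == "" else group, int(idx))
```
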