[Sync] Update to latest code (#2679)
* [Sync] Update to latest code
* Add new code files
* Add new code files
* update code
* Try to fix build.sh
* Try to fix build.sh
* Update code
* Update requirements.txt
* Update code

Co-authored-by: Jiang-Jia-Jun <jiangjiajun@baidu.com>
@@ -260,12 +260,13 @@ class ResultReader(object):
     ResultReader use an async thread to continue get infer result from redis
     """
 
-    def __init__(self, client, idx, batch=200, ttl=900):
+    def __init__(self, client, idx, batch=200, ttl=900, group=""):
         self.idx = idx
         self.batch = batch
         self.client = client
         self.data = deque()
         self.ttl = ttl
+        self.group = group
 
         self.reqs = dict()
         self.out_buffer = dict()
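Note (not part of the diff): ResultReader buffers parsed results in a collections.deque; the reader thread calls appendleft() (see the next hunk) and the consuming side pops from the right, as the ResultWriter hunk further down does with its own deque, giving oldest-first order with O(1) operations at both ends. A tiny illustration of that ordering:

    from collections import deque

    data = deque()
    data.appendleft("r1")      # producer side: reader thread
    data.appendleft("r2")
    assert data.pop() == "r1"  # consumer side: oldest result first
    assert data.pop() == "r2"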
@@ -380,15 +381,18 @@ class ResultReader(object):
         fetch infer results from redis for the give keys
         """
         total = 0
+        if self.group != "":
+            keys = [self.group]
         for key in keys:
+            #logger.info(f"Sync Results from Redis {key}")
             results = self.client.rpop(key, self.batch)
             if results is None or len(results) == 0:
                 continue
-            #logger.info(f"Rpop {self.idx}: {len(results)}")
+            #logger.info(f"Rpop {key} {self.idx}: {len(results)}")
             total += len(results)
             for result in results:
                 try:
-                    #logger.info(f"Scheduler Get Results: {result}")
+                    # logger.info(f"Scheduler Get Results: {result.request_id}")
                     data = orjson.loads(result)
                     result = RequestOutput.from_dict(data)
                     self.data.appendleft(result)
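Note: with group set, the sync no longer scans per-request keys; it drains the reader's single dedicated list, and an empty group preserves the old per-key behavior. A minimal sketch of this poll pattern, assuming a redis-py client against Redis >= 6.2 (where RPOP accepts a count); the function name and list names are illustrative, not from the commit:

    import orjson
    import redis

    client = redis.Redis(host="localhost", port=6379)

    def poll_once(keys, group="", batch=200):
        if group != "":
            keys = [group]       # one well-known list per reader
        out = []
        for key in keys:
            results = client.rpop(key, batch)  # None when the list is empty
            if not results:
                continue
            out.extend(orjson.loads(r) for r in results)
        return out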
@@ -425,8 +429,9 @@ class APIScheduler(object):
         start backup threads
         """
         for i in range(self.reader_parallel):
+            group = f"{self.nodeid}-{i}"
             reader = ResultReader(self.client, i, self.reader_batch_size,
-                                  self.ttl)
+                                  self.ttl, group)
             self.readers.append(reader)
 
         self.clear_expired_nodes_thread = threading.Thread(
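Note: each reader now gets a globally distinct group name built from the API node id plus the reader index, so two API nodes (or two readers on one node) never drain each other's result lists. The naming scheme, with an illustrative node id:

    nodeid = "apinode-7"    # illustrative; the real value is the scheduler's node id
    reader_parallel = 4

    groups = [f"{nodeid}-{i}" for i in range(reader_parallel)]
    # ['apinode-7-0', 'apinode-7-1', 'apinode-7-2', 'apinode-7-3']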
@@ -481,15 +486,16 @@ class APIScheduler(object):
 
                 reader = self.readers[reader_idx]
                 reader.add_req(req)
+                group = self.readers[reader_idx].group
                 reader_idx = (reader_idx + 1) % len(self.readers)
 
-                self.schedule(req, pnodes, dnodes, mnodes)
+                self.schedule(req, pnodes, dnodes, mnodes, group)
             except IndexError:
                 continue
             except Exception as e:
                 logger.error(f"APIScheduler Schedule req error: {str(e)}")
 
-    def schedule(self, req, pnodes, dnodes, mnodes):
+    def schedule(self, req, pnodes, dnodes, mnodes, group=""):
         """
         schedule an req to according redis node queue
         """
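Note: a request is first registered with a reader, then scheduled together with that same reader's group, so whichever worker eventually produces the result knows exactly which list to push it to; round-robin over readers spreads the load. A standalone sketch of the pairing (Reader and assign are stand-ins, not the commit's classes):

    from itertools import count

    class Reader:
        def __init__(self, group):
            self.group = group
            self.reqs = {}

        def add_req(self, req_id):
            self.reqs[req_id] = True

    readers = [Reader(f"apinode-7-{i}") for i in range(4)]
    _next = count()

    def assign(req_id):
        reader = readers[next(_next) % len(readers)]
        reader.add_req(req_id)
        return reader.group      # handed on to schedule(..., group)

    assert assign("req-1") == "apinode-7-0"
    assert assign("req-2") == "apinode-7-1"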
@@ -498,7 +504,9 @@ class APIScheduler(object):
             pnode = self.select_pd(req, pnodes, "prefill")
             if pnode.role == "mixed":
                 req.disaggregate_info = None
-                req_str = orjson.dumps(req.to_dict())
+                req_dict = req.to_dict()
+                req_dict["group"] = group
+                req_str = orjson.dumps(req_dict)
                 pkey = f"ReqQ_{pnode.nodeid}"
                 #logger.info(f"Schedule Req {req_str} to Mixed")
                 self.client.lpush(pkey, req_str)
@@ -518,7 +526,9 @@ class APIScheduler(object):
             disaggregated["transfer_protocol"] = transfer_protocol[0]
             req.disaggregate_info = disaggregated
             pkey, dkey = f"ReqQ_{pnode.nodeid}", f"ReqQ_{dnode.nodeid}"
-            req_str = orjson.dumps(req.to_dict())
+            req_dict = req.to_dict()
+            req_dict["group"] = group
+            req_str = orjson.dumps(req_dict)
             #logger.info(f"Schedule Req {req_str}")
             self.client.lpush(dkey, req_str)
             self.client.lpush(pkey, req_str)
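Note: in both branches (mixed above, prefill/decode here) the group rides inside the serialized request dict, so it crosses the ReqQ_* Redis queues without changing the Request class itself, and requests from older API nodes simply arrive without the field. An orjson round-trip sketch; the dict contents are illustrative:

    import orjson

    req_dict = {"request_id": "req-1", "prompt": "hi"}  # stand-in for req.to_dict()
    req_dict["group"] = "apinode-7-2"
    req_str = orjson.dumps(req_dict)

    # ... later, on the infer-scheduler side:
    decoded = orjson.loads(req_str)
    group = decoded.get("group", "")   # "" when an older node sent the request
    assert group == "apinode-7-2"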
@@ -634,7 +644,9 @@ class ResultWriter(object):
                 size = len(self.data)
                 if size == 0:
                     self.cond.wait()
+                #qsize = size
                 size = min(size, self.batch)
+                #logger.info(f"Writer {self.idx} Queue Size: {qsize}, Cur Size: {size}")
                 groups = dict()
                 for i in range(size):
                     key, item = self.data.pop()
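Note: the writer loop around this hunk parks on a condition variable while its deque is empty and otherwise drains at most batch items per pass; the two added commented-out lines only log that queue depth. A minimal sketch of the wait-then-drain pattern (written with a while loop that re-reads the size after wakeup, the defensive form of the same idea):

    import threading
    from collections import deque

    data = deque()
    cond = threading.Condition()

    def drain(batch=100):
        with cond:
            while len(data) == 0:
                cond.wait()               # producer calls cond.notify() after append
            size = min(len(data), batch)
            return [data.pop() for _ in range(size)]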
@@ -749,12 +761,13 @@ class InferScheduler(object):
 
             for req_str in reqs:
                 req = orjson.loads(req_str)
+                group = req.get("group", "")
                 req = Request.from_dict(req)
                 writer_idx = select_writer(req)
                 logger.info(
                     f"Infer Scheduler Get Req: {req.request_id} writer idx {writer_idx}"
                 )
-                req.request_id = f"{req.request_id}#{writer_idx}"
+                req.request_id = f"{req.request_id}#{writer_idx}#{group}"
                 if self.role == "prefill" or self.role == "mixed":
                     self.reqs_queue.append(req)
                     self.node.add_req(req.request_id,
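Note: the writer index, and now the group, are piggybacked onto request_id with '#' separators; the result path below recovers them with split("#"). This stays unambiguous only while original request ids contain no '#' themselves; an empty group still leaves a trailing separator ("req-9#0#"), so the three-way split keeps working. An encode/decode sketch:

    def tag(request_id: str, writer_idx: int, group: str) -> str:
        return f"{request_id}#{writer_idx}#{group}"

    def untag(tagged: str):
        req_id, idx, group = tagged.split("#")
        return req_id, int(idx), group

    assert untag(tag("req-1", 3, "apinode-7-2")) == ("req-1", 3, "apinode-7-2")
    assert untag(tag("req-9", 0, "")) == ("req-9", 0, "")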
@@ -813,10 +826,10 @@ class InferScheduler(object):
 
                 req_ids.add(result.request_id)
 
-                req_id, idx = result.request_id.split("#")
+                req_id, idx, group = result.request_id.split("#")
                 result.request_id = req_id
 
-                key = (req_id, int(idx))
+                key = (req_id if group == "" else group, int(idx))
                 if key not in groups:
                     groups[key] = list()
 
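Note: on the way back, results are bucketed by (group or req_id, writer_idx): everything bound for one reader's group list collapses into a single bucket and can be pushed back in one batch, while an empty group falls back to the old per-request bucketing. A sketch of that keying:

    from collections import defaultdict

    def bucket(tagged_results):
        # tagged_results: iterable of (tagged_request_id, payload)
        groups = defaultdict(list)
        for tagged_id, payload in tagged_results:
            req_id, idx, group = tagged_id.split("#")
            key = (req_id if group == "" else group, int(idx))
            groups[key].append(payload)
        return groups

    buckets = bucket([("req-1#0#apinode-7-0", b"r1"),
                      ("req-2#0#apinode-7-0", b"r2")])
    assert buckets[("apinode-7-0", 0)] == [b"r1", b"r2"]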