mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 08:37:06 +08:00
polish code with new pre-commit rule (#2923)
This commit is contained in:
@@ -13,6 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
import copy
|
||||
import hashlib
|
||||
import math
|
||||
@@ -25,34 +26,40 @@ from typing import List
|
||||
import orjson
|
||||
import redis
|
||||
|
||||
from fastdeploy.engine.request import (CompletionOutput, Request,
|
||||
RequestMetrics, RequestOutput)
|
||||
from fastdeploy.engine.request import (
|
||||
CompletionOutput,
|
||||
Request,
|
||||
RequestMetrics,
|
||||
RequestOutput,
|
||||
)
|
||||
from fastdeploy.utils import scheduler_logger as logger
|
||||
|
||||
|
||||
class SplitWiseSchedulerConfig(object):
|
||||
class SplitWiseSchedulerConfig:
|
||||
"""SplitWise Scheduler Configuration"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
nodeid=None,
|
||||
host="127.0.0.1", # redis host
|
||||
port=6379, # redis port
|
||||
password=None, # redis password
|
||||
topic="fd", # redis topic
|
||||
ttl=900,
|
||||
release_load_expire_period=600, #s
|
||||
sync_period=5, #ms
|
||||
expire_period=3000, #ms
|
||||
clear_expired_nodes_period=60, #s
|
||||
reader_parallel=4,
|
||||
reader_batch_size=200,
|
||||
writer_parallel=4,
|
||||
writer_batch_size=200,
|
||||
**kwargs):
|
||||
self,
|
||||
nodeid=None,
|
||||
host="127.0.0.1", # redis host
|
||||
port=6379, # redis port
|
||||
password=None, # redis password
|
||||
topic="fd", # redis topic
|
||||
ttl=900,
|
||||
release_load_expire_period=600, # s
|
||||
sync_period=5, # ms
|
||||
expire_period=3000, # ms
|
||||
clear_expired_nodes_period=60, # s
|
||||
reader_parallel=4,
|
||||
reader_batch_size=200,
|
||||
writer_parallel=4,
|
||||
writer_batch_size=200,
|
||||
**kwargs,
|
||||
):
|
||||
|
||||
if nodeid is None:
|
||||
import uuid
|
||||
|
||||
nodeid = str(uuid.uuid4())
|
||||
self.nodeid = nodeid
|
||||
|
||||
@@ -64,7 +71,7 @@ class SplitWiseSchedulerConfig(object):
|
||||
self.release_load_expire_period = release_load_expire_period
|
||||
|
||||
self.sync_period = sync_period
|
||||
self.expire_period = expire_period / 1000.
|
||||
self.expire_period = expire_period / 1000.0
|
||||
self.clear_expired_nodes_period = clear_expired_nodes_period
|
||||
self.reader_parallel = reader_parallel
|
||||
self.reader_batch_size = reader_batch_size
|
||||
@@ -82,13 +89,12 @@ class SplitWiseSchedulerConfig(object):
|
||||
logger.info("LocalScheduler Configuration Information :")
|
||||
for k, v in self.__dict__.items():
|
||||
logger.info("{:<20}:{:<6}{}".format(k, "", v))
|
||||
logger.info(
|
||||
"=============================================================")
|
||||
logger.info("=============================================================")
|
||||
|
||||
|
||||
class SplitWiseScheduler(object):
|
||||
class SplitWiseScheduler:
|
||||
"""
|
||||
SplitWise Scheduler
|
||||
SplitWise Scheduler
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
@@ -97,68 +103,73 @@ class SplitWiseScheduler(object):
|
||||
|
||||
def start(self, role, host, disaggregated):
    """
    Start the background workers of both halves of the scheduler.

    The three arguments are logged and then forwarded unchanged to
    ``self.infer.start``; ``self.scheduler.start`` is called afterwards
    with no arguments.

    Args:
        role: node role, passed through to the infer scheduler
            (other code in this file compares it with "prefill",
            "mixed" and "decode" - presumably those are the valid
            values; confirm against the caller).
        host: passed through to the infer scheduler.
        disaggregated: passed through to the infer scheduler.
    """
    logger.info(f"Scheduler Start With: role:{role}, host:{host}, disaggregated:{disaggregated}")
    # Order matters only in that it mirrors the original: infer side first,
    # then the API-side scheduler.
    self.infer.start(role, host, disaggregated)
    self.scheduler.start()
|
||||
|
||||
def reset_nodeid(self, nodeid):
    """
    Overwrite the node id on both sub-schedulers.

    Only the ``nodeid`` attribute of ``self.scheduler`` and ``self.infer``
    is reassigned; nothing else is touched here.

    Args:
        nodeid: the new identifier to store on both objects.
    """
    # Same assignment order as the original: API scheduler, then infer scheduler.
    for component in (self.scheduler, self.infer):
        component.nodeid = nodeid
|
||||
|
||||
def put_requests(self, reqs: List[Request]):
    """
    Hand a batch of requests to the global splitwise scheduler.

    Thin delegation: the list is passed as-is to
    ``self.scheduler.put_requests`` and its return value is returned
    unchanged.

    Args:
        reqs: the requests to enqueue.

    Returns:
        Whatever ``self.scheduler.put_requests`` returns.
    """
    return self.scheduler.put_requests(reqs)
|
||||
|
||||
def get_results(self, request_ids=None):
    """
    Fetch inference results from the global splitwise scheduler.

    Args:
        request_ids: accepted for interface compatibility but NOT used;
            the call always delegates to ``self.scheduler.get_results()``
            without arguments. The default is ``None`` rather than a
            shared mutable list.

    Returns:
        Whatever ``self.scheduler.get_results()`` returns.
    """
    return self.scheduler.get_results()
|
||||
|
||||
def get_requests(
    self,
    available_blocks,
    block_size,
    reserved_output_blocks,
    max_num_batched_tokens,
    batch=1,
):
    """
    Get scheduled requests from the global splitwise scheduler.

    Acts as a guard in front of ``self.infer.get_requests``: when the
    caller has no spare capacity the call short-circuits with an empty
    list instead of touching the infer scheduler.

    Args:
        available_blocks: number of blocks currently free.
        block_size: forwarded to the infer scheduler.
        reserved_output_blocks: blocks to keep back per request; if
            ``available_blocks`` does not exceed this, nothing is fetched.
        max_num_batched_tokens: forwarded to the infer scheduler.
        batch: forwarded to the infer scheduler; values below 1 yield
            an empty list.

    Returns:
        ``[]`` when resources are insufficient, otherwise the return
        value of ``self.infer.get_requests``.
    """
    insufficient = available_blocks <= reserved_output_blocks or batch < 1
    if insufficient:
        logger.info(
            f"Scheduler's resource are insufficient: available_blocks={available_blocks} "
            f"reserved_output_blocks={reserved_output_blocks} batch={batch} "
            f"max_num_batched_tokens={max_num_batched_tokens}"
        )
        return []
    return self.infer.get_requests(
        available_blocks,
        block_size,
        reserved_output_blocks,
        max_num_batched_tokens,
        batch,
    )
|
||||
|
||||
def put_results(self, results: List[RequestOutput]):
    """
    Publish inference results through the global splitwise scheduler.

    Thin delegation: the list is passed as-is to
    ``self.infer.put_results`` and its return value is returned unchanged.

    Args:
        results: the outputs to publish.

    Returns:
        Whatever ``self.infer.put_results`` returns.
    """
    return self.infer.put_results(results)
|
||||
|
||||
|
||||
class NodeInfo(object):
|
||||
class NodeInfo:
|
||||
"""
|
||||
Infer Node Info: load, rdma/ipc info
|
||||
Infer Node Info: load, rdma/ipc info
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def load_from(self, nodeid, info):
|
||||
"""
|
||||
load node info from seiralized string
|
||||
load node info from seiralized string
|
||||
"""
|
||||
health = orjson.loads(info)
|
||||
ts = health["ts"]
|
||||
@@ -168,8 +179,7 @@ class NodeInfo(object):
|
||||
disaggregated = health["disaggregated"]
|
||||
return NodeInfo(nodeid, role, host, disaggregated, load, ts)
|
||||
|
||||
def __init__(self, nodeid, role, host, disaggregated, load,
|
||||
ts=time.time()):
|
||||
def __init__(self, nodeid, role, host, disaggregated, load, ts=time.time()):
|
||||
self.nodeid = nodeid
|
||||
self.ts = ts
|
||||
self.host = host
|
||||
@@ -184,14 +194,14 @@ class NodeInfo(object):
|
||||
|
||||
def expired(self, expire_period):
    """
    Tell whether this node's last timestamp is older than ``expire_period``.

    Used by APIScheduler to detect stale nodes. The age is measured
    against ``time.time()``, so ``expire_period`` is in seconds.

    Args:
        expire_period: maximum tolerated age of ``self.ts`` in seconds.

    Returns:
        True when strictly more than ``expire_period`` seconds have
        passed since ``self.ts``, False otherwise.
    """
    age = time.time() - self.ts
    return age > expire_period
|
||||
|
||||
def serialize(self):
|
||||
"""
|
||||
InferScheduler used to sync load
|
||||
InferScheduler used to sync load
|
||||
"""
|
||||
self.ts = time.time()
|
||||
health = {
|
||||
@@ -199,7 +209,7 @@ class NodeInfo(object):
|
||||
"role": self.role,
|
||||
"load": self.load,
|
||||
"host": self.host,
|
||||
"disaggregated": self.disaggregated
|
||||
"disaggregated": self.disaggregated,
|
||||
}
|
||||
return orjson.dumps(health)
|
||||
|
||||
@@ -208,7 +218,7 @@ class NodeInfo(object):
|
||||
|
||||
def expire_reqs(self, ttl):
|
||||
"""
|
||||
InferScheduler used to clear expired reqs
|
||||
InferScheduler used to clear expired reqs
|
||||
"""
|
||||
cur_time = time.time()
|
||||
with self.lock:
|
||||
@@ -216,9 +226,7 @@ class NodeInfo(object):
|
||||
for req_id, pairs in self.reqs.items():
|
||||
load, arrival_time = pairs
|
||||
if cur_time - arrival_time > ttl:
|
||||
logger.error(
|
||||
f"InferScheduler Expire Reqs({req_id}), arrival({arrival_time}), ttl({ttl})"
|
||||
)
|
||||
logger.error(f"InferScheduler Expire Reqs({req_id}), arrival({arrival_time}), ttl({ttl})")
|
||||
expire_reqs.add((req_id, load))
|
||||
for req_id, load in expire_reqs:
|
||||
if req_id in self.reqs:
|
||||
@@ -227,7 +235,7 @@ class NodeInfo(object):
|
||||
|
||||
def add_req(self, req_id, load):
|
||||
"""
|
||||
InferScheduler used to record scheduled reqs(waiting or running)
|
||||
InferScheduler used to record scheduled reqs(waiting or running)
|
||||
"""
|
||||
with self.lock:
|
||||
if req_id not in self.reqs:
|
||||
@@ -236,7 +244,7 @@ class NodeInfo(object):
|
||||
|
||||
def update_req_timestamp(self, req_ids):
|
||||
"""
|
||||
InferScheduler used to update reqs timestamp
|
||||
InferScheduler used to update reqs timestamp
|
||||
"""
|
||||
cur_time = time.time()
|
||||
with self.lock:
|
||||
@@ -246,7 +254,7 @@ class NodeInfo(object):
|
||||
|
||||
def finish_req(self, req_id):
|
||||
"""
|
||||
InferScheduler used to clear finished reqs
|
||||
InferScheduler used to clear finished reqs
|
||||
"""
|
||||
with self.lock:
|
||||
if req_id in self.reqs:
|
||||
@@ -255,9 +263,9 @@ class NodeInfo(object):
|
||||
del self.reqs[req_id]
|
||||
|
||||
|
||||
class ResultReader(object):
|
||||
class ResultReader:
|
||||
"""
|
||||
ResultReader use an async thread to continue get infer result from redis
|
||||
ResultReader use an async thread to continue get infer result from redis
|
||||
"""
|
||||
|
||||
def __init__(self, client, idx, batch=200, ttl=900, group=""):
|
||||
@@ -277,7 +285,7 @@ class ResultReader(object):
|
||||
|
||||
def add_req(self, req):
|
||||
"""
|
||||
add a req to reader, reader will async fetch infer result from redis
|
||||
add a req to reader, reader will async fetch infer result from redis
|
||||
"""
|
||||
with self.lock:
|
||||
self.reqs[req.request_id] = {"arrival_time": req.arrival_time}
|
||||
@@ -285,8 +293,8 @@ class ResultReader(object):
|
||||
|
||||
def read(self):
|
||||
"""
|
||||
batch read infer results
|
||||
returns: dict(req_id, [ResultOutput])
|
||||
batch read infer results
|
||||
returns: dict(req_id, [ResultOutput])
|
||||
"""
|
||||
items = []
|
||||
size = len(self.data)
|
||||
@@ -335,7 +343,7 @@ class ResultReader(object):
|
||||
|
||||
def run(self):
|
||||
"""
|
||||
continue fetch infer results from redis
|
||||
continue fetch infer results from redis
|
||||
"""
|
||||
while True:
|
||||
try:
|
||||
@@ -344,21 +352,19 @@ class ResultReader(object):
|
||||
with self.lock:
|
||||
expired_reqs = set()
|
||||
for req_id, req in self.reqs.items():
|
||||
if cur_time - req.get("arrival_time",
|
||||
cur_time) > self.ttl:
|
||||
if cur_time - req.get("arrival_time", cur_time) > self.ttl:
|
||||
result = RequestOutput(
|
||||
request_id=req_id,
|
||||
prompt="",
|
||||
prompt_token_ids=[],
|
||||
outputs=CompletionOutput(-1, -1, []),
|
||||
metrics=RequestMetrics(
|
||||
arrival_time=req["arrival_time"]),
|
||||
metrics=RequestMetrics(arrival_time=req["arrival_time"]),
|
||||
error_code=500,
|
||||
error_msg=f"Req({req_id}) is expired({self.ttl})")
|
||||
error_msg=f"Req({req_id}) is expired({self.ttl})",
|
||||
)
|
||||
self.data.appendleft(result)
|
||||
|
||||
logger.error(
|
||||
f"Req({req_id}) is expired({self.ttl})")
|
||||
logger.error(f"Req({req_id}) is expired({self.ttl})")
|
||||
expired_reqs.add(req_id)
|
||||
continue
|
||||
keys.append(req_id)
|
||||
@@ -373,22 +379,21 @@ class ResultReader(object):
|
||||
if total == 0:
|
||||
time.sleep(0.01)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"ResultsReader{self.idx} sync results error: {str(e)}")
|
||||
logger.error(f"ResultsReader{self.idx} sync results error: {e!s}")
|
||||
|
||||
def sync_results(self, keys):
|
||||
"""
|
||||
fetch infer results from redis for the give keys
|
||||
fetch infer results from redis for the give keys
|
||||
"""
|
||||
total = 0
|
||||
if self.group != "":
|
||||
keys = [self.group]
|
||||
for key in keys:
|
||||
#logger.info(f"Sync Results from Redis {key}")
|
||||
# logger.info(f"Sync Results from Redis {key}")
|
||||
results = self.client.rpop(key, self.batch)
|
||||
if results is None or len(results) == 0:
|
||||
continue
|
||||
#logger.info(f"Rpop {key} {self.idx}: {len(results)}")
|
||||
# logger.info(f"Rpop {key} {self.idx}: {len(results)}")
|
||||
total += len(results)
|
||||
for result in results:
|
||||
try:
|
||||
@@ -401,9 +406,9 @@ class ResultReader(object):
|
||||
return total
|
||||
|
||||
|
||||
class APIScheduler(object):
|
||||
class APIScheduler:
|
||||
"""
|
||||
APIScheduler: put requests to global schedule, and get recording infer results
|
||||
APIScheduler: put requests to global schedule, and get recording infer results
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
@@ -416,9 +421,11 @@ class APIScheduler(object):
|
||||
self.topic = config.redis_topic
|
||||
self.cluster_key = f"{self.topic}.cluster"
|
||||
|
||||
self.client = redis.Redis(host=config.redis_host,
|
||||
port=config.redis_port,
|
||||
password=config.redis_password)
|
||||
self.client = redis.Redis(
|
||||
host=config.redis_host,
|
||||
port=config.redis_port,
|
||||
password=config.redis_password,
|
||||
)
|
||||
|
||||
self.req_cond = threading.Condition()
|
||||
self.reqs_queue = deque()
|
||||
@@ -426,16 +433,14 @@ class APIScheduler(object):
|
||||
|
||||
def start(self):
|
||||
"""
|
||||
start backup threads
|
||||
start backup threads
|
||||
"""
|
||||
for i in range(self.reader_parallel):
|
||||
group = f"{self.nodeid}-{i}"
|
||||
reader = ResultReader(self.client, i, self.reader_batch_size,
|
||||
self.ttl, group)
|
||||
reader = ResultReader(self.client, i, self.reader_batch_size, self.ttl, group)
|
||||
self.readers.append(reader)
|
||||
|
||||
self.clear_expired_nodes_thread = threading.Thread(
|
||||
target=self.loop_clear_expired_nodes)
|
||||
self.clear_expired_nodes_thread = threading.Thread(target=self.loop_clear_expired_nodes)
|
||||
self.clear_expired_nodes_thread.start()
|
||||
|
||||
self.schedule_thread = threading.Thread(target=self.loop_schedule)
|
||||
@@ -443,7 +448,7 @@ class APIScheduler(object):
|
||||
|
||||
def put_requests(self, reqs):
|
||||
"""
|
||||
put requests to local req queue. reqs will be async scheduled
|
||||
put requests to local req queue. reqs will be async scheduled
|
||||
"""
|
||||
ret = []
|
||||
with self.req_cond:
|
||||
@@ -455,7 +460,7 @@ class APIScheduler(object):
|
||||
|
||||
def get_results(self):
|
||||
"""
|
||||
get infer results from local queue. results is async fetched from redis
|
||||
get infer results from local queue. results is async fetched from redis
|
||||
"""
|
||||
outputs = dict()
|
||||
for reader in self.readers:
|
||||
@@ -465,7 +470,7 @@ class APIScheduler(object):
|
||||
|
||||
def loop_schedule(self):
|
||||
"""
|
||||
loop schedule req based on global load states.
|
||||
loop schedule req based on global load states.
|
||||
"""
|
||||
reader_idx = 0
|
||||
while True:
|
||||
@@ -493,11 +498,11 @@ class APIScheduler(object):
|
||||
except IndexError:
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.error(f"APIScheduler Schedule req error: {str(e)}")
|
||||
logger.error(f"APIScheduler Schedule req error: {e!s}")
|
||||
|
||||
def schedule(self, req, pnodes, dnodes, mnodes, group=""):
|
||||
"""
|
||||
schedule an req to according redis node queue
|
||||
schedule an req to according redis node queue
|
||||
"""
|
||||
pnodes.extend(mnodes)
|
||||
pnodes.sort()
|
||||
@@ -508,16 +513,14 @@ class APIScheduler(object):
|
||||
req_dict["group"] = group
|
||||
req_str = orjson.dumps(req_dict)
|
||||
pkey = f"ReqQ_{pnode.nodeid}"
|
||||
#logger.info(f"Schedule Req {req_str} to Mixed")
|
||||
# logger.info(f"Schedule Req {req_str} to Mixed")
|
||||
self.client.lpush(pkey, req_str)
|
||||
else:
|
||||
dnodes.sort()
|
||||
dnode = self.select_pd(req, dnodes, "decode")
|
||||
disaggregated = copy.deepcopy(dnode.disaggregated)
|
||||
transfer_protocol = disaggregated["transfer_protocol"]
|
||||
if len(
|
||||
transfer_protocol
|
||||
) > 1 and "ipc" in transfer_protocol and "rdma" in transfer_protocol:
|
||||
if len(transfer_protocol) > 1 and "ipc" in transfer_protocol and "rdma" in transfer_protocol:
|
||||
if pnode.host == dnode.host:
|
||||
disaggregated["transfer_protocol"] = "ipc"
|
||||
else:
|
||||
@@ -529,13 +532,13 @@ class APIScheduler(object):
|
||||
req_dict = req.to_dict()
|
||||
req_dict["group"] = group
|
||||
req_str = orjson.dumps(req_dict)
|
||||
#logger.info(f"Schedule Req {req_str}")
|
||||
# logger.info(f"Schedule Req {req_str}")
|
||||
self.client.lpush(dkey, req_str)
|
||||
self.client.lpush(pkey, req_str)
|
||||
|
||||
def sync_cluster(self):
|
||||
"""
|
||||
fetch cluster load states from redis
|
||||
fetch cluster load states from redis
|
||||
"""
|
||||
clusters = self.client.hgetall(self.cluster_key)
|
||||
pnodes, dnodes, mnodes = [], [], []
|
||||
@@ -556,7 +559,7 @@ class APIScheduler(object):
|
||||
|
||||
def loop_clear_expired_nodes(self):
|
||||
"""
|
||||
loop clear expired node's dirty data in redis
|
||||
loop clear expired node's dirty data in redis
|
||||
"""
|
||||
while True:
|
||||
try:
|
||||
@@ -567,16 +570,15 @@ class APIScheduler(object):
|
||||
if node.expired(self.clear_expired_nodes_period):
|
||||
expire_nodes.add(nodeid)
|
||||
for nodeid in expire_nodes:
|
||||
#logger.info(f"clear expired nodes: {nodeid}")
|
||||
# logger.info(f"clear expired nodes: {nodeid}")
|
||||
self.client.hdel(self.cluster_key, nodeid)
|
||||
time.sleep(self.clear_expired_nodes_period)
|
||||
except Exception:
|
||||
logger.error(
|
||||
"APIScheduler clear expired nodes error: {str(e)}")
|
||||
logger.error("APIScheduler clear expired nodes error: {str(e)}")
|
||||
|
||||
def select_pd(self, req, nodes, role):
|
||||
"""
|
||||
select a prefill/decode/mixed node based on load states
|
||||
select a prefill/decode/mixed node based on load states
|
||||
"""
|
||||
|
||||
def select(req, nodes, blur_step):
|
||||
@@ -587,10 +589,8 @@ class APIScheduler(object):
|
||||
if node.load >= blur_max:
|
||||
break
|
||||
blur_idx = idx
|
||||
node = random.choice(nodes[:blur_idx + 1])
|
||||
logger.info(
|
||||
f"Schedule Req {req.request_id}(len:{req.prompt_token_ids_len}) to {node}"
|
||||
)
|
||||
node = random.choice(nodes[: blur_idx + 1])
|
||||
logger.info(f"Schedule Req {req.request_id}(len:{req.prompt_token_ids_len}) to {node}")
|
||||
return node
|
||||
|
||||
if role == "prefill" or role == "mixed":
|
||||
@@ -607,9 +607,9 @@ class APIScheduler(object):
|
||||
raise Exception(f"Invalid Role: {role}")
|
||||
|
||||
|
||||
class ResultWriter(object):
|
||||
class ResultWriter:
|
||||
"""
|
||||
ResultWriter use an async thread to continue writer infer results to redis
|
||||
ResultWriter use an async thread to continue writer infer results to redis
|
||||
"""
|
||||
|
||||
def __init__(self, client, idx, batch, ttl=900):
|
||||
@@ -627,7 +627,7 @@ class ResultWriter(object):
|
||||
|
||||
def put(self, key, items):
|
||||
"""
|
||||
put infer results to writer
|
||||
put infer results to writer
|
||||
"""
|
||||
with self.cond:
|
||||
for item in items:
|
||||
@@ -636,7 +636,7 @@ class ResultWriter(object):
|
||||
|
||||
def run(self):
|
||||
"""
|
||||
continue batch write infer results to redis
|
||||
continue batch write infer results to redis
|
||||
"""
|
||||
while True:
|
||||
try:
|
||||
@@ -644,9 +644,9 @@ class ResultWriter(object):
|
||||
size = len(self.data)
|
||||
if size == 0:
|
||||
self.cond.wait()
|
||||
#qsize = size
|
||||
# qsize = size
|
||||
size = min(size, self.batch)
|
||||
#logger.info(f"Writer {self.idx} Queue Size: {qsize}, Cur Size: {size}")
|
||||
# logger.info(f"Writer {self.idx} Queue Size: {qsize}, Cur Size: {size}")
|
||||
groups = dict()
|
||||
for i in range(size):
|
||||
key, item = self.data.pop()
|
||||
@@ -654,22 +654,22 @@ class ResultWriter(object):
|
||||
groups[key] = []
|
||||
groups[key].append(item)
|
||||
for key, items in groups.items():
|
||||
#s = time.time()
|
||||
# s = time.time()
|
||||
with self.client.pipeline() as pipe:
|
||||
pipe.multi()
|
||||
pipe.lpush(key, *items)
|
||||
pipe.expire(key, math.ceil(self.ttl))
|
||||
pipe.execute()
|
||||
#self.client.lpush(key, *items)
|
||||
#e = time.time()
|
||||
#logger.info(f"Lpush {self.idx}: {key} used {e-s} {len(items)} items")
|
||||
# self.client.lpush(key, *items)
|
||||
# e = time.time()
|
||||
# logger.info(f"Lpush {self.idx}: {key} used {e-s} {len(items)} items")
|
||||
except Exception as e:
|
||||
logger.error(f"ResultWriter write error: {str(e)}")
|
||||
logger.error(f"ResultWriter write error: {e!s}")
|
||||
|
||||
|
||||
class InferScheduler(object):
|
||||
class InferScheduler:
|
||||
"""
|
||||
InferScheduler: get scheduled requests to local queue, write results to redis
|
||||
InferScheduler: get scheduled requests to local queue, write results to redis
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
@@ -682,20 +682,21 @@ class InferScheduler(object):
|
||||
self.ttl = config.ttl
|
||||
self.release_load_expire_period = config.release_load_expire_period
|
||||
|
||||
self.client = redis.Redis(host=config.redis_host,
|
||||
port=config.redis_port,
|
||||
password=config.redis_password)
|
||||
self.client = redis.Redis(
|
||||
host=config.redis_host,
|
||||
port=config.redis_port,
|
||||
password=config.redis_password,
|
||||
)
|
||||
|
||||
self.reqs_queue = deque()
|
||||
self.writers = []
|
||||
|
||||
def start(self, role, host, disaggregated):
|
||||
"""
|
||||
start backup threads
|
||||
start backup threads
|
||||
"""
|
||||
for i in range(self.writer_parallel):
|
||||
writer = ResultWriter(self.client, i, self.writer_batch_size,
|
||||
self.ttl)
|
||||
writer = ResultWriter(self.client, i, self.writer_batch_size, self.ttl)
|
||||
writer.start()
|
||||
self.writers.append(writer)
|
||||
|
||||
@@ -709,25 +710,24 @@ class InferScheduler(object):
|
||||
self.report_thread = threading.Thread(target=self.routine_report)
|
||||
self.report_thread.start()
|
||||
|
||||
self.expire_reqs_thread = threading.Thread(
|
||||
target=self.loop_expire_reqs)
|
||||
self.expire_reqs_thread = threading.Thread(target=self.loop_expire_reqs)
|
||||
self.expire_reqs_thread.start()
|
||||
|
||||
def routine_report(self):
    """
    Periodically publish this node's serialized info to redis.

    Runs forever: each round serializes ``self.node`` and stores the
    result in the ``self.cluster_key`` hash under ``self.nodeid``, then
    sleeps ``self.sync_period`` milliseconds. Any ``Exception`` raised
    while reporting is logged and the loop carries on.
    """
    while True:
        try:
            info = self.node.serialize()
            self.client.hset(self.cluster_key, self.nodeid, info)
        except Exception as e:
            logger.error(f"InferScheduler routine report error: {e!s}")
        # Sleep outside the try block so that a failing report (for example
        # redis being unreachable) still waits one period before retrying,
        # instead of spinning and flooding the log.
        time.sleep(self.sync_period / 1000.0)
|
||||
|
||||
def loop_expire_reqs(self):
|
||||
"""
|
||||
loop clear expired reqs
|
||||
loop clear expired reqs
|
||||
"""
|
||||
while True:
|
||||
try:
|
||||
@@ -738,7 +738,7 @@ class InferScheduler(object):
|
||||
|
||||
def loop_get_reqs(self):
|
||||
"""
|
||||
loop get global scheduled reqs to local queue
|
||||
loop get global scheduled reqs to local queue
|
||||
"""
|
||||
|
||||
def select_writer(req):
|
||||
@@ -764,23 +764,26 @@ class InferScheduler(object):
|
||||
group = req.get("group", "")
|
||||
req = Request.from_dict(req)
|
||||
writer_idx = select_writer(req)
|
||||
logger.info(
|
||||
f"Infer Scheduler Get Req: {req.request_id} writer idx {writer_idx}"
|
||||
)
|
||||
logger.info(f"Infer Scheduler Get Req: {req.request_id} writer idx {writer_idx}")
|
||||
req.request_id = f"{req.request_id}#{writer_idx}#{group}"
|
||||
if self.role == "prefill" or self.role == "mixed":
|
||||
self.reqs_queue.append(req)
|
||||
self.node.add_req(req.request_id,
|
||||
req.prompt_token_ids_len)
|
||||
self.node.add_req(req.request_id, req.prompt_token_ids_len)
|
||||
else:
|
||||
self.node.add_req(req.request_id, 1)
|
||||
except Exception as e:
|
||||
logger.error(f"InferScheduler loop get reqs error: {str(e)}")
|
||||
logger.error(f"InferScheduler loop get reqs error: {e!s}")
|
||||
|
||||
def get_requests(self, available_blocks, block_size,
|
||||
reserved_output_blocks, max_num_batched_tokens, batch):
|
||||
def get_requests(
|
||||
self,
|
||||
available_blocks,
|
||||
block_size,
|
||||
reserved_output_blocks,
|
||||
max_num_batched_tokens,
|
||||
batch,
|
||||
):
|
||||
"""
|
||||
get scheduled reqs from local reqs queue
|
||||
get scheduled reqs from local reqs queue
|
||||
"""
|
||||
if len(self.reqs_queue) == 0:
|
||||
return []
|
||||
@@ -793,19 +796,16 @@ class InferScheduler(object):
|
||||
try:
|
||||
req = self.reqs_queue.popleft()
|
||||
if cur_time - req.arrival_time > self.ttl:
|
||||
logger.error(
|
||||
f"req({req.request_id}) is expired({self.ttl}) when InferScheduler Get Requests"
|
||||
)
|
||||
logger.error(f"req({req.request_id}) is expired({self.ttl}) when InferScheduler Get Requests")
|
||||
self.node.finish_req(req.request_id)
|
||||
continue
|
||||
current_prefill_tokens += req.prompt_token_ids_len
|
||||
required_input_blocks = (req.prompt_token_ids_len +
|
||||
block_size - 1) // block_size
|
||||
required_input_blocks = (req.prompt_token_ids_len + block_size - 1) // block_size
|
||||
required_blocks += required_input_blocks + reserved_output_blocks
|
||||
if required_blocks > available_blocks or current_prefill_tokens > max_num_batched_tokens:
|
||||
self.reqs_queue.appendleft(req)
|
||||
return reqs
|
||||
#logger.info(f"Get Requests from Scheduler: {req.request_id}")
|
||||
# logger.info(f"Get Requests from Scheduler: {req.request_id}")
|
||||
reqs.append(req)
|
||||
except Exception:
|
||||
return reqs
|
||||
@@ -813,16 +813,14 @@ class InferScheduler(object):
|
||||
|
||||
def put_results(self, results):
|
||||
"""
|
||||
put infer results to according writer's local queue
|
||||
put infer results to according writer's local queue
|
||||
"""
|
||||
groups = dict()
|
||||
req_ids = set()
|
||||
for result in results:
|
||||
if result.error_code != 200 or result.finished:
|
||||
self.node.finish_req(result.request_id)
|
||||
logger.info(
|
||||
f"{result.request_id} finished, node load is {self.node.load}"
|
||||
)
|
||||
logger.info(f"{result.request_id} finished, node load is {self.node.load}")
|
||||
|
||||
req_ids.add(result.request_id)
|
||||
|
||||
@@ -837,7 +835,7 @@ class InferScheduler(object):
|
||||
result.finished = False
|
||||
|
||||
result_str = orjson.dumps(result.to_dict())
|
||||
#if self.role == "prefill" or result.error_code != 200 or result.finished:
|
||||
# if self.role == "prefill" or result.error_code != 200 or result.finished:
|
||||
# logger.info(f"Infer Put Finish Result: {result_str}")
|
||||
groups[key].append(result_str)
|
||||
|
||||
|
Reference in New Issue
Block a user