mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-04 00:06:38 +08:00
polish code with new pre-commit rule (#2923)
This commit is contained in:
@@ -68,8 +68,7 @@ class SplitwiseConnector:
|
||||
self.router_socket.setsockopt(zmq.LINGER, 0)
|
||||
self.router_socket.setsockopt(zmq.SNDHWM, 1000)
|
||||
self.router_socket.setsockopt(zmq.ROUTER_MANDATORY, 1)
|
||||
self.router_socket.bind(
|
||||
f"tcp://*:{self.cfg.cache_config.pd_comm_port[0]}")
|
||||
self.router_socket.bind(f"tcp://*:{self.cfg.cache_config.pd_comm_port[0]}")
|
||||
logger.info(f"bind {self.cfg.cache_config.pd_comm_port}")
|
||||
|
||||
self.poller = zmq.Poller()
|
||||
@@ -177,8 +176,7 @@ class SplitwiseConnector:
|
||||
for port in self.cfg.innode_prefill_ports:
|
||||
if port not in self.connect_innode_instances:
|
||||
self.create_connection(port)
|
||||
if self.connect_innode_instances[
|
||||
port].available_prefill_instances.qsize() > 0:
|
||||
if self.connect_innode_instances[port].available_prefill_instances.qsize() > 0:
|
||||
return False
|
||||
return True
|
||||
|
||||
@@ -199,15 +197,15 @@ class SplitwiseConnector:
|
||||
if self.connect_innode_instances[port].get_prefill_instances() == 1:
|
||||
for task in tasks:
|
||||
task.disaggregate_info = {
|
||||
"role": "prefill",
|
||||
"role": "prefill",
|
||||
"transfer_protocol": "ipc",
|
||||
"cache_info": {
|
||||
"ipc": {
|
||||
"ip": "0.0.0.0",
|
||||
"port": self.cfg.engine_worker_queue_port,
|
||||
"current_id": current_id
|
||||
"current_id": current_id,
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
self.connect_innode_instances[port].put_disaggregated_tasks(("prefill", tasks))
|
||||
current_port = port
|
||||
@@ -229,9 +227,9 @@ class SplitwiseConnector:
|
||||
"ipc": {
|
||||
"ip": "0.0.0.0",
|
||||
"port": current_port,
|
||||
"current_id": current_id
|
||||
"current_id": current_id,
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
def send_splitwise_tasks(self, tasks, current_id):
|
||||
@@ -254,21 +252,20 @@ class SplitwiseConnector:
|
||||
|
||||
if task.disaggregate_info["transfer_protocol"] == "ipc":
|
||||
addr = task.disaggregate_info["cache_info"]["ipc"]["port"]
|
||||
task.disaggregate_info["cache_info"]["ipc"][
|
||||
"current_id"] = current_id
|
||||
task.disaggregate_info["cache_info"]["ipc"]["current_id"] = current_id
|
||||
self.send_splitwise_tasks_innode([task], addr)
|
||||
|
||||
else:
|
||||
|
||||
addr = f"{task.disaggregate_info['cache_info']['rdma']['ip']}:"\
|
||||
+ f"{task.disaggregate_info['cache_info']['rdma']['port']}"
|
||||
addr = (
|
||||
f"{task.disaggregate_info['cache_info']['rdma']['ip']}:"
|
||||
+ f"{task.disaggregate_info['cache_info']['rdma']['port']}"
|
||||
)
|
||||
logger.info(f"send splitwise tasks to port {addr} decode")
|
||||
self.current_request_ids[task.request_id] = "init"
|
||||
decode_diagg = task.disaggregate_info["cache_info"]
|
||||
task.disaggregate_info[
|
||||
"cache_info"] = self.cfg.disaggregate_info["cache_info"]
|
||||
task.disaggregate_info["cache_info"]["rdma"][
|
||||
"current_id"] = current_id
|
||||
task.disaggregate_info["cache_info"] = self.cfg.disaggregate_info["cache_info"]
|
||||
task.disaggregate_info["cache_info"]["rdma"]["current_id"] = current_id
|
||||
self._send_message(addr, "prefill", [task])
|
||||
task.disaggregate_info["cache_info"] = decode_diagg
|
||||
task.disaggregate_info["role"] = "prefill"
|
||||
@@ -288,10 +285,8 @@ class SplitwiseConnector:
|
||||
if port not in self.connect_innode_instances:
|
||||
self.create_connection(port)
|
||||
for task in tasks:
|
||||
task.disaggregate_info["cache_info"]["ipc"][
|
||||
"port"] = self.cfg.engine_worker_queue_port
|
||||
self.connect_innode_instances[port].put_disaggregated_tasks(
|
||||
("decode", tasks))
|
||||
task.disaggregate_info["cache_info"]["ipc"]["port"] = self.cfg.engine_worker_queue_port
|
||||
self.connect_innode_instances[port].put_disaggregated_tasks(("decode", tasks))
|
||||
for task in tasks:
|
||||
task.disaggregate_info["cache_info"]["ipc"]["port"] = port
|
||||
logger.info(f"send splitwise tasks to port {port} decode")
|
||||
@@ -309,8 +304,7 @@ class SplitwiseConnector:
|
||||
port = prefill_msg["cache_info"]["ipc"]["port"]
|
||||
if port not in self.connect_innode_instances:
|
||||
self.create_connection(port)
|
||||
self.connect_innode_instances[port].put_disaggregated_tasks(
|
||||
("decode", tasks_list))
|
||||
self.connect_innode_instances[port].put_disaggregated_tasks(("decode", tasks_list))
|
||||
else:
|
||||
node = f"{prefill_msg['cache_info']['rdma']['ip']}:{prefill_msg['cache_info']['rdma']['port']}"
|
||||
logger.info(f"send first token to port {node} decode")
|
||||
@@ -326,18 +320,19 @@ class SplitwiseConnector:
|
||||
self.connect_innode_instances[port] = EngineWorkerQueue(
|
||||
address=("0.0.0.0", int(port)),
|
||||
num_client=self.cfg.tensor_parallel_size,
|
||||
client_id=0)
|
||||
client_id=0,
|
||||
)
|
||||
|
||||
def send_cache_infos(self, tasks, current_id):
|
||||
"""
|
||||
Send cache information to specific port.
|
||||
Send cache information to specific port.
|
||||
|
||||
Parameters:
|
||||
tasks (list): List of tasks.
|
||||
current_id (int): Current id to indicate the prefill number.
|
||||
Parameters:
|
||||
tasks (list): List of tasks.
|
||||
current_id (int): Current id to indicate the prefill number.
|
||||
|
||||
Returns:
|
||||
bool: Whether it is in decode status.
|
||||
Returns:
|
||||
bool: Whether it is in decode status.
|
||||
"""
|
||||
is_decode = False
|
||||
temp_cache_info = dict()
|
||||
@@ -348,38 +343,26 @@ class SplitwiseConnector:
|
||||
if tasks[i].disaggregate_info["role"] == "decode":
|
||||
if tasks[i].disaggregate_info["transfer_protocol"] == "ipc":
|
||||
cache_info = {
|
||||
"request_id":
|
||||
tasks[i].request_id,
|
||||
"device_ids":
|
||||
self.cfg.device_ids.split(","),
|
||||
"transfer_protocol":
|
||||
"ipc",
|
||||
"dest_block_ids":
|
||||
tasks[i].disaggregate_info["block_tables"],
|
||||
"request_id": tasks[i].request_id,
|
||||
"device_ids": self.cfg.device_ids.split(","),
|
||||
"transfer_protocol": "ipc",
|
||||
"dest_block_ids": tasks[i].disaggregate_info["block_tables"],
|
||||
}
|
||||
if tasks[i].disaggregate_info["cache_info"]["ipc"][
|
||||
"port"] not in temp_cache_info:
|
||||
temp_cache_info[tasks[i].disaggregate_info[
|
||||
"cache_info"]["ipc"]["port"]] = []
|
||||
temp_cache_info[tasks[i].disaggregate_info["cache_info"]
|
||||
["ipc"]["port"]].append(cache_info)
|
||||
if tasks[i].disaggregate_info["cache_info"]["ipc"]["port"] not in temp_cache_info:
|
||||
temp_cache_info[tasks[i].disaggregate_info["cache_info"]["ipc"]["port"]] = []
|
||||
temp_cache_info[tasks[i].disaggregate_info["cache_info"]["ipc"]["port"]].append(cache_info)
|
||||
else:
|
||||
addr = f"{tasks[i].disaggregate_info['cache_info']['rdma']['ip']}:" + \
|
||||
f"{tasks[i].disaggregate_info['cache_info']['rdma']['port']}"
|
||||
addr = (
|
||||
f"{tasks[i].disaggregate_info['cache_info']['rdma']['ip']}:"
|
||||
+ f"{tasks[i].disaggregate_info['cache_info']['rdma']['port']}"
|
||||
)
|
||||
cache_info = {
|
||||
"request_id":
|
||||
tasks[i].request_id,
|
||||
"device_ids":
|
||||
self.cfg.device_ids.split(","),
|
||||
"ip":
|
||||
self.cfg.host_ip,
|
||||
"rdma_ports":
|
||||
self.cfg.disaggregate_info["cache_info"]["rdma"]
|
||||
["rdma_port"],
|
||||
"transfer_protocol":
|
||||
"rdma",
|
||||
"dest_block_ids":
|
||||
tasks[i].disaggregate_info["block_tables"],
|
||||
"request_id": tasks[i].request_id,
|
||||
"device_ids": self.cfg.device_ids.split(","),
|
||||
"ip": self.cfg.host_ip,
|
||||
"rdma_ports": self.cfg.disaggregate_info["cache_info"]["rdma"]["rdma_port"],
|
||||
"transfer_protocol": "rdma",
|
||||
"dest_block_ids": tasks[i].disaggregate_info["block_tables"],
|
||||
}
|
||||
if addr not in temp_cache_info:
|
||||
temp_cache_info[addr] = []
|
||||
@@ -390,7 +373,7 @@ class SplitwiseConnector:
|
||||
else:
|
||||
addr = "prefill"
|
||||
if current_id == -1:
|
||||
current_id = tasks[i].disaggregate_info["cache_info"]["ipc"]['current_id']
|
||||
current_id = tasks[i].disaggregate_info["cache_info"]["ipc"]["current_id"]
|
||||
cache_info = {
|
||||
"request_id": tasks[i].request_id,
|
||||
"src_block_ids": tasks[i].block_tables,
|
||||
@@ -423,16 +406,13 @@ class SplitwiseConnector:
|
||||
if msg_type == "decode" or msg_type == "prefill":
|
||||
payload = [output.to_dict() for output in payload]
|
||||
|
||||
json_data = json.dumps({
|
||||
"type": msg_type,
|
||||
"payload": payload
|
||||
}).encode('utf-8')
|
||||
json_data = json.dumps({"type": msg_type, "payload": payload}).encode("utf-8")
|
||||
return json_data
|
||||
|
||||
def _deserialize_message(self, data: bytes):
|
||||
|
||||
# JSON反序列化
|
||||
message = json.loads(data.decode('utf-8'))
|
||||
message = json.loads(data.decode("utf-8"))
|
||||
return message["type"], message["payload"]
|
||||
|
||||
def _process_message(self, message: bytes):
|
||||
@@ -461,8 +441,7 @@ class SplitwiseConnector:
|
||||
"""
|
||||
|
||||
tasks_data = [Request.from_dict(task) for task in tasks]
|
||||
self.engine_worker_queue.put_disaggregated_tasks(
|
||||
("decode", tasks_data))
|
||||
self.engine_worker_queue.put_disaggregated_tasks(("decode", tasks_data))
|
||||
|
||||
def _handle_decode(self, payload):
|
||||
"""
|
||||
@@ -471,11 +450,14 @@ class SplitwiseConnector:
|
||||
tasks = []
|
||||
for task in payload:
|
||||
tasks.append(
|
||||
RequestOutput(request_id=task["request_id"],
|
||||
outputs=CompletionOutput(
|
||||
index=task["outputs"]["index"],
|
||||
send_idx=0,
|
||||
token_ids=task["outputs"]["token_ids"],
|
||||
),
|
||||
finished=True))
|
||||
RequestOutput(
|
||||
request_id=task["request_id"],
|
||||
outputs=CompletionOutput(
|
||||
index=task["outputs"]["index"],
|
||||
send_idx=0,
|
||||
token_ids=task["outputs"]["token_ids"],
|
||||
),
|
||||
finished=True,
|
||||
)
|
||||
)
|
||||
self.engine_worker_queue.put_disaggregated_tasks(("decode", tasks))
|
||||
|
Reference in New Issue
Block a user