polish code with new pre-commit rule (#2923)

Authored by Zero Rains on 2025-07-19 23:19:27 +08:00, committed by GitHub
parent b8676d71a8
commit 25698d56d1
424 changed files with 14307 additions and 13518 deletions
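
The commit message names only "a new pre-commit rule", not the hooks themselves. Judging from the diff below (single-quoted strings normalized to double quotes, wrapped statements joined up to a long line limit, trailing commas added, trailing whitespace stripped), the rule runs a black-style formatter plus standard whitespace fixers. A minimal .pre-commit-config.yaml sketch under that assumption; the hook revisions and the 119-character line length are illustrative guesses, not values taken from this repository:

    repos:
      - repo: https://github.com/psf/black
        rev: 24.4.2
        hooks:
          - id: black
            args: ["--line-length", "119"]
      - repo: https://github.com/pre-commit/pre-commit-hooks
        rev: v4.6.0
        hooks:
          - id: trailing-whitespace
          - id: end-of-file-fixer

With such a config in place, `pre-commit run --all-files` reformats the whole tree in one pass, which is how a 424-file, formatting-only commit like this one is typically produced.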


@@ -68,8 +68,7 @@ class SplitwiseConnector:
             self.router_socket.setsockopt(zmq.LINGER, 0)
             self.router_socket.setsockopt(zmq.SNDHWM, 1000)
             self.router_socket.setsockopt(zmq.ROUTER_MANDATORY, 1)
-            self.router_socket.bind(
-                f"tcp://*:{self.cfg.cache_config.pd_comm_port[0]}")
+            self.router_socket.bind(f"tcp://*:{self.cfg.cache_config.pd_comm_port[0]}")
             logger.info(f"bind {self.cfg.cache_config.pd_comm_port}")

             self.poller = zmq.Poller()
@@ -177,8 +176,7 @@ class SplitwiseConnector:
         for port in self.cfg.innode_prefill_ports:
             if port not in self.connect_innode_instances:
                 self.create_connection(port)
-            if self.connect_innode_instances[
-                    port].available_prefill_instances.qsize() > 0:
+            if self.connect_innode_instances[port].available_prefill_instances.qsize() > 0:
                 return False
         return True
@@ -199,15 +197,15 @@ class SplitwiseConnector:
             if self.connect_innode_instances[port].get_prefill_instances() == 1:
                 for task in tasks:
                     task.disaggregate_info = {
-                        "role": "prefill", 
+                        "role": "prefill",
                         "transfer_protocol": "ipc",
                         "cache_info": {
                             "ipc": {
                                 "ip": "0.0.0.0",
                                 "port": self.cfg.engine_worker_queue_port,
-                                "current_id": current_id
+                                "current_id": current_id,
                             },
-                        }
+                        },
                     }
                 self.connect_innode_instances[port].put_disaggregated_tasks(("prefill", tasks))
                 current_port = port
@@ -229,9 +227,9 @@ class SplitwiseConnector:
                     "ipc": {
                         "ip": "0.0.0.0",
                         "port": current_port,
-                        "current_id": current_id
+                        "current_id": current_id,
                     },
-                }
+                },
             }

     def send_splitwise_tasks(self, tasks, current_id):
@@ -254,21 +252,20 @@ class SplitwiseConnector:
             if task.disaggregate_info["transfer_protocol"] == "ipc":
                 addr = task.disaggregate_info["cache_info"]["ipc"]["port"]
-                task.disaggregate_info["cache_info"]["ipc"][
-                    "current_id"] = current_id
+                task.disaggregate_info["cache_info"]["ipc"]["current_id"] = current_id
                 self.send_splitwise_tasks_innode([task], addr)
             else:
-                addr = f"{task.disaggregate_info['cache_info']['rdma']['ip']}:"\
-                    + f"{task.disaggregate_info['cache_info']['rdma']['port']}"
+                addr = (
+                    f"{task.disaggregate_info['cache_info']['rdma']['ip']}:"
+                    + f"{task.disaggregate_info['cache_info']['rdma']['port']}"
+                )
                 logger.info(f"send splitwise tasks to port {addr} decode")
                 self.current_request_ids[task.request_id] = "init"
                 decode_diagg = task.disaggregate_info["cache_info"]
-                task.disaggregate_info[
-                    "cache_info"] = self.cfg.disaggregate_info["cache_info"]
-                task.disaggregate_info["cache_info"]["rdma"][
-                    "current_id"] = current_id
+                task.disaggregate_info["cache_info"] = self.cfg.disaggregate_info["cache_info"]
+                task.disaggregate_info["cache_info"]["rdma"]["current_id"] = current_id
                 self._send_message(addr, "prefill", [task])
                 task.disaggregate_info["cache_info"] = decode_diagg
                 task.disaggregate_info["role"] = "prefill"
@@ -288,10 +285,8 @@ class SplitwiseConnector:
         if port not in self.connect_innode_instances:
             self.create_connection(port)
         for task in tasks:
-            task.disaggregate_info["cache_info"]["ipc"][
-                "port"] = self.cfg.engine_worker_queue_port
-        self.connect_innode_instances[port].put_disaggregated_tasks(
-            ("decode", tasks))
+            task.disaggregate_info["cache_info"]["ipc"]["port"] = self.cfg.engine_worker_queue_port
+        self.connect_innode_instances[port].put_disaggregated_tasks(("decode", tasks))
         for task in tasks:
             task.disaggregate_info["cache_info"]["ipc"]["port"] = port
         logger.info(f"send splitwise tasks to port {port} decode")
@@ -309,8 +304,7 @@ class SplitwiseConnector:
             port = prefill_msg["cache_info"]["ipc"]["port"]
             if port not in self.connect_innode_instances:
                 self.create_connection(port)
-            self.connect_innode_instances[port].put_disaggregated_tasks(
-                ("decode", tasks_list))
+            self.connect_innode_instances[port].put_disaggregated_tasks(("decode", tasks_list))
         else:
             node = f"{prefill_msg['cache_info']['rdma']['ip']}:{prefill_msg['cache_info']['rdma']['port']}"
             logger.info(f"send first token to port {node} decode")
@@ -326,18 +320,19 @@ class SplitwiseConnector:
             self.connect_innode_instances[port] = EngineWorkerQueue(
                 address=("0.0.0.0", int(port)),
                 num_client=self.cfg.tensor_parallel_size,
-                client_id=0)
+                client_id=0,
+            )

     def send_cache_infos(self, tasks, current_id):
         """
-        Send cache information to specific port. 
+        Send cache information to specific port.

-        Parameters: 
-        tasks (list): List of tasks. 
-        current_id (int): Current id to indicate the prefill number. 
+        Parameters:
+        tasks (list): List of tasks.
+        current_id (int): Current id to indicate the prefill number.

-        Returns: 
-        bool: Whether it is in decode status. 
+        Returns:
+        bool: Whether it is in decode status.
         """
         is_decode = False
         temp_cache_info = dict()
@@ -348,38 +343,26 @@ class SplitwiseConnector:
             if tasks[i].disaggregate_info["role"] == "decode":
                 if tasks[i].disaggregate_info["transfer_protocol"] == "ipc":
                     cache_info = {
-                        "request_id":
-                        tasks[i].request_id,
-                        "device_ids":
-                        self.cfg.device_ids.split(","),
-                        "transfer_protocol":
-                        "ipc",
-                        "dest_block_ids":
-                        tasks[i].disaggregate_info["block_tables"],
+                        "request_id": tasks[i].request_id,
+                        "device_ids": self.cfg.device_ids.split(","),
+                        "transfer_protocol": "ipc",
+                        "dest_block_ids": tasks[i].disaggregate_info["block_tables"],
                     }
-                    if tasks[i].disaggregate_info["cache_info"]["ipc"][
-                            "port"] not in temp_cache_info:
-                        temp_cache_info[tasks[i].disaggregate_info[
-                            "cache_info"]["ipc"]["port"]] = []
-                    temp_cache_info[tasks[i].disaggregate_info["cache_info"]
-                                    ["ipc"]["port"]].append(cache_info)
+                    if tasks[i].disaggregate_info["cache_info"]["ipc"]["port"] not in temp_cache_info:
+                        temp_cache_info[tasks[i].disaggregate_info["cache_info"]["ipc"]["port"]] = []
+                    temp_cache_info[tasks[i].disaggregate_info["cache_info"]["ipc"]["port"]].append(cache_info)
                 else:
-                    addr = f"{tasks[i].disaggregate_info['cache_info']['rdma']['ip']}:" + \
-                        f"{tasks[i].disaggregate_info['cache_info']['rdma']['port']}"
+                    addr = (
+                        f"{tasks[i].disaggregate_info['cache_info']['rdma']['ip']}:"
+                        + f"{tasks[i].disaggregate_info['cache_info']['rdma']['port']}"
+                    )
                     cache_info = {
-                        "request_id":
-                        tasks[i].request_id,
-                        "device_ids":
-                        self.cfg.device_ids.split(","),
-                        "ip":
-                        self.cfg.host_ip,
-                        "rdma_ports":
-                        self.cfg.disaggregate_info["cache_info"]["rdma"]
-                        ["rdma_port"],
-                        "transfer_protocol":
-                        "rdma",
-                        "dest_block_ids":
-                        tasks[i].disaggregate_info["block_tables"],
+                        "request_id": tasks[i].request_id,
+                        "device_ids": self.cfg.device_ids.split(","),
+                        "ip": self.cfg.host_ip,
+                        "rdma_ports": self.cfg.disaggregate_info["cache_info"]["rdma"]["rdma_port"],
+                        "transfer_protocol": "rdma",
+                        "dest_block_ids": tasks[i].disaggregate_info["block_tables"],
                     }
                     if addr not in temp_cache_info:
                         temp_cache_info[addr] = []
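
The restyled block above only changes layout; the logic it carries is a small grouping step: each task's cache_info dict is bucketed in temp_cache_info under its destination (an IPC port or an "ip:port" RDMA address), and each bucket is later sent as one message. A self-contained sketch of that pattern, with hypothetical addresses and payloads:

    from collections import defaultdict

    def group_by_addr(cache_infos):
        # Bucket cache_info dicts per destination address, mirroring how
        # send_cache_infos fills temp_cache_info before dispatching.
        temp_cache_info = defaultdict(list)
        for addr, info in cache_infos:
            temp_cache_info[addr].append(info)
        return dict(temp_cache_info)

    grouped = group_by_addr(
        [
            ("10.0.0.1:8500", {"request_id": "req-0", "transfer_protocol": "rdma"}),
            ("10.0.0.1:8500", {"request_id": "req-1", "transfer_protocol": "rdma"}),
        ]
    )
    assert grouped == {
        "10.0.0.1:8500": [
            {"request_id": "req-0", "transfer_protocol": "rdma"},
            {"request_id": "req-1", "transfer_protocol": "rdma"},
        ]
    }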
@@ -390,7 +373,7 @@ class SplitwiseConnector:
             else:
                 addr = "prefill"
                 if current_id == -1:
-                    current_id = tasks[i].disaggregate_info["cache_info"]["ipc"]['current_id']
+                    current_id = tasks[i].disaggregate_info["cache_info"]["ipc"]["current_id"]
                 cache_info = {
                     "request_id": tasks[i].request_id,
                     "src_block_ids": tasks[i].block_tables,
@@ -423,16 +406,13 @@ class SplitwiseConnector:
         if msg_type == "decode" or msg_type == "prefill":
             payload = [output.to_dict() for output in payload]

-        json_data = json.dumps({
-            "type": msg_type,
-            "payload": payload
-        }).encode('utf-8')
+        json_data = json.dumps({"type": msg_type, "payload": payload}).encode("utf-8")
         return json_data

     def _deserialize_message(self, data: bytes):
         # JSON deserialization
-        message = json.loads(data.decode('utf-8'))
+        message = json.loads(data.decode("utf-8"))
         return message["type"], message["payload"]

     def _process_message(self, message: bytes):
@@ -461,8 +441,7 @@
         """
         tasks_data = [Request.from_dict(task) for task in tasks]
-        self.engine_worker_queue.put_disaggregated_tasks(
-            ("decode", tasks_data))
+        self.engine_worker_queue.put_disaggregated_tasks(("decode", tasks_data))

     def _handle_decode(self, payload):
         """
@@ -471,11 +450,14 @@
         tasks = []
         for task in payload:
             tasks.append(
-                RequestOutput(request_id=task["request_id"],
-                              outputs=CompletionOutput(
-                                  index=task["outputs"]["index"],
-                                  send_idx=0,
-                                  token_ids=task["outputs"]["token_ids"],
-                              ),
-                              finished=True))
+                RequestOutput(
+                    request_id=task["request_id"],
+                    outputs=CompletionOutput(
+                        index=task["outputs"]["index"],
+                        send_idx=0,
+                        token_ids=task["outputs"]["token_ids"],
+                    ),
+                    finished=True,
+                )
+            )
         self.engine_worker_queue.put_disaggregated_tasks(("decode", tasks))
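
The _serialize_message/_deserialize_message pair reformatted above defines the wire format between splitwise nodes: a JSON envelope {"type": ..., "payload": ...} encoded as UTF-8 bytes. A standalone round-trip sketch of that envelope; the free functions and the sample payload are illustrative stand-ins, not the class methods themselves:

    import json

    def serialize_message(msg_type: str, payload) -> bytes:
        # Wrap type + payload in a JSON envelope and encode for the socket.
        return json.dumps({"type": msg_type, "payload": payload}).encode("utf-8")

    def deserialize_message(data: bytes):
        # Decode the envelope back into a (type, payload) pair.
        message = json.loads(data.decode("utf-8"))
        return message["type"], message["payload"]

    wire = serialize_message("prefill", [{"request_id": "req-0"}])
    assert deserialize_message(wire) == ("prefill", [{"request_id": "req-0"}])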