Mirror of https://github.com/PaddlePaddle/FastDeploy.git, synced 2025-10-05 16:48:03 +08:00
[Bug fix] Test td cache messager (#3242)
* support disable cache task in decode node
* fix bugs
* Update engine.py
* Update expert_service.py
* Update splitwise_connector.py
* Optimize log for debug
* Optimize log for debug
* fix bug

---------

Co-authored-by: ltd0924 <ltd0924@sina.com>
Co-authored-by: ltd0924 <32387785+ltd0924@users.noreply.github.com>
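The main behavioral change: when `FD_ENABLE_CACHE_TASK` is not `"1"`, the prefill node records a per-request status ("init", then "finished" or an error message) and exposes `check_decode_allocated()` so the engine can wait for the decode node to confirm cache-block allocation before continuing. A minimal prefill-side sketch of that handshake follows; the `schedule_prefill` wrapper and the way the connector is obtained are hypothetical, only `check_decode_allocated` and the status values come from this change.

```python
import os

# Assumption for this sketch: any value other than "1" selects the new
# "wait for decode allocation" path introduced by this commit.
os.environ.setdefault("FD_ENABLE_CACHE_TASK", "0")

def schedule_prefill(connector, task):
    """Hypothetical caller: only prefill once decode has allocated cache blocks."""
    # check_decode_allocated() blocks while the request status is still "init"
    # and returns (True, "") on "finished" or (False, error_msg) on failure.
    ok, error_msg = connector.check_decode_allocated(task)
    if not ok:
        raise RuntimeError(f"decode allocation failed for {task.request_id}: {error_msg}")
    # ... continue with normal prefill scheduling ...
```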
splitwise_connector.py

@@ -26,8 +26,6 @@ from fastdeploy.engine.request import CompletionOutput, Request, RequestOutput
 from fastdeploy.inter_communicator import EngineWorkerQueue
 from fastdeploy.utils import get_logger

-logger = get_logger("splitwise_connector", "splitwise_connector.log")
-

 class SplitwiseConnector:
     """
@@ -45,6 +43,12 @@ class SplitwiseConnector:
             resource_manager (object): Resource manager object.
         """
         self.cfg = cfg
+        if self.cfg.parallel_config.enable_expert_parallel and self.cfg.parallel_config.data_parallel_size > 1:
+            self.logger = get_logger(
+                "splitwise_connector", f"splitwise_connector_{self.cfg.parallel_config.local_data_parallel_id}.log"
+            )
+        else:
+            self.logger = get_logger("splitwise_connector", "splitwise_connector.log")
         self.scheduler = scheduler
         self.engine_worker_queue = worker_queue
         self.resource_manager = resource_manager
@@ -52,6 +56,7 @@ class SplitwiseConnector:
         self.temp_cache_info = dict()
         self.current_request_ids = dict()
         self.splitwise_queue = splitwise_queue
+        self.enable_decode_cache_task = envs.FD_ENABLE_CACHE_TASK == "1"

         if self.cfg.cache_config.pd_comm_port is not None:
             self.zmq_ctx = zmq.Context()
@@ -70,7 +75,7 @@ class SplitwiseConnector:
             self.router_socket.setsockopt(zmq.SNDHWM, 1000)
             self.router_socket.setsockopt(zmq.ROUTER_MANDATORY, 1)
             self.router_socket.bind(f"tcp://*:{self.cfg.cache_config.pd_comm_port[0]}")
-            logger.info(f"bind {self.cfg.cache_config.pd_comm_port}")
+            self.logger.info(f"bind {self.cfg.cache_config.pd_comm_port[0]}")

             self.poller = zmq.Poller()
             self.poller.register(self.router_socket, zmq.POLLIN)
@@ -89,16 +94,16 @@ class SplitwiseConnector:
                 if not socks:
                     continue
                 else:
-                    logger.debug(f"receive {socks}")
+                    self.logger.debug(f"receive {socks}")

                     frames = self.router_socket.recv_multipart()
-                    logger.debug(f"frames: {frames}")
+                    self.logger.debug(f"frames: {frames}")
                     message = frames[-1]
                     self.io_executor.submit(self._process_message, message)
                 time.sleep(0.001)

             except Exception as e:
-                logger.error(f"Receiver error: {e}")
+                self.logger.error(f"Receiver error: {e}")
                 time.sleep(1)
@@ -110,7 +115,7 @@ class SplitwiseConnector:
             return sock

         try:
-            logger.info(f"Establishing new connection to {addr}")
+            self.logger.info(f"Establishing new connection to {addr}")
             sock = self.zmq_ctx.socket(zmq.DEALER)

             # Set connection parameters
@@ -129,7 +134,7 @@ class SplitwiseConnector:
             return sock

         except zmq.ZMQError as e:
-            logger.error(f"Connection to {addr} failed: {e}")
+            self.logger.error(f"Connection to {addr} failed: {e}")

             raise ConnectionError(f"Failed to connect to {addr}") from e

@@ -138,7 +143,7 @@ class SplitwiseConnector:
             return

         try:
-            logger.info(f"Sent {msg_type} to {addr}")
+            self.logger.info(f"Sent {msg_type} to {addr}")
             message = self._serialize_message(msg_type, payload)

             try:
@@ -146,18 +151,18 @@ class SplitwiseConnector:
                 sock = self._get_push_socket(addr)
                 sock.send_multipart([b"", message])

-                logger.info(f"Sent {msg_type} to {addr}")
+                self.logger.info(f"Sent {msg_type} to {addr}")

             except ConnectionError:
-                logger.warning(f"Connection to {addr} not established")
+                self.logger.warning(f"Connection to {addr} not established")
             except zmq.Again:
-                logger.warning(f"Send queue full for {addr}")
+                self.logger.warning(f"Send queue full for {addr}")
             except Exception as e:
-                logger.error(f"Send to {addr} failed: {e}")
+                self.logger.error(f"Send to {addr} failed: {e}")
                 self._close_connection(addr)

         except Exception as e:
-            logger.error(f"Message preparation failed: {e}")
+            self.logger.error(f"Message preparation failed: {e}")

     def _close_connection(self, addr):
         """
@@ -262,7 +267,7 @@ class SplitwiseConnector:
                     f"{task.disaggregate_info['cache_info']['rdma']['ip']}:"
                     + f"{task.disaggregate_info['cache_info']['rdma']['port']}"
                 )
-                logger.info(f"send splitwise tasks to port {addr} decode")
+                self.logger.info(f"send splitwise tasks to port {addr} decode")
                 self.current_request_ids[task.request_id] = "init"
                 decode_diagg = task.disaggregate_info["cache_info"]
                 task.disaggregate_info["cache_info"] = self.cfg.disaggregate_info["cache_info"]
@@ -290,7 +295,7 @@ class SplitwiseConnector:
                 self.connect_innode_instances[port].put_disaggregated_tasks(("decode", tasks))
                 for task in tasks:
                     task.disaggregate_info["cache_info"]["ipc"]["port"] = port
-                logger.info(f"send splitwise tasks to port {port} decode")
+                self.logger.info(f"send splitwise tasks to port {port} decode")
                 current_port = port
         return current_port

@@ -300,7 +305,7 @@ class SplitwiseConnector:
         """
         if not isinstance(tasks_list, list):
             tasks_list = [tasks_list]
-        logger.info("send first token to port decode")
+        self.logger.info("send first token to port decode")
         if prefill_msg["transfer_protocol"] == "ipc":
             port = prefill_msg["cache_info"]["ipc"]["port"]
             if port not in self.connect_innode_instances:
@@ -308,7 +313,7 @@ class SplitwiseConnector:
             self.connect_innode_instances[port].put_disaggregated_tasks(("decode", tasks_list))
         else:
             node = f"{prefill_msg['cache_info']['rdma']['ip']}:{prefill_msg['cache_info']['rdma']['port']}"
-            logger.info(f"send first token to port {node} decode")
+            self.logger.info(f"send first token to port {node} decode")
             self._send_message(node, "decode", tasks_list)

     def create_connection(self, port):
@@ -324,6 +329,22 @@ class SplitwiseConnector:
             client_id=0,
         )

+    def check_decode_allocated(self, task):
+        if task.disaggregate_info is None:
+            return True, ""
+        if self.enable_decode_cache_task:
+            return True, ""
+        if task.disaggregate_info["role"] != "prefill":
+            return True, ""
+        while self.current_request_ids[task.request_id] == "init":
+            time.sleep(0.001)
+        msg = self.current_request_ids[task.request_id]
+        del self.current_request_ids[task.request_id]
+        if msg == "finished":
+            return True, ""
+        self.logger.error(f"Receive_decode_allocated error: {msg}")
+        return False, msg
+
     def send_cache_infos(self, tasks, current_id):
         """
         Send cache information to specific port.
@@ -340,15 +361,21 @@ class SplitwiseConnector:
         for i in range(len(tasks)):
             if tasks[i].disaggregate_info is None:
                 continue
-            logger.info(f"{tasks[i].disaggregate_info}")
+            self.logger.info(f"{tasks[i].disaggregate_info}")
             if tasks[i].disaggregate_info["role"] == "decode":
                 if tasks[i].disaggregate_info["transfer_protocol"] == "ipc":
-                    cache_info = {
-                        "request_id": tasks[i].request_id,
-                        "device_ids": self.cfg.device_ids.split(","),
-                        "transfer_protocol": "ipc",
-                        "dest_block_ids": tasks[i].disaggregate_info["block_tables"],
-                    }
+                    if tasks[i].get("error_msg", None) is not None:
+                        cache_info = {
+                            "request_id": tasks[i].request_id,
+                            "error_msg": tasks[i].get("error_msg"),
+                        }
+                    else:
+                        cache_info = {
+                            "request_id": tasks[i].request_id,
+                            "device_ids": self.cfg.device_ids.split(","),
+                            "transfer_protocol": "ipc",
+                            "dest_block_ids": tasks[i].disaggregate_info["block_tables"],
+                        }
                     if tasks[i].disaggregate_info["cache_info"]["ipc"]["port"] not in temp_cache_info:
                         temp_cache_info[tasks[i].disaggregate_info["cache_info"]["ipc"]["port"]] = []
                     temp_cache_info[tasks[i].disaggregate_info["cache_info"]["ipc"]["port"]].append(cache_info)
@@ -357,14 +384,20 @@ class SplitwiseConnector:
                         f"{tasks[i].disaggregate_info['cache_info']['rdma']['ip']}:"
                         + f"{tasks[i].disaggregate_info['cache_info']['rdma']['port']}"
                     )
-                    cache_info = {
-                        "request_id": tasks[i].request_id,
-                        "device_ids": self.cfg.device_ids.split(","),
-                        "ip": self.cfg.host_ip,
-                        "rdma_ports": self.cfg.disaggregate_info["cache_info"]["rdma"]["rdma_port"],
-                        "transfer_protocol": "rdma",
-                        "dest_block_ids": tasks[i].disaggregate_info["block_tables"],
-                    }
+                    if tasks[i].get("error_msg", None) is not None:
+                        cache_info = {
+                            "request_id": tasks[i].request_id,
+                            "error_msg": tasks[i].get("error_msg"),
+                        }
+                    else:
+                        cache_info = {
+                            "request_id": tasks[i].request_id,
+                            "device_ids": self.cfg.device_ids.split(","),
+                            "ip": self.cfg.host_ip,
+                            "rdma_ports": self.cfg.disaggregate_info["cache_info"]["rdma"]["rdma_port"],
+                            "transfer_protocol": "rdma",
+                            "dest_block_ids": tasks[i].disaggregate_info["block_tables"],
+                        }
                     if addr not in temp_cache_info:
                         temp_cache_info[addr] = []

@@ -391,7 +424,7 @@ class SplitwiseConnector:
         else:
             if len(temp_cache_info):
                 for k, v in temp_cache_info.items():
-                    logger.info(f"{k} {v}")
+                    self.logger.info(f"{k} {v}")
                     if ":" in str(k):
                         self._send_message(k, "cache_sync", v)
                     else:
@@ -408,7 +441,7 @@ class SplitwiseConnector:
             payload = [output.to_dict() for output in payload]

         req_ids = [task["request_id"] for task in payload]
-        logger.info(f"send message {msg_type} {req_ids}")
+        self.logger.info(f"send message {msg_type} {req_ids}")

         json_data = msgpack.packb({"type": msg_type, "payload": payload})

@@ -419,7 +452,7 @@ class SplitwiseConnector:
         # JSON deserialization
         message = msgpack.unpackb(data)
         req_ids = [task["request_id"] for task in message["payload"]]
-        logger.info(f"send message {message['type']} {req_ids}")
+        self.logger.info(f"recv message type {message['type']} for {req_ids}")
         return message["type"], message["payload"]

     def _process_message(self, message: bytes):
@@ -428,7 +461,7 @@ class SplitwiseConnector:
         """
         try:
             msg_type, payload = self._deserialize_message(message)
-            logger.info(f"{msg_type}")
+            self.logger.info(f"{msg_type}")

             if msg_type == "prefill":
                 self._handle_prefill(payload)
@@ -436,11 +469,16 @@ class SplitwiseConnector:
                 self._handle_decode(payload)
             elif msg_type == "cache_sync":
                 for task in payload:
-                    del self.current_request_ids[task["request_id"]]
-                self.engine_worker_queue.put_cache_info(payload)
+                    self.logger.info(f"cache_sync task: {task}")
+                    current_status = task.get("error_msg", "finished")
+                    self.current_request_ids[task["request_id"]] = current_status
+                    if self.enable_decode_cache_task:
+                        del self.current_request_ids[task["request_id"]]
+                        if current_status == "finished":
+                            self.engine_worker_queue.put_cache_info(payload)

         except Exception as e:
-            logger.error(f"Message processing failed: {e}")
+            self.logger.error(f"Message processing failed: {e}")

     def _handle_prefill(self, tasks):
         """
@@ -450,7 +488,7 @@ class SplitwiseConnector:
         tasks_data = [Request.from_dict(task) for task in tasks]
         req_ids = [task["request_id"] for task in tasks]
         self.splitwise_queue.append(("decode", tasks_data))
-        logger.debug(f"{req_ids} received prefill data")
+        self.logger.info(f"{req_ids} received prefill data")

     def _handle_decode(self, payload):
         """
@@ -471,4 +509,4 @@ class SplitwiseConnector:
         )
         req_ids = [task["request_id"] for task in payload]
         self.splitwise_queue.append(("decode", tasks))
-        logger.debug(f"{req_ids} received decode data")
+        self.logger.info(f"{req_ids} received decode data")
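For context on the wire format these hunks touch: the connector packs messages with msgpack, and after this change a `cache_sync` entry can carry an `error_msg` instead of allocation details. A small self-contained sketch of that round trip, with illustrative field values modeled on `send_cache_infos()`:

```python
import msgpack

# Illustrative cache_sync payload; field names mirror send_cache_infos(),
# the concrete values are made up for this example.
payload = [
    {"request_id": "req-1", "device_ids": ["0"], "transfer_protocol": "ipc", "dest_block_ids": [3, 4, 5]},
    {"request_id": "req-2", "error_msg": "no free cache blocks"},  # failure path added by this commit
]

# _serialize_message / _deserialize_message use msgpack rather than JSON.
wire = msgpack.packb({"type": "cache_sync", "payload": payload})
message = msgpack.unpackb(wire)

for task in message["payload"]:
    # Mirrors _process_message: a missing error_msg means the allocation finished.
    status = task.get("error_msg", "finished")
    print(task["request_id"], status)
```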