mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 16:48:03 +08:00
[Feature] Support pd ep deployment with yiyan adapter (#4029)
* [Feature] Support mixed deployment with yiyan adapter in release2.2 * fix metrics * add unit test * add unit test * add unit test * Support pd ep deployment with yiyan adapter * Support pd ep deployment with yiyan adapter * refactor cache messager * support scheduler v1 in PD * support pd v1 + chunk prefill * support pd v1 + chunk prefill * add eplb * support eplb * support eplb * support eplb * support v1 * fix * fix * fix bug * remove eplb support * support prefix cache in P * fix bug * fix bug * support one stop in V1 * fix bug * fix ci * fix ci * fix * fix * fix * fix * fix --------- Co-authored-by: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com>
This commit is contained in:
@@ -28,8 +28,6 @@ from fastdeploy.inter_communicator import EngineWorkerQueue
|
||||
from fastdeploy.metrics.metrics import main_process_metrics
|
||||
from fastdeploy.utils import get_logger
|
||||
|
||||
logger = get_logger("splitwise_connector", "splitwise_connector.log")
|
||||
|
||||
|
||||
class SplitwiseConnector:
|
||||
"""
|
||||
@@ -46,12 +44,19 @@ class SplitwiseConnector:
|
||||
resource_manager (object): Resource manager object.
|
||||
"""
|
||||
self.cfg = cfg
|
||||
if self.cfg.parallel_config.enable_expert_parallel and self.cfg.parallel_config.data_parallel_size > 1:
|
||||
self.logger = get_logger(
|
||||
"splitwise_connector", f"splitwise_connector_{self.cfg.parallel_config.local_data_parallel_id}.log"
|
||||
)
|
||||
else:
|
||||
self.logger = get_logger("splitwise_connector", "splitwise_connector.log")
|
||||
self.engine_worker_queue = worker_queue
|
||||
self.resource_manager = resource_manager
|
||||
self.connect_innode_instances = {}
|
||||
self.temp_cache_info = dict()
|
||||
self.current_request_ids = dict()
|
||||
self.idx = self.cfg.parallel_config.local_data_parallel_id
|
||||
self.enable_decode_cache_task = envs.FD_ENABLE_CACHE_TASK == "1"
|
||||
|
||||
if self.cfg.cache_config.pd_comm_port is not None:
|
||||
self.zmq_ctx = zmq.Context()
|
||||
@@ -70,7 +75,7 @@ class SplitwiseConnector:
|
||||
self.router_socket.setsockopt(zmq.SNDHWM, 1000)
|
||||
self.router_socket.setsockopt(zmq.ROUTER_MANDATORY, 1)
|
||||
self.router_socket.bind(f"tcp://*:{self.cfg.cache_config.pd_comm_port[0]}")
|
||||
logger.info(f"bind {self.cfg.cache_config.pd_comm_port}")
|
||||
self.logger.info(f"bind {self.cfg.cache_config.pd_comm_port}")
|
||||
|
||||
self.poller = zmq.Poller()
|
||||
self.poller.register(self.router_socket, zmq.POLLIN)
|
||||
@@ -90,17 +95,17 @@ class SplitwiseConnector:
|
||||
if not socks:
|
||||
continue
|
||||
else:
|
||||
logger.debug(f"receive {socks}")
|
||||
self.logger.debug(f"receive {socks}")
|
||||
|
||||
frames = self.router_socket.recv_multipart()
|
||||
logger.debug(f"frames: {frames}")
|
||||
self.logger.debug(f"frames: {frames}")
|
||||
message = frames[-1]
|
||||
self.io_executor.submit(self._process_message, message)
|
||||
time.sleep(0.001)
|
||||
else:
|
||||
time.sleep(5)
|
||||
except Exception as e:
|
||||
logger.error(f"Receiver error: {e}, {str(traceback.format_exc())}")
|
||||
self.logger.error(f"Receiver error: {e}, {str(traceback.format_exc())}")
|
||||
time.sleep(1)
|
||||
|
||||
def _get_push_socket(self, addr):
|
||||
@@ -112,7 +117,7 @@ class SplitwiseConnector:
|
||||
return sock
|
||||
|
||||
try:
|
||||
logger.info(f"Establishing new connection to {addr}")
|
||||
self.logger.info(f"Establishing new connection to {addr}")
|
||||
sock = self.zmq_ctx.socket(zmq.DEALER)
|
||||
|
||||
# 设置连接参数
|
||||
@@ -131,7 +136,7 @@ class SplitwiseConnector:
|
||||
return sock
|
||||
|
||||
except zmq.ZMQError as e:
|
||||
logger.error(f"Connection to {addr} failed: {e}")
|
||||
self.logger.error(f"Connection to {addr} failed: {e}")
|
||||
|
||||
raise ConnectionError(f"Failed to connect to {addr}") from e
|
||||
|
||||
@@ -140,7 +145,7 @@ class SplitwiseConnector:
|
||||
return
|
||||
|
||||
try:
|
||||
logger.info(f"Sent {msg_type} to {addr}")
|
||||
self.logger.info(f"Sent {msg_type} to {addr}")
|
||||
message = self._serialize_message(msg_type, payload)
|
||||
|
||||
try:
|
||||
@@ -148,19 +153,19 @@ class SplitwiseConnector:
|
||||
sock = self._get_push_socket(addr)
|
||||
sock.send_multipart([b"", message])
|
||||
|
||||
logger.info(f"Sent {msg_type} to {addr}")
|
||||
self.logger.info(f"Sent {msg_type} to {addr}")
|
||||
|
||||
except ConnectionError:
|
||||
logger.warning(f"Connection to {addr} not established")
|
||||
self.logger.warning(f"Connection to {addr} not established")
|
||||
except zmq.Again:
|
||||
logger.warning(f"Send queue full for {addr}")
|
||||
self.logger.warning(f"Send queue full for {addr}")
|
||||
except Exception as e:
|
||||
logger.error(f"Send to {addr} failed: {e}, {str(traceback.format_exc())}")
|
||||
self.logger.error(f"Send to {addr} failed: {e}, {str(traceback.format_exc())}")
|
||||
main_process_metrics.send_cache_failed_num.inc()
|
||||
self._close_connection(addr)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Message preparation failed: {e}")
|
||||
self.logger.error(f"Message preparation failed: {e}")
|
||||
|
||||
def _close_connection(self, addr):
|
||||
"""
|
||||
@@ -265,7 +270,7 @@ class SplitwiseConnector:
|
||||
f"{task.disaggregate_info['cache_info']['rdma']['ip']}:"
|
||||
+ f"{task.disaggregate_info['cache_info']['rdma']['port']}"
|
||||
)
|
||||
logger.info(f"send splitwise tasks to port {addr} decode")
|
||||
self.logger.info(f"send splitwise tasks to port {addr} decode")
|
||||
self.current_request_ids[task.request_id] = "init"
|
||||
decode_diagg = task.disaggregate_info["cache_info"]
|
||||
task.disaggregate_info["cache_info"] = self.cfg.disaggregate_info["cache_info"]
|
||||
@@ -295,7 +300,7 @@ class SplitwiseConnector:
|
||||
self.connect_innode_instances[port].put_disaggregated_tasks(("decode", tasks))
|
||||
for task in tasks:
|
||||
task.disaggregate_info["cache_info"]["ipc"]["port"] = port
|
||||
logger.info(f"send splitwise tasks to port {port} decode")
|
||||
self.logger.info(f"send splitwise tasks to port {port} decode")
|
||||
current_port = port
|
||||
return current_port
|
||||
|
||||
@@ -305,7 +310,7 @@ class SplitwiseConnector:
|
||||
"""
|
||||
if not isinstance(tasks_list, list):
|
||||
tasks_list = [tasks_list]
|
||||
logger.info("send first token to port decode")
|
||||
self.logger.info("send first token to port decode")
|
||||
if prefill_msg["transfer_protocol"] == "ipc":
|
||||
port = prefill_msg["cache_info"]["ipc"]["port"]
|
||||
if port not in self.connect_innode_instances:
|
||||
@@ -313,7 +318,7 @@ class SplitwiseConnector:
|
||||
self.connect_innode_instances[port].put_disaggregated_tasks(("decode", tasks_list))
|
||||
else:
|
||||
node = f"{prefill_msg['cache_info']['rdma']['ip']}:{prefill_msg['cache_info']['rdma']['port']}"
|
||||
logger.info(f"send first token to port {node} decode")
|
||||
self.logger.info(f"send first token to port {node} decode")
|
||||
self._send_message(node, "decode", tasks_list)
|
||||
|
||||
def create_connection(self, port):
|
||||
@@ -329,6 +334,26 @@ class SplitwiseConnector:
|
||||
client_id=0,
|
||||
)
|
||||
|
||||
def check_decode_allocated(self, task, *, timeout_s=30.0):
    """
    Block until the status tracked for a prefill task is no longer "init".

    The status is read from ``self.current_request_ids[task.request_id]``.
    The entry is removed from that dict before returning on every path that
    actually waits (success, error, or timeout).

    Args:
        task: Request object exposing ``disaggregate_info`` (``None`` or a
            dict with a ``"role"`` key) and ``request_id``.
        timeout_s (float): Maximum number of seconds to wait for the status
            to change. Defaults to 30, the previously hard-coded limit.

    Returns:
        tuple: ``(True, "")`` when nothing has to be waited for (no
        disaggregate info, decode cache task enabled, or role is not
        "prefill") or when the status became "finished";
        ``(False, "timeout")`` when the status is still "init" after
        ``timeout_s`` seconds; otherwise ``(False, status)`` where
        ``status`` is the non-"finished" value that was recorded.

    Raises:
        KeyError: If ``task.request_id`` has no entry in
            ``self.current_request_ids`` when the wait starts.
    """
    if task.disaggregate_info is None:
        return True, ""
    if self.enable_decode_cache_task:
        return True, ""
    if task.disaggregate_info["role"] != "prefill":
        return True, ""

    request_id = task.request_id
    # Use a monotonic clock: wall-clock time can jump (NTP step, manual
    # adjustment), which would end the wait too early or far too late.
    started_at = time.monotonic()
    while self.current_request_ids[request_id] == "init":
        time.sleep(0.001)
        if time.monotonic() - started_at > timeout_s:
            # pop() with a default so cleanup itself can never raise.
            self.current_request_ids.pop(request_id, None)
            return False, "timeout"

    # Read and forget the final status in a single step.
    status = self.current_request_ids.pop(request_id)
    if status == "finished":
        return True, ""
    self.logger.error(f"Receive_decode_allocated error: {status}")
    return False, status
|
||||
|
||||
def send_cache_infos(self, tasks, current_id):
|
||||
"""
|
||||
Send cache information to specific port.
|
||||
@@ -345,7 +370,7 @@ class SplitwiseConnector:
|
||||
for i in range(len(tasks)):
|
||||
if tasks[i].disaggregate_info is None:
|
||||
continue
|
||||
logger.info(f"{tasks[i].disaggregate_info}")
|
||||
self.logger.info(f"{tasks[i].disaggregate_info}")
|
||||
if tasks[i].disaggregate_info["role"] == "decode":
|
||||
if tasks[i].disaggregate_info["transfer_protocol"] == "ipc":
|
||||
cache_info = {
|
||||
@@ -380,11 +405,19 @@ class SplitwiseConnector:
|
||||
addr = "prefill"
|
||||
if current_id == -1:
|
||||
current_id = tasks[i].disaggregate_info["cache_info"]["ipc"]["current_id"]
|
||||
cache_info = {
|
||||
"request_id": tasks[i].request_id,
|
||||
"src_block_ids": tasks[i].block_tables,
|
||||
"current_id": current_id,
|
||||
}
|
||||
if envs.ENABLE_V1_KVCACHE_SCHEDULER:
|
||||
cache_info = {
|
||||
"request_id": tasks[i].request_id,
|
||||
"src_block_ids": tasks[i].block_tables,
|
||||
"current_id": tasks[i].idx,
|
||||
"need_prefill_tokens": tasks[i].need_prefill_tokens,
|
||||
}
|
||||
else:
|
||||
cache_info = {
|
||||
"request_id": tasks[i].request_id,
|
||||
"src_block_ids": tasks[i].block_tables,
|
||||
"current_id": current_id,
|
||||
}
|
||||
if addr not in temp_cache_info:
|
||||
temp_cache_info[addr] = []
|
||||
|
||||
@@ -396,7 +429,7 @@ class SplitwiseConnector:
|
||||
else:
|
||||
if len(temp_cache_info):
|
||||
for k, v in temp_cache_info.items():
|
||||
logger.info(f"{k} {v}")
|
||||
self.logger.info(f"{k} {v}")
|
||||
if ":" in str(k):
|
||||
self._send_message(k, "cache_sync", v)
|
||||
else:
|
||||
@@ -427,7 +460,7 @@ class SplitwiseConnector:
|
||||
"""
|
||||
try:
|
||||
msg_type, payload = self._deserialize_message(message)
|
||||
logger.info(f"{msg_type}")
|
||||
self.logger.info(f"{msg_type}")
|
||||
|
||||
if msg_type == "prefill":
|
||||
self._handle_prefill(payload)
|
||||
@@ -435,11 +468,16 @@ class SplitwiseConnector:
|
||||
self._handle_decode(payload)
|
||||
elif msg_type == "cache_sync":
|
||||
for task in payload:
|
||||
del self.current_request_ids[task["request_id"]]
|
||||
self.engine_worker_queue.put_cache_info(payload)
|
||||
self.logger.info(f"cache_sync task: {task}")
|
||||
current_status = task.get("error_msg", "finished")
|
||||
self.current_request_ids[task["request_id"]] = current_status
|
||||
if self.enable_decode_cache_task:
|
||||
del self.current_request_ids[task["request_id"]]
|
||||
if current_status == "finished":
|
||||
self.engine_worker_queue.put_cache_info(payload)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Message processing failed: {e}, {str(traceback.format_exc())}")
|
||||
self.logger.error(f"Message processing failed: {e}, {str(traceback.format_exc())}")
|
||||
|
||||
def _handle_prefill(self, tasks):
|
||||
"""
|
||||
@@ -462,8 +500,12 @@ class SplitwiseConnector:
|
||||
index=task["outputs"]["index"],
|
||||
send_idx=0,
|
||||
token_ids=task["outputs"]["token_ids"],
|
||||
draft_token_ids=task["outputs"]["draft_token_ids"],
|
||||
),
|
||||
finished=True,
|
||||
num_cached_tokens=task["num_cached_tokens"],
|
||||
error_code=task["error_code"],
|
||||
error_msg=task["error_msg"],
|
||||
)
|
||||
)
|
||||
self.engine_worker_queue.put_disaggregated_tasks(("decode", tasks))
|
||||
|
Reference in New Issue
Block a user