[Feature] Support pd ep deployment with yiyan adapter (#4029)

* [Feature] Support mixed deployment with yiyan adapter in release2.2

* fix metrics

* add unit test

* add unit test

* add unit test

* Support pd ep deployment with yiyan adapter

* Support pd ep deployment with yiyan adapter

* refactor cache messager

* support scheduler v1 in PD

* support pd v1 + chunk prefill

* support pd v1 + chunk prefill

* add eplb

* support eplb

* support eplb

* support eplb

* support v1

* fix

* fix

* fix bug

* remove eplb support

* support prefix cache in P

* fix bug

* fix bug

* support one stop in V1

* fix bug

* fix ci

* fix ci

* fix

* fix

* fix

* fix

* fix

---------

Co-authored-by: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com>

Author: chenjian
Date: 2025-09-22 16:41:38 +08:00
Committed by: GitHub
Parent: 9845f0d010
Commit: 918ccdb123
22 changed files with 1838 additions and 343 deletions


@@ -28,8 +28,6 @@ from fastdeploy.inter_communicator import EngineWorkerQueue
from fastdeploy.metrics.metrics import main_process_metrics
from fastdeploy.utils import get_logger
logger = get_logger("splitwise_connector", "splitwise_connector.log")
class SplitwiseConnector:
"""
@@ -46,12 +44,19 @@ class SplitwiseConnector:
resource_manager (object): Resource manager object.
"""
self.cfg = cfg
if self.cfg.parallel_config.enable_expert_parallel and self.cfg.parallel_config.data_parallel_size > 1:
self.logger = get_logger(
"splitwise_connector", f"splitwise_connector_{self.cfg.parallel_config.local_data_parallel_id}.log"
)
else:
self.logger = get_logger("splitwise_connector", "splitwise_connector.log")
self.engine_worker_queue = worker_queue
self.resource_manager = resource_manager
self.connect_innode_instances = {}
self.temp_cache_info = dict()
self.current_request_ids = dict()
self.idx = self.cfg.parallel_config.local_data_parallel_id
self.enable_decode_cache_task = envs.FD_ENABLE_CACHE_TASK == "1"
if self.cfg.cache_config.pd_comm_port is not None:
self.zmq_ctx = zmq.Context()
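
The hunk above replaces the module-level logger with a per-instance one: when expert parallelism spans multiple data-parallel groups, each connector writes to its own log file so DP ranks do not interleave in a single file. A minimal standalone sketch of that pattern (the get_logger helper here is a hypothetical stand-in for fastdeploy.utils.get_logger, and the rank id is a placeholder):

import logging

def get_logger(name: str, filename: str) -> logging.Logger:
    # Hypothetical stand-in for fastdeploy.utils.get_logger:
    # one logger per target file, each with its own file handler.
    logger = logging.getLogger(f"{name}:{filename}")
    if not logger.handlers:
        handler = logging.FileHandler(filename)
        handler.setFormatter(logging.Formatter("%(asctime)s %(levelname)s %(message)s"))
        logger.addHandler(handler)
        logger.setLevel(logging.INFO)
    return logger

local_dp_id = 0  # placeholder for cfg.parallel_config.local_data_parallel_id
logger = get_logger("splitwise_connector", f"splitwise_connector_{local_dp_id}.log")
logger.info("per-rank log file, no interleaving across DP ranks")
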
@@ -70,7 +75,7 @@ class SplitwiseConnector:
self.router_socket.setsockopt(zmq.SNDHWM, 1000)
self.router_socket.setsockopt(zmq.ROUTER_MANDATORY, 1)
self.router_socket.bind(f"tcp://*:{self.cfg.cache_config.pd_comm_port[0]}")
logger.info(f"bind {self.cfg.cache_config.pd_comm_port}")
self.logger.info(f"bind {self.cfg.cache_config.pd_comm_port}")
self.poller = zmq.Poller()
self.poller.register(self.router_socket, zmq.POLLIN)
@@ -90,17 +95,17 @@ class SplitwiseConnector:
if not socks:
continue
else:
logger.debug(f"receive {socks}")
self.logger.debug(f"receive {socks}")
frames = self.router_socket.recv_multipart()
logger.debug(f"frames: {frames}")
self.logger.debug(f"frames: {frames}")
message = frames[-1]
self.io_executor.submit(self._process_message, message)
time.sleep(0.001)
else:
time.sleep(5)
except Exception as e:
logger.error(f"Receiver error: {e}, {str(traceback.format_exc())}")
self.logger.error(f"Receiver error: {e}, {str(traceback.format_exc())}")
time.sleep(1)
def _get_push_socket(self, addr):
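
For context, the receive path above is a standard pyzmq ROUTER-plus-Poller loop: poll with a timeout, pull a multipart message, and hand the last frame to a worker. A self-contained sketch of that pattern, with a placeholder port standing in for cfg.cache_config.pd_comm_port[0] (runs until interrupted):

import zmq

ctx = zmq.Context()
router = ctx.socket(zmq.ROUTER)
router.setsockopt(zmq.SNDHWM, 1000)         # bound the outbound queue
router.setsockopt(zmq.ROUTER_MANDATORY, 1)  # error out instead of dropping silently
router.bind("tcp://*:8201")                 # placeholder port

poller = zmq.Poller()
poller.register(router, zmq.POLLIN)

while True:
    socks = dict(poller.poll(timeout=100))  # milliseconds
    if router not in socks:
        continue
    # DEALER peers send [b"", payload]; ROUTER prepends the peer identity,
    # so the application payload is always the last frame.
    frames = router.recv_multipart()
    payload = frames[-1]
    # hand payload off to a worker thread, as the connector does via io_executor
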
@@ -112,7 +117,7 @@ class SplitwiseConnector:
return sock
try:
logger.info(f"Establishing new connection to {addr}")
self.logger.info(f"Establishing new connection to {addr}")
sock = self.zmq_ctx.socket(zmq.DEALER)
# Set connection parameters
@@ -131,7 +136,7 @@ class SplitwiseConnector:
return sock
except zmq.ZMQError as e:
logger.error(f"Connection to {addr} failed: {e}")
self.logger.error(f"Connection to {addr} failed: {e}")
raise ConnectionError(f"Failed to connect to {addr}") from e
@@ -140,7 +145,7 @@ class SplitwiseConnector:
return
try:
logger.info(f"Sent {msg_type} to {addr}")
self.logger.info(f"Sent {msg_type} to {addr}")
message = self._serialize_message(msg_type, payload)
try:
@@ -148,19 +153,19 @@ class SplitwiseConnector:
sock = self._get_push_socket(addr)
sock.send_multipart([b"", message])
logger.info(f"Sent {msg_type} to {addr}")
self.logger.info(f"Sent {msg_type} to {addr}")
except ConnectionError:
logger.warning(f"Connection to {addr} not established")
self.logger.warning(f"Connection to {addr} not established")
except zmq.Again:
logger.warning(f"Send queue full for {addr}")
self.logger.warning(f"Send queue full for {addr}")
except Exception as e:
logger.error(f"Send to {addr} failed: {e}, {str(traceback.format_exc())}")
self.logger.error(f"Send to {addr} failed: {e}, {str(traceback.format_exc())}")
main_process_metrics.send_cache_failed_num.inc()
self._close_connection(addr)
except Exception as e:
logger.error(f"Message preparation failed: {e}")
self.logger.error(f"Message preparation failed: {e}")
def _close_connection(self, addr):
"""
@@ -265,7 +270,7 @@ class SplitwiseConnector:
f"{task.disaggregate_info['cache_info']['rdma']['ip']}:"
+ f"{task.disaggregate_info['cache_info']['rdma']['port']}"
)
logger.info(f"send splitwise tasks to port {addr} decode")
self.logger.info(f"send splitwise tasks to port {addr} decode")
self.current_request_ids[task.request_id] = "init"
decode_diagg = task.disaggregate_info["cache_info"]
task.disaggregate_info["cache_info"] = self.cfg.disaggregate_info["cache_info"]
@@ -295,7 +300,7 @@ class SplitwiseConnector:
self.connect_innode_instances[port].put_disaggregated_tasks(("decode", tasks))
for task in tasks:
task.disaggregate_info["cache_info"]["ipc"]["port"] = port
logger.info(f"send splitwise tasks to port {port} decode")
self.logger.info(f"send splitwise tasks to port {port} decode")
current_port = port
return current_port
@@ -305,7 +310,7 @@ class SplitwiseConnector:
"""
if not isinstance(tasks_list, list):
tasks_list = [tasks_list]
logger.info("send first token to port decode")
self.logger.info("send first token to port decode")
if prefill_msg["transfer_protocol"] == "ipc":
port = prefill_msg["cache_info"]["ipc"]["port"]
if port not in self.connect_innode_instances:
@@ -313,7 +318,7 @@ class SplitwiseConnector:
self.connect_innode_instances[port].put_disaggregated_tasks(("decode", tasks_list))
else:
node = f"{prefill_msg['cache_info']['rdma']['ip']}:{prefill_msg['cache_info']['rdma']['port']}"
logger.info(f"send first token to port {node} decode")
self.logger.info(f"send first token to port {node} decode")
self._send_message(node, "decode", tasks_list)
def create_connection(self, port):
@@ -329,6 +334,26 @@ class SplitwiseConnector:
client_id=0,
)
def check_decode_allocated(self, task):
start_time = time.time()
if task.disaggregate_info is None:
return True, ""
if self.enable_decode_cache_task:
return True, ""
if task.disaggregate_info["role"] != "prefill":
return True, ""
while self.current_request_ids[task.request_id] == "init":
time.sleep(0.001)
if time.time() - start_time > 30:
del self.current_request_ids[task.request_id]
return False, "timeout"
msg = self.current_request_ids[task.request_id]
del self.current_request_ids[task.request_id]
if msg == "finished":
return True, ""
self.logger.error(f"Receive_decode_allocated error: {msg}")
return False, msg
def send_cache_infos(self, tasks, current_id):
"""
Send cache information to specific port.
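
The new check_decode_allocated is the prefill half of a handshake: each request starts as "init" in current_request_ids, the decode side's cache_sync reply (handled further down) flips it to "finished" or an error string, and prefill polls for up to 30 seconds before giving up. A toy, self-contained model of that handshake, with names and timings illustrative only:

import threading
import time

status = {"req-0": "init"}  # toy stand-in for current_request_ids

def decode_side():
    # The decode instance replies via cache_sync once KV blocks are
    # allocated: "finished" on success, an error message otherwise.
    time.sleep(0.05)
    status["req-0"] = "finished"

threading.Thread(target=decode_side).start()

deadline = time.time() + 1.0  # the real connector waits up to 30 s
while status["req-0"] == "init" and time.time() < deadline:
    time.sleep(0.001)
msg = status.pop("req-0")
print("allocated" if msg == "finished" else f"failed: {msg}")
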
@@ -345,7 +370,7 @@ class SplitwiseConnector:
for i in range(len(tasks)):
if tasks[i].disaggregate_info is None:
continue
logger.info(f"{tasks[i].disaggregate_info}")
self.logger.info(f"{tasks[i].disaggregate_info}")
if tasks[i].disaggregate_info["role"] == "decode":
if tasks[i].disaggregate_info["transfer_protocol"] == "ipc":
cache_info = {
@@ -380,11 +405,19 @@ class SplitwiseConnector:
addr = "prefill"
if current_id == -1:
current_id = tasks[i].disaggregate_info["cache_info"]["ipc"]["current_id"]
cache_info = {
"request_id": tasks[i].request_id,
"src_block_ids": tasks[i].block_tables,
"current_id": current_id,
}
if envs.ENABLE_V1_KVCACHE_SCHEDULER:
cache_info = {
"request_id": tasks[i].request_id,
"src_block_ids": tasks[i].block_tables,
"current_id": tasks[i].idx,
"need_prefill_tokens": tasks[i].need_prefill_tokens,
}
else:
cache_info = {
"request_id": tasks[i].request_id,
"src_block_ids": tasks[i].block_tables,
"current_id": current_id,
}
if addr not in temp_cache_info:
temp_cache_info[addr] = []
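
Under the v1 KV-cache scheduler the prefill side identifies the request by its slot index rather than a separately tracked current_id, and additionally ships need_prefill_tokens so the decode peer of a chunked prefill knows how many prompt tokens to expect. A sketch of that switch (plain dicts and an os.environ lookup stand in for the Request object and fastdeploy's envs module):

import os

V1_SCHEDULER = os.environ.get("ENABLE_V1_KVCACHE_SCHEDULER") == "1"

def build_cache_info(task: dict, current_id: int) -> dict:
    info = {
        "request_id": task["request_id"],
        "src_block_ids": task["block_tables"],
    }
    if V1_SCHEDULER:
        info["current_id"] = task["idx"]  # slot index doubles as the id
        info["need_prefill_tokens"] = task["need_prefill_tokens"]
    else:
        info["current_id"] = current_id
    return info
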
@@ -396,7 +429,7 @@ class SplitwiseConnector:
else:
if len(temp_cache_info):
for k, v in temp_cache_info.items():
logger.info(f"{k} {v}")
self.logger.info(f"{k} {v}")
if ":" in str(k):
self._send_message(k, "cache_sync", v)
else:
@@ -427,7 +460,7 @@ class SplitwiseConnector:
"""
try:
msg_type, payload = self._deserialize_message(message)
logger.info(f"{msg_type}")
self.logger.info(f"{msg_type}")
if msg_type == "prefill":
self._handle_prefill(payload)
@@ -435,11 +468,16 @@ class SplitwiseConnector:
self._handle_decode(payload)
elif msg_type == "cache_sync":
for task in payload:
del self.current_request_ids[task["request_id"]]
self.engine_worker_queue.put_cache_info(payload)
self.logger.info(f"cache_sync task: {task}")
current_status = task.get("error_msg", "finished")
self.current_request_ids[task["request_id"]] = current_status
if self.enable_decode_cache_task:
del self.current_request_ids[task["request_id"]]
if current_status == "finished":
self.engine_worker_queue.put_cache_info(payload)
except Exception as e:
logger.error(f"Message processing failed: {e}, {str(traceback.format_exc())}")
self.logger.error(f"Message processing failed: {e}, {str(traceback.format_exc())}")
def _handle_prefill(self, tasks):
"""
@@ -462,8 +500,12 @@ class SplitwiseConnector:
index=task["outputs"]["index"],
send_idx=0,
token_ids=task["outputs"]["token_ids"],
draft_token_ids=task["outputs"]["draft_token_ids"],
),
finished=True,
num_cached_tokens=task["num_cached_tokens"],
error_code=task["error_code"],
error_msg=task["error_msg"],
)
)
self.engine_worker_queue.put_disaggregated_tasks(("decode", tasks))
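
With decode-side cache tasks disabled, the prefill instance now also propagates num_cached_tokens (prefix-cache hits in P), draft tokens, and any error back alongside the first generated token. Field names below are taken from the hunk above; the values are placeholders, and the real object is a RequestOutput rather than a dict:

first_token_task = {
    "request_id": "req-0",
    "outputs": {
        "index": 0,
        "send_idx": 0,
        "token_ids": [42],      # first token sampled during prefill
        "draft_token_ids": [],  # speculative-decoding drafts, if any
    },
    "finished": True,           # prefill's share of the request is done
    "num_cached_tokens": 0,     # prefix-cache hits observed in P
    "error_code": 0,
    "error_msg": "",
}
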