[Feature] Support pd ep deployment with yiyan adapter (#4029)

* [Feature] Support mixed deployment with yiyan adapter in release2.2

* fix metrics

* add unit test

* add unit test

* add unit test

* Support pd ep deployment with yiyan adapter

* Support pd ep deployment with yiyan adapter

* refactor cache messager

* support scheduler v1 in PD

* support pd v1 + chunk prefill

* support pd v1 + chunk prefill

* add eplb

* support eplb

* support eplb

* support eplb

* support v1

* fix

* fix

* fix bug

* remove eplb support

* support prefix cache in P

* fix bug

* fix bug

* support one stop in V1

* fix bug

* fix ci

* fix ci

* fix

* fix

* fix

* fix

* fix

---------

Co-authored-by: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com>
chenjian
2025-09-22 16:41:38 +08:00
committed by GitHub
parent 9845f0d010
commit 918ccdb123
22 changed files with 1838 additions and 343 deletions
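The bulk of the change lands in `EngineWorkerQueue`, which gains the shared state for P/D (prefill/decode) disaggregation: a finished-add-cache-task queue and barrier, RDMA connect-task lists, and the lock guarding them. As a minimal sketch, assuming a constructor that mirrors the fields used in the diff below (`address`, `authkey`, `is_server`, `num_client`, `local_data_parallel_size`/`_id`; the real signature may differ), the engine and a worker would attach like this:

```python
address = ("0.0.0.0", 50077)  # illustrative address/port

# Engine side: owns the shared queues, locks, and barriers.
engine_queue = EngineWorkerQueue(
    address=address,
    authkey=b"engine_worker_queue",   # assumed authkey
    is_server=True,
    num_client=8,                     # workers per data-parallel group
    local_data_parallel_size=2,
)

# Worker side: registers the same typeids, then connects with retry.
worker_queue = EngineWorkerQueue(
    address=address,
    authkey=b"engine_worker_queue",
    is_server=False,
    num_client=8,
    local_data_parallel_id=0,         # this worker's data-parallel group
)
```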


@@ -84,18 +84,28 @@ class EngineWorkerQueue:
Value("i", 0) for _ in range(self.local_data_parallel_size)
]
self.finished_req_queue = [Queue() for _ in range(self.local_data_parallel_size)]
self.finished_add_cache_task_queue = [Queue() for _ in range(self.local_data_parallel_size)]
self.cache_infos_init: List[List[Any]] = [list() for _ in range(self.local_data_parallel_size)]
self.connect_rdma_tasks_list = [list() for _ in range(self.local_data_parallel_size)]
self.connect_rdma_tasks_response_list = [list() for _ in range(self.local_data_parallel_size)]
self.client_read_info_flag_init: List[List[int]] = [
[1] * self.num_client for _ in range(self.local_data_parallel_size)
]
self.lock_info_init: List[threading.Lock] = [
threading.Lock() for _ in range(self.local_data_parallel_size)
]
self.connect_task_lock_init: List[threading.Lock] = [
threading.Lock() for _ in range(self.local_data_parallel_size)
]
self.finish_request_barrier = [
threading.Barrier(self.num_client) for _ in range(self.local_data_parallel_size)
]
self.finish_add_cache_task_barrier = [
threading.Barrier(self.num_client) for _ in range(self.local_data_parallel_size)
]
# Register shared objects with proxy types
QueueManager.register(
"get_tasks",
@@ -117,6 +127,19 @@ class EngineWorkerQueue:
                callable=lambda idx: self.read_finish_flag_init[idx],
                proxytype=ValueProxy,
            )
            QueueManager.register(
                "get_connect_task_lock",
                callable=lambda idx: self.connect_task_lock_init[idx],
                proxytype=AcquirerProxy,
            )
            QueueManager.register(
                "get_connect_rdma_tasks", callable=lambda idx: self.connect_rdma_tasks_list[idx], proxytype=ListProxy
            )
            QueueManager.register(
                "get_connect_rdma_tasks_responses",
                callable=lambda idx: self.connect_rdma_tasks_response_list[idx],
                proxytype=ListProxy,
            )
            QueueManager.register(
                "get_connected_client_counter",
                callable=lambda idx: self.connected_client_counter_init[idx],
@@ -128,6 +151,11 @@ class EngineWorkerQueue:
                callable=lambda idx: self.finished_req_queue[idx],
            )
            QueueManager.register(
                "get_finish_add_cache_task_queue",
                callable=lambda idx: self.finished_add_cache_task_queue[idx],
            )
            QueueManager.register(
                "get_cache_infos",
                callable=lambda idx: self.cache_infos_init[idx],
@@ -161,6 +189,10 @@ class EngineWorkerQueue:
"get_finish_request_barrier",
callable=lambda idx: self.finish_request_barrier[idx],
)
QueueManager.register(
"get_finish_add_cache_task_barrier",
callable=lambda idx: self.finish_add_cache_task_barrier[idx],
)
self.manager: BaseManager = QueueManager(address=self.address, authkey=self.authkey)
self.manager.start()
else:
@@ -174,12 +206,17 @@ class EngineWorkerQueue:
QueueManager.register("get_read_finish_flag")
QueueManager.register("get_connected_client_counter")
QueueManager.register("get_finish_request_queue")
QueueManager.register("get_finish_add_cache_task_queue")
QueueManager.register("get_cache_infos")
QueueManager.register("get_client_read_info_flag")
QueueManager.register("get_lock_info")
QueueManager.register("get_disaggregate_requests")
QueueManager.register("get_available_prefill_instances")
QueueManager.register("get_finish_request_barrier")
QueueManager.register("get_finish_add_cache_task_barrier")
QueueManager.register("get_connect_rdma_tasks")
QueueManager.register("get_connect_rdma_tasks_responses")
QueueManager.register("get_connect_task_lock")
self.manager = QueueManager(address=self.address, authkey=self.authkey)
self._connect_with_retry()
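The registration dance above is the standard `multiprocessing.managers` pattern: the server process registers each shared object behind a typeid with an explicit proxy type, while clients register the bare typeid and receive proxies after connecting. A self-contained sketch of that pattern (names such as `DemoManager` are illustrative, not FastDeploy code):

```python
import threading
from multiprocessing.managers import AcquirerProxy, BaseManager, ListProxy

tasks = [list() for _ in range(2)]            # one shared list per DP group
locks = [threading.Lock() for _ in range(2)]

class DemoManager(BaseManager):
    pass

# Server side: expose per-index objects through explicit proxy types.
DemoManager.register("get_tasks", callable=lambda idx: tasks[idx], proxytype=ListProxy)
DemoManager.register("get_lock", callable=lambda idx: locks[idx], proxytype=AcquirerProxy)

class DemoClient(BaseManager):
    pass

# Client side: register the typeids only; proxies are built on access.
DemoClient.register("get_tasks")
DemoClient.register("get_lock")

if __name__ == "__main__":
    server = DemoManager(address=("127.0.0.1", 50078), authkey=b"demo")
    server.start()

    client = DemoClient(address=("127.0.0.1", 50078), authkey=b"demo")
    client.connect()

    with client.get_lock(0):                  # AcquirerProxy supports "with"
        client.get_tasks(0).append({"req_id": "demo"})
    print(client.get_tasks(0)[0])             # {'req_id': 'demo'}
    server.shutdown()
```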
@@ -199,7 +236,20 @@ class EngineWorkerQueue:
        self.disaggregate_requests = self.manager.get_disaggregate_requests(self.local_data_parallel_id)
        self.available_prefill_instances = self.manager.get_available_prefill_instances()
        self.finish_request_barrier = self.manager.get_finish_request_barrier(self.local_data_parallel_id)
        self.finish_add_cache_task_barrier = self.manager.get_finish_add_cache_task_barrier(
            self.local_data_parallel_id
        )
        self.finished_req_queue = self.manager.get_finish_request_queue(self.local_data_parallel_id)
        self.finished_add_cache_task_queue = self.manager.get_finish_add_cache_task_queue(
            self.local_data_parallel_id
        )
        # P/D interconnection
        self.connect_rdma_task_queue = self.manager.get_connect_rdma_tasks(self.local_data_parallel_id)
        self.connect_rdma_task_response_queue = self.manager.get_connect_rdma_tasks_responses(
            self.local_data_parallel_id
        )
        self.connect_task_lock = self.manager.get_connect_task_lock(self.local_data_parallel_id)

        assert self.num_client == len(self.client_read_flag)

        if is_server:
@@ -281,6 +331,44 @@ class EngineWorkerQueue:
            self.lock.release()
        return total_num

    def put_connect_rdma_task(self, connect_rdma_task):
        self.connect_task_lock.acquire()
        self.connect_rdma_task_queue.append(connect_rdma_task)
        self.connect_task_lock.release()

    def get_connect_rdma_task(self):
        result = None
        self.connect_task_lock.acquire()
        if len(self.connect_rdma_task_queue) == 0:
            self.connect_task_lock.release()
            return result
        try:
            result = self.connect_rdma_task_queue.pop(0)
        except Exception as e:
            llm_logger.info(f"get_connect_rdma_task got exception: {e}")
        finally:
            self.connect_task_lock.release()
        return result

    def put_connect_rdma_task_response(self, connect_rdma_task_response):
        self.connect_task_lock.acquire()
        self.connect_rdma_task_response_queue.append(connect_rdma_task_response)
        self.connect_task_lock.release()

    def get_connect_rdma_task_response(self):
        result = None
        self.connect_task_lock.acquire()
        if len(self.connect_rdma_task_response_queue) == 0:
            self.connect_task_lock.release()
            return result
        try:
            result = self.connect_rdma_task_response_queue.pop(0)
        except Exception as e:
            llm_logger.info(f"get_connect_rdma_task_response got exception: {e}")
        finally:
            self.connect_task_lock.release()
        return result

    def get_prefill_instances(self):
        """
        check if the prefill queue is empty
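Both connect-RDMA queues are plain shared lists guarded by a single lock, so producer and consumer poll rather than block. A hedged sketch of how the handshake might be driven (the loop structure, the `task_id` field, and which role sits on which side are assumptions, not the shipped protocol):

```python
import time

def request_rdma_connection(queue, task):
    """One side (assumed: decode) enqueues a connect request and polls for the reply."""
    queue.put_connect_rdma_task(task)
    while True:
        resp = queue.get_connect_rdma_task_response()   # None when empty
        if resp is not None and resp.get("task_id") == task["task_id"]:
            return resp
        time.sleep(0.01)

def serve_rdma_connections(queue, connect_fn):
    """Other side (assumed: prefill cache messager) pops requests and answers each."""
    while True:
        task = queue.get_connect_rdma_task()            # None when empty
        if task is None:
            time.sleep(0.01)
            continue
        ok = connect_fn(task)                           # e.g. establish the RDMA link
        queue.put_connect_rdma_task_response({"task_id": task["task_id"], "success": ok})
```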
@@ -365,6 +453,29 @@ class EngineWorkerQueue:
llm_logger.debug(f"get finished req: {ans}")
return ans
def put_finished_add_cache_task_req(self, req_ids) -> None:
"""
Put finished request ID into the queue.
Args:
req_ids: Request ID to be added to the queue
"""
self.finished_add_cache_task_queue.put(req_ids)
def get_finished_add_cache_task_req(self) -> str:
"""
Get finished request ID from the queue.
Returns:
str: Finished request ID
"""
ans = []
if self.finished_add_cache_task_queue.empty():
return ans
ans = self.finished_add_cache_task_queue.get()
llm_logger.debug(f"get finished req: {ans}")
return ans
def disaggregate_queue_empty(self):
"""
Check if the disaggregated task queue is empty.
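The finished-add-cache-task queue travels with a barrier of `num_client` parties, which suggests every worker in a group rendezvous before the result is drained once. A sketch under that assumption (the rank-0 convention and the helper name are illustrative):

```python
def drain_finished_add_cache_tasks(worker_queue, client_id):
    # All num_client workers arrive here before anyone reads the queue.
    worker_queue.finish_add_cache_task_barrier.wait()
    if client_id == 0:
        # Returns [] when nothing has finished since the last drain.
        return worker_queue.get_finished_add_cache_task_req()
    return []
```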