Mirror of https://github.com/PaddlePaddle/FastDeploy.git, synced 2025-10-06 00:57:33 +08:00
[Feature] Support pd ep deployment with yiyan adapter (#4029)
* [Feature] Support mixed deployment with yiyan adapter in release2.2
* fix metrics
* add unit test
* add unit test
* add unit test
* Support pd ep deployment with yiyan adapter
* Support pd ep deployment with yiyan adapter
* refactor cache messager
* support scheduler v1 in PD
* support pd v1 + chunk prefill
* support pd v1 + chunk prefill
* add eplb
* support eplb
* support eplb
* support eplb
* support v1
* fix
* fix
* fix bug
* remove eplb support
* support prefix cache in P
* fix bug
* fix bug
* support one stop in V1
* fix bug
* fix ci
* fix ci
* fix
* fix
* fix
* fix
* fix

---------

Co-authored-by: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com>
@@ -84,18 +84,28 @@ class EngineWorkerQueue:
                Value("i", 0) for _ in range(self.local_data_parallel_size)
            ]
            self.finished_req_queue = [Queue() for _ in range(self.local_data_parallel_size)]
            self.finished_add_cache_task_queue = [Queue() for _ in range(self.local_data_parallel_size)]
            self.cache_infos_init: List[List[Any]] = [list() for _ in range(self.local_data_parallel_size)]
            self.connect_rdma_tasks_list = [list() for _ in range(self.local_data_parallel_size)]
            self.connect_rdma_tasks_response_list = [list() for _ in range(self.local_data_parallel_size)]
            self.client_read_info_flag_init: List[List[int]] = [
                [1] * self.num_client for _ in range(self.local_data_parallel_size)
            ]
            self.lock_info_init: List[threading.Lock] = [
                threading.Lock() for _ in range(self.local_data_parallel_size)
            ]
            self.connect_task_lock_init: List[threading.Lock] = [
                threading.Lock() for _ in range(self.local_data_parallel_size)
            ]

            self.finish_request_barrier = [
                threading.Barrier(self.num_client) for _ in range(self.local_data_parallel_size)
            ]

            self.finish_add_cache_task_barrier = [
                threading.Barrier(self.num_client) for _ in range(self.local_data_parallel_size)
            ]

            # Register shared objects with proxy types
            QueueManager.register(
                "get_tasks",
@@ -117,6 +127,19 @@ class EngineWorkerQueue:
                callable=lambda idx: self.read_finish_flag_init[idx],
                proxytype=ValueProxy,
            )
            QueueManager.register(
                "get_connect_task_lock",
                callable=lambda idx: self.connect_task_lock_init[idx],
                proxytype=AcquirerProxy,
            )
            QueueManager.register(
                "get_connect_rdma_tasks", callable=lambda idx: self.connect_rdma_tasks_list[idx], proxytype=ListProxy
            )
            QueueManager.register(
                "get_connect_rdma_tasks_responses",
                callable=lambda idx: self.connect_rdma_tasks_response_list[idx],
                proxytype=ListProxy,
            )
            QueueManager.register(
                "get_connected_client_counter",
                callable=lambda idx: self.connected_client_counter_init[idx],
@@ -128,6 +151,11 @@ class EngineWorkerQueue:
                callable=lambda idx: self.finished_req_queue[idx],
            )

            QueueManager.register(
                "get_finish_add_cache_task_queue",
                callable=lambda idx: self.finished_add_cache_task_queue[idx],
            )

            QueueManager.register(
                "get_cache_infos",
                callable=lambda idx: self.cache_infos_init[idx],
@@ -161,6 +189,10 @@ class EngineWorkerQueue:
                "get_finish_request_barrier",
                callable=lambda idx: self.finish_request_barrier[idx],
            )
            QueueManager.register(
                "get_finish_add_cache_task_barrier",
                callable=lambda idx: self.finish_add_cache_task_barrier[idx],
            )
            self.manager: BaseManager = QueueManager(address=self.address, authkey=self.authkey)
            self.manager.start()
        else:
@@ -174,12 +206,17 @@ class EngineWorkerQueue:
            QueueManager.register("get_read_finish_flag")
            QueueManager.register("get_connected_client_counter")
            QueueManager.register("get_finish_request_queue")
            QueueManager.register("get_finish_add_cache_task_queue")
            QueueManager.register("get_cache_infos")
            QueueManager.register("get_client_read_info_flag")
            QueueManager.register("get_lock_info")
            QueueManager.register("get_disaggregate_requests")
            QueueManager.register("get_available_prefill_instances")
            QueueManager.register("get_finish_request_barrier")
            QueueManager.register("get_finish_add_cache_task_barrier")
            QueueManager.register("get_connect_rdma_tasks")
            QueueManager.register("get_connect_rdma_tasks_responses")
            QueueManager.register("get_connect_task_lock")
            self.manager = QueueManager(address=self.address, authkey=self.authkey)
            self._connect_with_retry()
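Both branches above follow the standard multiprocessing.managers.BaseManager pattern: the server process registers each shared object behind a callable (plus a proxy type matching the object's interface) and calls start(), while client processes register the same names without callables and connect(). A minimal self-contained sketch of that split, run in separate processes as in the diff; the names ConnectQueueManager, run_server, and run_client are illustrative, not FastDeploy APIs:

import threading
from multiprocessing.managers import AcquirerProxy, BaseManager, ListProxy

class ConnectQueueManager(BaseManager):
    # Toy subclass standing in for EngineWorkerQueue's QueueManager.
    pass

def run_server(address=("127.0.0.1", 56666), authkey=b"demo", dp_size=2):
    tasks = [[] for _ in range(dp_size)]                # one task list per DP rank
    locks = [threading.Lock() for _ in range(dp_size)]  # one lock per DP rank
    # Server side: expose callables that return the shared objects; the
    # proxytype decides which methods clients may call on the returned proxy.
    ConnectQueueManager.register(
        "get_connect_rdma_tasks", callable=lambda idx: tasks[idx], proxytype=ListProxy
    )
    ConnectQueueManager.register(
        "get_connect_task_lock", callable=lambda idx: locks[idx], proxytype=AcquirerProxy
    )
    manager = ConnectQueueManager(address=address, authkey=authkey)
    manager.start()  # spawns the manager server process
    return manager

def run_client(address=("127.0.0.1", 56666), authkey=b"demo", dp_id=0):
    # Client side: register the names only, then connect to the running server.
    ConnectQueueManager.register("get_connect_rdma_tasks")
    ConnectQueueManager.register("get_connect_task_lock")
    manager = ConnectQueueManager(address=address, authkey=authkey)
    manager.connect()
    # Each call invokes the server-side callable and returns a proxy
    # bound to that DP rank's object.
    return manager.get_connect_rdma_tasks(dp_id), manager.get_connect_task_lock(dp_id)

Registering the lock with AcquirerProxy is what lets acquire()/release() work across process boundaries, which the put_connect_rdma_task/get_connect_rdma_task methods later in this diff rely on.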
@@ -199,7 +236,20 @@ class EngineWorkerQueue:
        self.disaggregate_requests = self.manager.get_disaggregate_requests(self.local_data_parallel_id)
        self.available_prefill_instances = self.manager.get_available_prefill_instances()
        self.finish_request_barrier = self.manager.get_finish_request_barrier(self.local_data_parallel_id)
        self.finish_add_cache_task_barrier = self.manager.get_finish_add_cache_task_barrier(
            self.local_data_parallel_id
        )
        self.finished_req_queue = self.manager.get_finish_request_queue(self.local_data_parallel_id)
        self.finished_add_cache_task_queue = self.manager.get_finish_add_cache_task_queue(
            self.local_data_parallel_id
        )
        # P/D (prefill/decode) interconnection
        self.connect_rdma_task_queue = self.manager.get_connect_rdma_tasks(self.local_data_parallel_id)
        self.connect_rdma_task_response_queue = self.manager.get_connect_rdma_tasks_responses(
            self.local_data_parallel_id
        )
        self.connect_task_lock = self.manager.get_connect_task_lock(self.local_data_parallel_id)

        assert self.num_client == len(self.client_read_flag)

        if is_server:
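Continuing the toy sketch above, a client process can then mutate a rank's shared list under that rank's shared lock, which is exactly the access pattern the per-rank proxies fetched in this hunk exist for:

tasks_proxy, lock_proxy = run_client(dp_id=0)
lock_proxy.acquire()
try:
    # The append executes in the manager process, so every
    # connected client observes it.
    tasks_proxy.append({"task_id": "req-7"})
finally:
    lock_proxy.release()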
@@ -281,6 +331,44 @@ class EngineWorkerQueue:
        self.lock.release()
        return total_num

    def put_connect_rdma_task(self, connect_rdma_task):
        self.connect_task_lock.acquire()
        self.connect_rdma_task_queue.append(connect_rdma_task)
        self.connect_task_lock.release()

    def get_connect_rdma_task(self):
        result = None
        self.connect_task_lock.acquire()
        if len(self.connect_rdma_task_queue) == 0:
            self.connect_task_lock.release()
            return result
        try:
            result = self.connect_rdma_task_queue.pop(0)
        except Exception as e:
            llm_logger.info(f"get_connect_rdma_task got exception: {e}")
        finally:
            self.connect_task_lock.release()
        return result

    def put_connect_rdma_task_response(self, connect_rdma_task_response):
        self.connect_task_lock.acquire()
        self.connect_rdma_task_response_queue.append(connect_rdma_task_response)
        self.connect_task_lock.release()

    def get_connect_rdma_task_response(self):
        result = None
        self.connect_task_lock.acquire()
        if len(self.connect_rdma_task_response_queue) == 0:
            self.connect_task_lock.release()
            return result
        try:
            result = self.connect_rdma_task_response_queue.pop(0)
        except Exception as e:
            llm_logger.info(f"get_connect_rdma_task_response got exception: {e}")
        finally:
            self.connect_task_lock.release()
        return result

    def get_prefill_instances(self):
        """
        check if the prefill queue is empty
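The four methods above form a small cross-process mailbox: a manager-proxied list guarded by a manager-proxied lock, where pop(0) gives FIFO order and None signals "empty". A usage sketch; the producer/consumer roles and the payload fields are assumptions based on the naming, not confirmed by this diff:

# Decode side (assumed): request an RDMA connection to a prefill instance.
engine_worker_queue.put_connect_rdma_task(
    {"task_id": "req-42", "ip": "10.0.0.3", "rdma_ports": [7790]}  # illustrative payload
)

# Prefill side (assumed): poll the mailbox and answer via the response list.
task = engine_worker_queue.get_connect_rdma_task()
if task is not None:
    response = {"task_id": task["task_id"], "success": True}  # illustrative
    engine_worker_queue.put_connect_rdma_task_response(response)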
@@ -365,6 +453,29 @@ class EngineWorkerQueue:
        llm_logger.debug(f"get finished req: {ans}")
        return ans

    def put_finished_add_cache_task_req(self, req_ids) -> None:
        """
        Put finished add-cache-task request IDs into the queue.

        Args:
            req_ids: Request IDs to be added to the queue
        """
        self.finished_add_cache_task_queue.put(req_ids)

    def get_finished_add_cache_task_req(self) -> List[str]:
        """
        Get finished add-cache-task request IDs from the queue.

        Returns:
            List[str]: Finished request IDs; empty if none are ready.
        """
        ans = []
        if self.finished_add_cache_task_queue.empty():
            return ans
        ans = self.finished_add_cache_task_queue.get()
        llm_logger.debug(f"get finished add cache task req: {ans}")
        return ans

    def disaggregate_queue_empty(self):
        """
        Check if the disaggregated task queue is empty.
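A usage sketch for the finished add-cache-task pair, assuming the worker/engine role split suggested by the finish_add_cache_task_barrier registered earlier; the roles and the callback are illustrative:

# Worker side (assumed): all num_client workers meet at the barrier, and the
# one waiter that drew index 0 publishes the result exactly once per step.
if engine_worker_queue.finish_add_cache_task_barrier.wait() == 0:
    engine_worker_queue.put_finished_add_cache_task_req(["req-1", "req-2"])

# Engine side (assumed): non-blocking poll; returns [] when nothing finished.
for req_id in engine_worker_queue.get_finished_add_cache_task_req():
    mark_cache_task_done(req_id)  # hypothetical callback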