mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 16:48:03 +08:00
[Feature] Support ep pd with external module (#3194)
* Support external module * Support external module * Support external module * Support external module * refactor code to make it more clear * refactor code to make it more clear * refactor code to make it more clear * refactor code to make it more clear * fix according to review * fix according to review * fix according to review * fix according to review * fix according to review * fix according to review * fix bug * fix bug * fix bug * merge --------- Co-authored-by: root <root@tjdm-inf-sci-k8s-hzz2-h12ni8-0202.tjdm.baidu.com>
This commit is contained in:
@@ -85,12 +85,15 @@ class EngineWorkerQueue:
|
||||
]
|
||||
self.finished_req_queue = [Queue() for _ in range(self.local_data_parallel_size)]
|
||||
self.cache_infos_init: List[List[Any]] = [list() for _ in range(self.local_data_parallel_size)]
|
||||
self.connect_rdma_tasks_list = [list() for _ in range(self.local_data_parallel_size)]
|
||||
self.connect_rdma_tasks_response_list = [list() for _ in range(self.local_data_parallel_size)]
|
||||
self.client_read_info_flag_init: List[List[int]] = [
|
||||
[1] * self.num_client for _ in range(self.local_data_parallel_size)
|
||||
]
|
||||
self.lock_info_init: List[threading.Lock] = [
|
||||
threading.Lock() for _ in range(self.local_data_parallel_size)
|
||||
]
|
||||
self.connect_task_lock_init: List[threading.Lock] = [threading.Lock() for _ in range(self.local_data_parallel_size)]
|
||||
|
||||
self.finish_request_barrier = [
|
||||
threading.Barrier(self.num_client) for _ in range(self.local_data_parallel_size)
|
||||
@@ -112,11 +115,26 @@ class EngineWorkerQueue:
|
||||
callable=lambda idx: self.lock_init[idx],
|
||||
proxytype=AcquirerProxy,
|
||||
)
|
||||
QueueManager.register(
|
||||
"get_connect_task_lock",
|
||||
callable=lambda idx: self.connect_task_lock_init[idx],
|
||||
proxytype=AcquirerProxy,
|
||||
)
|
||||
QueueManager.register(
|
||||
"get_read_finish_flag",
|
||||
callable=lambda idx: self.read_finish_flag_init[idx],
|
||||
proxytype=ValueProxy,
|
||||
)
|
||||
QueueManager.register(
|
||||
"get_connect_rdma_tasks",
|
||||
callable=lambda idx: self.connect_rdma_tasks_list[idx],
|
||||
proxytype=ListProxy
|
||||
)
|
||||
QueueManager.register(
|
||||
"get_connect_rdma_tasks_responses",
|
||||
callable=lambda idx: self.connect_rdma_tasks_response_list[idx],
|
||||
proxytype=ListProxy
|
||||
)
|
||||
QueueManager.register(
|
||||
"get_connected_client_counter",
|
||||
callable=lambda idx: self.connected_client_counter_init[idx],
|
||||
@@ -180,6 +198,9 @@ class EngineWorkerQueue:
|
||||
QueueManager.register("get_disaggregate_requests")
|
||||
QueueManager.register("get_available_prefill_instances")
|
||||
QueueManager.register("get_finish_request_barrier")
|
||||
QueueManager.register("get_connect_rdma_tasks")
|
||||
QueueManager.register("get_connect_rdma_tasks_responses")
|
||||
QueueManager.register("get_connect_task_lock")
|
||||
self.manager = QueueManager(address=self.address, authkey=self.authkey)
|
||||
self._connect_with_retry()
|
||||
|
||||
@@ -200,6 +221,13 @@ class EngineWorkerQueue:
|
||||
self.available_prefill_instances = self.manager.get_available_prefill_instances()
|
||||
self.finish_request_barrier = self.manager.get_finish_request_barrier(self.local_data_parallel_id)
|
||||
self.finished_req_queue = self.manager.get_finish_request_queue(self.local_data_parallel_id)
|
||||
# P/D interconnection
|
||||
self.connect_rdma_task_queue = self.manager.get_connect_rdma_tasks(self.local_data_parallel_id)
|
||||
self.connect_rdma_task_response_queue = self.manager.get_connect_rdma_tasks_responses(
|
||||
self.local_data_parallel_id
|
||||
)
|
||||
self.connect_task_lock = self.manager.get_connect_task_lock(self.local_data_parallel_id)
|
||||
|
||||
assert self.num_client == len(self.client_read_flag)
|
||||
|
||||
if is_server:
|
||||
@@ -280,6 +308,45 @@ class EngineWorkerQueue:
|
||||
total_num: int = len(self.tasks)
|
||||
self.lock.release()
|
||||
return total_num
|
||||
|
||||
def put_connect_rdma_task(self, connect_rdma_task):
    """Append an RDMA connect task to the shared connect-task list.

    The append is performed while holding ``self.connect_task_lock``.

    Args:
        connect_rdma_task: The task object to enqueue; it is stored as-is.

    Raises:
        Exception: Whatever ``append`` on the underlying list raises is
            propagated to the caller after the lock has been released.
    """
    self.connect_task_lock.acquire()
    try:
        self.connect_rdma_task_queue.append(connect_rdma_task)
    finally:
        # Always release, even when the append fails, so that other users
        # of the same lock are not blocked forever.
        self.connect_task_lock.release()
|
||||
|
||||
def get_connect_rdma_task(self):
    """Pop the oldest RDMA connect task from the shared connect-task list.

    The emptiness check and the pop both happen while holding
    ``self.connect_task_lock``.

    Returns:
        The first element of ``self.connect_rdma_task_queue``, or ``None``
        when the list is empty or the pop raised (the failure is logged).

    Raises:
        Exception: An error raised by the length check itself is propagated
            to the caller after the lock has been released.
    """
    self.connect_task_lock.acquire()
    try:
        if len(self.connect_rdma_task_queue) == 0:
            return None
        try:
            return self.connect_rdma_task_queue.pop(0)
        except Exception as e:
            # Best-effort read: a failed pop is reported and treated as
            # "nothing available".
            llm_logger.info(f"get_connect_rdma_task got exception: {e}")
            return None
    finally:
        # Single release point covering every exit path, including a failing
        # length check.
        self.connect_task_lock.release()
|
||||
|
||||
def put_connect_rdma_task_response(self, connect_rdma_task_response):
    """Append an RDMA connect task response to the shared response list.

    The append is performed while holding ``self.connect_task_lock``.

    Args:
        connect_rdma_task_response: The response object to enqueue; it is
            stored as-is.

    Raises:
        Exception: Whatever ``append`` on the underlying list raises is
            propagated to the caller after the lock has been released.
    """
    self.connect_task_lock.acquire()
    try:
        self.connect_rdma_task_response_queue.append(connect_rdma_task_response)
    finally:
        # Always release, even when the append fails, so that other users
        # of the same lock are not blocked forever.
        self.connect_task_lock.release()
|
||||
|
||||
def get_connect_rdma_task_response(self):
    """Pop the oldest RDMA connect task response from the shared response list.

    The emptiness check and the pop both happen while holding
    ``self.connect_task_lock``.

    Returns:
        The first element of ``self.connect_rdma_task_response_queue``, or
        ``None`` when the list is empty or the pop raised (the failure is
        logged).

    Raises:
        Exception: An error raised by the length check itself is propagated
            to the caller after the lock has been released.
    """
    self.connect_task_lock.acquire()
    try:
        if len(self.connect_rdma_task_response_queue) == 0:
            return None
        try:
            return self.connect_rdma_task_response_queue.pop(0)
        except Exception as e:
            # Best-effort read: a failed pop is reported and treated as
            # "nothing available".
            llm_logger.info(f"get_connect_rdma_task_response got exception: {e}")
            return None
    finally:
        # Single release point covering every exit path, including a failing
        # length check.
        self.connect_task_lock.release()
|
||||
|
||||
|
||||
def get_prefill_instances(self):
|
||||
"""
|
||||
|
Reference in New Issue
Block a user