| """
 | |
| # Copyright (c) 2025  PaddlePaddle Authors. All Rights Reserved.
 | |
| #
 | |
| # Licensed under the Apache License, Version 2.0 (the "License"
 | |
| # you may not use this file except in compliance with the License.
 | |
| # You may obtain a copy of the License at
 | |
| #
 | |
| #     http://www.apache.org/licenses/LICENSE-2.0
 | |
| #
 | |
| # Unless required by applicable law or agreed to in writing, software
 | |
| # distributed under the License is distributed on an "AS IS" BASIS,
 | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | |
| # See the License for the specific language governing permissions and
 | |
| # limitations under the License.
 | |
| """
 | |
| 
 | |
| import threading
 | |
| import time
 | |
| from multiprocessing.managers import (
 | |
|     AcquirerProxy,
 | |
|     BaseManager,
 | |
|     ListProxy,
 | |
|     Value,
 | |
|     ValueProxy,
 | |
| )
 | |
| from queue import Queue
 | |
| from typing import Any, List, Tuple
 | |
| 
 | |
| import numpy as np
 | |
| 
 | |
| from fastdeploy.utils import llm_logger
 | |
| 
 | |
| 
 | |


class EngineWorkerQueue:
    """
    Cross-machine and cross-process communication queue between Engine and Worker.
    Manages shared resources using multiprocessing managers for inter-process communication.
    """
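
    # A minimal wiring sketch (values are illustrative; the engine supplies its
    # own address, authkey and parallel sizes in practice):
    #
    #     server = EngineWorkerQueue(("0.0.0.0", 5000), is_server=True, num_client=2)
    #     worker = EngineWorkerQueue(("127.0.0.1", 5000), is_server=False,
    #                                num_client=2, client_id=0)
    #     server.put_tasks(batch)                # publish one batch to all clients
    #     tasks, all_read = worker.get_tasks()   # each client reads it exactly once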

    def __init__(
        self,
        address: Tuple[str, int] = ("0.0.0.0", 5000),
        authkey: bytes = b"secret_key",
        is_server: bool = False,
        num_client: int = 1,  # tensor parallel size
        client_id: int = -1,  # tensor parallel id
        local_data_parallel_size: int = 1,  # data parallel size
        local_data_parallel_id: int = 0,  # local data parallel id
    ) -> None:
        """
        Initialize the communication queue.

        Args:
            address: Network address (IP, port) of the queue server
            authkey: Authentication key for a secure connection
            is_server: Whether this instance acts as the server
            num_client: Total number of expected clients (tensor parallel size)
            client_id: Unique identifier of this client (tensor parallel rank)
            local_data_parallel_size: Number of local data parallel groups
            local_data_parallel_id: Index of the data parallel group this
                instance belongs to
        """
        self.address: Tuple[str, int] = address
        self.authkey: bytes = authkey
        self.is_server: bool = is_server
        self.num_client: int = num_client
        self.client_id: int = client_id
        self.local_data_parallel_size = local_data_parallel_size
        self.local_data_parallel_id = local_data_parallel_id

        class QueueManager(BaseManager):
            """
            Custom QueueManager for proxy object registration.
            """

            pass

        if is_server:
            # Server-side initialization: one set of shared structures per
            # local data parallel group.
            self.tasks_init: List[List[Any]] = [list() for _ in range(self.local_data_parallel_size)]
            self.client_read_flag_init: List[List[int]] = [
                [1] * self.num_client for _ in range(self.local_data_parallel_size)
            ]
            self.lock_init: List[threading.Lock] = [threading.Lock() for _ in range(self.local_data_parallel_size)]
            self.read_finish_flag_init: List[Value] = [Value("i", 0) for _ in range(self.local_data_parallel_size)]
            self.connected_client_counter_init: List[Value] = [
                Value("i", 0) for _ in range(self.local_data_parallel_size)
            ]
            self.finished_req_queue = [Queue() for _ in range(self.local_data_parallel_size)]
            self.cache_infos_init: List[List[Any]] = [list() for _ in range(self.local_data_parallel_size)]
            self.client_read_info_flag_init: List[List[int]] = [
                [1] * self.num_client for _ in range(self.local_data_parallel_size)
            ]
            self.lock_info_init: List[threading.Lock] = [
                threading.Lock() for _ in range(self.local_data_parallel_size)
            ]

            self.finish_request_barrier = [
                threading.Barrier(self.num_client) for _ in range(self.local_data_parallel_size)
            ]

            # Register shared objects with proxy types
            QueueManager.register(
                "get_tasks",
                callable=lambda idx: self.tasks_init[idx],
                proxytype=ListProxy,
            )
            QueueManager.register(
                "get_client_read_flag",
                callable=lambda idx: self.client_read_flag_init[idx],
                proxytype=ListProxy,
            )
            QueueManager.register(
                "get_lock",
                callable=lambda idx: self.lock_init[idx],
                proxytype=AcquirerProxy,
            )
            QueueManager.register(
                "get_read_finish_flag",
                callable=lambda idx: self.read_finish_flag_init[idx],
                proxytype=ValueProxy,
            )
            QueueManager.register(
                "get_connected_client_counter",
                callable=lambda idx: self.connected_client_counter_init[idx],
                proxytype=ValueProxy,
            )

            QueueManager.register(
                "get_finish_request_queue",
                callable=lambda idx: self.finished_req_queue[idx],
            )

            QueueManager.register(
                "get_cache_infos",
                callable=lambda idx: self.cache_infos_init[idx],
                proxytype=ListProxy,
            )

            QueueManager.register(
                "get_client_read_info_flag",
                callable=lambda idx: self.client_read_info_flag_init[idx],
                proxytype=ListProxy,
            )
            QueueManager.register(
                "get_lock_info",
                callable=lambda idx: self.lock_info_init[idx],
                proxytype=AcquirerProxy,
            )

            self.disaggregate_requests = [Queue() for _ in range(self.local_data_parallel_size)]
            QueueManager.register(
                "get_disaggregate_requests",
                callable=lambda idx: self.disaggregate_requests[idx],
            )

            self.available_prefill_instances = Queue()
            QueueManager.register(
                "get_available_prefill_instances",
                callable=lambda: self.available_prefill_instances,
            )

            QueueManager.register(
                "get_finish_request_barrier",
                callable=lambda idx: self.finish_request_barrier[idx],
            )
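            # start() launches a separate manager process that serves the
            # registered objects at `address`; clients reach them through the
            # proxies registered above.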
            self.manager: BaseManager = QueueManager(address=self.address, authkey=self.authkey)
            self.manager.start()
        else:
            # Client-side connection setup
            assert (
                0 <= self.client_id < self.num_client
            ), f"self.client_id={self.client_id}, self.num_client={self.num_client}"
            QueueManager.register("get_tasks")
            QueueManager.register("get_client_read_flag")
            QueueManager.register("get_lock")
            QueueManager.register("get_read_finish_flag")
            QueueManager.register("get_connected_client_counter")
            QueueManager.register("get_finish_request_queue")
            QueueManager.register("get_cache_infos")
            QueueManager.register("get_client_read_info_flag")
            QueueManager.register("get_lock_info")
            QueueManager.register("get_disaggregate_requests")
            QueueManager.register("get_available_prefill_instances")
            QueueManager.register("get_finish_request_barrier")
            self.manager = QueueManager(address=self.address, authkey=self.authkey)
            self._connect_with_retry()

            # Get proxy objects for shared resources
            self.tasks: ListProxy = self.manager.get_tasks(self.local_data_parallel_id)
            self.client_read_flag: ListProxy = self.manager.get_client_read_flag(self.local_data_parallel_id)
            self.lock: AcquirerProxy = self.manager.get_lock(self.local_data_parallel_id)
            self.read_finish_flag: ValueProxy = self.manager.get_read_finish_flag(self.local_data_parallel_id)
            self.connected_client_counter: ValueProxy = self.manager.get_connected_client_counter(
                self.local_data_parallel_id
            )
            self.cache_infos: ListProxy = self.manager.get_cache_infos(self.local_data_parallel_id)
            self.client_read_info_flag: ListProxy = self.manager.get_client_read_info_flag(self.local_data_parallel_id)
            self.lock_info: AcquirerProxy = self.manager.get_lock_info(self.local_data_parallel_id)

            # Prefill/decode (p/d) disaggregation: fetch the shared queues
            self.disaggregate_requests = self.manager.get_disaggregate_requests(self.local_data_parallel_id)
            self.available_prefill_instances = self.manager.get_available_prefill_instances()
            self.finish_request_barrier = self.manager.get_finish_request_barrier(self.local_data_parallel_id)
            self.finished_req_queue = self.manager.get_finish_request_queue(self.local_data_parallel_id)
            assert self.num_client == len(self.client_read_flag)

        if is_server:
            llm_logger.info("EngineWorkerQueue server started.")
        else:
            # Update client connection counter
            self.lock.acquire()
            self.connected_client_counter.set(self.connected_client_counter.get() + 1)
            self.lock.release()
            llm_logger.info(
                f"Connected EngineWorkerQueue client_id: {self.client_id}, number "
                f"of connected clients: {self.connected_client_counter.get()}"
            )

    def _connect_with_retry(self, max_retries: int = 5, interval: int = 3) -> None:
        """
        Connect to the server with a retry mechanism.

        Args:
            max_retries: Maximum number of connection attempts
            interval: Retry interval in seconds

        Raises:
            ConnectionError: If all connection attempts fail
        """
        for _ in range(max_retries):
            try:
                self.manager.connect()
                return
            except ConnectionRefusedError:
                time.sleep(interval)
        raise ConnectionError(f"EngineWorkerQueue cannot connect to {self.address}")

    def put_tasks(self, tasks: List[Any]) -> None:
        """
        Add tasks to the shared queue in a thread-safe manner.
        Waits until all clients have read the previous batch before adding
        new tasks.

        Args:
            tasks: Tasks to be added to the queue
        """
        self.lock.acquire()
        # Spin until every client has consumed the previous batch; drop the
        # lock while sleeping so that readers can make progress.
        while sum(self.client_read_flag) < self.num_client:
            self.lock.release()
            time.sleep(0.001)
            self.lock.acquire()

        self.tasks[:] = list()
        self.client_read_flag[:] = [0] * self.num_client
        self.tasks.append(tasks)
        self.lock.release()

    def get_tasks(self) -> Tuple[List[Any], bool]:
        """
        Retrieve tasks from the shared queue and update this client's read
        status.

        Returns:
            tuple: (list of tasks, bool indicating whether all clients have read)
        """
        tasks: List[Any] = list()
        self.lock.acquire()
        tasks.extend(self.tasks)
        self.client_read_flag[self.client_id] = 1
        all_client_read: bool = np.sum(self.client_read_flag) == self.num_client
        # The last client to read clears the shared list so the producer can
        # publish the next batch.
        if all_client_read:
            self.tasks[:] = list()
        self.lock.release()
        return tasks, all_client_read
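
    # Together, put_tasks()/get_tasks() form a broadcast: each tensor parallel
    # client observes the same batch exactly once. A worker-side polling loop
    # might look like the following sketch (`queue` is a connected client
    # instance; `process` is a hypothetical helper):
    #
    #     while True:
    #         tasks, _ = queue.get_tasks()
    #         if tasks:
    #             process(tasks)
    #         time.sleep(0.001)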

    def num_tasks(self) -> int:
        """
        Get current number of tasks in the queue.

        Returns:
            int: Total number of tasks
        """
        self.lock.acquire()
        total_num: int = len(self.tasks)
        self.lock.release()
        return total_num

    def get_prefill_instances(self):
        """
        Pop an available prefill instance from the queue, or return 0 if
        none is currently available.
        """
        if self.available_prefill_instances.qsize() == 0:
            return 0
        else:
            return self.available_prefill_instances.get()

    def put_cache_info(self, cache_info) -> None:
        """
        Publish cache info to the shared list in a thread-safe manner.
        Waits until all clients have read the previous batch before adding
        new entries.

        Args:
            cache_info: Cache info entries to be added
        """
        self.lock_info.acquire()
        while sum(self.client_read_info_flag) < self.num_client:
            self.lock_info.release()
            time.sleep(0.001)
            self.lock_info.acquire()

        self.cache_infos[:] = list()
        self.client_read_info_flag[:] = [0] * self.num_client

        self.cache_infos.extend(cache_info)
        llm_logger.debug(f"cache_infos: {self.cache_infos}  local_data_parallel_id:{self.local_data_parallel_id}")
        self.lock_info.release()

    def get_cache_info(self) -> List[Any]:
        """
        Retrieve cache info from the shared list and update this client's
        read status. Returns an empty list if this client has already read
        the current batch.

        Returns:
            List[Any]: Cache info entries
        """
        cache_infos: List[Any] = list()
        self.lock_info.acquire()
        if self.client_read_info_flag[self.client_id] == 1:
            self.lock_info.release()
            return cache_infos
        cache_infos.extend(self.cache_infos)
        self.client_read_info_flag[self.client_id] = 1
        all_client_read: bool = np.sum(self.client_read_info_flag) == self.num_client
        if all_client_read:
            self.cache_infos[:] = list()
        self.lock_info.release()
        if len(cache_infos) != 0:
            llm_logger.debug(f"get cache infos: {cache_infos}  local_data_parallel_id:{self.local_data_parallel_id}")
        return cache_infos
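
    # Cache info follows the same "publish once, every client reads once"
    # protocol as put_tasks()/get_tasks(), guarded by lock_info instead of lock.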

    def num_cache_infos(self) -> int:
        """
        Get current number of cache info entries in the shared list.

        Returns:
            int: Total number of cache info entries
        """
        self.lock_info.acquire()
        total_num: int = len(self.cache_infos)
        self.lock_info.release()
        return total_num

    def put_finished_req(self, req_ids) -> None:
        """
        Put finished request IDs into the queue.

        Args:
            req_ids: Finished request IDs to be added to the queue
        """
        self.finished_req_queue.put(req_ids)

    def get_finished_req(self) -> List[str]:
        """
        Get finished request IDs from the queue.

        Returns:
            List[str]: Finished request IDs, or an empty list if the queue
                is empty
        """
        ans = []
        if self.finished_req_queue.empty():
            return ans
        ans = self.finished_req_queue.get()
        llm_logger.debug(f"get finished req: {ans}")
        return ans

    def disaggregate_queue_empty(self):
        """
        Check if the disaggregated task queue is empty.
        """
        return self.disaggregate_requests.qsize() == 0

    def put_disaggregated_tasks(self, item):
        """
        Put disaggregated tasks into the queue.
        """
        llm_logger.debug("put item to queue")
        self.disaggregate_requests.put(item)
        llm_logger.debug("put item to queue success")

    def get_disaggregated_tasks(self):
        """
        Drain all disaggregated tasks from the queue, or return None if the
        queue is empty.
        """
        llm_logger.debug("get tasks from queue")
        if self.disaggregate_requests.qsize() == 0:
            return None
        item = []
        while not self.disaggregate_requests.empty():
            item.append(self.disaggregate_requests.get())
        llm_logger.debug("get tasks from queue success")
        return item

    def cleanup(self):
        """
        Exit the worker queue gracefully.
        """
        if self.manager is not None and self.is_server:
            self.manager.shutdown()