Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-10-07 09:31:35 +08:00)
[Feature] Support ep pd with external module (#3194)
* Support external module

* refactor code to make it more clear

* fix according to review

* fix bug

* merge

---------

Co-authored-by: root <root@tjdm-inf-sci-k8s-hzz2-h12ni8-0202.tjdm.baidu.com>
@@ -142,12 +142,16 @@ class CacheMessager:
         self.gpu_id = gpu_id
         self.cache_info = dict()
-        self.dp_rank_id = self.rank + local_data_parallel_id * self.nranks
+        self.rank_id = self.rank + local_data_parallel_id * self.nranks  # align with engine worker rank (paddle.distributed.launch)

         layerwise_send_cache_thread = threading.Thread(target=self._prefill_layerwise_send_cache_thread)
         layerwise_send_cache_thread.daemon = True
         layerwise_send_cache_thread.start()

+        connect_rdma_thread = threading.Thread(target=self._handle_connect_task)
+        connect_rdma_thread.daemon = True
+        connect_rdma_thread.start()

         logger.info(f"cache messager init finished, use {transfer_protocol}")

     def _prefill_layerwise_send_cache_thread(self):
@@ -160,14 +164,14 @@ class CacheMessager:
         prefilled_layer_idx_data = np.zeros(shape=[1], dtype=np.int32)
         try:
             step_shm_value = IPCSignal(
-                name=f"splitwise_complete_prefilled_step_{self.dp_rank_id}",
+                name=f"splitwise_complete_prefilled_step_{self.rank_id}",
                 array=prefilled_step_idx_data,
                 dtype=np.int32,
                 suffix=self.gpu_id,
                 create=True,
             )
             layer_shm_value = IPCSignal(
-                name=f"splitwise_complete_prefilled_layer_{self.dp_rank_id}",
+                name=f"splitwise_complete_prefilled_layer_{self.rank_id}",
                 array=prefilled_layer_idx_data,
                 dtype=np.int32,
                 suffix=self.gpu_id,
@@ -175,14 +179,14 @@ class CacheMessager:
             )
         except:
             step_shm_value = IPCSignal(
-                name=f"splitwise_complete_prefilled_step_{self.dp_rank_id}",
+                name=f"splitwise_complete_prefilled_step_{self.rank_id}",
                 array=prefilled_step_idx_data,
                 dtype=np.int32,
                 suffix=self.gpu_id,
                 create=False,
             )
             layer_shm_value = IPCSignal(
-                name=f"splitwise_complete_prefilled_layer_{self.dp_rank_id}",
+                name=f"splitwise_complete_prefilled_layer_{self.rank_id}",
                 array=prefilled_layer_idx_data,
                 dtype=np.int32,
                 suffix=self.gpu_id,
@@ -310,3 +314,22 @@ class CacheMessager:
         except Exception as e:
             logger.error(f"prefill layerwise send cache thread has exception: {e}")

+    def _handle_connect_task(self):
+        while True:
+            try:
+                task = self.engine_worker_queue.get_connect_rdma_task()
+                if task is None:
+                    time.sleep(0.001)
+                    continue
+                logger.info(f"_handle_connect_task recv task: {task}")
+                task_id = task["task_id"]
+                ip, rdma_port = task["ip"], task["rdma_port"]
+                status = self.messager["rdma"].connect(ip, rdma_port)
+                if not status:
+                    response = {"task_id": task_id, "success": False}
+                else:
+                    response = {"task_id": task_id, "success": True}
+                self.engine_worker_queue.put_connect_rdma_task_response(response)
+            except Exception as e:
+                logger.error(f"handle_connect_task has exception: {e}")
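The loop above pairs with the `EngineWorkerQueue` additions further down: an external module enqueues a connect task, and the cache messager answers with a success flag. A minimal sketch of that round trip, with an in-process stub standing in for the real queue and RDMA messager (the stub class and all values are illustrative, not FastDeploy APIs):

```python
# Illustrative stub of the connect-task round trip.
import queue

class StubWorkerQueue:
    """Stands in for EngineWorkerQueue's connect-rdma task lists."""
    def __init__(self):
        self._tasks = queue.Queue()
        self._responses = queue.Queue()

    def put_connect_rdma_task(self, task):
        self._tasks.put(task)

    def get_connect_rdma_task(self):
        try:
            return self._tasks.get_nowait()
        except queue.Empty:
            return None

    def put_connect_rdma_task_response(self, response):
        self._responses.put(response)

# Producer side (InternalAdapter receives a "connect_rdma" command):
q = StubWorkerQueue()
q.put_connect_rdma_task({"task_id": "t-0", "ip": "10.0.0.2", "rdma_port": 6789})

# Consumer side (CacheMessager._handle_connect_task polls and answers):
task = q.get_connect_rdma_task()
if task is not None:
    status = True  # the real code calls self.messager["rdma"].connect(ip, rdma_port)
    q.put_connect_rdma_task_response({"task_id": task["task_id"], "success": status})
```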
@@ -820,6 +820,7 @@ class EngineArgs:
             "max_num_partial_prefills",
             "max_long_partial_prefills",
             "long_prefill_token_threshold",
+            "splitwise_role"
         ]

         all = asdict(self)
@@ -47,12 +47,14 @@ from fastdeploy.inter_communicator import (
     EngineCacheQueue,
     EngineWorkerQueue,
     IPCSignal,
-    ZmqClient,
+    ZmqIpcServer,
+    ZmqTcpServer,
 )
 from fastdeploy.metrics.metrics import main_process_metrics
 from fastdeploy.metrics.trace_util import start_span, start_span_request
 from fastdeploy.model_executor.guided_decoding import schema_checker
 from fastdeploy.output.token_processor import TokenProcessor, WarmUpTokenProcessor
+from fastdeploy.splitwise.internal_adapter_utils import InternalAdapter
 from fastdeploy.splitwise.splitwise_connector import SplitwiseConnector
 from fastdeploy.utils import EngineError, console_logger, envs, llm_logger
@@ -179,11 +181,64 @@ class LLMEngine:
         self.data_processor = self.input_processor.create_processor()

         if api_server_pid is not None:
-            self.zmq_server = ZmqClient(name=api_server_pid, mode=zmq.PULL)
-            self.zmq_server.start_server()
-            self.zmq_server.create_router()
+            if envs.FD_ENABLE_INTERNAL_ADAPTER:
+                self.recv_request_server = ZmqTcpServer(port=envs.FD_ZMQ_RECV_REQUEST_SERVER_PORT, mode=zmq.PULL)
+                self.send_response_server = ZmqTcpServer(port=envs.FD_ZMQ_SEND_RESPONSE_SERVER_PORT, mode=zmq.ROUTER)
+                self.external_adapter = InternalAdapter(
+                    cfg=self.cfg, engine=self, dp_rank=self.cfg.node_rank * self.cfg.worker_num_per_node
+                )
+            else:
+                self.recv_request_server = ZmqIpcServer(name=api_server_pid, mode=zmq.PULL)
+                self.send_response_server = ZmqIpcServer(name=api_server_pid, mode=zmq.ROUTER)
             time.sleep(3)

+        self.cfg.init_cache_info()
+
+        role = self.cfg.splitwise_role
+        host_ip = self.cfg.host_ip
+        disaggregate = self.cfg.disaggregate_info
+        request_queues_for_dp_ipc = (
+            None  # Each dp has its own process; a multiprocessing.Queue delivers requests to each dp
+        )
+        result_queue_for_dp_ipc = None
+        if self.cfg.scheduler_config.name == "splitwise":
+            self.scheduler.start(role, host_ip, disaggregate)
+        elif self.cfg.scheduler_config.name == "dp":
+            request_queues_for_dp_ipc = []
+            result_queue_for_dp_ipc = multiprocessing.Queue()
+            for i in range(self.cfg.parallel_config.data_parallel_size):
+                request_queues_for_dp_ipc.append(multiprocessing.Queue())
+            self.scheduler.start(
+                self.cfg.node_rank * self.cfg.worker_num_per_node, request_queues_for_dp_ipc, result_queue_for_dp_ipc
+            )
+
+        time.sleep(1)
+
+        if self.cfg.parallel_config.enable_expert_parallel and self.cfg.parallel_config.data_parallel_size > 1:
+            self.dp_processed = []
+            for i in range(
+                1,
+                self.cfg.parallel_config.data_parallel_size // self.cfg.nnode,
+            ):
+                time.sleep(1)
+                self.dp_processed.append(
+                    multiprocessing.Process(
+                        target=start_expert_service,
+                        args=(
+                            self.cfg,
+                            i + self.cfg.node_rank * self.cfg.worker_num_per_node,
+                            self.ipc_signal_suffix,
+                            request_queues_for_dp_ipc,
+                            result_queue_for_dp_ipc,
+                        ),
+                    )
+                )
+                llm_logger.info(
+                    f"Engine is initialized successfully with {self.cfg.tensor_parallel_size}"
+                    + f" data parallel id {i}"
+                )
+                self.dp_processed[-1].start()

         if self.do_profile == 0 and (
             self.cfg.cache_config.enable_prefix_caching or self.cfg.splitwise_role != "mixed"
         ):
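The `dp` branch above builds one `multiprocessing.Queue` per data-parallel rank for requests plus a single shared queue for results. A standalone sketch of that fan-out/fan-in topology (`data_parallel_size` and the worker body are assumed values for illustration):

```python
# Sketch of the per-rank queue topology the "dp" scheduler path creates.
import multiprocessing

def worker(rank, request_queue, result_queue):
    request = request_queue.get()  # each rank consumes only its own queue
    result_queue.put((rank, f"handled {request}"))

if __name__ == "__main__":
    data_parallel_size = 4
    request_queues_for_dp_ipc = [multiprocessing.Queue() for _ in range(data_parallel_size)]
    result_queue_for_dp_ipc = multiprocessing.Queue()  # all ranks funnel results here

    procs = [
        multiprocessing.Process(target=worker, args=(i, request_queues_for_dp_ipc[i], result_queue_for_dp_ipc))
        for i in range(data_parallel_size)
    ]
    for p in procs:
        p.start()
    for i in range(data_parallel_size):
        request_queues_for_dp_ipc[i].put(f"req-{i}")  # routed by dp rank
    for _ in range(data_parallel_size):
        print(result_queue_for_dp_ipc.get())  # single shared result stream
    for p in procs:
        p.join()
```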
@@ -238,44 +293,11 @@ class LLMEngine:
             # Single-machine logic
             self.engine_worker_queue.available_prefill_instances.put(1)
             self.split_mode_get_tasks()
-        if self.cfg.scheduler_config.name == "splitwise":
+        if self.cfg.scheduler_config.name == "splitwise" or self.cfg.scheduler_config.name == "dp":
             self.splitwise_receive_thread = threading.Thread(target=self.split_connector.start_receiver, args=())
             self.splitwise_receive_thread.daemon = True
             self.splitwise_receive_thread.start()

-        self.cfg.init_cache_info()
-
-        role = self.cfg.splitwise_role
-        host_ip = self.cfg.host_ip
-        disaggregate = self.cfg.disaggregate_info
-        if self.cfg.scheduler_config.name == "splitwise":
-            self.scheduler.start(role, host_ip, disaggregate)
-
-        time.sleep(1)
-
-        if self.cfg.parallel_config.enable_expert_parallel and self.cfg.parallel_config.data_parallel_size > 1:
-            self.dp_processed = []
-            for i in range(
-                1,
-                self.cfg.parallel_config.data_parallel_size // self.cfg.nnode,
-            ):
-                time.sleep(1)
-                self.dp_processed.append(
-                    multiprocessing.Process(
-                        target=start_expert_service,
-                        args=(
-                            self.cfg,
-                            i + self.cfg.node_rank * self.cfg.worker_num_per_node,
-                            self.ipc_signal_suffix,
-                        ),
-                    )
-                )
-                llm_logger.info(
-                    f"Engine is initialized successfully with {self.cfg.tensor_parallel_size}"
-                    + f" data parallel id {i}"
-                )
-                self.dp_processed[-1].start()

         console_logger.info(f"Worker processes are launched with {time.time() - start_time} seconds.")
         return True
@@ -291,7 +313,7 @@ class LLMEngine:
                     time.sleep(0.005)
                     continue
                 for request_id, contents in results.items():
-                    self.zmq_server.send_multipart(request_id, contents)
+                    self.send_response_server.send_response(request_id, contents)

         except Exception as e:
             llm_logger.error(f"Unexcepted error happend: {e}, {traceback.format_exc()!s}")
@@ -415,14 +437,18 @@ class LLMEngine:
         if self.api_server_pid is None:
             return

+        if envs.FD_ENABLE_INTERNAL_ADAPTER:
+            if self.cfg.splitwise_role == "decode":
+                return
+
         added_requests: Dict[str, int] = dict()
         while self.running:
             try:
                 block = True if len(added_requests) == 0 else False
                 if not self.cfg.enable_mm:
-                    err, data = self.zmq_server.receive_json_once(block)
+                    err, data = self.recv_request_server.receive_json_once(block)
                 else:
-                    err, data = self.zmq_server.receive_pyobj_once(block)
+                    err, data = self.recv_request_server.receive_pyobj_once(block)
                 if err is not None:
                     llm_logger.error("Engine stops inserting zmq task into scheduler, err:{err}")
                     break
@@ -470,7 +496,7 @@ class LLMEngine:
                     )
                     # Since the request is not in scheduler
                     # Send result by zmq directly
-                    self.zmq_server.send_multipart(request_id, error_result)
+                    self.send_response_server.send_response(request_id, error_result)
             except Exception as e:
                 llm_logger.error(
                     f"Error happend while receving new request from zmq, details={e}, "
@@ -989,8 +1015,12 @@ class LLMEngine:
             print(f"Error extracting sub services: {e}")

         self.engine_worker_queue.cleanup()
-        if hasattr(self, "zmq_server") and self.zmq_server is not None:
-            self.zmq_server.close()
+        if hasattr(self, "send_response_server") and self.send_response_server is not None:
+            self.send_response_server.close()
+        if hasattr(self, "recv_request_server") and self.recv_request_server is not None:
+            self.recv_request_server.close()
+        if hasattr(self, "recv_control_cmd_server") and self.recv_control_cmd_server is not None:
+            self.recv_control_cmd_server.close()
         if hasattr(self, "dp_processed"):
             for p in self.dp_processed:
                 p.join()
@@ -29,8 +29,9 @@ from fastdeploy.engine.resource_manager import ResourceManager
 from fastdeploy.inter_communicator import EngineWorkerQueue
 from fastdeploy.metrics.metrics import main_process_metrics
 from fastdeploy.output.token_processor import TokenProcessor
+from fastdeploy.splitwise.internal_adapter_utils import InternalAdapter
 from fastdeploy.splitwise.splitwise_connector import SplitwiseConnector
-from fastdeploy.utils import EngineError, console_logger, llm_logger
+from fastdeploy.utils import EngineError, console_logger, envs, llm_logger


 class ExpertService:
@@ -60,6 +61,7 @@ class ExpertService:

         self.scheduler = cfg.scheduler_config.scheduler()

+        if self.cfg.scheduler_config.name == "splitwise":
             self.scheduler.reset_nodeid(f"{self.scheduler.infer.nodeid}_{local_data_parallel_id!s}")

         self.cfg.parallel_config.local_data_parallel_id = local_data_parallel_id
@@ -111,8 +113,12 @@ class ExpertService:
         )

         self._finalizer = weakref.finalize(self, self._exit_sub_services)
+        if envs.FD_ENABLE_INTERNAL_ADAPTER:
+            self.external_adapter = InternalAdapter(cfg=self.cfg, engine=self, dp_rank=local_data_parallel_id)

-    def start(self, ipc_signal_suffix, local_data_parallel_id):
+    def start(
+        self, ipc_signal_suffix, local_data_parallel_id, request_queues_for_dp_ipc=None, result_queue_for_dp_ipc=None
+    ):
         """
         Initializes the engine and starts its sub-services.
         If `api_server_pid` is defined, will launch a thread
@@ -127,7 +133,7 @@ class ExpertService:
             cache_config=self.cfg.cache_config,
             tensor_parallel_size=self.cfg.tensor_parallel_size,
             device_ids=self.cfg.local_device_ids,
-            pod_ip=self.cfg.pod_ips[0],
+            pod_ip=self.cfg.master_ip,
             engine_worker_queue_port=self.cfg.engine_worker_queue_port,
             pid_suffix=f"{local_data_parallel_id}_{ipc_signal_suffix}",
         )
@@ -147,6 +153,10 @@ class ExpertService:
         role = self.cfg.splitwise_role
         host_ip = self.cfg.host_ip
         disaggregate = self.cfg.disaggregate_info
+        if self.cfg.scheduler_config.name == "dp":
+            assert (request_queues_for_dp_ipc is not None) and (result_queue_for_dp_ipc is not None)
+            self.scheduler.start(local_data_parallel_id, request_queues_for_dp_ipc, result_queue_for_dp_ipc)
+        elif self.cfg.scheduler_config.name == "splitwise":
             self.scheduler.start(role, host_ip, disaggregate)
         self.cfg.print()
@@ -356,13 +366,17 @@ class ExpertService:
             self.zmq_server.close()


-def start_expert_service(cfg, local_data_parallel_id, ipc_signal_suffix):
+def start_expert_service(
+    cfg, local_data_parallel_id, ipc_signal_suffix, request_queues_for_dp_ipc=None, result_queue_for_dp_ipc=None
+):
     """
     Start expert service
     """
     expert_service = ExpertService(cfg, local_data_parallel_id)
     try:
-        expert_service.start(ipc_signal_suffix, local_data_parallel_id)
+        expert_service.start(
+            ipc_signal_suffix, local_data_parallel_id, request_queues_for_dp_ipc, result_queue_for_dp_ipc
+        )
+        expert_service.split_connector.start_receiver()
     except Exception as e:
         llm_logger.exception(f"Expert service failed to start: {e}")
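With the widened signature, a `dp`-scheduler deployment is expected to hand every expert-service process the full request-queue list plus the shared result queue. A hedged sketch of such a launcher (the `cfg` object is assumed to be a fully built engine config and the suffix value is illustrative; `start_expert_service` is the function above):

```python
# Sketch: launching expert services with the new queue arguments.
import multiprocessing

def launch_expert_services(cfg, start_expert_service):
    dp_size = cfg.parallel_config.data_parallel_size
    request_queues = [multiprocessing.Queue() for _ in range(dp_size)]
    result_queue = multiprocessing.Queue()
    procs = []
    for i in range(1, dp_size // cfg.nnode):
        p = multiprocessing.Process(
            target=start_expert_service,
            args=(cfg, i + cfg.node_rank * cfg.worker_num_per_node,
                  "suffix", request_queues, result_queue),  # "suffix" is illustrative
        )
        p.start()
        procs.append(p)
    return procs, request_queues, result_queue
```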
@@ -71,6 +71,7 @@ class Request:
         guided_json_object: Optional[bool] = None,
         enable_thinking: Optional[bool] = True,
         trace_carrier: dict = dict(),
+        dp_rank: Optional[int] = None
     ) -> None:
         self.request_id = request_id
         self.prompt = prompt
@@ -119,6 +120,7 @@ class Request:
         self.task_type = RequestType.PREFILL
         self.idx = None
         self.need_prefill_tokens = self.prompt_token_ids_len
+        self.dp_rank = dp_rank

     @classmethod
     def from_dict(cls, d: dict):
@@ -151,6 +153,7 @@ class Request:
             guided_json_object=d.get("guided_json_object", None),
             enable_thinking=d.get("enable_thinking", True),
             trace_carrier=d.get("trace_carrier", {}),
+            dp_rank=d.get("dp_rank", None)
         )

     @property
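The three hunks above thread `dp_rank` through the `Request` lifecycle: constructor parameter, instance attribute, and dict deserialization. Illustratively (the payload below omits the many other fields a real `Request` requires):

```python
# Round-tripping dp_rank through Request.from_dict (illustrative payload).
payload = {
    "request_id": "req-42",
    "prompt": "hello",
    "dp_rank": 2,  # pins this request to data-parallel rank 2
}
# request = Request.from_dict(payload)   # with a complete payload
# assert request.dp_rank == 2
# A payload without the key falls back to None via d.get("dp_rank", None).
```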
@@ -21,7 +21,7 @@ import numpy as np

 from fastdeploy.engine.config import ModelConfig
 from fastdeploy.input.preprocess import InputPreprocessor
-from fastdeploy.inter_communicator import IPCSignal, ZmqClient
+from fastdeploy.inter_communicator import IPCSignal, ZmqIpcClient
 from fastdeploy.metrics.work_metrics import work_process_metrics
 from fastdeploy.multimodal.registry import MultimodalRegistry
 from fastdeploy.platforms import current_platform
@@ -90,7 +90,7 @@ class EngineClient:
         """
         Create a ZMQ client.
         """
-        self.zmq_client = ZmqClient(model, mode)
+        self.zmq_client = ZmqIpcClient(model, mode)
         self.zmq_client.connect()

     def format_and_add_data(self, prompts: dict):
@@ -80,6 +80,14 @@ environment_variables: dict[str, Callable[[], Any]] = {
     "EXPORTER_OTLP_HEADERS": lambda: os.getenv("EXPORTER_OTLP_HEADERS"),
     # enable kv cache block scheduler v1 (no need for kv_cache_ratio)
     "ENABLE_V1_KVCACHE_SCHEDULER": lambda: int(os.getenv("ENABLE_V1_KVCACHE_SCHEDULER", "0")),
+    # enable internal module to access LLMEngine.
+    "FD_ENABLE_INTERNAL_ADAPTER": lambda: int(os.getenv("FD_ENABLE_INTERNAL_ADAPTER", "0")),
+    # LLMEngine receive-requests port, used when FD_ENABLE_INTERNAL_ADAPTER=1
+    "FD_ZMQ_RECV_REQUEST_SERVER_PORT": lambda: os.getenv("FD_ZMQ_RECV_REQUEST_SERVER_PORT", "8200"),
+    # LLMEngine send-response port, used when FD_ENABLE_INTERNAL_ADAPTER=1
+    "FD_ZMQ_SEND_RESPONSE_SERVER_PORT": lambda: os.getenv("FD_ZMQ_SEND_RESPONSE_SERVER_PORT", "8201"),
+    # LLMEngine receive-control-command port, used when FD_ENABLE_INTERNAL_ADAPTER=1
+    "FD_ZMQ_CONTROL_CMD_SERVER_PORTS": lambda: os.getenv("FD_ZMQ_CONTROL_CMD_SERVER_PORTS", "8202"),
     # Whether to use PLUGINS.
     "FD_PLUGINS": lambda: None if "FD_PLUGINS" not in os.environ else os.environ["FD_PLUGINS"].split(","),
 }
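In practice the adapter is switched on through the environment before the engine process starts; the single ports below are the defaults registered above, while the two-entry control-port list is an assumed example for a two-rank setup:

```python
# Enabling the internal adapter via environment variables.
import os

os.environ["FD_ENABLE_INTERNAL_ADAPTER"] = "1"
os.environ["FD_ZMQ_RECV_REQUEST_SERVER_PORT"] = "8200"       # engine pulls requests here
os.environ["FD_ZMQ_SEND_RESPONSE_SERVER_PORT"] = "8201"      # engine routes responses here
os.environ["FD_ZMQ_CONTROL_CMD_SERVER_PORTS"] = "8202,8203"  # one control port per DP rank
```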
@@ -17,6 +17,7 @@
 from .engine_cache_queue import EngineCacheQueue
 from .engine_worker_queue import EngineWorkerQueue
 from .ipc_signal import IPCSignal
-from .zmq_client import ZmqClient
+from .zmq_client import ZmqIpcClient
+from .zmq_server import ZmqIpcServer, ZmqTcpServer

-__all__ = ["ZmqClient", "IPCSignal", "EngineWorkerQueue", "EngineCacheQueue"]
+__all__ = ["ZmqIpcClient", "IPCSignal", "EngineWorkerQueue", "EngineCacheQueue", "ZmqTcpServer", "ZmqIpcServer"]
@@ -85,12 +85,15 @@ class EngineWorkerQueue:
         ]
         self.finished_req_queue = [Queue() for _ in range(self.local_data_parallel_size)]
         self.cache_infos_init: List[List[Any]] = [list() for _ in range(self.local_data_parallel_size)]
+        self.connect_rdma_tasks_list = [list() for _ in range(self.local_data_parallel_size)]
+        self.connect_rdma_tasks_response_list = [list() for _ in range(self.local_data_parallel_size)]
         self.client_read_info_flag_init: List[List[int]] = [
             [1] * self.num_client for _ in range(self.local_data_parallel_size)
         ]
         self.lock_info_init: List[threading.Lock] = [
             threading.Lock() for _ in range(self.local_data_parallel_size)
         ]
+        self.connect_task_lock_init: List[threading.Lock] = [threading.Lock() for _ in range(self.local_data_parallel_size)]

         self.finish_request_barrier = [
             threading.Barrier(self.num_client) for _ in range(self.local_data_parallel_size)
@@ -112,11 +115,26 @@ class EngineWorkerQueue:
             callable=lambda idx: self.lock_init[idx],
             proxytype=AcquirerProxy,
         )
+        QueueManager.register(
+            "get_connect_task_lock",
+            callable=lambda idx: self.connect_task_lock_init[idx],
+            proxytype=AcquirerProxy,
+        )
         QueueManager.register(
             "get_read_finish_flag",
             callable=lambda idx: self.read_finish_flag_init[idx],
             proxytype=ValueProxy,
         )
+        QueueManager.register(
+            "get_connect_rdma_tasks",
+            callable=lambda idx: self.connect_rdma_tasks_list[idx],
+            proxytype=ListProxy
+        )
+        QueueManager.register(
+            "get_connect_rdma_tasks_responses",
+            callable=lambda idx: self.connect_rdma_tasks_response_list[idx],
+            proxytype=ListProxy
+        )
         QueueManager.register(
             "get_connected_client_counter",
             callable=lambda idx: self.connected_client_counter_init[idx],
@@ -180,6 +198,9 @@ class EngineWorkerQueue:
             QueueManager.register("get_disaggregate_requests")
             QueueManager.register("get_available_prefill_instances")
             QueueManager.register("get_finish_request_barrier")
+            QueueManager.register("get_connect_rdma_tasks")
+            QueueManager.register("get_connect_rdma_tasks_responses")
+            QueueManager.register("get_connect_task_lock")
             self.manager = QueueManager(address=self.address, authkey=self.authkey)
             self._connect_with_retry()
@@ -200,6 +221,13 @@ class EngineWorkerQueue:
         self.available_prefill_instances = self.manager.get_available_prefill_instances()
         self.finish_request_barrier = self.manager.get_finish_request_barrier(self.local_data_parallel_id)
         self.finished_req_queue = self.manager.get_finish_request_queue(self.local_data_parallel_id)
+        # p/d interconnection
+        self.connect_rdma_task_queue = self.manager.get_connect_rdma_tasks(self.local_data_parallel_id)
+        self.connect_rdma_task_response_queue = self.manager.get_connect_rdma_tasks_responses(
+            self.local_data_parallel_id
+        )
+        self.connect_task_lock = self.manager.get_connect_task_lock(self.local_data_parallel_id)

         assert self.num_client == len(self.client_read_flag)

         if is_server:
@@ -281,6 +309,45 @@ class EngineWorkerQueue:
         self.lock.release()
         return total_num

+    def put_connect_rdma_task(self, connect_rdma_task):
+        self.connect_task_lock.acquire()
+        self.connect_rdma_task_queue.append(connect_rdma_task)
+        self.connect_task_lock.release()
+
+    def get_connect_rdma_task(self):
+        result = None
+        self.connect_task_lock.acquire()
+        if len(self.connect_rdma_task_queue) == 0:
+            self.connect_task_lock.release()
+            return result
+        try:
+            result = self.connect_rdma_task_queue.pop(0)
+        except Exception as e:
+            llm_logger.info(f"get_connect_rdma_task got exception: {e}")
+        finally:
+            self.connect_task_lock.release()
+        return result
+
+    def put_connect_rdma_task_response(self, connect_rdma_task_response):
+        self.connect_task_lock.acquire()
+        self.connect_rdma_task_response_queue.append(connect_rdma_task_response)
+        self.connect_task_lock.release()
+
+    def get_connect_rdma_task_response(self):
+        result = None
+        self.connect_task_lock.acquire()
+        if len(self.connect_rdma_task_response_queue) == 0:
+            self.connect_task_lock.release()
+            return result
+        try:
+            result = self.connect_rdma_task_response_queue.pop(0)
+        except Exception as e:
+            llm_logger.info(f"get_connect_rdma_task_response got exception: {e}")
+        finally:
+            self.connect_task_lock.release()
+        return result
+
     def get_prefill_instances(self):
         """
         check if the prefill queue is empty
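The four methods above implement a lock-guarded FIFO over plain lists shared through the queue manager: append under the lock, pop from index 0 under the lock, `None` when empty. A self-contained sketch of the same pattern using a stock `multiprocessing.Manager` (the real code registers these proxies on its own `QueueManager`; names here mirror the diff but the setup is illustrative):

```python
# The pattern behind put/get_connect_rdma_task: a manager-shared list,
# guarded by a shared lock, consumed FIFO from index 0.
from multiprocessing import Manager

def put_task(tasks, lock, task):
    with lock:
        tasks.append(task)

def get_task(tasks, lock):
    with lock:
        if len(tasks) == 0:
            return None
        return tasks.pop(0)

if __name__ == "__main__":
    manager = Manager()
    tasks = manager.list()  # stands in for connect_rdma_tasks_list[idx]
    lock = manager.Lock()   # stands in for connect_task_lock_init[idx]

    put_task(tasks, lock, {"task_id": "t-1", "ip": "10.0.0.3", "rdma_port": 7000})
    print(get_task(tasks, lock))  # {'task_id': 't-1', ...}
    print(get_task(tasks, lock))  # None once drained
```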
@@ -14,200 +14,78 @@
 # limitations under the License.
 """

-import os
-import threading
-import time
+from abc import ABC, abstractmethod

-import msgpack
 import zmq

-from fastdeploy import envs
 from fastdeploy.utils import llm_logger


-class ZmqClient:
+class ZmqClientBase(ABC):
     """
-    ZmqClient is a class that provides a client-side interface for sending and receiving messages using ZeroMQ.
+    ZmqClientBase is a base class that provides a client-side interface for sending and receiving messages using ZeroMQ.
     """

-    def __init__(self, name, mode):
-        self.context = zmq.Context()
-        self.socket = self.context.socket(mode)
-        self.file_name = f"/dev/shm/{name}.socket"
-        self.router_path = f"/dev/shm/router_{name}.ipc"
+    def __init__(self):
+        pass

-        self.ZMQ_SNDHWM = int(envs.FD_ZMQ_SNDHWM)
-        self.aggregate_send = envs.FD_USE_AGGREGATE_SEND
+    @abstractmethod
+    def _create_socket(self):
+        """Abstract method to create and return a ZeroMQ socket."""
+        pass

-        self.mutex = threading.Lock()
-        self.req_dict = dict()
-        self.router = None
-        self.poller = None
-        self.running = True
+    def _ensure_socket(self):
+        """Ensure the socket is created before use."""
+        if self.socket is None:
+            self.socket = self._create_socket()

+    @abstractmethod
     def connect(self):
         """
         Connect to the server using the file name specified in the constructor.
         """
-        self.socket.connect(f"ipc://{self.file_name}")
-
-    def start_server(self):
-        """
-        Start the server using the file name specified in the constructor.
-        """
-        self.socket.setsockopt(zmq.SNDHWM, self.ZMQ_SNDHWM)
-        self.socket.setsockopt(zmq.SNDTIMEO, -1)
-        self.socket.bind(f"ipc://{self.file_name}")
-        self.poller = zmq.Poller()
-        self.poller.register(self.socket, zmq.POLLIN)
-
-    def create_router(self):
-        """
-        Create a ROUTER socket and bind it to the specified router path.
-        """
-        self.router = self.context.socket(zmq.ROUTER)
-        self.router.setsockopt(zmq.SNDHWM, self.ZMQ_SNDHWM)
-        self.router.setsockopt(zmq.SNDTIMEO, -1)
-        self.router.bind(f"ipc://{self.router_path}")
+        pass

     def send_json(self, data):
         """
         Send a JSON-serializable object over the socket.
         """
+        self._ensure_socket()
         self.socket.send_json(data)

     def recv_json(self):
         """
         Receive a JSON-serializable object from the socket.
         """
+        self._ensure_socket()
         return self.socket.recv_json()

     def send_pyobj(self, data):
         """
         Send a Pickle-serializable object over the socket.
         """
+        self._ensure_socket()
         self.socket.send_pyobj(data)

     def recv_pyobj(self):
         """
         Receive a Pickle-serializable object from the socket.
         """
+        self._ensure_socket()
         return self.socket.recv_pyobj()

-    def pack_aggregated_data(self, data):
-        """
-        Aggregate multiple responses into one and send them to the client.
-        """
-        result = data[0]
-        if len(data) > 1:
-            for response in data[1:]:
-                result.add(response)
-        result = msgpack.packb([result.to_dict()])
-        return result
-
-    def send_multipart(self, req_id, data):
-        """
-        Send a multipart message to the router socket.
-        """
-        if self.router is None:
-            raise RuntimeError("Router socket not created. Call create_router() first.")
-
-        while self.running:
-            with self.mutex:
-                if req_id not in self.req_dict:
-                    try:
-                        client, _, request_id = self.router.recv_multipart(flags=zmq.NOBLOCK)
-                        req_id_str = request_id.decode("utf-8")
-                        self.req_dict[req_id_str] = client
-                    except zmq.Again:
-                        time.sleep(0.001)
-                        continue
-                else:
-                    break
-
-        try:
-            start_send = time.time()
-            if self.aggregate_send:
-                result = self.pack_aggregated_data(data)
-            else:
-                result = msgpack.packb([response.to_dict() for response in data])
-            self.router.send_multipart([self.req_dict[req_id], b"", result])
-            llm_logger.debug(f"send_multipart result: {req_id} len {len(data)} elapse: {time.time()-start_send}")
-
-        except Exception as e:
-            llm_logger.error(f"Send result to zmq client failed: {e}")
-
-        if data[-1].finished:
-            with self.mutex:
-                self.req_dict.pop(req_id, None)
-                llm_logger.info(f"send_multipart finished, req_id: {req_id}")
-
-    def receive_json_once(self, block=False):
-        """
-        Receive a single message from the socket.
-        """
-        if self.socket is None or self.socket.closed:
-            return "zmp socket has closed", None
-        try:
-            flags = zmq.NOBLOCK if not block else 0
-            return None, self.socket.recv_json(flags=flags)
-        except zmq.Again:
-            return None, None
-        except Exception as e:
-            self.close()
-            llm_logger.warning(f"{e}")
-            return str(e), None
-
-    def receive_pyobj_once(self, block=False):
-        """
-        Receive a single message from the socket.
-        """
-        if self.socket is None or self.socket.closed:
-            return "zmp socket has closed", None
-        try:
-            flags = zmq.NOBLOCK if not block else 0
-            return None, self.socket.recv_pyobj(flags=flags)
-        except zmq.Again:
-            return None, None
-        except Exception as e:
-            self.close()
-            llm_logger.warning(f"{e}")
-            return str(e), None
-
-    def _clear_ipc(self, name):
-        """
-        Remove the IPC file with the given name.
-        """
-        if os.path.exists(name):
-            try:
-                os.remove(name)
-            except OSError as e:
-                llm_logger.warning(f"Failed to remove IPC file {name} - {e}")
-
-    def close(self):
-        """
-        Close the socket and context, and remove the IPC files.
-        """
-        if not self.running:
-            return
-
-        self.running = False
-        llm_logger.info("Closing ZMQ connection...")
-        try:
-            if hasattr(self, "socket") and not self.socket.closed:
-                self.socket.close()
-
-            if self.router is not None and not self.router.closed:
-                self.router.close()
-
-            if not self.context.closed:
-                self.context.term()
-
-            self._clear_ipc(self.file_name)
-            self._clear_ipc(self.router_path)
-        except Exception as e:
-            llm_logger.warning(f"Failed to close ZMQ connection - {e}")
-        return
-
-    def __exit__(self, exc_type, exc_val, exc_tb):
-        self.close()
+
+class ZmqIpcClient(ZmqClientBase):
+    def __init__(self, name, mode):
+        self.name = name
+        self.mode = mode
+        self.file_name = f"/dev/shm/{name}.socket"
+        self.context = zmq.Context()
+        self.socket = self.context.socket(self.mode)
+
+    def _create_socket(self):
+        """create and return a ZeroMQ socket."""
+        self.context = zmq.Context()
+        return self.context.socket(self.mode)
+
+    def connect(self):
+        self._ensure_socket()
+        self.socket.connect(f"ipc://{self.file_name}")
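After the split, client construction is unchanged apart from the class name; server-side concerns (binding, routing, aggregation) move to the new `zmq_server.py` below. A minimal usage sketch (assuming a PULL server is already bound to the matching IPC endpoint, which is what `ZmqIpcServer` sets up):

```python
# Minimal ZmqIpcClient usage (sketch; endpoint name and payload are illustrative).
import zmq
from fastdeploy.inter_communicator import ZmqIpcClient

client = ZmqIpcClient(name="api_server_pid", mode=zmq.PUSH)
client.connect()  # connects to ipc:///dev/shm/api_server_pid.socket
client.send_json({"request_id": "req-1", "prompt": "hi"})
```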
fastdeploy/inter_communicator/zmq_server.py (new file, 273 lines)
@@ -0,0 +1,273 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

import os
import threading
import time
from abc import ABC, abstractmethod

import msgpack
import zmq

from fastdeploy import envs
from fastdeploy.utils import llm_logger


class ZmqServerBase(ABC):
    """
    ZmqServerBase
    """

    def __init__(self):
        pass

    @abstractmethod
    def _create_socket(self):
        """Abstract method to create and return a ZeroMQ socket."""
        pass

    def _ensure_socket(self):
        """Ensure the socket is created before use."""
        if self.socket is None:
            self.socket = self._create_socket()

    def pack_aggregated_data(self, data):
        """
        Aggregate multiple responses into one and send them to the client.
        """
        result = data[0]
        if len(data) > 1:
            for response in data[1:]:
                result.add(response)
        result = msgpack.packb([result.to_dict()])
        return result

    def receive_json_once(self, block=False):
        """
        Receive a single message from the socket.
        """
        self._ensure_socket()
        if self.socket is None or self.socket.closed:
            return "zmp socket has closed", None
        try:
            flags = zmq.NOBLOCK if not block else 0
            return None, self.socket.recv_json(flags=flags)
        except zmq.Again:
            return None, None
        except Exception as e:
            self.close()
            llm_logger.warning(f"{e}")
            return str(e), None

    def receive_pyobj_once(self, block=False):
        """
        Receive a single message from the socket.
        """
        self._ensure_socket()
        if self.socket is None or self.socket.closed:
            return "zmp socket has closed", None
        try:
            flags = zmq.NOBLOCK if not block else 0
            return None, self.socket.recv_pyobj(flags=flags)
        except zmq.Again:
            return None, None
        except Exception as e:
            self.close()
            llm_logger.warning(f"{e}")
            return str(e), None

    def send_response(self, req_id, data):
        """
        Send generated token result to client.
        """
        self._ensure_socket()
        if self.socket is None:
            raise RuntimeError("Router socket not created. Call create_router() first.")

        while self.running:
            with self.mutex:
                if req_id not in self.req_dict:
                    try:
                        client, _, request_id = self.socket.recv_multipart(flags=zmq.NOBLOCK)
                        req_id_str = request_id.decode("utf-8")
                        self.req_dict[req_id_str] = client
                    except zmq.Again:
                        time.sleep(0.001)
                        continue
                else:
                    break

        try:
            start_send = time.time()
            if self.aggregate_send:
                result = self.pack_aggregated_data(data)
            else:
                result = msgpack.packb([response.to_dict() for response in data])
            self.socket.send_multipart([self.req_dict[req_id], b"", result])
            llm_logger.debug(f"send_multipart result: {req_id} len {len(data)} elapse: {time.time()-start_send}")

        except Exception as e:
            llm_logger.error(f"Send result to zmq client failed: {e}")

        if data[-1].finished:
            with self.mutex:
                self.req_dict.pop(req_id, None)
                llm_logger.info(f"send_multipart finished, req_id: {req_id}")

    @abstractmethod
    def close(self):
        pass

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()


class ZmqIpcServer(ZmqServerBase):
    """
    ZmqIpcServer, used when FD_ENABLE_INTERNAL_ADAPTER=0
    """

    def __init__(self, name, mode):
        self.name = name
        self.mode = mode
        if mode == zmq.PULL:
            self.file_name = f"/dev/shm/{name}.socket"
        elif mode == zmq.ROUTER:
            self.file_name = f"/dev/shm/router_{name}.ipc"
        self.ZMQ_SNDHWM = int(envs.FD_ZMQ_SNDHWM)
        self.aggregate_send = envs.FD_USE_AGGREGATE_SEND
        self.mutex = threading.Lock()
        self.req_dict = dict()
        self.running = True
        self.context = zmq.Context()
        self._create_socket()

    def _create_socket(self):
        """create and return a ZeroMQ socket."""
        self.socket = self.context.socket(self.mode)
        self.socket.setsockopt(zmq.SNDHWM, self.ZMQ_SNDHWM)
        self.socket.setsockopt(zmq.SNDTIMEO, -1)
        self.socket.bind(f"ipc://{self.file_name}")
        return self.socket

    def _clear_ipc(self, name):
        """
        Remove the IPC file with the given name.
        """
        if os.path.exists(name):
            try:
                os.remove(name)
            except OSError as e:
                llm_logger.warning(f"Failed to remove IPC file {name} - {e}")

    def close(self):
        """
        Close the socket and context, and remove the IPC files.
        """
        if not self.running:
            return

        self.running = False
        llm_logger.info("Closing ZMQ connection...")
        try:
            if self.socket is not None and not self.socket.closed:
                self.socket.close()
            if not self.context.closed:
                self.context.term()
            self._clear_ipc(self.file_name)
        except Exception as e:
            llm_logger.warning(f"Failed to close ZMQ connection - {e}")
        return


class ZmqTcpServer(ZmqServerBase):
    """
    ZmqTcpServer, used when FD_ENABLE_INTERNAL_ADAPTER=1
    """

    def __init__(self, port, mode):
        self.mode = mode
        self.port = port
        self.ZMQ_SNDHWM = int(envs.FD_ZMQ_SNDHWM)
        self.aggregate_send = envs.FD_USE_AGGREGATE_SEND

        self.mutex = threading.Lock()
        self.req_dict = dict()
        self.running = True
        self.context = zmq.Context()
        self._create_socket()

    def _create_socket(self):
        """create and return a ZeroMQ socket."""
        self.socket = self.context.socket(self.mode)
        self.socket.setsockopt(zmq.SNDHWM, self.ZMQ_SNDHWM)
        self.socket.setsockopt(zmq.SNDTIMEO, -1)
        self.socket.bind(f"tcp://*:{self.port}")
        return self.socket

    def recv_control_cmd(self):
        """
        Receive control command from client
        """
        self._ensure_socket()
        while self.running:
            try:
                client, _, task_data = self.socket.recv_multipart(flags=zmq.NOBLOCK)
                task = msgpack.unpackb(task_data)
                task_id_str = task["task_id"]
            except zmq.Again:
                time.sleep(0.001)
                continue
            with self.mutex:
                self.req_dict[task_id_str] = client
            return task

    def response_for_control_cmd(self, task_id, result):
        """
        Send command result back to client.
        """
        self._ensure_socket()
        if self.socket is None:
            raise RuntimeError("Router socket not created.")
        try:
            result = msgpack.packb(result)
            self.socket.send_multipart([self.req_dict[task_id], b"", result])

        except Exception as e:
            llm_logger.error(f"Send result to zmq client failed: {e}")

        with self.mutex:
            self.req_dict.pop(task_id, None)
        llm_logger.info(f"response control cmd finished, task_id: {task_id}")

    def close(self):
        """
        Close the socket and context.
        """
        if not self.running:
            return

        self.running = False
        llm_logger.info("Closing ZMQ connection...")
        try:
            if self.socket is not None and not self.socket.closed:
                self.socket.close()
            if not self.context.closed:
                self.context.term()

        except Exception as e:
            llm_logger.warning(f"Failed to close ZMQ connection - {e}")
        return
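The ROUTER-based control channel expects peers that speak the `[identity, b"", payload]` framing used by `recv_control_cmd` and `response_for_control_cmd` above; a ZeroMQ DEALER socket produces exactly that. A matching client sketch (port, task_id, and command values are illustrative):

```python
# Sketch of a control-plane client for ZmqTcpServer.
import msgpack
import zmq

ctx = zmq.Context()
sock = ctx.socket(zmq.DEALER)
sock.connect("tcp://127.0.0.1:8202")  # default from FD_ZMQ_CONTROL_CMD_SERVER_PORTS

task = {"task_id": "cmd-1", "cmd": "get_payload"}
sock.send_multipart([b"", msgpack.packb(task)])  # ROUTER sees [identity, b"", payload]

_, reply = sock.recv_multipart()                 # server replies [b"", msgpack(result)]
print(msgpack.unpackb(reply))                    # {"task_id": "cmd-1", "result": {...}}
```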
@@ -412,7 +412,11 @@ class TokenProcessor:
                         self._record_completion_metrics(task, current_time)
                     self._recycle_resources(task_id, i, task, result, is_prefill)
                     break
-                if not is_prefill or self.cfg.scheduler_config.name == "splitwise":
+                if (
+                    not is_prefill
+                    or self.cfg.scheduler_config.name == "splitwise"
+                    or self.cfg.scheduler_config.name == "dp"
+                ):
                     batch_result.append(result)

         self.postprocess(batch_result)
@@ -531,7 +535,11 @@ class TokenProcessor:
                         self._record_completion_metrics(task, current_time)
                     self._recycle_resources(task_id, i, task, result, is_prefill)
                     break
-                if not is_prefill or self.cfg.scheduler_config.name == "splitwise":
+                if (
+                    not is_prefill
+                    or self.cfg.scheduler_config.name == "splitwise"
+                    or self.cfg.scheduler_config.name == "dp"
+                ):
                     batch_result.append(result)

         self.postprocess(batch_result)
@@ -18,6 +18,7 @@ import redis

 from fastdeploy.utils import llm_logger

+from .dp_scheduler import DPScheduler
 from .global_scheduler import GlobalScheduler
 from .local_scheduler import LocalScheduler
 from .splitwise_scheduler import SplitWiseScheduler, SplitWiseSchedulerConfig
@@ -89,6 +90,57 @@ class LocalSchedulerConfig:
         llm_logger.info("=============================================================")


+class DPLocalSchedulerConfig(LocalSchedulerConfig):
+    """
+    Configuration class for DPLocalScheduler.
+
+    Attributes:
+        max_size: Maximum number of concurrent requests (-1 for unlimited)
+        ttl: Time-to-live in seconds for request expiration
+    """
+
+    def __init__(
+        self,
+        max_size: int = -1,
+        ttl: int = 900,
+        max_model_len: int = 8192,
+        enable_chunked_prefill: bool = False,
+        max_num_partial_prefills: int = 1,
+        max_long_partial_prefills: int = 1,
+        long_prefill_token_threshold: int = 0,
+        splitwise_role: str = "prefill",
+        **kwargs,
+    ):
+        """
+        Initialize LocalScheduler configuration.
+
+        Args:
+            max_size: Maximum concurrent requests (-1 for unlimited, 0 for disabled)
+            ttl: Time-to-live in seconds for request expiration (default 900s)
+            max_model_len: Maximum model context length in tokens
+            enable_chunked_prefill: Whether to enable chunked prefill processing
+            max_num_partial_prefills: Max partial prefill operations allowed
+            max_long_partial_prefills: Max long-running partial prefill ops
+            long_prefill_token_threshold: Token count threshold for long prefill
+            **kwargs: Additional unused arguments (for forward compatibility)
+
+        Note:
+            - If long_prefill_token_threshold is 0, it's auto-calculated as 4% of max_model_len
+            - See LocalScheduler class for implementation details
+        """
+        self.max_size = max_size
+        self.ttl = ttl
+
+        self.max_model_len = max_model_len
+        self.enable_chunked_prefill = enable_chunked_prefill
+        self.max_num_partial_prefills = max_num_partial_prefills
+        self.max_long_partial_prefills = max_long_partial_prefills
+        self.long_prefill_token_threshold = long_prefill_token_threshold
+        if self.long_prefill_token_threshold == 0:
+            self.long_prefill_token_threshold = int(self.max_model_len * 0.04)
+        self.splitwise_role = splitwise_role
+
+
 class GlobalSchedulerConfig:
     """
     Configuration class for GlobalScheduler (Redis-based).
@@ -229,6 +281,9 @@ class SchedulerConfig:
         if name == "splitwise":
             self.config = SplitWiseSchedulerConfig(**kwargs)

+        if name == "dp":
+            self.config = DPLocalSchedulerConfig(**kwargs)
+
     def check(self):
         """
         Validate the configuration.
@@ -236,7 +291,7 @@ class SchedulerConfig:
         Raises:
             Exception: If invalid scheduler type is specified
         """
-        if self.name not in ["local", "global", "splitwise"]:
+        if self.name not in ["local", "global", "splitwise", "dp"]:
             raise Exception(f"Unknown scheduler type {self.name}")

         self.config.check()
@@ -274,6 +329,17 @@ class SchedulerConfig:
         if self.name == "splitwise":
             return SplitWiseScheduler(self.config)

+        if self.name == "dp":
+            return DPScheduler(
+                max_size=self.config.max_size,
+                ttl=self.config.ttl,
+                enable_chunked_prefill=self.config.enable_chunked_prefill,
+                max_num_partial_prefills=self.config.max_num_partial_prefills,
+                max_long_partial_prefills=self.config.max_long_partial_prefills,
+                long_prefill_token_threshold=self.config.long_prefill_token_threshold,
+                splitwise_role=self.config.splitwise_role,
+            )
+
         return LocalScheduler(
             max_size=self.config.max_size,
             ttl=self.config.ttl,
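Selecting the new scheduler is then a matter of passing `name="dp"` plus the knobs `DPLocalSchedulerConfig` understands. A hedged sketch (the import path and the `SchedulerConfig(name=..., **kwargs)` construction are assumptions inferred from the hunks above):

```python
# Sketch: instantiating the DP scheduler through SchedulerConfig
# (module path and constructor shape are assumed from the diff).
from fastdeploy.scheduler.config import SchedulerConfig

cfg = SchedulerConfig(
    name="dp",
    max_size=-1,
    ttl=900,
    max_model_len=8192,
    splitwise_role="prefill",
)
cfg.check()                  # "dp" is now an accepted scheduler type
scheduler = cfg.scheduler()  # returns a DPScheduler
```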
fastdeploy/scheduler/dp_scheduler.py (new file, 179 lines)
@@ -0,0 +1,179 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

import threading
import time
from multiprocessing import Queue
from typing import Dict, List, Optional

from fastdeploy.engine.request import Request, RequestOutput
from fastdeploy.scheduler.data import ScheduledResponse
from fastdeploy.scheduler.local_scheduler import LocalScheduler
from fastdeploy.utils import scheduler_logger


class DPLocalScheduler(LocalScheduler):
    def __init__(
        self,
        max_size: int,
        ttl: int,
        enable_chunked_prefill: bool,
        max_num_partial_prefills: int,
        max_long_partial_prefills: int,
        long_prefill_token_threshold: int,
        splitwise_role: str = "prefill",
    ):
        super().__init__(
            max_size,
            ttl,
            enable_chunked_prefill,
            max_num_partial_prefills,
            max_long_partial_prefills,
            long_prefill_token_threshold,
        )
        self.splitwise_role = splitwise_role

    def put_results(self, results: List[RequestOutput]):
        """
        Add processing results back to the scheduler.
        Args:
            results: List of RequestOutput objects containing results
        """
        responses: List[ScheduledResponse] = [ScheduledResponse(result) for result in results]

        finished_responses = [response.request_id for response in responses if response.finished]
        if len(finished_responses) > 0:
            scheduler_logger.info(f"Scheduler has received some finished responses: {finished_responses}")

        with self.mutex:
            for response in responses:
                if response.request_id not in self.responses:
                    self.responses[response.request_id] = [response]
                    continue
                self.responses[response.request_id].append(response)
            self.responses_not_empty.notify_all()

    def _recycle(self, request_id: Optional[str] = None):
        """
        Clean up expired or completed requests to free memory.
        Args:
            request_id: Optional specific request ID to remove.
                        If None, removes all expired requests.
        """
        if request_id is not None:
            self.requests.pop(request_id, None)
            self.responses.pop(request_id, None)
            if self.splitwise_role == "decode":
                return
            self.ids.pop(self.ids.index(request_id))
            self.ids_read_cursor -= 1
            return

        if self.max_size <= 0:
            return

        if len(self.requests) <= self.max_size:
            return

        now = time.time()
        expired_ids = []
        for request_id in self.ids:
            request = self.requests[request_id]
            if now - request.schedule_time < self.ttl:
                break
            expired_ids.append(request.request_id)

        for i, expired_id in enumerate(expired_ids):
            self.requests.pop(expired_id, None)
            self.responses.pop(expired_id, None)
            self.ids.pop(i)

        if len(expired_ids) > 0:
            if len(expired_ids) - 1 >= self.ids_read_cursor:
                self.ids_read_cursor = 0
            else:
                self.ids_read_cursor -= len(expired_ids)


class DPScheduler:
    def __init__(
        self,
        max_size: int,
        ttl: int,
        enable_chunked_prefill: bool,
        max_num_partial_prefills: int,
        max_long_partial_prefills: int,
        long_prefill_token_threshold: int,
        splitwise_role: str = "prefill",
    ):
        self._scheduler = DPLocalScheduler(
            max_size,
            ttl,
            enable_chunked_prefill,
            max_num_partial_prefills,
            max_long_partial_prefills,
            long_prefill_token_threshold,
            splitwise_role,
        )

    def start(self, dp_rank: int, request_queues: List[Queue], result_queue: Queue):
        self.dp_rank = dp_rank
        self.request_queues = request_queues
        self.result_queue = result_queue
        threading.Thread(target=self._put_requests_to_local).start()
        threading.Thread(target=self._get_response_from_local).start()

    def put_requests(self, requests: List[Dict]):
        results = []
        for request in requests:
            if not hasattr(request, "dp_rank"):
                raise ValueError(f"Request object is missing the 'dp_rank' attribute: {request}")
            self.request_queues[request.dp_rank].put(request)
            results.append((request.request_id, None))
        return results

    def _put_requests_to_local(self):
        while True:
            request = self.request_queues[self.dp_rank].get()
            self._scheduler.put_requests([request])

    def _get_response_from_local(self):
        while True:
            results = self._scheduler.get_results()
            if len(results) == 0:
                continue
            self.result_queue.put(results)

    def get_requests(
        self,
        available_blocks,
        block_size,
        reserved_output_blocks,
        max_num_batched_tokens,
        batch=1,
    ) -> List[Request]:
        return self._scheduler.get_requests(
            available_blocks, block_size, reserved_output_blocks, max_num_batched_tokens, batch
        )

    def get_unhandled_request_num(self):
        return len(self._scheduler.requests)

    def put_results(self, results: List[RequestOutput]):
        self._scheduler.put_results(results)

    def get_results(self) -> Dict[str, List[RequestOutput]]:
        return self.result_queue.get()
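`DPScheduler.put_requests` is pure routing: each request lands only on the queue of its `dp_rank`, and each rank's background thread drains its own queue into the local scheduler. A self-contained mirror of that routing rule (`Req` stands in for the real `Request`, which carries `dp_rank` after this change):

```python
# Mirror of DPScheduler.put_requests' routing rule (illustrative types).
import queue
from dataclasses import dataclass

@dataclass
class Req:
    request_id: str
    dp_rank: int

def put_requests(requests, request_queues):
    results = []
    for request in requests:
        if not hasattr(request, "dp_rank"):
            raise ValueError(f"Request object is missing the 'dp_rank' attribute: {request}")
        request_queues[request.dp_rank].put(request)  # strict routing by rank
        results.append((request.request_id, None))
    return results

request_queues = [queue.Queue() for _ in range(2)]
put_requests([Req("req-0", 0), Req("req-1", 1)], request_queues)
print(request_queues[1].get().request_id)  # "req-1"
```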
fastdeploy/splitwise/internal_adapter_utils.py (new file, 107 lines)
@@ -0,0 +1,107 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

import threading
import time
import traceback

# **Note**: Just for internal use
import zmq

from fastdeploy.inter_communicator import ZmqTcpServer
from fastdeploy.metrics.metrics import get_filtered_metrics, main_process_metrics
from fastdeploy.utils import envs, get_logger

logger = get_logger("internal_adapter_utils", "internal_adapter_utils.log")


class InternalAdapter:
    def __init__(self, cfg, engine, dp_rank):
        self.cfg = cfg
        self.engine = engine
        self.dp_rank = dp_rank
        recv_control_cmd_ports = envs.FD_ZMQ_CONTROL_CMD_SERVER_PORTS.split(",")
        self.recv_control_cmd_server = ZmqTcpServer(port=recv_control_cmd_ports[dp_rank], mode=zmq.ROUTER)
        self.recv_external_instruct_thread = threading.Thread(
            target=self._recv_external_module_control_instruct, daemon=True
        )
        self.recv_external_instruct_thread.start()
        self.response_external_instruct_thread = threading.Thread(
            target=self._response_external_module_control_instruct, daemon=True
        )
        self.response_external_instruct_thread.start()

    def _get_current_server_info(self):
        """
        Get resources information
        """
        available_batch_size = min(self.cfg.max_prefill_batch, self.engine.resource_manager.available_batch())

        available_block_num = self.engine.resource_manager.available_block_num()
        server_info = {
            "splitwise_role": self.cfg.splitwise_role,
            "block_size": int(self.cfg.cache_config.block_size),
            "block_num": int(available_block_num),
            "dec_token_num": int(self.cfg.cache_config.dec_token_num),
            "available_resource": 1.0 * available_block_num / self.cfg.cache_config.total_block_num,
            "max_batch_size": int(available_batch_size),
            "max_input_token_num": self.cfg.max_num_batched_tokens,
            "unhandled_request_num": self.engine.scheduler.get_unhandled_request_num(),
        }
        return server_info

    def _recv_external_module_control_instruct(self):
        """
        Receive a multipart message from the control cmd socket.
        """
        while True:
            try:
                task = self.recv_control_cmd_server.recv_control_cmd()
                logger.info(f"Receive control task: {task}")
                task_id_str = task["task_id"]
                if task["cmd"] == "get_payload":
                    payload_info = self._get_current_server_info()
                    result = {"task_id": task_id_str, "result": payload_info}
                    logger.info(f"Response for task: {task_id_str}")
                    self.recv_control_cmd_server.response_for_control_cmd(task_id_str, result)

                elif task["cmd"] == "get_metrics":
                    metrics_text = get_filtered_metrics(
                        [],
                        extra_register_func=lambda reg: main_process_metrics.register_all(reg, workers=1),
                    )
                    result = {"task_id": task_id_str, "result": metrics_text}
                    logger.info(f"Response for task: {task_id_str}")
                    self.recv_control_cmd_server.response_for_control_cmd(task_id_str, result)
                elif task["cmd"] == "connect_rdma":
                    self.engine.engine_worker_queue.put_connect_rdma_task(task)

            except Exception as e:
                logger.error(f"handle_control_cmd got error: {e}, {traceback.format_exc()!s}")

    def _response_external_module_control_instruct(self):
        while True:
            try:
                result_data = self.engine.engine_worker_queue.get_connect_rdma_task_response()
                if result_data:
                    task_id_str = result_data["task_id"]
                    result = {"task_id": task_id_str, "result": result_data}
                    logger.info(f"Response for task: {task_id_str}")
                    self.recv_control_cmd_server.response_for_control_cmd(task_id_str, result)
                else:
                    time.sleep(0.001)
            except Exception as e:
                logger.error(f"_handle_connect_rdma_results got error: {e}, {traceback.format_exc()!s}")
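The adapter dispatches on `task["cmd"]`; per the handler above, an external module can issue exactly three command shapes (the `task_id`, `ip`, and `rdma_port` values below are illustrative):

```python
# The three control commands _recv_external_module_control_instruct handles.
get_payload = {"task_id": "p-1", "cmd": "get_payload"}   # answered with server_info
get_metrics = {"task_id": "m-1", "cmd": "get_metrics"}   # answered with metrics text
connect_rdma = {                                         # forwarded to the cache messager
    "task_id": "c-1",
    "cmd": "connect_rdma",
    "ip": "10.0.0.2",
    "rdma_port": 6789,
}
```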