[Feature] Support ep pd with external module (#3194)

* Support external module

* Refactor code to make it clearer

* Fix issues according to review

* Fix bugs

* Merge

---------

Co-authored-by: root <root@tjdm-inf-sci-k8s-hzz2-h12ni8-0202.tjdm.baidu.com>
Author: chenjian
Date: 2025-08-04 20:32:41 +08:00
Committed by: GitHub
Parent: 0443587a57
Commit: 9f9971844f
15 changed files with 876 additions and 218 deletions

View File

@@ -142,12 +142,16 @@ class CacheMessager:
         self.gpu_id = gpu_id
         self.cache_info = dict()
-        self.dp_rank_id = self.rank + local_data_parallel_id * self.nranks
+        self.rank_id = self.rank + local_data_parallel_id * self.nranks  # align with engine worker rank (paddle.distributed.launch)

         layerwise_send_cache_thread = threading.Thread(target=self._prefill_layerwise_send_cache_thread)
         layerwise_send_cache_thread.daemon = True
         layerwise_send_cache_thread.start()
+        connect_rdma_thread = threading.Thread(target=self._handle_connect_task)
+        connect_rdma_thread.daemon = True
+        connect_rdma_thread.start()
         logger.info(f"cache messager init finished, use {transfer_protocol}")

     def _prefill_layerwise_send_cache_thread(self):
@@ -160,14 +164,14 @@ class CacheMessager:
         prefilled_layer_idx_data = np.zeros(shape=[1], dtype=np.int32)
         try:
             step_shm_value = IPCSignal(
-                name=f"splitwise_complete_prefilled_step_{self.dp_rank_id}",
+                name=f"splitwise_complete_prefilled_step_{self.rank_id}",
                 array=prefilled_step_idx_data,
                 dtype=np.int32,
                 suffix=self.gpu_id,
                 create=True,
             )
             layer_shm_value = IPCSignal(
-                name=f"splitwise_complete_prefilled_layer_{self.dp_rank_id}",
+                name=f"splitwise_complete_prefilled_layer_{self.rank_id}",
                 array=prefilled_layer_idx_data,
                 dtype=np.int32,
                 suffix=self.gpu_id,
@@ -175,14 +179,14 @@ class CacheMessager:
             )
         except:
             step_shm_value = IPCSignal(
-                name=f"splitwise_complete_prefilled_step_{self.dp_rank_id}",
+                name=f"splitwise_complete_prefilled_step_{self.rank_id}",
                 array=prefilled_step_idx_data,
                 dtype=np.int32,
                 suffix=self.gpu_id,
                 create=False,
             )
             layer_shm_value = IPCSignal(
-                name=f"splitwise_complete_prefilled_layer_{self.dp_rank_id}",
+                name=f"splitwise_complete_prefilled_layer_{self.rank_id}",
                 array=prefilled_layer_idx_data,
                 dtype=np.int32,
                 suffix=self.gpu_id,
@@ -310,3 +314,22 @@ class CacheMessager:
         except Exception as e:
             logger.error(f"prefill layerwise send cache thread has exception: {e}")
+
+    def _handle_connect_task(self):
+        while True:
+            try:
+                task = self.engine_worker_queue.get_connect_rdma_task()
+                if task is None:
+                    time.sleep(0.001)
+                    continue
+                logger.info(f"_handle_connect_task recv task: {task}")
+                task_id = task["task_id"]
+                ip, rdma_port = task["ip"], task["rdma_port"]
+                status = self.messager["rdma"].connect(ip, rdma_port)
+                if not status:
+                    response = {"task_id": task_id, "success": False}
+                else:
+                    response = {"task_id": task_id, "success": True}
+                self.engine_worker_queue.put_connect_rdma_task_response(response)
+            except Exception as e:
+                logger.error(f"handle_connect_task has exception: {e}")

View File

@@ -820,6 +820,7 @@ class EngineArgs:
             "max_num_partial_prefills",
             "max_long_partial_prefills",
             "long_prefill_token_threshold",
+            "splitwise_role",
         ]
         all = asdict(self)

View File

@@ -47,12 +47,14 @@ from fastdeploy.inter_communicator import (
     EngineCacheQueue,
     EngineWorkerQueue,
     IPCSignal,
-    ZmqClient,
+    ZmqIpcServer,
+    ZmqTcpServer,
 )
 from fastdeploy.metrics.metrics import main_process_metrics
 from fastdeploy.metrics.trace_util import start_span, start_span_request
 from fastdeploy.model_executor.guided_decoding import schema_checker
 from fastdeploy.output.token_processor import TokenProcessor, WarmUpTokenProcessor
+from fastdeploy.splitwise.internal_adapter_utils import InternalAdapter
 from fastdeploy.splitwise.splitwise_connector import SplitwiseConnector
 from fastdeploy.utils import EngineError, console_logger, envs, llm_logger
@@ -179,11 +181,64 @@ class LLMEngine:
         self.data_processor = self.input_processor.create_processor()

         if api_server_pid is not None:
-            self.zmq_server = ZmqClient(name=api_server_pid, mode=zmq.PULL)
-            self.zmq_server.start_server()
-            self.zmq_server.create_router()
+            if envs.FD_ENABLE_INTERNAL_ADAPTER:
+                self.recv_request_server = ZmqTcpServer(port=envs.FD_ZMQ_RECV_REQUEST_SERVER_PORT, mode=zmq.PULL)
+                self.send_response_server = ZmqTcpServer(port=envs.FD_ZMQ_SEND_RESPONSE_SERVER_PORT, mode=zmq.ROUTER)
+                self.external_adapter = InternalAdapter(
+                    cfg=self.cfg, engine=self, dp_rank=self.cfg.node_rank * self.cfg.worker_num_per_node
+                )
+            else:
+                self.recv_request_server = ZmqIpcServer(name=api_server_pid, mode=zmq.PULL)
+                self.send_response_server = ZmqIpcServer(name=api_server_pid, mode=zmq.ROUTER)
             time.sleep(3)
+
+        self.cfg.init_cache_info()
+        role = self.cfg.splitwise_role
+        host_ip = self.cfg.host_ip
+        disaggregate = self.cfg.disaggregate_info
+        request_queues_for_dp_ipc = (
+            None  # Each dp has its own process; use a multiprocessing.Queue to deliver requests to each dp
+        )
+        result_queue_for_dp_ipc = None
+        if self.cfg.scheduler_config.name == "splitwise":
+            self.scheduler.start(role, host_ip, disaggregate)
+        elif self.cfg.scheduler_config.name == "dp":
+            request_queues_for_dp_ipc = []
+            result_queue_for_dp_ipc = multiprocessing.Queue()
+            for i in range(self.cfg.parallel_config.data_parallel_size):
+                request_queues_for_dp_ipc.append(multiprocessing.Queue())
+            self.scheduler.start(
+                self.cfg.node_rank * self.cfg.worker_num_per_node, request_queues_for_dp_ipc, result_queue_for_dp_ipc
+            )
+        time.sleep(1)
+
+        if self.cfg.parallel_config.enable_expert_parallel and self.cfg.parallel_config.data_parallel_size > 1:
+            self.dp_processed = []
+            for i in range(
+                1,
+                self.cfg.parallel_config.data_parallel_size // self.cfg.nnode,
+            ):
+                time.sleep(1)
+                self.dp_processed.append(
+                    multiprocessing.Process(
+                        target=start_expert_service,
+                        args=(
+                            self.cfg,
+                            i + self.cfg.node_rank * self.cfg.worker_num_per_node,
+                            self.ipc_signal_suffix,
+                            request_queues_for_dp_ipc,
+                            result_queue_for_dp_ipc,
+                        ),
+                    )
+                )
+                llm_logger.info(
+                    f"Engine is initialized successfully with {self.cfg.tensor_parallel_size}"
+                    + f" data parallel id {i}"
+                )
+                self.dp_processed[-1].start()

         if self.do_profile == 0 and (
             self.cfg.cache_config.enable_prefix_caching or self.cfg.splitwise_role != "mixed"
         ):
@@ -238,44 +293,11 @@ class LLMEngine:
             # Single-node logic
             self.engine_worker_queue.available_prefill_instances.put(1)
             self.split_mode_get_tasks()
-        if self.cfg.scheduler_config.name == "splitwise":
+        if self.cfg.scheduler_config.name == "splitwise" or self.cfg.scheduler_config.name == "dp":
             self.splitwise_receive_thread = threading.Thread(target=self.split_connector.start_receiver, args=())
             self.splitwise_receive_thread.daemon = True
             self.splitwise_receive_thread.start()

-        self.cfg.init_cache_info()
-        role = self.cfg.splitwise_role
-        host_ip = self.cfg.host_ip
-        disaggregate = self.cfg.disaggregate_info
-        if self.cfg.scheduler_config.name == "splitwise":
-            self.scheduler.start(role, host_ip, disaggregate)
-        time.sleep(1)
-        if self.cfg.parallel_config.enable_expert_parallel and self.cfg.parallel_config.data_parallel_size > 1:
-            self.dp_processed = []
-            for i in range(
-                1,
-                self.cfg.parallel_config.data_parallel_size // self.cfg.nnode,
-            ):
-                time.sleep(1)
-                self.dp_processed.append(
-                    multiprocessing.Process(
-                        target=start_expert_service,
-                        args=(
-                            self.cfg,
-                            i + self.cfg.node_rank * self.cfg.worker_num_per_node,
-                            self.ipc_signal_suffix,
-                        ),
-                    )
-                )
-                llm_logger.info(
-                    f"Engine is initialized successfully with {self.cfg.tensor_parallel_size}"
-                    + f" data parallel id {i}"
-                )
-                self.dp_processed[-1].start()
-
         console_logger.info(f"Worker processes are launched with {time.time() - start_time} seconds.")
         return True
@@ -291,7 +313,7 @@ class LLMEngine:
                     time.sleep(0.005)
                     continue
                 for request_id, contents in results.items():
-                    self.zmq_server.send_multipart(request_id, contents)
+                    self.send_response_server.send_response(request_id, contents)
         except Exception as e:
             llm_logger.error(f"Unexcepted error happend: {e}, {traceback.format_exc()!s}")
@@ -415,14 +437,18 @@ class LLMEngine:
         if self.api_server_pid is None:
             return

+        if envs.FD_ENABLE_INTERNAL_ADAPTER:
+            if self.cfg.splitwise_role == "decode":
+                return
+
         added_requests: Dict[str, int] = dict()
         while self.running:
             try:
                 block = True if len(added_requests) == 0 else False
                 if not self.cfg.enable_mm:
-                    err, data = self.zmq_server.receive_json_once(block)
+                    err, data = self.recv_request_server.receive_json_once(block)
                 else:
-                    err, data = self.zmq_server.receive_pyobj_once(block)
+                    err, data = self.recv_request_server.receive_pyobj_once(block)
                 if err is not None:
                     llm_logger.error("Engine stops inserting zmq task into scheduler, err:{err}")
                     break
@@ -470,7 +496,7 @@ class LLMEngine:
                     )
                     # Since the request is not in scheduler
                     # Send result by zmq directly
-                    self.zmq_server.send_multipart(request_id, error_result)
+                    self.send_response_server.send_response(request_id, error_result)
             except Exception as e:
                 llm_logger.error(
                     f"Error happend while receving new request from zmq, details={e}, "
@@ -989,8 +1015,12 @@ class LLMEngine:
             print(f"Error extracting sub services: {e}")

         self.engine_worker_queue.cleanup()
-        if hasattr(self, "zmq_server") and self.zmq_server is not None:
-            self.zmq_server.close()
+        if hasattr(self, "send_response_server") and self.send_response_server is not None:
+            self.send_response_server.close()
+        if hasattr(self, "recv_request_server") and self.recv_request_server is not None:
+            self.recv_request_server.close()
+        if hasattr(self, "recv_control_cmd_server") and self.recv_control_cmd_server is not None:
+            self.recv_control_cmd_server.close()
         if hasattr(self, "dp_processed"):
             for p in self.dp_processed:
                 p.join()

View File

@@ -29,8 +29,9 @@ from fastdeploy.engine.resource_manager import ResourceManager
 from fastdeploy.inter_communicator import EngineWorkerQueue
 from fastdeploy.metrics.metrics import main_process_metrics
 from fastdeploy.output.token_processor import TokenProcessor
+from fastdeploy.splitwise.internal_adapter_utils import InternalAdapter
 from fastdeploy.splitwise.splitwise_connector import SplitwiseConnector
-from fastdeploy.utils import EngineError, console_logger, llm_logger
+from fastdeploy.utils import EngineError, console_logger, envs, llm_logger


 class ExpertService:
@@ -60,6 +61,7 @@ class ExpertService:
         self.scheduler = cfg.scheduler_config.scheduler()
-        self.scheduler.reset_nodeid(f"{self.scheduler.infer.nodeid}_{local_data_parallel_id!s}")
+        if self.cfg.scheduler_config.name == "splitwise":
+            self.scheduler.reset_nodeid(f"{self.scheduler.infer.nodeid}_{local_data_parallel_id!s}")

         self.cfg.parallel_config.local_data_parallel_id = local_data_parallel_id
@@ -111,8 +113,12 @@ class ExpertService:
         )

         self._finalizer = weakref.finalize(self, self._exit_sub_services)
+        if envs.FD_ENABLE_INTERNAL_ADAPTER:
+            self.external_adapter = InternalAdapter(cfg=self.cfg, engine=self, dp_rank=local_data_parallel_id)

-    def start(self, ipc_signal_suffix, local_data_parallel_id):
+    def start(
+        self, ipc_signal_suffix, local_data_parallel_id, request_queues_for_dp_ipc=None, result_queue_for_dp_ipc=None
+    ):
         """
         Initializes the engine and starts its sub-services.
         If `api_server_pid` is defined, will launch a thread
@@ -127,7 +133,7 @@ class ExpertService:
             cache_config=self.cfg.cache_config,
             tensor_parallel_size=self.cfg.tensor_parallel_size,
             device_ids=self.cfg.local_device_ids,
-            pod_ip=self.cfg.pod_ips[0],
+            pod_ip=self.cfg.master_ip,
             engine_worker_queue_port=self.cfg.engine_worker_queue_port,
             pid_suffix=f"{local_data_parallel_id}_{ipc_signal_suffix}",
         )
@@ -147,6 +153,10 @@ class ExpertService:
         role = self.cfg.splitwise_role
         host_ip = self.cfg.host_ip
         disaggregate = self.cfg.disaggregate_info
-        self.scheduler.start(role, host_ip, disaggregate)
+        if self.cfg.scheduler_config.name == "dp":
+            assert (request_queues_for_dp_ipc is not None) and (result_queue_for_dp_ipc is not None)
+            self.scheduler.start(local_data_parallel_id, request_queues_for_dp_ipc, result_queue_for_dp_ipc)
+        elif self.cfg.scheduler_config.name == "splitwise":
+            self.scheduler.start(role, host_ip, disaggregate)

         self.cfg.print()
@@ -356,13 +366,17 @@ class ExpertService:
         self.zmq_server.close()


-def start_expert_service(cfg, local_data_parallel_id, ipc_signal_suffix):
+def start_expert_service(
+    cfg, local_data_parallel_id, ipc_signal_suffix, request_queues_for_dp_ipc=None, result_queue_for_dp_ipc=None
+):
     """
     Start expert service
     """
     expert_service = ExpertService(cfg, local_data_parallel_id)
     try:
-        expert_service.start(ipc_signal_suffix, local_data_parallel_id)
+        expert_service.start(
+            ipc_signal_suffix, local_data_parallel_id, request_queues_for_dp_ipc, result_queue_for_dp_ipc
+        )
         expert_service.split_connector.start_receiver()
     except Exception as e:
         llm_logger.exception(f"Expert service failed to start: {e}")

View File

@@ -71,6 +71,7 @@ class Request:
         guided_json_object: Optional[bool] = None,
         enable_thinking: Optional[bool] = True,
         trace_carrier: dict = dict(),
+        dp_rank: Optional[int] = None
     ) -> None:
         self.request_id = request_id
         self.prompt = prompt
@@ -119,6 +120,7 @@ class Request:
         self.task_type = RequestType.PREFILL
         self.idx = None
         self.need_prefill_tokens = self.prompt_token_ids_len
+        self.dp_rank = dp_rank

     @classmethod
     def from_dict(cls, d: dict):
@@ -151,6 +153,7 @@ class Request:
             guided_json_object=d.get("guided_json_object", None),
             enable_thinking=d.get("enable_thinking", True),
             trace_carrier=d.get("trace_carrier", {}),
+            dp_rank=d.get("dp_rank", None)
         )

     @property

View File

@@ -21,7 +21,7 @@ import numpy as np

 from fastdeploy.engine.config import ModelConfig
 from fastdeploy.input.preprocess import InputPreprocessor
-from fastdeploy.inter_communicator import IPCSignal, ZmqClient
+from fastdeploy.inter_communicator import IPCSignal, ZmqIpcClient
 from fastdeploy.metrics.work_metrics import work_process_metrics
 from fastdeploy.multimodal.registry import MultimodalRegistry
 from fastdeploy.platforms import current_platform
@@ -90,7 +90,7 @@ class EngineClient:
         """
         Create a ZMQ client.
         """
-        self.zmq_client = ZmqClient(model, mode)
+        self.zmq_client = ZmqIpcClient(model, mode)
         self.zmq_client.connect()

     def format_and_add_data(self, prompts: dict):

View File

@@ -80,6 +80,14 @@ environment_variables: dict[str, Callable[[], Any]] = {
     "EXPORTER_OTLP_HEADERS": lambda: os.getenv("EXPORTER_OTLP_HEADERS"),
     # enable kv cache block scheduler v1 (no need for kv_cache_ratio)
     "ENABLE_V1_KVCACHE_SCHEDULER": lambda: int(os.getenv("ENABLE_V1_KVCACHE_SCHEDULER", "0")),
+    # enable an internal module to access the LLMEngine
+    "FD_ENABLE_INTERNAL_ADAPTER": lambda: int(os.getenv("FD_ENABLE_INTERNAL_ADAPTER", "0")),
+    # LLMEngine receive-requests port, used when FD_ENABLE_INTERNAL_ADAPTER=1
+    "FD_ZMQ_RECV_REQUEST_SERVER_PORT": lambda: os.getenv("FD_ZMQ_RECV_REQUEST_SERVER_PORT", "8200"),
+    # LLMEngine send-response port, used when FD_ENABLE_INTERNAL_ADAPTER=1
+    "FD_ZMQ_SEND_RESPONSE_SERVER_PORT": lambda: os.getenv("FD_ZMQ_SEND_RESPONSE_SERVER_PORT", "8201"),
+    # LLMEngine receive-control-command ports (one per dp rank), used when FD_ENABLE_INTERNAL_ADAPTER=1
+    "FD_ZMQ_CONTROL_CMD_SERVER_PORTS": lambda: os.getenv("FD_ZMQ_CONTROL_CMD_SERVER_PORTS", "8202"),
     # Whether to use PLUGINS.
     "FD_PLUGINS": lambda: None if "FD_PLUGINS" not in os.environ else os.environ["FD_PLUGINS"].split(","),
 }

View File

@@ -17,6 +17,7 @@
 from .engine_cache_queue import EngineCacheQueue
 from .engine_worker_queue import EngineWorkerQueue
 from .ipc_signal import IPCSignal
-from .zmq_client import ZmqClient
+from .zmq_client import ZmqIpcClient
+from .zmq_server import ZmqIpcServer, ZmqTcpServer

-__all__ = ["ZmqClient", "IPCSignal", "EngineWorkerQueue", "EngineCacheQueue"]
+__all__ = ["ZmqIpcClient", "IPCSignal", "EngineWorkerQueue", "EngineCacheQueue", "ZmqTcpServer", "ZmqIpcServer"]

View File

@@ -85,12 +85,15 @@ class EngineWorkerQueue:
         ]
         self.finished_req_queue = [Queue() for _ in range(self.local_data_parallel_size)]
         self.cache_infos_init: List[List[Any]] = [list() for _ in range(self.local_data_parallel_size)]
+        self.connect_rdma_tasks_list = [list() for _ in range(self.local_data_parallel_size)]
+        self.connect_rdma_tasks_response_list = [list() for _ in range(self.local_data_parallel_size)]
         self.client_read_info_flag_init: List[List[int]] = [
             [1] * self.num_client for _ in range(self.local_data_parallel_size)
         ]
         self.lock_info_init: List[threading.Lock] = [
             threading.Lock() for _ in range(self.local_data_parallel_size)
         ]
+        self.connect_task_lock_init: List[threading.Lock] = [threading.Lock() for _ in range(self.local_data_parallel_size)]

         self.finish_request_barrier = [
             threading.Barrier(self.num_client) for _ in range(self.local_data_parallel_size)
@@ -112,11 +115,26 @@ class EngineWorkerQueue:
                 callable=lambda idx: self.lock_init[idx],
                 proxytype=AcquirerProxy,
             )
+            QueueManager.register(
+                "get_connect_task_lock",
+                callable=lambda idx: self.connect_task_lock_init[idx],
+                proxytype=AcquirerProxy,
+            )
             QueueManager.register(
                 "get_read_finish_flag",
                 callable=lambda idx: self.read_finish_flag_init[idx],
                 proxytype=ValueProxy,
             )
+            QueueManager.register(
+                "get_connect_rdma_tasks",
+                callable=lambda idx: self.connect_rdma_tasks_list[idx],
+                proxytype=ListProxy,
+            )
+            QueueManager.register(
+                "get_connect_rdma_tasks_responses",
+                callable=lambda idx: self.connect_rdma_tasks_response_list[idx],
+                proxytype=ListProxy,
+            )
             QueueManager.register(
                 "get_connected_client_counter",
                 callable=lambda idx: self.connected_client_counter_init[idx],
@@ -180,6 +198,9 @@ class EngineWorkerQueue:
             QueueManager.register("get_disaggregate_requests")
             QueueManager.register("get_available_prefill_instances")
             QueueManager.register("get_finish_request_barrier")
+            QueueManager.register("get_connect_rdma_tasks")
+            QueueManager.register("get_connect_rdma_tasks_responses")
+            QueueManager.register("get_connect_task_lock")
             self.manager = QueueManager(address=self.address, authkey=self.authkey)
             self._connect_with_retry()
@@ -200,6 +221,13 @@ class EngineWorkerQueue:
         self.available_prefill_instances = self.manager.get_available_prefill_instances()
         self.finish_request_barrier = self.manager.get_finish_request_barrier(self.local_data_parallel_id)
         self.finished_req_queue = self.manager.get_finish_request_queue(self.local_data_parallel_id)
+        # P/D interconnection
+        self.connect_rdma_task_queue = self.manager.get_connect_rdma_tasks(self.local_data_parallel_id)
+        self.connect_rdma_task_response_queue = self.manager.get_connect_rdma_tasks_responses(
+            self.local_data_parallel_id
+        )
+        self.connect_task_lock = self.manager.get_connect_task_lock(self.local_data_parallel_id)
         assert self.num_client == len(self.client_read_flag)

         if is_server:
@@ -281,6 +309,45 @@ class EngineWorkerQueue:
             self.lock.release()
         return total_num

+    def put_connect_rdma_task(self, connect_rdma_task):
+        self.connect_task_lock.acquire()
+        self.connect_rdma_task_queue.append(connect_rdma_task)
+        self.connect_task_lock.release()
+
+    def get_connect_rdma_task(self):
+        result = None
+        self.connect_task_lock.acquire()
+        if len(self.connect_rdma_task_queue) == 0:
+            self.connect_task_lock.release()
+            return result
+        try:
+            result = self.connect_rdma_task_queue.pop(0)
+        except Exception as e:
+            llm_logger.info(f"get_connect_rdma_task got exception: {e}")
+        finally:
+            self.connect_task_lock.release()
+        return result
+
+    def put_connect_rdma_task_response(self, connect_rdma_task_response):
+        self.connect_task_lock.acquire()
+        self.connect_rdma_task_response_queue.append(connect_rdma_task_response)
+        self.connect_task_lock.release()
+
+    def get_connect_rdma_task_response(self):
+        result = None
+        self.connect_task_lock.acquire()
+        if len(self.connect_rdma_task_response_queue) == 0:
+            self.connect_task_lock.release()
+            return result
+        try:
+            result = self.connect_rdma_task_response_queue.pop(0)
+        except Exception as e:
+            llm_logger.info(f"get_connect_rdma_task_response got exception: {e}")
+        finally:
+            self.connect_task_lock.release()
+        return result
+
     def get_prefill_instances(self):
         """
         check if the prefill queue is empty
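
These helpers implement a simple cross-process pattern: a manager-shared list guarded by one lock, drained by polling consumers. A self-contained sketch of the same pattern with plain threading primitives (names here are illustrative, not part of the PR):

    import threading
    import time

    task_lock = threading.Lock()
    task_list = []  # stands in for the manager-registered connect_rdma_tasks list

    def put_task(task):
        with task_lock:
            task_list.append(task)

    def get_task():
        # Non-blocking: return None when empty, mirroring get_connect_rdma_task().
        with task_lock:
            return task_list.pop(0) if task_list else None

    put_task({"task_id": "t-0", "ip": "10.0.0.2", "rdma_port": "6379"})
    task = get_task()
    while task is None:
        time.sleep(0.001)  # same polling cadence as CacheMessager._handle_connect_task
        task = get_task()
    print(task["task_id"])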

View File

@@ -14,200 +14,78 @@
 # limitations under the License.
 """

-import os
-import threading
-import time
-
-import msgpack
+from abc import ABC, abstractmethod
+
 import zmq

-from fastdeploy import envs
-from fastdeploy.utils import llm_logger


-class ZmqClient:
+class ZmqClientBase(ABC):
     """
-    ZmqClient is a class that provides a client-side interface for sending and receiving messages using ZeroMQ.
+    ZmqClientBase is a base class that provides a client-side interface for sending and receiving messages using ZeroMQ.
     """

-    def __init__(self, name, mode):
-        self.context = zmq.Context()
-        self.socket = self.context.socket(mode)
-        self.file_name = f"/dev/shm/{name}.socket"
-        self.router_path = f"/dev/shm/router_{name}.ipc"
-
-        self.ZMQ_SNDHWM = int(envs.FD_ZMQ_SNDHWM)
-        self.aggregate_send = envs.FD_USE_AGGREGATE_SEND
-
-        self.mutex = threading.Lock()
-        self.req_dict = dict()
-        self.router = None
-        self.poller = None
-        self.running = True
+    def __init__(self):
+        pass

+    @abstractmethod
+    def _create_socket(self):
+        """Abstract method to create and return a ZeroMQ socket."""
+        pass
+
+    def _ensure_socket(self):
+        """Ensure the socket is created before use."""
+        if self.socket is None:
+            self.socket = self._create_socket()
+
+    @abstractmethod
     def connect(self):
         """
         Connect to the server using the file name specified in the constructor.
         """
-        self.socket.connect(f"ipc://{self.file_name}")
-
-    def start_server(self):
-        """
-        Start the server using the file name specified in the constructor.
-        """
-        self.socket.setsockopt(zmq.SNDHWM, self.ZMQ_SNDHWM)
-        self.socket.setsockopt(zmq.SNDTIMEO, -1)
-        self.socket.bind(f"ipc://{self.file_name}")
-        self.poller = zmq.Poller()
-        self.poller.register(self.socket, zmq.POLLIN)
-
-    def create_router(self):
-        """
-        Create a ROUTER socket and bind it to the specified router path.
-        """
-        self.router = self.context.socket(zmq.ROUTER)
-        self.router.setsockopt(zmq.SNDHWM, self.ZMQ_SNDHWM)
-        self.router.setsockopt(zmq.SNDTIMEO, -1)
-        self.router.bind(f"ipc://{self.router_path}")
+        pass

     def send_json(self, data):
         """
         Send a JSON-serializable object over the socket.
         """
+        self._ensure_socket()
         self.socket.send_json(data)

     def recv_json(self):
         """
         Receive a JSON-serializable object from the socket.
         """
+        self._ensure_socket()
         return self.socket.recv_json()

     def send_pyobj(self, data):
         """
         Send a Pickle-serializable object over the socket.
         """
+        self._ensure_socket()
         self.socket.send_pyobj(data)

     def recv_pyobj(self):
         """
         Receive a Pickle-serializable object from the socket.
         """
+        self._ensure_socket()
         return self.socket.recv_pyobj()

-    def pack_aggregated_data(self, data):
-        """
-        Aggregate multiple responses into one and send them to the client.
-        """
-        result = data[0]
-        if len(data) > 1:
-            for response in data[1:]:
-                result.add(response)
-        result = msgpack.packb([result.to_dict()])
-        return result
-
-    def send_multipart(self, req_id, data):
-        """
-        Send a multipart message to the router socket.
-        """
-        if self.router is None:
-            raise RuntimeError("Router socket not created. Call create_router() first.")
-
-        while self.running:
-            with self.mutex:
-                if req_id not in self.req_dict:
-                    try:
-                        client, _, request_id = self.router.recv_multipart(flags=zmq.NOBLOCK)
-                        req_id_str = request_id.decode("utf-8")
-                        self.req_dict[req_id_str] = client
-                    except zmq.Again:
-                        time.sleep(0.001)
-                        continue
-                else:
-                    break
-
-        try:
-            start_send = time.time()
-            if self.aggregate_send:
-                result = self.pack_aggregated_data(data)
-            else:
-                result = msgpack.packb([response.to_dict() for response in data])
-            self.router.send_multipart([self.req_dict[req_id], b"", result])
-            llm_logger.debug(f"send_multipart result: {req_id} len {len(data)} elapse: {time.time()-start_send}")
-        except Exception as e:
-            llm_logger.error(f"Send result to zmq client failed: {e}")
-
-        if data[-1].finished:
-            with self.mutex:
-                self.req_dict.pop(req_id, None)
-                llm_logger.info(f"send_multipart finished, req_id: {req_id}")
-
-    def receive_json_once(self, block=False):
-        """
-        Receive a single message from the socket.
-        """
-        if self.socket is None or self.socket.closed:
-            return "zmp socket has closed", None
-        try:
-            flags = zmq.NOBLOCK if not block else 0
-            return None, self.socket.recv_json(flags=flags)
-        except zmq.Again:
-            return None, None
-        except Exception as e:
-            self.close()
-            llm_logger.warning(f"{e}")
-            return str(e), None
-
-    def receive_pyobj_once(self, block=False):
-        """
-        Receive a single message from the socket.
-        """
-        if self.socket is None or self.socket.closed:
-            return "zmp socket has closed", None
-        try:
-            flags = zmq.NOBLOCK if not block else 0
-            return None, self.socket.recv_pyobj(flags=flags)
-        except zmq.Again:
-            return None, None
-        except Exception as e:
-            self.close()
-            llm_logger.warning(f"{e}")
-            return str(e), None
-
-    def _clear_ipc(self, name):
-        """
-        Remove the IPC file with the given name.
-        """
-        if os.path.exists(name):
-            try:
-                os.remove(name)
-            except OSError as e:
-                llm_logger.warning(f"Failed to remove IPC file {name} - {e}")
-
-    def close(self):
-        """
-        Close the socket and context, and remove the IPC files.
-        """
-        if not self.running:
-            return
-
-        self.running = False
-        llm_logger.info("Closing ZMQ connection...")
-        try:
-            if hasattr(self, "socket") and not self.socket.closed:
-                self.socket.close()
-            if self.router is not None and not self.router.closed:
-                self.router.close()
-            if not self.context.closed:
-                self.context.term()
-            self._clear_ipc(self.file_name)
-            self._clear_ipc(self.router_path)
-        except Exception as e:
-            llm_logger.warning(f"Failed to close ZMQ connection - {e}")
-        return
-
-    def __exit__(self, exc_type, exc_val, exc_tb):
-        self.close()
+
+class ZmqIpcClient(ZmqClientBase):
+    def __init__(self, name, mode):
+        self.name = name
+        self.mode = mode
+        self.file_name = f"/dev/shm/{name}.socket"
+        self.context = zmq.Context()
+        self.socket = self.context.socket(self.mode)
+
+    def _create_socket(self):
+        """create and return a ZeroMQ socket."""
+        self.context = zmq.Context()
+        return self.context.socket(self.mode)
+
+    def connect(self):
+        self._ensure_socket()
+        self.socket.connect(f"ipc://{self.file_name}")

View File

@@ -0,0 +1,273 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

import os
import threading
import time
from abc import ABC, abstractmethod

import msgpack
import zmq

from fastdeploy import envs
from fastdeploy.utils import llm_logger


class ZmqServerBase(ABC):
    """
    ZmqServerBase
    """

    def __init__(self):
        pass

    @abstractmethod
    def _create_socket(self):
        """Abstract method to create and return a ZeroMQ socket."""
        pass

    def _ensure_socket(self):
        """Ensure the socket is created before use."""
        if self.socket is None:
            self.socket = self._create_socket()

    def pack_aggregated_data(self, data):
        """
        Aggregate multiple responses into one and send them to the client.
        """
        result = data[0]
        if len(data) > 1:
            for response in data[1:]:
                result.add(response)
        result = msgpack.packb([result.to_dict()])
        return result

    def receive_json_once(self, block=False):
        """
        Receive a single message from the socket.
        """
        self._ensure_socket()
        if self.socket is None or self.socket.closed:
            return "zmp socket has closed", None
        try:
            flags = zmq.NOBLOCK if not block else 0
            return None, self.socket.recv_json(flags=flags)
        except zmq.Again:
            return None, None
        except Exception as e:
            self.close()
            llm_logger.warning(f"{e}")
            return str(e), None

    def receive_pyobj_once(self, block=False):
        """
        Receive a single message from the socket.
        """
        self._ensure_socket()
        if self.socket is None or self.socket.closed:
            return "zmp socket has closed", None
        try:
            flags = zmq.NOBLOCK if not block else 0
            return None, self.socket.recv_pyobj(flags=flags)
        except zmq.Again:
            return None, None
        except Exception as e:
            self.close()
            llm_logger.warning(f"{e}")
            return str(e), None

    def send_response(self, req_id, data):
        """
        Send generated token result to client.
        """
        self._ensure_socket()
        if self.socket is None:
            raise RuntimeError("Router socket not created. Call create_router() first.")

        while self.running:
            with self.mutex:
                if req_id not in self.req_dict:
                    try:
                        client, _, request_id = self.socket.recv_multipart(flags=zmq.NOBLOCK)
                        req_id_str = request_id.decode("utf-8")
                        self.req_dict[req_id_str] = client
                    except zmq.Again:
                        time.sleep(0.001)
                        continue
                else:
                    break

        try:
            start_send = time.time()
            if self.aggregate_send:
                result = self.pack_aggregated_data(data)
            else:
                result = msgpack.packb([response.to_dict() for response in data])
            self.socket.send_multipart([self.req_dict[req_id], b"", result])
            llm_logger.debug(f"send_multipart result: {req_id} len {len(data)} elapse: {time.time()-start_send}")
        except Exception as e:
            llm_logger.error(f"Send result to zmq client failed: {e}")

        if data[-1].finished:
            with self.mutex:
                self.req_dict.pop(req_id, None)
                llm_logger.info(f"send_multipart finished, req_id: {req_id}")

    @abstractmethod
    def close(self):
        pass

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()


class ZmqIpcServer(ZmqServerBase):
    """
    ZmqIpcServer, used when FD_ENABLE_INTERNAL_ADAPTER=0
    """

    def __init__(self, name, mode):
        self.name = name
        self.mode = mode
        if mode == zmq.PULL:
            self.file_name = f"/dev/shm/{name}.socket"
        elif mode == zmq.ROUTER:
            self.file_name = f"/dev/shm/router_{name}.ipc"
        self.ZMQ_SNDHWM = int(envs.FD_ZMQ_SNDHWM)
        self.aggregate_send = envs.FD_USE_AGGREGATE_SEND
        self.mutex = threading.Lock()
        self.req_dict = dict()
        self.running = True
        self.context = zmq.Context()
        self._create_socket()

    def _create_socket(self):
        """create and return a ZeroMQ socket."""
        self.socket = self.context.socket(self.mode)
        self.socket.setsockopt(zmq.SNDHWM, self.ZMQ_SNDHWM)
        self.socket.setsockopt(zmq.SNDTIMEO, -1)
        self.socket.bind(f"ipc://{self.file_name}")
        return self.socket

    def _clear_ipc(self, name):
        """
        Remove the IPC file with the given name.
        """
        if os.path.exists(name):
            try:
                os.remove(name)
            except OSError as e:
                llm_logger.warning(f"Failed to remove IPC file {name} - {e}")

    def close(self):
        """
        Close the socket and context, and remove the IPC files.
        """
        if not self.running:
            return

        self.running = False
        llm_logger.info("Closing ZMQ connection...")
        try:
            if self.socket is not None and not self.socket.closed:
                self.socket.close()
            if not self.context.closed:
                self.context.term()
            self._clear_ipc(self.file_name)
        except Exception as e:
            llm_logger.warning(f"Failed to close ZMQ connection - {e}")
        return


class ZmqTcpServer(ZmqServerBase):
    """
    ZmqTcpServer, used when FD_ENABLE_INTERNAL_ADAPTER=1
    """

    def __init__(self, port, mode):
        self.mode = mode
        self.port = port
        self.ZMQ_SNDHWM = int(envs.FD_ZMQ_SNDHWM)
        self.aggregate_send = envs.FD_USE_AGGREGATE_SEND
        self.mutex = threading.Lock()
        self.req_dict = dict()
        self.running = True
        self.context = zmq.Context()
        self._create_socket()

    def _create_socket(self):
        """create and return a ZeroMQ socket."""
        self.socket = self.context.socket(self.mode)
        self.socket.setsockopt(zmq.SNDHWM, self.ZMQ_SNDHWM)
        self.socket.setsockopt(zmq.SNDTIMEO, -1)
        self.socket.bind(f"tcp://*:{self.port}")
        return self.socket

    def recv_control_cmd(self):
        """
        Receive a control command from the client.
        """
        self._ensure_socket()
        while self.running:
            try:
                client, _, task_data = self.socket.recv_multipart(flags=zmq.NOBLOCK)
                task = msgpack.unpackb(task_data)
                task_id_str = task["task_id"]
            except zmq.Again:
                time.sleep(0.001)
                continue
            with self.mutex:
                self.req_dict[task_id_str] = client
            return task

    def response_for_control_cmd(self, task_id, result):
        """
        Send the command result back to the client.
        """
        self._ensure_socket()
        if self.socket is None:
            raise RuntimeError("Router socket not created.")
        try:
            result = msgpack.packb(result)
            self.socket.send_multipart([self.req_dict[task_id], b"", result])
        except Exception as e:
            llm_logger.error(f"Send result to zmq client failed: {e}")
        with self.mutex:
            self.req_dict.pop(task_id, None)
        llm_logger.info(f"response control cmd finished, task_id: {task_id}")

    def close(self):
        """
        Close the socket and context.
        """
        if not self.running:
            return

        self.running = False
        llm_logger.info("Closing ZMQ connection...")
        try:
            if self.socket is not None and not self.socket.closed:
                self.socket.close()
            if not self.context.closed:
                self.context.term()
        except Exception as e:
            llm_logger.warning(f"Failed to close ZMQ connection - {e}")
        return
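
Framing for both TCP servers follows standard ZeroMQ ROUTER/DEALER conventions. A hedged client-side sketch against the defaults from envs.py (host and request fields are illustrative assumptions):

    import msgpack
    import zmq

    ctx = zmq.Context()

    # Submit a request to the PULL intake server (FD_ZMQ_RECV_REQUEST_SERVER_PORT).
    push = ctx.socket(zmq.PUSH)
    push.connect("tcp://127.0.0.1:8200")
    push.send_json({"request_id": "req-0", "prompt": "hello"})

    # Register interest in req-0 with the ROUTER response server; send_response()
    # learns the client identity from this first multipart message.
    dealer = ctx.socket(zmq.DEALER)
    dealer.connect("tcp://127.0.0.1:8201")
    dealer.send_multipart([b"", b"req-0"])
    _, payload = dealer.recv_multipart()
    print(msgpack.unpackb(payload))  # list of response dicts (see pack_aggregated_data)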

View File

@@ -412,7 +412,11 @@ class TokenProcessor:
                     self._record_completion_metrics(task, current_time)
                     self._recycle_resources(task_id, i, task, result, is_prefill)
                     break
-            if not is_prefill or self.cfg.scheduler_config.name == "splitwise":
+            if (
+                not is_prefill
+                or self.cfg.scheduler_config.name == "splitwise"
+                or self.cfg.scheduler_config.name == "dp"
+            ):
                 batch_result.append(result)

         self.postprocess(batch_result)
@@ -531,7 +535,11 @@ class TokenProcessor:
                     self._record_completion_metrics(task, current_time)
                     self._recycle_resources(task_id, i, task, result, is_prefill)
                     break
-            if not is_prefill or self.cfg.scheduler_config.name == "splitwise":
+            if (
+                not is_prefill
+                or self.cfg.scheduler_config.name == "splitwise"
+                or self.cfg.scheduler_config.name == "dp"
+            ):
                 batch_result.append(result)

         self.postprocess(batch_result)

View File

@@ -18,6 +18,7 @@ import redis

 from fastdeploy.utils import llm_logger

+from .dp_scheduler import DPScheduler
 from .global_scheduler import GlobalScheduler
 from .local_scheduler import LocalScheduler
 from .splitwise_scheduler import SplitWiseScheduler, SplitWiseSchedulerConfig
@@ -89,6 +90,57 @@ class LocalSchedulerConfig:
         llm_logger.info("=============================================================")


+class DPLocalSchedulerConfig(LocalSchedulerConfig):
+    """
+    Configuration class for DPLocalScheduler.
+
+    Attributes:
+        max_size: Maximum number of concurrent requests (-1 for unlimited)
+        ttl: Time-to-live in seconds for request expiration
+    """
+
+    def __init__(
+        self,
+        max_size: int = -1,
+        ttl: int = 900,
+        max_model_len: int = 8192,
+        enable_chunked_prefill: bool = False,
+        max_num_partial_prefills: int = 1,
+        max_long_partial_prefills: int = 1,
+        long_prefill_token_threshold: int = 0,
+        splitwise_role: str = "prefill",
+        **kwargs,
+    ):
+        """
+        Initialize DPLocalScheduler configuration.
+
+        Args:
+            max_size: Maximum concurrent requests (-1 for unlimited, 0 for disabled)
+            ttl: Time-to-live in seconds for request expiration (default 900s)
+            max_model_len: Maximum model context length in tokens
+            enable_chunked_prefill: Whether to enable chunked prefill processing
+            max_num_partial_prefills: Max partial prefill operations allowed
+            max_long_partial_prefills: Max long-running partial prefill ops
+            long_prefill_token_threshold: Token count threshold for long prefill
+            **kwargs: Additional unused arguments (for forward compatibility)
+
+        Note:
+            - If long_prefill_token_threshold is 0, it's auto-calculated as 4% of max_model_len
+            - See LocalScheduler class for implementation details
+        """
+        self.max_size = max_size
+        self.ttl = ttl
+        self.max_model_len = max_model_len
+        self.enable_chunked_prefill = enable_chunked_prefill
+        self.max_num_partial_prefills = max_num_partial_prefills
+        self.max_long_partial_prefills = max_long_partial_prefills
+        self.long_prefill_token_threshold = long_prefill_token_threshold
+        if self.long_prefill_token_threshold == 0:
+            self.long_prefill_token_threshold = int(self.max_model_len * 0.04)
+        self.splitwise_role = splitwise_role
+
+
 class GlobalSchedulerConfig:
     """
     Configuration class for GlobalScheduler (Redis-based).
@@ -229,6 +281,9 @@ class SchedulerConfig:
         if name == "splitwise":
             self.config = SplitWiseSchedulerConfig(**kwargs)

+        if name == "dp":
+            self.config = DPLocalSchedulerConfig(**kwargs)
+
     def check(self):
         """
         Validate the configuration.
@@ -236,7 +291,7 @@ class SchedulerConfig:

         Raises:
             Exception: If invalid scheduler type is specified
         """
-        if self.name not in ["local", "global", "splitwise"]:
+        if self.name not in ["local", "global", "splitwise", "dp"]:
             raise Exception(f"Unknown scheduler type {self.name}")

         self.config.check()
@@ -274,6 +329,17 @@ class SchedulerConfig:
         if self.name == "splitwise":
             return SplitWiseScheduler(self.config)

+        if self.name == "dp":
+            return DPScheduler(
+                max_size=self.config.max_size,
+                ttl=self.config.ttl,
+                enable_chunked_prefill=self.config.enable_chunked_prefill,
+                max_num_partial_prefills=self.config.max_num_partial_prefills,
+                max_long_partial_prefills=self.config.max_long_partial_prefills,
+                long_prefill_token_threshold=self.config.long_prefill_token_threshold,
+                splitwise_role=self.config.splitwise_role,
+            )
+
         return LocalScheduler(
             max_size=self.config.max_size,
             ttl=self.config.ttl,

View File

@@ -0,0 +1,179 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

import threading
import time
from multiprocessing import Queue
from typing import Dict, List, Optional

from fastdeploy.engine.request import Request, RequestOutput
from fastdeploy.scheduler.data import ScheduledResponse
from fastdeploy.scheduler.local_scheduler import LocalScheduler
from fastdeploy.utils import scheduler_logger


class DPLocalScheduler(LocalScheduler):
    def __init__(
        self,
        max_size: int,
        ttl: int,
        enable_chunked_prefill: bool,
        max_num_partial_prefills: int,
        max_long_partial_prefills: int,
        long_prefill_token_threshold: int,
        splitwise_role: str = "prefill",
    ):
        super().__init__(
            max_size,
            ttl,
            enable_chunked_prefill,
            max_num_partial_prefills,
            max_long_partial_prefills,
            long_prefill_token_threshold,
        )
        self.splitwise_role = splitwise_role

    def put_results(self, results: List[RequestOutput]):
        """
        Add processing results back to the scheduler.

        Args:
            results: List of RequestOutput objects containing results
        """
        responses: List[ScheduledResponse] = [ScheduledResponse(result) for result in results]

        finished_responses = [response.request_id for response in responses if response.finished]
        if len(finished_responses) > 0:
            scheduler_logger.info(f"Scheduler has received some finished responses: {finished_responses}")

        with self.mutex:
            for response in responses:
                if response.request_id not in self.responses:
                    self.responses[response.request_id] = [response]
                    continue
                self.responses[response.request_id].append(response)
            self.responses_not_empty.notify_all()

    def _recycle(self, request_id: Optional[str] = None):
        """
        Clean up expired or completed requests to free memory.

        Args:
            request_id: Optional specific request ID to remove.
                If None, removes all expired requests.
        """
        if request_id is not None:
            self.requests.pop(request_id, None)
            self.responses.pop(request_id, None)
            if self.splitwise_role == "decode":
                return
            self.ids.pop(self.ids.index(request_id))
            self.ids_read_cursor -= 1
            return

        if self.max_size <= 0:
            return
        if len(self.requests) <= self.max_size:
            return

        now = time.time()
        expired_ids = []
        for request_id in self.ids:
            request = self.requests[request_id]
            if now - request.schedule_time < self.ttl:
                break
            expired_ids.append(request.request_id)

        for i, expired_id in enumerate(expired_ids):
            self.requests.pop(expired_id, None)
            self.responses.pop(expired_id, None)
            self.ids.pop(i)

        if len(expired_ids) > 0:
            if len(expired_ids) - 1 >= self.ids_read_cursor:
                self.ids_read_cursor = 0
            else:
                self.ids_read_cursor -= len(expired_ids)


class DPScheduler:
    def __init__(
        self,
        max_size: int,
        ttl: int,
        enable_chunked_prefill: bool,
        max_num_partial_prefills: int,
        max_long_partial_prefills: int,
        long_prefill_token_threshold: int,
        splitwise_role: str = "prefill",
    ):
        self._scheduler = DPLocalScheduler(
            max_size,
            ttl,
            enable_chunked_prefill,
            max_num_partial_prefills,
            max_long_partial_prefills,
            long_prefill_token_threshold,
            splitwise_role,
        )

    def start(self, dp_rank: int, request_queues: List[Queue], result_queue: Queue):
        self.dp_rank = dp_rank
        self.request_queues = request_queues
        self.result_queue = result_queue
        threading.Thread(target=self._put_requests_to_local).start()
        threading.Thread(target=self._get_response_from_local).start()

    def put_requests(self, requests: List[Dict]):
        results = []
        for request in requests:
            if not hasattr(request, "dp_rank"):
                raise ValueError(f"Request object is missing the 'dp_rank' attribute: {request}")
            self.request_queues[request.dp_rank].put(request)
            results.append((request.request_id, None))
        return results

    def _put_requests_to_local(self):
        while True:
            request = self.request_queues[self.dp_rank].get()
            self._scheduler.put_requests([request])

    def _get_response_from_local(self):
        while True:
            results = self._scheduler.get_results()
            if len(results) == 0:
                continue
            self.result_queue.put(results)

    def get_requests(
        self,
        available_blocks,
        block_size,
        reserved_output_blocks,
        max_num_batched_tokens,
        batch=1,
    ) -> List[Request]:
        return self._scheduler.get_requests(
            available_blocks, block_size, reserved_output_blocks, max_num_batched_tokens, batch
        )

    def get_unhandled_request_num(self):
        return len(self._scheduler.requests)

    def put_results(self, results: List[RequestOutput]):
        self._scheduler.put_results(results)

    def get_results(self) -> Dict[str, List[RequestOutput]]:
        return self.result_queue.get()
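
A hedged usage sketch of the pieces above, mirroring how engine.py wires the "dp" scheduler (queue sizes and parameter values are illustrative):

    from multiprocessing import Queue

    from fastdeploy.scheduler.dp_scheduler import DPScheduler

    data_parallel_size = 2
    request_queues = [Queue() for _ in range(data_parallel_size)]  # one per dp rank
    result_queue = Queue()  # shared by all dp ranks

    scheduler = DPScheduler(
        max_size=-1,
        ttl=900,
        enable_chunked_prefill=False,
        max_num_partial_prefills=1,
        max_long_partial_prefills=1,
        long_prefill_token_threshold=0,
        splitwise_role="prefill",
    )
    scheduler.start(dp_rank=0, request_queues=request_queues, result_queue=result_queue)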

View File

@@ -0,0 +1,107 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

import threading
import time
import traceback

# **Note**: Just for internal use
import zmq

from fastdeploy.inter_communicator import ZmqTcpServer
from fastdeploy.metrics.metrics import get_filtered_metrics, main_process_metrics
from fastdeploy.utils import envs, get_logger

logger = get_logger("internal_adapter_utils", "internal_adapter_utils.log")


class InternalAdapter:
    def __init__(self, cfg, engine, dp_rank):
        self.cfg = cfg
        self.engine = engine
        self.dp_rank = dp_rank
        recv_control_cmd_ports = envs.FD_ZMQ_CONTROL_CMD_SERVER_PORTS.split(",")
        self.recv_control_cmd_server = ZmqTcpServer(port=recv_control_cmd_ports[dp_rank], mode=zmq.ROUTER)
        self.recv_external_instruct_thread = threading.Thread(
            target=self._recv_external_module_control_instruct, daemon=True
        )
        self.recv_external_instruct_thread.start()
        self.response_external_instruct_thread = threading.Thread(
            target=self._response_external_module_control_instruct, daemon=True
        )
        self.response_external_instruct_thread.start()

    def _get_current_server_info(self):
        """
        Get resource information.
        """
        available_batch_size = min(self.cfg.max_prefill_batch, self.engine.resource_manager.available_batch())
        available_block_num = self.engine.resource_manager.available_block_num()
        server_info = {
            "splitwise_role": self.cfg.splitwise_role,
            "block_size": int(self.cfg.cache_config.block_size),
            "block_num": int(available_block_num),
            "dec_token_num": int(self.cfg.cache_config.dec_token_num),
            "available_resource": 1.0 * available_block_num / self.cfg.cache_config.total_block_num,
            "max_batch_size": int(available_batch_size),
            "max_input_token_num": self.cfg.max_num_batched_tokens,
            "unhandled_request_num": self.engine.scheduler.get_unhandled_request_num(),
        }
        return server_info

    def _recv_external_module_control_instruct(self):
        """
        Receive a multipart message from the control cmd socket.
        """
        while True:
            try:
                task = self.recv_control_cmd_server.recv_control_cmd()
                logger.info(f"Receive control task: {task}")
                task_id_str = task["task_id"]
                if task["cmd"] == "get_payload":
                    payload_info = self._get_current_server_info()
                    result = {"task_id": task_id_str, "result": payload_info}
                    logger.info(f"Response for task: {task_id_str}")
                    self.recv_control_cmd_server.response_for_control_cmd(task_id_str, result)
                elif task["cmd"] == "get_metrics":
                    metrics_text = get_filtered_metrics(
                        [],
                        extra_register_func=lambda reg: main_process_metrics.register_all(reg, workers=1),
                    )
                    result = {"task_id": task_id_str, "result": metrics_text}
                    logger.info(f"Response for task: {task_id_str}")
                    self.recv_control_cmd_server.response_for_control_cmd(task_id_str, result)
                elif task["cmd"] == "connect_rdma":
                    self.engine.engine_worker_queue.put_connect_rdma_task(task)
            except Exception as e:
                logger.error(f"handle_control_cmd got error: {e}, {traceback.format_exc()!s}")

    def _response_external_module_control_instruct(self):
        while True:
            try:
                result_data = self.engine.engine_worker_queue.get_connect_rdma_task_response()
                if result_data:
                    task_id_str = result_data["task_id"]
                    result = {"task_id": task_id_str, "result": result_data}
                    logger.info(f"Response for task: {task_id_str}")
                    self.recv_control_cmd_server.response_for_control_cmd(task_id_str, result)
                else:
                    time.sleep(0.001)
            except Exception as e:
                logger.error(f"_handle_connect_rdma_results got error: {e}, {traceback.format_exc()!s}")