[Feature] Support EP/PD with external module (#3194)

* Support external module

* refactor code to make it more clear

* fix according to review

* fix bug

* merge

---------

Co-authored-by: root <root@tjdm-inf-sci-k8s-hzz2-h12ni8-0202.tjdm.baidu.com>
Author: chenjian
Date: 2025-08-04 20:32:41 +08:00
Committed by: GitHub
Parent: 0443587a57
Commit: 9f9971844f
15 changed files with 876 additions and 218 deletions


@@ -29,8 +29,9 @@ from fastdeploy.engine.resource_manager import ResourceManager
 from fastdeploy.inter_communicator import EngineWorkerQueue
 from fastdeploy.metrics.metrics import main_process_metrics
 from fastdeploy.output.token_processor import TokenProcessor
+from fastdeploy.splitwise.internal_adapter_utils import InternalAdapter
 from fastdeploy.splitwise.splitwise_connector import SplitwiseConnector
-from fastdeploy.utils import EngineError, console_logger, llm_logger
+from fastdeploy.utils import EngineError, console_logger, envs, llm_logger
@@ -60,7 +61,8 @@ class ExpertService:
         self.scheduler = cfg.scheduler_config.scheduler()
-        self.scheduler.reset_nodeid(f"{self.scheduler.infer.nodeid}_{local_data_parallel_id!s}")
+        if self.cfg.scheduler_config.name == "splitwise":
+            self.scheduler.reset_nodeid(f"{self.scheduler.infer.nodeid}_{local_data_parallel_id!s}")
         self.cfg.parallel_config.local_data_parallel_id = local_data_parallel_id
@@ -111,8 +113,12 @@
         )
         self._finalizer = weakref.finalize(self, self._exit_sub_services)
+        if envs.FD_ENABLE_INTERNAL_ADAPTER:
+            self.external_adapter = InternalAdapter(cfg=self.cfg, engine=self, dp_rank=local_data_parallel_id)

-    def start(self, ipc_signal_suffix, local_data_parallel_id):
+    def start(
+        self, ipc_signal_suffix, local_data_parallel_id, request_queues_for_dp_ipc=None, result_queue_for_dp_ipc=None
+    ):
         """
         Initializes the engine and starts its sub-services.
         If `api_server_pid` is defined, will launch a thread
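This hunk adds two things: `start` gains optional `request_queues_for_dp_ipc` / `result_queue_for_dp_ipc` parameters, and an `InternalAdapter` is attached when the `FD_ENABLE_INTERNAL_ADAPTER` environment flag is set. A minimal sketch of how a caller could wire the new queue parameters, assuming plain `multiprocessing` queues; the launcher function and the one-inbox-per-rank layout are illustrative assumptions, not FastDeploy's own launcher code:

```python
# Hypothetical single-rank launcher exercising the new optional parameters.
import multiprocessing as mp

def launch_one_rank(cfg, dp_rank, num_dp_ranks, ipc_signal_suffix):
    # Assumed layout: one request inbox per DP rank, one shared result outbox.
    request_queues = [mp.Queue() for _ in range(num_dp_ranks)]
    result_queue = mp.Queue()

    service = ExpertService(cfg, dp_rank)  # the class patched in this diff
    service.start(
        ipc_signal_suffix,
        dp_rank,
        request_queues_for_dp_ipc=request_queues,
        result_queue_for_dp_ipc=result_queue,
    )
    return service
```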
@@ -127,7 +133,7 @@
             cache_config=self.cfg.cache_config,
             tensor_parallel_size=self.cfg.tensor_parallel_size,
             device_ids=self.cfg.local_device_ids,
-            pod_ip=self.cfg.pod_ips[0],
+            pod_ip=self.cfg.master_ip,
             engine_worker_queue_port=self.cfg.engine_worker_queue_port,
             pid_suffix=f"{local_data_parallel_id}_{ipc_signal_suffix}",
         )
@@ -147,7 +153,11 @@
         role = self.cfg.splitwise_role
         host_ip = self.cfg.host_ip
         disaggregate = self.cfg.disaggregate_info
-        self.scheduler.start(role, host_ip, disaggregate)
+        if self.cfg.scheduler_config.name == "dp":
+            assert (request_queues_for_dp_ipc is not None) and (result_queue_for_dp_ipc is not None)
+            self.scheduler.start(local_data_parallel_id, request_queues_for_dp_ipc, result_queue_for_dp_ipc)
+        elif self.cfg.scheduler_config.name == "splitwise":
+            self.scheduler.start(role, host_ip, disaggregate)
         self.cfg.print()
         console_logger.info(f"Worker processes are launched with {time.time() - start_time} seconds.")
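`start` now dispatches on the scheduler name: the new "dp" scheduler requires both IPC queues and is started with `(rank, request_queues, result_queue)`, while "splitwise" keeps the original `(role, host_ip, disaggregate)` call. Condensed as a standalone helper for reference (a control-flow sketch of the branch above; presumably the dp scheduler is fed entirely over these in-process queues rather than over the network):

```python
# Sketch of the dispatch introduced above; cfg/scheduler stand in for
# the ExpertService attributes of the same names.
def start_scheduler(cfg, scheduler, dp_rank, request_queues=None, result_queue=None):
    if cfg.scheduler_config.name == "dp":
        # Queue-driven scheduling: both ends must come from the parent process.
        assert request_queues is not None and result_queue is not None
        scheduler.start(dp_rank, request_queues, result_queue)
    elif cfg.scheduler_config.name == "splitwise":
        # Network-oriented scheduling keeps the original signature.
        scheduler.start(cfg.splitwise_role, cfg.host_ip, cfg.disaggregate_info)
    # Any other scheduler name is not started here, matching the hunk above.
```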
@@ -356,13 +366,17 @@
         self.zmq_server.close()

-def start_expert_service(cfg, local_data_parallel_id, ipc_signal_suffix):
+def start_expert_service(
+    cfg, local_data_parallel_id, ipc_signal_suffix, request_queues_for_dp_ipc=None, result_queue_for_dp_ipc=None
+):
     """
     Start expert service
     """
     expert_service = ExpertService(cfg, local_data_parallel_id)
     try:
-        expert_service.start(ipc_signal_suffix, local_data_parallel_id)
+        expert_service.start(
+            ipc_signal_suffix, local_data_parallel_id, request_queues_for_dp_ipc, result_queue_for_dp_ipc
+        )
         expert_service.split_connector.start_receiver()
     except Exception as e:
         llm_logger.exception(f"Expert service failed to start: {e}")
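Because `start_expert_service` is a process entry point, the new arguments let a parent process create the queues once and share them across every DP rank. An illustrative parent-side launcher under that assumption (FastDeploy's engine presumably does equivalent wiring internally; this is not its actual code):

```python
# Hypothetical parent process: one expert-service subprocess per DP rank,
# all sharing the same request/result queue set.
import multiprocessing as mp

def launch_all_dp_ranks(cfg, num_dp_ranks, ipc_signal_suffix):
    request_queues = [mp.Queue() for _ in range(num_dp_ranks)]
    result_queue = mp.Queue()
    procs = []
    for rank in range(num_dp_ranks):
        p = mp.Process(
            target=start_expert_service,  # the function patched above
            args=(cfg, rank, ipc_signal_suffix, request_queues, result_queue),
        )
        p.start()
        procs.append(p)
    # The parent pushes work into request_queues[rank] and drains the
    # shared result_queue; the message format is defined by the dp scheduler.
    return request_queues, result_queue, procs
```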