Launch expert_service before kv_cache initialization in worker_process (#3045)

* launch expert_service before kv_cache initialization

* add two signals to make sure model loading and expert_service launching are finished

* fix the EP bug

* fix EP

* update the launching approach

* fix EP

* update

* roll back EP

* pre-commit all files

---------

Co-authored-by: RAM <gstian5555@outlook.com>
Co-authored-by: Divano <dddivano@outlook.com>
Commit: b23af29d0b
Parent: c27a3dc43b
Author: Zero Rains
Date: 2025-08-11 19:38:46 +08:00
Committed by: GitHub
6 changed files with 175 additions and 100 deletions


@@ -26,7 +26,7 @@ import weakref
 import numpy as np
 from fastdeploy.engine.resource_manager import ResourceManager
-from fastdeploy.inter_communicator import EngineWorkerQueue
+from fastdeploy.inter_communicator import EngineWorkerQueue, IPCSignal
 from fastdeploy.metrics.metrics import main_process_metrics
 from fastdeploy.output.token_processor import TokenProcessor
 from fastdeploy.splitwise.splitwise_connector import SplitwiseConnector
@@ -127,7 +127,7 @@ class ExpertService:
             cache_config=self.cfg.cache_config,
             tensor_parallel_size=self.cfg.tensor_parallel_size,
             device_ids=self.cfg.local_device_ids,
-            pod_ip=self.cfg.pod_ips[0],
+            pod_ip=self.cfg.master_ip,
             engine_worker_queue_port=self.cfg.engine_worker_queue_port,
             pid_suffix=f"{local_data_parallel_id}_{ipc_signal_suffix}",
         )
@@ -141,16 +141,29 @@
         os.environ["INFERENCE_MSG_QUEUE_ID"] = str(local_data_parallel_id + int(self.cfg.engine_worker_queue_port))
         self.token_processor.run()
         self.cfg.init_cache_info()
         role = self.cfg.splitwise_role
         host_ip = self.cfg.host_ip
         disaggregate = self.cfg.disaggregate_info
         self.scheduler.start(role, host_ip, disaggregate)
         self.cfg.print()
-        console_logger.info(f"Worker processes are launched with {time.time() - start_time} seconds.")
+        launched_expert_service_signal_data = np.zeros(
+            shape=[self.cfg.parallel_config.data_parallel_size // self.cfg.nnode], dtype=np.int32
+        )
+        self.launched_expert_service_signal = IPCSignal(
+            name="launched_expert_service_signal",
+            array=launched_expert_service_signal_data,
+            dtype=np.int32,
+            suffix=ipc_signal_suffix,
+            create=False,
+        )
+        local_rank = local_data_parallel_id % self.cfg.worker_num_per_node
+        self.launched_expert_service_signal.value[local_rank] = 1
+        console_logger.info(
+            f"Worker processes(rank {local_rank}) are launched with {time.time() - start_time} seconds."
+        )
         return True

     def _insert_task_to_worker(self):
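
The added block is the service-side half of a readiness handshake: each expert service attaches to a pre-created shared int32 array (hence create=False) and flips its own rank's slot to 1, which lets the parent process hold off kv_cache initialization until every rank has reported in. Below is a minimal, self-contained sketch of the same pattern built on Python's standard multiprocessing.shared_memory instead of FastDeploy's IPCSignal; every name in it is illustrative, not a FastDeploy API.

import time
from multiprocessing import Process, shared_memory

import numpy as np

SIGNAL_NAME = "launched_expert_service_signal_demo"  # illustrative name
NUM_RANKS = 2  # stand-in for data_parallel_size // nnode


def expert_service(rank: int) -> None:
    # Attach to the existing signal (the create=False side in the diff)
    # and mark this rank as launched once its setup work is done.
    shm = shared_memory.SharedMemory(name=SIGNAL_NAME)
    signal = np.ndarray((NUM_RANKS,), dtype=np.int32, buffer=shm.buf)
    time.sleep(0.2)  # stand-in for the actual expert_service launch work
    signal[rank] = 1
    shm.close()


def main() -> None:
    # The parent creates the signal before spawning the services, then
    # polls it and only proceeds (e.g. to kv_cache initialization) once
    # every rank has flipped its slot.
    shm = shared_memory.SharedMemory(name=SIGNAL_NAME, create=True, size=NUM_RANKS * 4)
    signal = np.ndarray((NUM_RANKS,), dtype=np.int32, buffer=shm.buf)
    signal[:] = 0
    procs = [Process(target=expert_service, args=(r,)) for r in range(NUM_RANKS)]
    for p in procs:
        p.start()
    while not (signal == 1).all():  # wait until all ranks report ready
        time.sleep(0.01)
    print("all expert services launched; safe to initialize kv_cache")
    for p in procs:
        p.join()
    shm.close()
    shm.unlink()


if __name__ == "__main__":
    main()

A flat int32 array with one slot per rank keeps the handshake lock-free: each writer touches only its own index and the reader's poll is a cheap vectorized comparison, which matches how the diff sizes the array by data_parallel_size // nnode and indexes it by local rank.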