Launch expert_service before kv_cache initialization in worker_process (#3045)

* launch expert_service before kv_cache initialization

* add two signals to make sure model loading and expert_service launching have finished

* fix the EP bug

* fix ep

* update launching approach

* fix ep

* update

* roll back ep

* run pre-commit on all files

---------

Co-authored-by: RAM <gstian5555@outlook.com>
Co-authored-by: Divano <dddivano@outlook.com>
Authored by: Zero Rains
Committed by: GitHub
Date: 2025-08-11 19:38:46 +08:00
Parent: c27a3dc43b
Commit: b23af29d0b
6 changed files with 175 additions and 100 deletions
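
The heart of the change is an IPC handshake: the engine allocates a shared-memory flag (`loaded_model_signal`) with `create=True` before spawning workers, each worker attaches to it with `create=False` and raises it once `load_model()` returns, and the engine polls the flag before launching the remaining components. The following is a condensed, illustrative sketch of that pattern based on the hunks below; the `engine_side`/`worker_side` wrappers are hypothetical and not part of the codebase:

import time

import numpy as np

from fastdeploy.inter_communicator import IPCSignal  # same class used in the hunks below


def engine_side(suffix: str) -> None:
    # Engine process: allocate the shared flag (create=True) before starting workers.
    loaded_model_signal = IPCSignal(
        name="loaded_model_signal",
        array=np.zeros([1], dtype=np.int32),
        dtype=np.int32,
        suffix=suffix,
        create=True,
    )
    # ... start worker processes here ...
    # Block until the workers report that model loading has finished.
    while loaded_model_signal.value[0] == 0:
        time.sleep(1)
    # Only now launch scheduler / cache_manager / expert_service, then KV cache setup.


def worker_side(suffix: str) -> None:
    # Worker process: attach to the engine-owned flag (create=False) and raise it
    # once the model has been loaded.
    loaded_model_signal = IPCSignal(
        name="loaded_model_signal",
        array=np.zeros([1], dtype=np.int32),
        dtype=np.int32,
        suffix=suffix,
        create=False,
    )
    # ... load weights here ...
    loaded_model_signal.value[0] = 1

The same pattern is repeated for `launched_expert_service_signal`, one slot per local data-parallel rank.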

View File

@@ -196,13 +196,42 @@ class LLMEngine:
                 engine_worker_queue_port=self.cfg.engine_worker_queue_port,
                 pid_suffix=self.ipc_signal_suffix,
             )
-            self.launched_cache_manager_signal.value[0] = 1
         self.worker_proc = self._start_worker_service()
         console_logger.info("Waiting worker processes ready...")
         time.sleep(5)
         self.worker_init_status = dict()
-        if not self.check_worker_initialize_status():
+        result_container = {}
+
+        def check_worker_initialize_status_func(res: dict):
+            res["worker_is_alive"] = True
+            if not self.check_worker_initialize_status():
+                console_logger.error("Failed to launch worker processes, check log/workerlog.* for more details.")
+                res["worker_is_alive"] = False
+
+        self.check_worker_initialize_status_func_thread = threading.Thread(
+            target=check_worker_initialize_status_func, args=(result_container,), daemon=True
+        )
+        self.check_worker_initialize_status_func_thread.start()
+
+        # Wait for model loading
+        while self.loaded_model_signal.value[0] == 0:
+            # Make sure the worker process is alive
+            if not self.check_worker_initialize_status_func_thread.is_alive():
+                return False
+            time.sleep(1)
+
+        if self.do_profile:
+            self._stop_profile()
+        # Launch components: scheduler, cache_manager, expert_service etc.
+        self.launch_components()
+        if self.cfg.cache_config.enable_prefix_caching or self.cfg.splitwise_role != "mixed":
+            self.launched_cache_manager_signal.value[0] = 1
+
+        # Workers launched
+        self.check_worker_initialize_status_func_thread.join()
+        if not result_container["worker_is_alive"]:
             console_logger.error("Failed to launch worker processes, check log/workerlog.* for more details.")
             return False
@@ -214,68 +243,6 @@ class LLMEngine:
             self._del_warmup_token_processor()
             console_logger.info("Warmup finished")

-        self.token_processor.tasks_queue = self.engine_worker_queue
-
-        if envs.ENABLE_V1_KVCACHE_SCHEDULER:
-            self.insert_task_to_worker_thread = threading.Thread(target=self._scheduler_task_to_worker_v1, daemon=True)
-        else:
-            self.insert_task_to_worker_thread = threading.Thread(target=self._insert_task_to_worker, daemon=True)
-        self.insert_task_to_worker_thread.start()
-
-        if self.api_server_pid is not None:
-            self.insert_task_to_scheduler_thread = threading.Thread(
-                target=self._insert_zmq_task_to_scheduler, daemon=True
-            )
-            self.insert_task_to_scheduler_thread.start()
-
-            self.receive_output_thread = threading.Thread(target=self._zmq_send_generated_tokens, daemon=True)
-            self.receive_output_thread.start()
-
-        # Start TokenProcessor thread
-        self.token_processor.run()
-
-        if self.cfg.splitwise_role != "mixed":
-            # Single-machine logic
-            self.engine_worker_queue.available_prefill_instances.put(1)
-            self.split_mode_get_tasks()
-            if self.cfg.scheduler_config.name == "splitwise":
-                self.splitwise_receive_thread = threading.Thread(target=self.split_connector.start_receiver, args=())
-                self.splitwise_receive_thread.daemon = True
-                self.splitwise_receive_thread.start()
-
-        self.cfg.init_cache_info()
-
-        role = self.cfg.splitwise_role
-        host_ip = self.cfg.host_ip
-        disaggregate = self.cfg.disaggregate_info
-        if self.cfg.scheduler_config.name == "splitwise":
-            self.scheduler.start(role, host_ip, disaggregate)
-
-        time.sleep(1)
-
-        if self.cfg.parallel_config.enable_expert_parallel and self.cfg.parallel_config.data_parallel_size > 1:
-            self.dp_processed = []
-            for i in range(
-                1,
-                self.cfg.parallel_config.data_parallel_size // self.cfg.nnode,
-            ):
-                time.sleep(1)
-                self.dp_processed.append(
-                    multiprocessing.Process(
-                        target=start_expert_service,
-                        args=(
-                            self.cfg,
-                            i + self.cfg.node_rank * self.cfg.worker_num_per_node,
-                            self.ipc_signal_suffix,
-                        ),
-                    )
-                )
-                llm_logger.info(
-                    f"Engine is initialized successfully with {self.cfg.tensor_parallel_size}"
-                    + f" data parallel id {i}"
-                )
-                self.dp_processed[-1].start()
-
         console_logger.info(f"Worker processes are launched with {time.time() - start_time} seconds.")
         return True
@@ -909,7 +876,7 @@ class LLMEngine:
             create=True,
         )

-        # exist_task_signal 用于各worker进程感知是否有新Task需要处理
+        # exist_task_signal: Used by each worker process to detect whether there is a new task to be processed
         exist_task_signal_data = np.zeros([self.cfg.parallel_config.data_parallel_size], dtype=np.int32)
         self.exist_task_signal = IPCSignal(
             name="exist_task_signal",
@@ -919,7 +886,7 @@ class LLMEngine:
             create=True,
         )

-        # exist_swapped_task_signal 用于engine感知worker中是否存在swapped task
+        # exist_swapped_task_signal: Used by the engine to detect whether there is a swapped task in the worker
         exist_swapped_task_signal_data = np.zeros([self.cfg.parallel_config.data_parallel_size], dtype=np.int32)
         self.exist_swapped_task_signal = IPCSignal(
             name="exist_swapped_task_signal",
@@ -929,7 +896,7 @@ class LLMEngine:
             create=True,
         )

-        # exist_prefill_task_signal 用于各worker进程感知是否进行prefill
+        # exist_prefill_task_signal: Used by each worker process to detect whether to prefill
         exist_prefill_task_signal_data = np.zeros([1], dtype=np.int32)
         self.exist_prefill_task_signal = IPCSignal(
             name="exist_prefill_task_signal",
@@ -939,7 +906,7 @@ class LLMEngine:
             create=True,
         )

-        # launched_cache_manager_signal 用于感知engine是否启动了cache_manager
+        # launched_cache_manager_signal: Used to detect whether the engine has started cache_manager
         if self.cfg.cache_config.enable_prefix_caching or self.cfg.splitwise_role != "mixed":
             launched_cache_manager_signal_data = np.zeros([1], dtype=np.int32)
             self.launched_cache_manager_signal = IPCSignal(
@@ -950,7 +917,30 @@ class LLMEngine:
                 create=True,
             )

-        # worker_live_signal 用于engine感知各worker进程是否存活，记录每个step时间
+        # launched_expert_service_signal: Used to detect whether each expert_service has started successfully
+        if self.cfg.parallel_config.enable_expert_parallel and self.cfg.parallel_config.data_parallel_size > 1:
+            launched_expert_service_signal_data = np.zeros(
+                shape=[self.cfg.parallel_config.data_parallel_size // self.cfg.nnode], dtype=np.int32
+            )
+            self.launched_expert_service_signal = IPCSignal(
+                name="launched_expert_service_signal",
+                array=launched_expert_service_signal_data,
+                dtype=np.int32,
+                suffix=self.ipc_signal_suffix,
+                create=True,
+            )
+
+        # loaded_model_signal: Used to detect whether each worker has completed model loading
+        loaded_model_signal_data = np.zeros([1], dtype=np.int32)
+        self.loaded_model_signal = IPCSignal(
+            name="loaded_model_signal",
+            array=loaded_model_signal_data,
+            dtype=np.int32,
+            suffix=self.ipc_signal_suffix,
+            create=True,
+        )
+
+        # worker_live_signal: Used by the engine to detect whether each worker process is alive and record the time of each step
         worker_healthy_live_recorded_time_array = np.zeros(shape=[self.cfg.worker_num_per_node], dtype=np.int32)
         self.worker_healthy_live_signal = IPCSignal(
             name="worker_healthy_live_signal",
@@ -1187,7 +1177,7 @@ class LLMEngine:
             llm_logger.error(f"Error happend while adding request, details={e}")
             raise EngineError(str(e), error_code=400)

-        # 获取当前请求的结果
+        # Get the result of the current request
         for result in self._get_generated_tokens(req_id):
             is_end = result.finished
             if stream and not is_end:
@@ -1231,7 +1221,6 @@ class LLMEngine:
                 engine_worker_queue_port=self.cfg.engine_worker_queue_port,
                 pid_suffix=self.ipc_signal_suffix,
             )
-            self.launched_cache_manager_signal.value[0] = 1

     def check_health(self, time_interval_threashold=30):
         """
@@ -1245,6 +1234,72 @@ class LLMEngine:

         return True, ""

+    def launch_components(self):
+        self.token_processor.tasks_queue = self.engine_worker_queue
+
+        if envs.ENABLE_V1_KVCACHE_SCHEDULER:
+            self.insert_task_to_worker_thread = threading.Thread(target=self._scheduler_task_to_worker_v1, daemon=True)
+        else:
+            self.insert_task_to_worker_thread = threading.Thread(target=self._insert_task_to_worker, daemon=True)
+        self.insert_task_to_worker_thread.start()
+
+        if self.api_server_pid is not None:
+            self.insert_task_to_scheduler_thread = threading.Thread(
+                target=self._insert_zmq_task_to_scheduler, daemon=True
+            )
+            self.insert_task_to_scheduler_thread.start()
+
+            self.receive_output_thread = threading.Thread(target=self._zmq_send_generated_tokens, daemon=True)
+            self.receive_output_thread.start()
+
+        # Start TokenProcessor thread
+        self.token_processor.run()
+
+        if self.cfg.splitwise_role != "mixed":
+            # Single-machine logic
+            self.engine_worker_queue.available_prefill_instances.put(1)
+            self.split_mode_get_tasks()
+            if self.cfg.scheduler_config.name == "splitwise":
+                self.splitwise_receive_thread = threading.Thread(target=self.split_connector.start_receiver, args=())
+                self.splitwise_receive_thread.daemon = True
+                self.splitwise_receive_thread.start()
+
+        self.cfg.init_cache_info()
+
+        role = self.cfg.splitwise_role
+        host_ip = self.cfg.host_ip
+        disaggregate = self.cfg.disaggregate_info
+        if self.cfg.scheduler_config.name == "splitwise":
+            self.scheduler.start(role, host_ip, disaggregate)
+
+        time.sleep(1)
+
+        expert_service_nums = self.cfg.parallel_config.data_parallel_size // self.cfg.nnode
+        if self.cfg.parallel_config.enable_expert_parallel and self.cfg.parallel_config.data_parallel_size > 1:
+            self.dp_processed = []
+            for i in range(
+                1,
+                expert_service_nums,
+            ):
+                time.sleep(1)
+                self.dp_processed.append(
+                    multiprocessing.Process(
+                        target=start_expert_service,
+                        args=(
+                            self.cfg,
+                            i + self.cfg.node_rank * self.cfg.worker_num_per_node,
+                            self.ipc_signal_suffix,
+                        ),
+                    )
+                )
+                llm_logger.info(
+                    f"Engine is initialized successfully with {self.cfg.tensor_parallel_size}"
+                    + f" data parallel id {i}"
+                )
+                self.dp_processed[-1].start()
+            for i in range(1, expert_service_nums):
+                while self.launched_expert_service_signal.value[i] == 0:
+                    time.sleep(10)
+
     def check_worker_initialize_status(self):
         """
         Check the initlialize status of workers by stdout logging
@@ -1270,10 +1325,6 @@ class LLMEngine:
         self.checking_worker_status_thread = threading.Thread(target=detect_thread, daemon=True)
         self.checking_worker_status_thread.start()

-        checking_worker_init_kv_cache_status_thread = None
-        if self.do_profile:
-            checking_worker_init_kv_cache_status_thread = threading.Thread(target=self._stop_profile, daemon=True)
-            checking_worker_init_kv_cache_status_thread.start()
-
         # display weight loadding progress
         with tqdm(total=100, desc="Loading Weights") as pbar:
@@ -1304,8 +1355,6 @@ class LLMEngine:
                 self.worker_init_status["finished"] = True

         try:
             self.checking_worker_status_thread.join(timeout=1)
-            if checking_worker_init_kv_cache_status_thread is not None:
-                checking_worker_init_kv_cache_status_thread.join(timeout=1)
         except Exception:
             pass
         return True

View File

@@ -26,7 +26,7 @@ import weakref
 import numpy as np

 from fastdeploy.engine.resource_manager import ResourceManager
-from fastdeploy.inter_communicator import EngineWorkerQueue
+from fastdeploy.inter_communicator import EngineWorkerQueue, IPCSignal
 from fastdeploy.metrics.metrics import main_process_metrics
 from fastdeploy.output.token_processor import TokenProcessor
 from fastdeploy.splitwise.splitwise_connector import SplitwiseConnector
@@ -127,7 +127,7 @@ class ExpertService:
                 cache_config=self.cfg.cache_config,
                 tensor_parallel_size=self.cfg.tensor_parallel_size,
                 device_ids=self.cfg.local_device_ids,
-                pod_ip=self.cfg.pod_ips[0],
+                pod_ip=self.cfg.master_ip,
                 engine_worker_queue_port=self.cfg.engine_worker_queue_port,
                 pid_suffix=f"{local_data_parallel_id}_{ipc_signal_suffix}",
             )
@@ -141,16 +141,29 @@ class ExpertService:
         os.environ["INFERENCE_MSG_QUEUE_ID"] = str(local_data_parallel_id + int(self.cfg.engine_worker_queue_port))

         self.token_processor.run()

         self.cfg.init_cache_info()

         role = self.cfg.splitwise_role
         host_ip = self.cfg.host_ip
         disaggregate = self.cfg.disaggregate_info
         self.scheduler.start(role, host_ip, disaggregate)

         self.cfg.print()

-        console_logger.info(f"Worker processes are launched with {time.time() - start_time} seconds.")
+        launched_expert_service_signal_data = np.zeros(
+            shape=[self.cfg.parallel_config.data_parallel_size // self.cfg.nnode], dtype=np.int32
+        )
+        self.launched_expert_service_signal = IPCSignal(
+            name="launched_expert_service_signal",
+            array=launched_expert_service_signal_data,
+            dtype=np.int32,
+            suffix=ipc_signal_suffix,
+            create=False,
+        )
+        local_rank = local_data_parallel_id % self.cfg.worker_num_per_node
+        self.launched_expert_service_signal.value[local_rank] = 1
+        console_logger.info(
+            f"Worker processes(rank {local_rank}) are launched with {time.time() - start_time} seconds."
+        )

         return True

     def _insert_task_to_worker(self):
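
This is the expert-service side of the second handshake: each spawned service attaches to `launched_expert_service_signal` and marks only its own slot, while the engine's `launch_components()` (first file above) polls slots 1..N-1, since data-parallel rank 0 is served by the engine process itself. A reduced illustration; the wrapper functions and the `NUM_LOCAL_DP` constant are hypothetical, the `IPCSignal` calls mirror the hunk above:

import time

import numpy as np

from fastdeploy.inter_communicator import IPCSignal

NUM_LOCAL_DP = 4  # hypothetical value of data_parallel_size // nnode on one node


def mark_expert_service_ready(ipc_signal_suffix: str, local_rank: int) -> None:
    # Runs inside each spawned expert_service process (create=False: attach only).
    signal = IPCSignal(
        name="launched_expert_service_signal",
        array=np.zeros(shape=[NUM_LOCAL_DP], dtype=np.int32),
        dtype=np.int32,
        suffix=ipc_signal_suffix,
        create=False,
    )
    signal.value[local_rank] = 1


def wait_for_expert_services(signal) -> None:
    # Runs in the engine; slot 0 belongs to the engine-local instance, so only
    # slots 1..N-1 are awaited, mirroring the loop in launch_components().
    for i in range(1, NUM_LOCAL_DP):
        while signal.value[i] == 0:
            time.sleep(10)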

View File

@@ -97,13 +97,13 @@ class ResourceManagerV1(ResourceManager):
     def _prepare_preempt_task(self, request):
         return ScheduledPreemptTask(idx=request.idx, request_id=request.request_id)

     def reschedule_preempt_task(self, request_id):
         with self.lock:
             if request_id in self.to_be_rescheduled_request_id_set and request_id in self.requests:
                 request = self.requests[request_id]
                 self.waiting.appendleft(request)
                 self.to_be_rescheduled_request_id_set.remove(request_id)

     def _trigger_preempt(self, request, num_new_blocks, preempted_reqs, scheduled_reqs):
         can_schedule = True
@@ -422,9 +422,15 @@ class ResourceManagerV1(ResourceManager):
                 self.running.remove(request)
                 request.status = RequestStatus.FINISHED
                 self._free_blocks(request)
-                if request.request_id in self.to_be_rescheduled_request_id_set:  # finished after preempted, blocks have been recycled.
-                    self.to_be_rescheduled_request_id_set.remove(request.request_id)  # just remove from to_be_rescheduled_request_id_set
-                if request in self.waiting:  # after finished, this request still scheduled from preempted to waiting, unexpected error, should not be here
+                if (
+                    request.request_id in self.to_be_rescheduled_request_id_set
+                ):  # finished after preempted, blocks have been recycled.
+                    self.to_be_rescheduled_request_id_set.remove(
+                        request.request_id
+                    )  # just remove from to_be_rescheduled_request_id_set
+                if (
+                    request in self.waiting
+                ):  # after finished, this request still scheduled from preempted to waiting, unexpected error, should not be here
                     raise RuntimeError(f"request {request.request_id} scheduled into waiting list, after finished")

                 self.tasks_list[request.idx] = None

View File

@@ -296,12 +296,14 @@ class TokenProcessor:
             else:
                 batch = self.output_tokens[1, 0]
                 tokens = tokens[2 : batch + 2]

             batch_result = list()
             if envs.ENABLE_V1_KVCACHE_SCHEDULER:
                 need_to_be_reschedule_req_ids = list(self.resource_manager.to_be_rescheduled_request_id_set)
                 for request_id in need_to_be_reschedule_req_ids:
-                    if self.resource_manager.requests[request_id].idx >= (batch - 1):  # No more token generated for preempted request
+                    if self.resource_manager.requests[request_id].idx >= (
+                        batch - 1
+                    ):  # No more token generated for preempted request
                         self.resource_manager.reschedule_preempt_task(request_id)
             for i in range(batch):
                 if self.resource_manager.stop_flags[i]:
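
The guard in this hunk only reschedules a preempted request when its slot index lies at or beyond the last row of the current decode batch, i.e. no further token can arrive for it in this step. A toy check with made-up numbers, mirroring the condition above:

batch = 4                  # rows 0..3 received new tokens this step
preempted_request_idx = 5  # hypothetical slot of a request that was preempted

if preempted_request_idx >= (batch - 1):
    # same condition as the hunk above: safe to push the request back onto the
    # waiting queue via reschedule_preempt_task()
    print("reschedule preempted request")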

View File

@@ -431,7 +431,19 @@ class PaddleDisWorkerProc:
     def load_model(self) -> None:
         """Load weights and create model"""
         self.worker.load_model()
+        loaded_model_signal_data = np.zeros(shape=[1], dtype=np.int32)
+        self.loaded_model_signal = IPCSignal(
+            name="loaded_model_signal",
+            array=loaded_model_signal_data,
+            dtype=np.int32,
+            suffix=self.parallel_config.engine_pid,
+            create=False,
+        )
+        if self.ranks > 1:
+            paddle.distributed.barrier()
+
+        self.loaded_model_signal.value[0] = 1


 def parse_args():
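
The ordering in this hunk matters for multi-rank workers: the barrier ensures the flag is raised only after the slowest rank has finished loading, so the engine's polling loop cannot resume while some ranks are still reading weights. A stripped-down sketch of that ordering, assuming an already-initialized distributed environment (the standalone function name is illustrative; `loaded_model_signal` is the IPCSignal attached above):

import paddle.distributed as dist


def report_model_loaded(loaded_model_signal, ranks: int) -> None:
    # ... load_model() has already returned on this rank ...
    if ranks > 1:
        dist.barrier()                 # wait for every rank to finish loading
    loaded_model_signal.value[0] = 1   # single flag observed by the engine's wait loop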

View File

@@ -7,23 +7,16 @@
 Boundary value checking for API parameters
 """

-import json
-from core import (
-    TEMPLATE,
-    URL,
-    build_request_payload,
-    send_request,
-)
+from core import TEMPLATE, URL, build_request_payload, send_request


 def test_max_min_1_token():
     data = {
         "stream": False,
         "messages": [{"role": "user", "content": "非洲的首都是?"}],
         "max_tokens": 1,
-        "metadata": {
-            "min_tokens": 1
-        },
+        "metadata": {"min_tokens": 1},
     }
     payload = build_request_payload(TEMPLATE, data)
     response = send_request(URL, payload).json()
@@ -33,4 +26,4 @@ def test_max_min_1_token():
     completion_tokens = response["usage"]["completion_tokens"]
     assert completion_tokens == 1, f"实际生成的token数为: {completion_tokens}, 应该为1"
     finish_reason = response["choices"][0]["finish_reason"]
     assert finish_reason == "length", f"内容不可能完整生成, 但实际finish_reason为: {response}"