[Optimize] Reduce worker process communication time cost

This commit is contained in:
root
2025-11-25 03:57:05 +00:00
parent edf0d09257
commit 7cada8627f
2 changed files with 35 additions and 4 deletions

View File

@@ -182,6 +182,25 @@ class EngineService:
self.cfg.parallel_config.engine_worker_queue_port[self.cfg.parallel_config.local_data_parallel_id]
)
self.llm_logger.info(f"current_suffix: {current_suffix}")
insert_task_signal_data = np.zeros([1], dtype=np.int32)
self.insert_task_signal = IPCSignal(
name="insert_task_signal",
array=insert_task_signal_data,
dtype=np.int32,
suffix=current_suffix,
create=True,
)
exist_task_signal_data = np.zeros([1], dtype=np.int32)
self.exist_task_signal = IPCSignal(
name="exist_task_signal",
array=exist_task_signal_data,
dtype=np.int32,
suffix=current_suffix,
create=True,
)
exist_task_signal_data = np.zeros([1], dtype=np.int32)
self.exist_task_signal = IPCSignal(
name="exist_task_signal",
@@ -792,7 +811,7 @@ class EngineService:
while self.running:
try:
if self.engine_worker_queue.num_tasks() > 0:
if self.insert_task_signal.value[0] == 1 or self.engine_worker_queue.num_tasks() > 0:
time.sleep(0.001)
continue
if self.cfg.scheduler_config.splitwise_role != "mixed":
@@ -840,6 +859,7 @@ class EngineService:
trace_print(LoggingEventName.REQUEST_SCHEDULE_END, task.request_id, getattr(task, "user", ""))
trace_print(LoggingEventName.INFERENCE_START, task.request_id, getattr(task, "user", ""))
self.engine_worker_queue.put_tasks((tasks, self.resource_manager.real_bsz))
self.insert_task_signal.value[0] = 1
# 4. Response error tasks
if error_tasks:

View File

@@ -229,6 +229,15 @@ class PaddleDisWorkerProc:
create=False,
)
insert_task_signal_data = np.zeros([1], dtype=np.int32)
self.insert_task_signal = IPCSignal(
name="insert_task_signal",
array=insert_task_signal_data,
dtype=np.int32,
suffix=self.parallel_config.engine_worker_queue_port,
create=False,
)
# init exist_task_signal
workers_exist_task = np.zeros([1], dtype=np.int32)
self.exist_task_signal = IPCSignal(
@@ -435,7 +444,9 @@ class PaddleDisWorkerProc:
# The first worker detects whether there are tasks in the task queue
if tp_rank == 0:
if self.task_queue.num_tasks() > 0:
start = time.perf_counter()
if self.insert_task_signal.value[0] == 1:
#if self.task_queue.num_tasks() > 0:
if envs.ENABLE_V1_KVCACHE_SCHEDULER or not (
self.fd_config.model_config.enable_mm and self.worker.exist_prefill()
):
@@ -473,7 +484,7 @@ class PaddleDisWorkerProc:
self.model_weights_signal[0] = ModelWeightsStatus.NORMAL
logger.info(f"Rank: {self.local_rank} has updated or cleared parameters.")
if self.exist_task_signal.value[0] == ExistTaskStatus.EXIST or self.task_queue.read_finish_flag.get() == 1:
if self.exist_task_signal.value[0] == ExistTaskStatus.EXIST:
logger.info(f"Rank: {self.local_rank} Detected new requests.")
self.insert_step = True
@@ -481,7 +492,7 @@ class PaddleDisWorkerProc:
if read_finish:
# Ensure that every worker gets the task
self.exist_task_signal.value[0] = ExistTaskStatus.EMPTY
self.task_queue.read_finish_flag.set(0)
self.insert_task_signal.value[0] = 0
req_dicts = []
for req_dict, bsz in tasks: