From c499bd9e90ddb34b8da68c9b7d71a580cbb9f208 Mon Sep 17 00:00:00 2001 From: root Date: Tue, 25 Nov 2025 06:24:39 +0000 Subject: [PATCH] Remove lock in get_task/put_task --- fastdeploy/engine/common_engine.py | 2 +- .../inter_communicator/engine_worker_queue.py | 15 +++++++++++++++ fastdeploy/worker/worker_process.py | 11 ++++++----- 3 files changed, 22 insertions(+), 6 deletions(-) diff --git a/fastdeploy/engine/common_engine.py b/fastdeploy/engine/common_engine.py index ef39ba34c..dc540c0d7 100644 --- a/fastdeploy/engine/common_engine.py +++ b/fastdeploy/engine/common_engine.py @@ -848,7 +848,7 @@ class EngineService: trace_print(LoggingEventName.RESOURCE_ALLOCATE_END, task.request_id, getattr(task, "user", "")) trace_print(LoggingEventName.REQUEST_SCHEDULE_END, task.request_id, getattr(task, "user", "")) trace_print(LoggingEventName.INFERENCE_START, task.request_id, getattr(task, "user", "")) - self.engine_worker_queue.put_tasks((tasks, self.resource_manager.real_bsz)) + self.engine_worker_queue.put_tasks_v1((tasks, self.resource_manager.real_bsz)) self.insert_task_signal.value[0] = 1 # 4. Response error tasks diff --git a/fastdeploy/inter_communicator/engine_worker_queue.py b/fastdeploy/inter_communicator/engine_worker_queue.py index d6c9993fb..fafbd2b50 100644 --- a/fastdeploy/inter_communicator/engine_worker_queue.py +++ b/fastdeploy/inter_communicator/engine_worker_queue.py @@ -496,6 +496,21 @@ class EngineWorkerQueue: self.tasks.append(tasks) self.lock.release() + def put_tasks_v1(self, tasks: List[Any]) -> None: + if envs.FD_ENABLE_MAX_PREFILL or envs.FD_ENABLE_E2W_TENSOR_CONVERT: + # multimodal input numpy -> tensor + to_tensor(tasks[0]) + self.tasks[:] = list() + self.tasks.append(tasks) + + def get_tasks_v1(self) -> Tuple[List[Any], bool]: + tasks = list() + tasks.extend(self.tasks) + return tasks + + def clear_tasks_v1(self): + self.tasks[:] = list() + def get_tasks(self) -> Tuple[List[Any], bool]: """ Retrieve tasks from the shared queue and update read status. diff --git a/fastdeploy/worker/worker_process.py b/fastdeploy/worker/worker_process.py index 21d7f6482..7d561a241 100644 --- a/fastdeploy/worker/worker_process.py +++ b/fastdeploy/worker/worker_process.py @@ -487,11 +487,7 @@ class PaddleDisWorkerProc: logger.info(f"Rank: {self.local_rank} Detected new requests.") self.insert_step = True - tasks, read_finish = self.task_queue.get_tasks() - if read_finish: - # Ensure that every worker get the task - self.exist_task_signal.value[0] = ExistTaskStatus.EMPTY - self.insert_task_signal.value[0] = 0 + tasks = self.task_queue.get_tasks_v1() req_dicts = [] for req_dict, bsz in tasks: @@ -518,6 +514,11 @@ class PaddleDisWorkerProc: # These generated tokens can be obtained through get_output op. start_execute_time = time.time() self.worker.execute_model(req_dicts, num_running_requests) + if tp_rank == 0 and req_dicts is not None: + self.insert_task_signal.value[0] = 0 + self.task_queue.clear_tasks_v1() + self.exist_task_signal.value[0] = ExistTaskStatus.EMPTY + self.exist_prefill_task_signal.value[0] = self.worker.exist_prefill() logger.debug(f"execute model cost: {time.time()-start_execute_time:.5f} s")