diff --git a/fastdeploy/worker/worker_process.py b/fastdeploy/worker/worker_process.py
index 0cb301b89..e289e9897 100644
--- a/fastdeploy/worker/worker_process.py
+++ b/fastdeploy/worker/worker_process.py
@@ -173,7 +173,6 @@ class PaddleDisWorkerProc:
             exist_swapped_task_signal:
             model_weights_status:
         """
-        self.max_chips_per_node = 16 if current_platform.is_iluvatar() else 8
         if (
             self.parallel_config.enable_expert_parallel
             and self.parallel_config.data_parallel_size > 1
@@ -413,8 +412,7 @@ class PaddleDisWorkerProc:
         self._init_eplb_signal()
         tp_size = self.parallel_config.tensor_parallel_size
         # Currently, only support single node
-        self.nnode = int((tp_size + 7) // 8)
-        req_ids = []
+        self.nnode = (tp_size + self.max_chips_per_node - 1) // self.max_chips_per_node
         num_running_requests = 0
 
         tp_rank = self.local_rank % tp_size
@@ -435,7 +433,6 @@ class PaddleDisWorkerProc:
                     src=0,
                     group=self.parallel_config.tp_group
                 )
-            self.insert_step = False
             req_dicts = None
             self.worker_healthy_live_signal.value[tp_rank % self.max_chips_per_node] = int(time.time())
 
@@ -445,14 +442,14 @@ class PaddleDisWorkerProc:
                     if envs.ENABLE_V1_KVCACHE_SCHEDULER or not (
                         self.fd_config.model_config.enable_mm and self.worker.exist_prefill()
                     ):
-                        if self.nnode > 1 and tp_size > self.max_chips_per_node:
+                        if self.nnode > 1:
                             self.task_queue.read_finish_flag.set(1)
                         else:
                             self.exist_task_signal.value[0] = ExistTaskStatus.EXIST
 
                 if tp_size > 1:
-                    # Synchronize the signal for other workers
+                    # Make the signal set by tp_rank 0 visible to the other workers
                     self._tp_barrier_wait()
 
                 if self.fd_config.load_config.dynamic_load_weight:
                     if self.parallel_config.enable_expert_parallel:
@@ -483,13 +480,15 @@ class PaddleDisWorkerProc:
 
                 if self.exist_task_signal.value[0] == ExistTaskStatus.EXIST or self.task_queue.read_finish_flag.get() == 1:
                     logger.info(f"Rank: {self.local_rank} Detected new requests.")
-                    self.insert_step = True
 
                     tasks, read_finish = self.task_queue.get_tasks()
+                    # Only one of the tp_size clients will get read_finish == True.
                     if read_finish:
-                        # Ensure that every worker get the task
-                        self.exist_task_signal.value[0] = ExistTaskStatus.EMPTY
-                        self.task_queue.read_finish_flag.set(0)
+                        # Reset whichever signal announced this batch of tasks.
+                        if self.nnode > 1:
+                            self.task_queue.read_finish_flag.set(0)
+                        else:
+                            self.exist_task_signal.value[0] = ExistTaskStatus.EMPTY
 
                     req_dicts = []
                     for req_dict, bsz in tasks:
@@ -506,8 +505,8 @@ class PaddleDisWorkerProc:
                     self.worker.preprocess_new_task(req_dicts, num_running_requests)
 
                 if (not self.parallel_config.use_ep) and (not self.worker.model_runner.not_need_stop()):
-                    if self.ranks > 1:
+                    if tp_size > 1:
                         self._tp_barrier_wait()
 
                     time.sleep(0.001)
                     continue
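
For reference, the reset branch above undoes exactly the announcement made earlier in the loop: on multi-node deployments tp_rank 0 raises `read_finish_flag` (readable across nodes), while on a single node it writes `exist_task_signal` (node-local shared memory), and the one rank whose `get_tasks()` returns `read_finish == True` clears that same signal. Below is a minimal, self-contained sketch of this pairing and of the node-count calculation; `Flag`, `announce_tasks`, and `reset_after_read` are hypothetical stand-ins for illustration, not FastDeploy APIs:

```python
# Standalone sketch; Flag and the plain list below are hypothetical
# stand-ins for FastDeploy's shared-memory read_finish_flag and
# exist_task_signal IPC objects.

EXIST, EMPTY = 1, 0


def compute_nnode(tp_size: int, max_chips_per_node: int) -> int:
    # Ceiling division: the number of nodes needed to host tp_size ranks.
    # Without the "- 1", an exact multiple (e.g. tp_size == max_chips_per_node)
    # would be over-counted by one node.
    return (tp_size + max_chips_per_node - 1) // max_chips_per_node


class Flag:
    """Stand-in for the cross-node read_finish_flag."""

    def __init__(self) -> None:
        self._v = 0

    def set(self, v: int) -> None:
        self._v = v

    def get(self) -> int:
        return self._v


def announce_tasks(nnode: int, flag: Flag, exist_task_signal: list) -> None:
    # Executed by tp_rank 0 only: pick the signal that the other ranks
    # (possibly on other nodes) are able to observe.
    if nnode > 1:
        flag.set(1)
    else:
        exist_task_signal[0] = EXIST


def reset_after_read(nnode: int, flag: Flag, exist_task_signal: list) -> None:
    # Executed by the single rank that received read_finish == True:
    # clear exactly the signal that announce_tasks() set.
    if nnode > 1:
        flag.set(0)
    else:
        exist_task_signal[0] = EMPTY


if __name__ == "__main__":
    assert compute_nnode(8, 8) == 1   # exact multiple -> one node
    assert compute_nnode(9, 8) == 2
    assert compute_nnode(16, 8) == 2

    flag, signal = Flag(), [EMPTY]
    announce_tasks(1, flag, signal)
    assert signal[0] == EXIST and flag.get() == 0
    reset_after_read(1, flag, signal)
    assert signal[0] == EMPTY
```

The pairing matters because the two signals have different visibility: resetting `exist_task_signal` on a multi-node run would leave `read_finish_flag` raised, and remote ranks would keep detecting phantom requests on every loop iteration.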