diff --git a/fastdeploy/config.py b/fastdeploy/config.py
index f009d8a1d..01dbbc5d8 100644
--- a/fastdeploy/config.py
+++ b/fastdeploy/config.py
@@ -579,6 +579,8 @@ class ParallelConfig:
         self.use_internode_ll_two_stage: bool = False
         # disable sequence parallel moe
         self.disable_sequence_parallel_moe: bool = False
+        # shut down the comm group when the worker is idle
+        self.shutdown_comm_group_if_worker_idle: bool = None
 
         self.pod_ip: str = None
         # enable the custom all-reduce kernel and fall back to NCCL(dist.all_reduce).
@@ -596,6 +598,9 @@ class ParallelConfig:
             self.expert_parallel_size = 1
         self.use_ep = self.expert_parallel_size > 1
 
+        if self.shutdown_comm_group_if_worker_idle is None:
+            self.shutdown_comm_group_if_worker_idle = not self.use_ep
+
         # pd_disaggregation
         use_pd_disaggregation: int = int(os.getenv("FLAGS_use_pd_disaggregation", 0))
         use_pd_disaggregation_per_chunk: int = int(os.getenv("FLAGS_use_pd_disaggregation_per_chunk", 0))
diff --git a/fastdeploy/engine/args_utils.py b/fastdeploy/engine/args_utils.py
index b915999c7..6c77e5eb2 100644
--- a/fastdeploy/engine/args_utils.py
+++ b/fastdeploy/engine/args_utils.py
@@ -268,6 +268,11 @@ class EngineArgs:
     # This optimization is enabled by default, and can be disabled by using this flag.
     """
 
+    shutdown_comm_group_if_worker_idle: bool = None
+    """
+    Whether to shut down the communication group once the worker's weights are cleared (worker idle).
+    """
+
    engine_worker_queue_port: Optional[Union[int, str, list]] = None
    """
    Port for worker queue communication.
@@ -951,6 +956,12 @@ class EngineArgs:
             default=EngineArgs.chunked_moe_size,
             help="Chunked size of moe input.",
         )
+        parallel_group.add_argument(
+            "--shutdown-comm-group-if-worker-idle",
+            action=argparse.BooleanOptionalAction,
+            default=EngineArgs.shutdown_comm_group_if_worker_idle,
+            help="Shut down the communication group when the worker is idle.",
+        )
 
         # Load group
         load_group = parser.add_argument_group("Load Configuration")
diff --git a/fastdeploy/engine/engine.py b/fastdeploy/engine/engine.py
index dbb425388..43eb18e47 100644
--- a/fastdeploy/engine/engine.py
+++ b/fastdeploy/engine/engine.py
@@ -590,6 +590,7 @@ class LLMEngine:
             "disable_sequence_parallel_moe": self.cfg.parallel_config.disable_sequence_parallel_moe,
             "enable_logprob": self.cfg.model_config.enable_logprob,
             "lm_head_fp32": self.cfg.model_config.lm_head_fp32,
+            "shutdown_comm_group_if_worker_idle": self.cfg.parallel_config.shutdown_comm_group_if_worker_idle,
         }
         for worker_flag, value in worker_store_true_flag.items():
             if value:
diff --git a/fastdeploy/rl/dynamic_weight_manager.py b/fastdeploy/rl/dynamic_weight_manager.py
index 5583108dd..1313fa2d1 100644
--- a/fastdeploy/rl/dynamic_weight_manager.py
+++ b/fastdeploy/rl/dynamic_weight_manager.py
@@ -258,15 +258,16 @@ class DynamicWeightManager:
         value[self.rank] = status
 
     @staticmethod
-    def check_model_weights_status(model_weights_status, model_runner, pid):
+    def check_model_weights_status(model_weights_status, model_runner, pid, block):
         """
         check model weights status
         """
         # logger.info(f"dynamic weight manager is check model weights status! {model_weights_status.value[0]}")
-        while (
-            model_weights_status.value[0] != ModelWeightsStatus.NORMAL
-            and model_weights_status.value[0] != ModelWeightsStatus.CLEARED
+        while model_weights_status.value[0] != ModelWeightsStatus.NORMAL and (
+            block or model_weights_status.value[0] != ModelWeightsStatus.CLEARED
         ):
+            # In block mode the loop does not exit until the weights are updated and the comm group is rebuilt.
+            # In non-block mode the loop exits once the weights are either updated or cleared.
             if model_weights_status.value[0] == ModelWeightsStatus.UPDATING:
                 logger.info("infer engine stopped! start to load new checkpoint...")
                 model_runner.clear_requests()
diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py
index a614a354f..f9fed8bfc 100644
--- a/fastdeploy/worker/gpu_model_runner.py
+++ b/fastdeploy/worker/gpu_model_runner.py
@@ -2770,7 +2770,9 @@ class GPUModelRunner(ModelRunnerBase):
         if self.use_cudagraph:
             self.model.clear_grpah_opt_backend()
         # Clear parameters and Send single
-        self.dynamic_weight_manager.clear_parameters(pid)
+        self.dynamic_weight_manager.clear_parameters(
+            pid, self.fd_config.parallel_config.shutdown_comm_group_if_worker_idle
+        )
         self.clear_cache()
         paddle.device.cuda.empty_cache()
@@ -2787,7 +2789,9 @@ class GPUModelRunner(ModelRunnerBase):
     def update_parameters(self, pid):
         """Dynamic model loader use to update parameters use for RL"""
         # Update parameters
-        self.dynamic_weight_manager.update_parameters(pid)
+        self.dynamic_weight_manager.update_parameters(
+            pid, self.fd_config.parallel_config.shutdown_comm_group_if_worker_idle
+        )
         self.initialize_kv_cache()
         # Recapture CUDAGraph
         if self.use_cudagraph:
diff --git a/fastdeploy/worker/worker_process.py b/fastdeploy/worker/worker_process.py
index c37403f29..2153da1eb 100644
--- a/fastdeploy/worker/worker_process.py
+++ b/fastdeploy/worker/worker_process.py
@@ -458,14 +458,18 @@ class PaddleDisWorkerProc:
                     # model_weights_signal
                     self.worker.model_runner,
                     self.parallel_config.local_engine_worker_queue_port,
+                    self.parallel_config.shutdown_comm_group_if_worker_idle,
                 )
                 logger.info(f"current task queue data: {self.task_queue.num_tasks()}")
                 self.task_queue.clear_data()
                 self.model_weights_signal[0] = ModelWeightsStatus.NORMAL
                 logger.info(f"Rank: {self.local_rank} has updated or cleared parameters.")
-                while self.model_weights_status.value[0] == ModelWeightsStatus.CLEARED:
-                    time.sleep(0.01)
-                    continue
+
+                # An extra wait after clearing weights is only needed when the comm group is kept alive (otherwise the status signals get out of sync).
+                if not self.fd_config.parallel_config.shutdown_comm_group_if_worker_idle:
+                    while self.model_weights_status.value[0] == ModelWeightsStatus.CLEARED:
+                        time.sleep(0.01)
+                        continue
 
             if self.exist_task_signal.value[0] == ExistTaskStatus.EXIST or self.task_queue.read_finish_flag.get() == 1:
                 logger.info(f"Rank: {self.local_rank} Detected new requests.")
@@ -883,6 +887,12 @@ def parse_args():
         help="Configation of Rollout Routing Replay.",
     )
 
+    parser.add_argument(
+        "--shutdown_comm_group_if_worker_idle",
+        action="store_true",
+        help="Shut down the comm group when the worker is idle.",
+    )
+
     args = parser.parse_args()
     return args