From ea16c82b430a32ab9ca82db633ca0a2c6780ceb8 Mon Sep 17 00:00:00 2001
From: Yonghua Li <39643373+liyonghua0910@users.noreply.github.com>
Date: Fri, 19 Dec 2025 23:18:03 +0800
Subject: [PATCH] [Cherry-Pick] [RL] provide options for whether shutdown comm
 group after weights cleared (#5663) (#5664)

* [rl] provide options for whether shutdown comm group after weights cleared

* [fix] fix args hardcode

* [fix] change args type

* [fix] add worker process args
---
 fastdeploy/config.py                    |  5 +++++
 fastdeploy/engine/args_utils.py         | 11 +++++++++++
 fastdeploy/engine/engine.py             |  1 +
 fastdeploy/rl/dynamic_weight_manager.py |  9 +++++----
 fastdeploy/worker/gpu_model_runner.py   |  8 ++++++--
 fastdeploy/worker/worker_process.py     | 16 +++++++++++++---
 6 files changed, 41 insertions(+), 9 deletions(-)

diff --git a/fastdeploy/config.py b/fastdeploy/config.py
index 0539ca7b2..a2fb35e84 100644
--- a/fastdeploy/config.py
+++ b/fastdeploy/config.py
@@ -566,6 +566,8 @@ class ParallelConfig:
         self.use_internode_ll_two_stage: bool = False
         # disable sequence parallel moe
         self.disable_sequence_parallel_moe: bool = False
+        # shutdown comm group if worker idle
+        self.shutdown_comm_group_if_worker_idle: bool = None
 
         self.pod_ip: str = None
         # enable the custom all-reduce kernel and fall back to NCCL(dist.all_reduce).
@@ -585,6 +587,9 @@ class ParallelConfig:
             self.expert_parallel_size = 1
         self.use_ep = self.expert_parallel_size > 1
 
+        if self.shutdown_comm_group_if_worker_idle is None:
+            self.shutdown_comm_group_if_worker_idle = not self.use_ep
+
         # pd_disaggregation
         use_pd_disaggregation: int = int(os.getenv("FLAGS_use_pd_disaggregation", 0))
         use_pd_disaggregation_per_chunk: int = int(os.getenv("FLAGS_use_pd_disaggregation_per_chunk", 0))
diff --git a/fastdeploy/engine/args_utils.py b/fastdeploy/engine/args_utils.py
index d2d7c6f90..edfb6fdb1 100644
--- a/fastdeploy/engine/args_utils.py
+++ b/fastdeploy/engine/args_utils.py
@@ -266,6 +266,11 @@ class EngineArgs:
    # This optimization is enabled by default, and can be disabled by using this flag.
    """
 
+    shutdown_comm_group_if_worker_idle: bool = None
+    """
+    Whether to shut down the comm group when the weights are cleared.
+    """
+
    engine_worker_queue_port: str = "0"
    """
    Port for worker queue communication.
@@ -906,6 +911,12 @@ class EngineArgs:
            default=EngineArgs.chunked_moe_size,
            help="Chunked size of moe input.",
        )
+        parallel_group.add_argument(
+            "--shutdown-comm-group-if-worker-idle",
+            action=argparse.BooleanOptionalAction,
+            default=EngineArgs.shutdown_comm_group_if_worker_idle,
+            help="Shutdown communication group when worker is idle.",
+        )
 
        # Load group
        load_group = parser.add_argument_group("Load Configuration")
diff --git a/fastdeploy/engine/engine.py b/fastdeploy/engine/engine.py
index a753775c6..3762fe5af 100644
--- a/fastdeploy/engine/engine.py
+++ b/fastdeploy/engine/engine.py
@@ -583,6 +583,7 @@ class LLMEngine:
            "disable_sequence_parallel_moe": self.cfg.parallel_config.disable_sequence_parallel_moe,
            "enable_logprob": self.cfg.model_config.enable_logprob,
            "lm_head_fp32": self.cfg.model_config.lm_head_fp32,
+            "shutdown_comm_group_if_worker_idle": self.cfg.parallel_config.shutdown_comm_group_if_worker_idle,
        }
        for worker_flag, value in worker_store_true_flag.items():
            if value:
diff --git a/fastdeploy/rl/dynamic_weight_manager.py b/fastdeploy/rl/dynamic_weight_manager.py
index a865b9c62..bee87de3b 100644
--- a/fastdeploy/rl/dynamic_weight_manager.py
+++ b/fastdeploy/rl/dynamic_weight_manager.py
@@ -267,15 +267,16 @@ class DynamicWeightManager:
            value[self.rank] = status
 
    @staticmethod
-    def check_model_weights_status(model_weights_status, model_runner, pid):
+    def check_model_weights_status(model_weights_status, model_runner, pid, block):
        """
        check model weights status
        """
        # logger.info(f"dynamic weight manager is check model weights status! {model_weights_status.value[0]}")
-        while (
-            model_weights_status.value[0] != ModelWeightsStatus.NORMAL
-            and model_weights_status.value[0] != ModelWeightsStatus.CLEARED
+        while model_weights_status.value[0] != ModelWeightsStatus.NORMAL and (
+            block or model_weights_status.value[0] != ModelWeightsStatus.CLEARED
        ):
+            # In blocking mode, the loop does not exit until the weights are updated and the comm group is rebuilt
+            # In non-blocking mode, the loop exits once the weights have been either updated or cleared
            if model_weights_status.value[0] == ModelWeightsStatus.UPDATING:
                logger.info("infer engine stopped! start to load new checkpoint...")
                model_runner.clear_requests()
diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py
index 418fef909..56e4ceb42 100644
--- a/fastdeploy/worker/gpu_model_runner.py
+++ b/fastdeploy/worker/gpu_model_runner.py
@@ -2691,7 +2691,9 @@ class GPUModelRunner(ModelRunnerBase):
        if self.use_cudagraph:
            self.model.clear_grpah_opt_backend()
        # Clear parameters and Send single
-        self.dynamic_weight_manager.clear_parameters(pid)
+        self.dynamic_weight_manager.clear_parameters(
+            pid, self.fd_config.parallel_config.shutdown_comm_group_if_worker_idle
+        )
        self.clear_cache()
        paddle.device.cuda.empty_cache()
 
@@ -2708,7 +2710,9 @@ class GPUModelRunner(ModelRunnerBase):
    def update_parameters(self, pid):
        """Dynamic model loader use to update parameters use for RL"""
        # Update parameters
-        self.dynamic_weight_manager.update_parameters(pid)
+        self.dynamic_weight_manager.update_parameters(
+            pid, self.fd_config.parallel_config.shutdown_comm_group_if_worker_idle
+        )
        self.initialize_kv_cache()
        # Recapture CUDAGraph
        if self.use_cudagraph:
diff --git a/fastdeploy/worker/worker_process.py b/fastdeploy/worker/worker_process.py
index 9092bd3ba..e00d92c09 100644
--- a/fastdeploy/worker/worker_process.py
+++ b/fastdeploy/worker/worker_process.py
@@ -466,14 +466,18 @@ class PaddleDisWorkerProc:
                        # model_weights_signal
                        self.worker.model_runner,
                        self.parallel_config.engine_worker_queue_port,
+                        self.parallel_config.shutdown_comm_group_if_worker_idle,
                    )
                    logger.info(f"current task queue data: {self.task_queue.num_tasks()}")
                    self.task_queue.clear_data()
                    self.model_weights_signal[0] = ModelWeightsStatus.NORMAL
                    logger.info(f"Rank: {self.local_rank} has updated or cleared parameters.")
-                    while self.model_weights_status.value[0] == ModelWeightsStatus.CLEARED:
-                        time.sleep(0.01)
-                        continue
+
+                    # The extra wait after clearing weights is only needed when the comm group is not shut down (otherwise the semaphores would get out of sync)
+                    if not self.fd_config.parallel_config.shutdown_comm_group_if_worker_idle:
+                        while self.model_weights_status.value[0] == ModelWeightsStatus.CLEARED:
+                            time.sleep(0.01)
+                            continue
 
            if self.exist_task_signal.value[0] == ExistTaskStatus.EXIST or self.task_queue.read_finish_flag.get() == 1:
                logger.info(f"Rank: {self.local_rank} Detected new requests.")
@@ -890,6 +894,12 @@ def parse_args():
        help="Configation of Rollout Routing Replay.",
    )
 
+    parser.add_argument(
+        "--shutdown_comm_group_if_worker_idle",
+        action="store_true",
+        help="Shutdown comm group if worker idle.",
+    )
+
    args = parser.parse_args()
    return args
 
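Reviewer note (not part of the patch): the new `block` argument controls whether `check_model_weights_status` keeps polling until the weights are fully reloaded, or may also return once they have merely been cleared. The sketch below is a simplified, self-contained illustration of that wait condition, assuming a plain `get_status` callable and a stripped-down `WeightsStatus` enum in place of FastDeploy's shared-memory `model_weights_status` array and its `ModelWeightsStatus` values; it is not the actual implementation.

import time
from enum import IntEnum


class WeightsStatus(IntEnum):
    """Simplified stand-in for ModelWeightsStatus (values are illustrative)."""
    NORMAL = 0    # weights loaded, worker can serve
    UPDATING = 1  # new weights are being loaded
    CLEARED = 2   # weights dropped, worker idle


def wait_for_weights(get_status, block, poll_interval=0.01):
    """Mirror of the patched wait condition in check_model_weights_status.

    block=True  (shutdown_comm_group_if_worker_idle=True): return only once the
                status is NORMAL again, i.e. after the weights were reloaded and
                the comm group rebuilt.
    block=False (comm group kept alive): also return as soon as the status
                becomes CLEARED; the caller then performs its own extra wait
                while the status stays CLEARED, as worker_process.py does.
    """
    while get_status() != WeightsStatus.NORMAL and (
        block or get_status() != WeightsStatus.CLEARED
    ):
        time.sleep(poll_interval)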
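Reviewer note (not part of the patch): because the engine-level flag is registered with `argparse.BooleanOptionalAction` (Python 3.9+), argparse also generates a matching `--no-shutdown-comm-group-if-worker-idle` form, and leaving the flag unset keeps the `None` default so `ParallelConfig` can later resolve it to `not use_ep`. A minimal stand-alone parser showing the three outcomes (plain stdlib argparse, not FastDeploy's full EngineArgs):

import argparse

# Stand-alone parser mirroring only the new engine-level argument.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--shutdown-comm-group-if-worker-idle",
    action=argparse.BooleanOptionalAction,
    default=None,  # None means "let ParallelConfig fall back to `not use_ep`"
)

# Flag omitted: value stays None and is resolved later in ParallelConfig.
print(parser.parse_args([]).shutdown_comm_group_if_worker_idle)        # None

# Positive form: shut the comm group down once the weights are cleared.
args = parser.parse_args(["--shutdown-comm-group-if-worker-idle"])
print(args.shutdown_comm_group_if_worker_idle)                         # True

# Negative form generated by BooleanOptionalAction: keep the comm group alive.
args = parser.parse_args(["--no-shutdown-comm-group-if-worker-idle"])
print(args.shutdown_comm_group_if_worker_idle)                         # False

On the worker side the option is a plain store_true argument (`--shutdown_comm_group_if_worker_idle`), which LLMEngine appends to the worker command line only when the resolved value is truthy, matching the `worker_store_true_flag` handling in engine.py.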