[RL] provide options for whether shutdown comm group after weights cleared (#5663)
Some checks failed
Publish Job / publish_pre_check (push) Has been cancelled
Publish Job / print_publish_pre_check_outputs (push) Has been cancelled
Publish Job / FD-Clone-Linux (push) Has been cancelled
Publish Job / Show Code Archive Output (push) Has been cancelled
Publish Job / BUILD_SM8090 (push) Has been cancelled
Publish Job / BUILD_SM8689 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8090 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8689 (push) Has been cancelled
Publish Job / Run FD Image Build (push) Has been cancelled
Publish Job / Run FastDeploy Unit Tests and Coverage (push) Has been cancelled
Publish Job / Run FastDeploy LogProb Tests (push) Has been cancelled
Publish Job / Extracted partial CE model tasks to run in CI. (push) Has been cancelled
Publish Job / Run Base Tests (push) Has been cancelled
Publish Job / Run Accuracy Tests (push) Has been cancelled
Publish Job / Run Stable Tests (push) Has been cancelled
CI Images Build / FD-Clone-Linux (push) Has been cancelled
CI Images Build / Show Code Archive Output (push) Has been cancelled
CI Images Build / CI Images Build (push) Has been cancelled
CI Images Build / BUILD_SM8090 (push) Has been cancelled
CI Images Build / Run FastDeploy Unit Tests and Coverage (push) Has been cancelled
CI Images Build / Run FastDeploy LogProb Tests (push) Has been cancelled
CI Images Build / Extracted partial CE model tasks to run in CI. (push) Has been cancelled
CI Images Build / Run Base Tests (push) Has been cancelled
CI Images Build / Publish Docker Images Pre Check (push) Has been cancelled
CE Compile Job / ce_job_pre_check (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled

* [rl] provide options for whether shutdown comm group after weights cleared

* [fix] fix args hardcode

* [fix] change args type

* [fix] add worker process args
This commit is contained in:
Yonghua Li
2025-12-19 23:06:48 +08:00
committed by GitHub
parent fe55baae47
commit 4f830aa505
6 changed files with 41 additions and 9 deletions

View File

@@ -579,6 +579,8 @@ class ParallelConfig:
self.use_internode_ll_two_stage: bool = False
# disable sequence parallel moe
self.disable_sequence_parallel_moe: bool = False
# shutdown comm group if worker idle
self.shutdown_comm_group_if_worker_idle: bool = None
self.pod_ip: str = None
# enable the custom all-reduce kernel and fall back to NCCL(dist.all_reduce).
@@ -596,6 +598,9 @@ class ParallelConfig:
self.expert_parallel_size = 1
self.use_ep = self.expert_parallel_size > 1
if self.shutdown_comm_group_if_worker_idle is None:
self.shutdown_comm_group_if_worker_idle = not self.use_ep
# pd_disaggregation
use_pd_disaggregation: int = int(os.getenv("FLAGS_use_pd_disaggregation", 0))
use_pd_disaggregation_per_chunk: int = int(os.getenv("FLAGS_use_pd_disaggregation_per_chunk", 0))

View File

@@ -268,6 +268,11 @@ class EngineArgs:
# This optimization is enabled by default, and can be disabled by using this flag.
"""
shutdown_comm_group_if_worker_idle: bool = None
"""
Whether to shutdown the comm group when the weight is cleared.
"""
engine_worker_queue_port: Optional[Union[int, str, list]] = None
"""
Port for worker queue communication.
@@ -951,6 +956,12 @@ class EngineArgs:
default=EngineArgs.chunked_moe_size,
help="Chunked size of moe input.",
)
parallel_group.add_argument(
"--shutdown-comm-group-if-worker-idle",
action=argparse.BooleanOptionalAction,
default=EngineArgs.shutdown_comm_group_if_worker_idle,
help="Shutdown communication group when worker is idle.",
)
# Load group
load_group = parser.add_argument_group("Load Configuration")

View File

@@ -590,6 +590,7 @@ class LLMEngine:
"disable_sequence_parallel_moe": self.cfg.parallel_config.disable_sequence_parallel_moe,
"enable_logprob": self.cfg.model_config.enable_logprob,
"lm_head_fp32": self.cfg.model_config.lm_head_fp32,
"shutdown_comm_group_if_worker_idle": self.cfg.parallel_config.shutdown_comm_group_if_worker_idle,
}
for worker_flag, value in worker_store_true_flag.items():
if value:

View File

@@ -258,15 +258,16 @@ class DynamicWeightManager:
value[self.rank] = status
@staticmethod
def check_model_weights_status(model_weights_status, model_runner, pid):
def check_model_weights_status(model_weights_status, model_runner, pid, block):
"""
check model weights status
"""
# logger.info(f"dynamic weight manager is check model weights status! {model_weights_status.value[0]}")
while (
model_weights_status.value[0] != ModelWeightsStatus.NORMAL
and model_weights_status.value[0] != ModelWeightsStatus.CLEARED
while model_weights_status.value[0] != ModelWeightsStatus.NORMAL and (
block or model_weights_status.value[0] != ModelWeightsStatus.CLEARED
):
# In blocking mode the loop never exits until the weights are updated and the comm group is rebuilt
# In non-blocking mode the loop exits once the weights are either updated or cleared
if model_weights_status.value[0] == ModelWeightsStatus.UPDATING:
logger.info("infer engine stopped! start to load new checkpoint...")
model_runner.clear_requests()

View File

@@ -2770,7 +2770,9 @@ class GPUModelRunner(ModelRunnerBase):
if self.use_cudagraph:
self.model.clear_grpah_opt_backend()
# Clear parameters and send signal
self.dynamic_weight_manager.clear_parameters(pid)
self.dynamic_weight_manager.clear_parameters(
pid, self.fd_config.parallel_config.shutdown_comm_group_if_worker_idle
)
self.clear_cache()
paddle.device.cuda.empty_cache()
@@ -2787,7 +2789,9 @@ class GPUModelRunner(ModelRunnerBase):
def update_parameters(self, pid):
"""Dynamic model loader use to update parameters use for RL"""
# Update parameters
self.dynamic_weight_manager.update_parameters(pid)
self.dynamic_weight_manager.update_parameters(
pid, self.fd_config.parallel_config.shutdown_comm_group_if_worker_idle
)
self.initialize_kv_cache()
# Recapture CUDAGraph
if self.use_cudagraph:

View File

@@ -458,14 +458,18 @@ class PaddleDisWorkerProc:
# model_weights_signal
self.worker.model_runner,
self.parallel_config.local_engine_worker_queue_port,
self.parallel_config.shutdown_comm_group_if_worker_idle,
)
logger.info(f"current task queue data: {self.task_queue.num_tasks()}")
self.task_queue.clear_data()
self.model_weights_signal[0] = ModelWeightsStatus.NORMAL
logger.info(f"Rank: {self.local_rank} has updated or cleared parameters.")
while self.model_weights_status.value[0] == ModelWeightsStatus.CLEARED:
time.sleep(0.01)
continue
# Only when the comm group is NOT shut down is an extra wait needed after the weights are cleared (otherwise the semaphores get out of sync)
if not self.fd_config.parallel_config.shutdown_comm_group_if_worker_idle:
while self.model_weights_status.value[0] == ModelWeightsStatus.CLEARED:
time.sleep(0.01)
continue
if self.exist_task_signal.value[0] == ExistTaskStatus.EXIST or self.task_queue.read_finish_flag.get() == 1:
logger.info(f"Rank: {self.local_rank} Detected new requests.")
@@ -883,6 +887,12 @@ def parse_args():
help="Configation of Rollout Routing Replay.",
)
parser.add_argument(
"--shutdown_comm_group_if_worker_idle",
action="store_true",
help="Shutdown comm group if worker idle.",
)
args = parser.parse_args()
return args