mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
[RL] provide options for whether shutdown comm group after weights cleared (#5663)
Some checks failed
Publish Job / publish_pre_check (push) Has been cancelled
Publish Job / print_publish_pre_check_outputs (push) Has been cancelled
Publish Job / FD-Clone-Linux (push) Has been cancelled
Publish Job / Show Code Archive Output (push) Has been cancelled
Publish Job / BUILD_SM8090 (push) Has been cancelled
Publish Job / BUILD_SM8689 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8090 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8689 (push) Has been cancelled
Publish Job / Run FD Image Build (push) Has been cancelled
Publish Job / Run FastDeploy Unit Tests and Coverage (push) Has been cancelled
Publish Job / Run FastDeploy LogProb Tests (push) Has been cancelled
Publish Job / Extracted partial CE model tasks to run in CI. (push) Has been cancelled
Publish Job / Run Base Tests (push) Has been cancelled
Publish Job / Run Accuracy Tests (push) Has been cancelled
Publish Job / Run Stable Tests (push) Has been cancelled
CI Images Build / FD-Clone-Linux (push) Has been cancelled
CI Images Build / Show Code Archive Output (push) Has been cancelled
CI Images Build / CI Images Build (push) Has been cancelled
CI Images Build / BUILD_SM8090 (push) Has been cancelled
CI Images Build / Run FastDeploy Unit Tests and Coverage (push) Has been cancelled
CI Images Build / Run FastDeploy LogProb Tests (push) Has been cancelled
CI Images Build / Extracted partial CE model tasks to run in CI. (push) Has been cancelled
CI Images Build / Run Base Tests (push) Has been cancelled
CI Images Build / Publish Docker Images Pre Check (push) Has been cancelled
CE Compile Job / ce_job_pre_check (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Some checks failed
Publish Job / publish_pre_check (push) Has been cancelled
Publish Job / print_publish_pre_check_outputs (push) Has been cancelled
Publish Job / FD-Clone-Linux (push) Has been cancelled
Publish Job / Show Code Archive Output (push) Has been cancelled
Publish Job / BUILD_SM8090 (push) Has been cancelled
Publish Job / BUILD_SM8689 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8090 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8689 (push) Has been cancelled
Publish Job / Run FD Image Build (push) Has been cancelled
Publish Job / Run FastDeploy Unit Tests and Coverage (push) Has been cancelled
Publish Job / Run FastDeploy LogProb Tests (push) Has been cancelled
Publish Job / Extracted partial CE model tasks to run in CI. (push) Has been cancelled
Publish Job / Run Base Tests (push) Has been cancelled
Publish Job / Run Accuracy Tests (push) Has been cancelled
Publish Job / Run Stable Tests (push) Has been cancelled
CI Images Build / FD-Clone-Linux (push) Has been cancelled
CI Images Build / Show Code Archive Output (push) Has been cancelled
CI Images Build / CI Images Build (push) Has been cancelled
CI Images Build / BUILD_SM8090 (push) Has been cancelled
CI Images Build / Run FastDeploy Unit Tests and Coverage (push) Has been cancelled
CI Images Build / Run FastDeploy LogProb Tests (push) Has been cancelled
CI Images Build / Extracted partial CE model tasks to run in CI. (push) Has been cancelled
CI Images Build / Run Base Tests (push) Has been cancelled
CI Images Build / Publish Docker Images Pre Check (push) Has been cancelled
CE Compile Job / ce_job_pre_check (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
* [rl] provide options for whether shutdown comm group after weights cleared * [fix] fix args hardcode * [fix] change args type * [fix] add worker process args
This commit is contained in:
@@ -579,6 +579,8 @@ class ParallelConfig:
|
||||
self.use_internode_ll_two_stage: bool = False
|
||||
# disable sequence parallel moe
|
||||
self.disable_sequence_parallel_moe: bool = False
|
||||
# shutdown comm group if worker idle
|
||||
self.shutdown_comm_group_if_worker_idle: bool = None
|
||||
|
||||
self.pod_ip: str = None
|
||||
# enable the custom all-reduce kernel and fall back to NCCL(dist.all_reduce).
|
||||
@@ -596,6 +598,9 @@ class ParallelConfig:
|
||||
self.expert_parallel_size = 1
|
||||
self.use_ep = self.expert_parallel_size > 1
|
||||
|
||||
if self.shutdown_comm_group_if_worker_idle is None:
|
||||
self.shutdown_comm_group_if_worker_idle = not self.use_ep
|
||||
|
||||
# pd_disaggregation
|
||||
use_pd_disaggregation: int = int(os.getenv("FLAGS_use_pd_disaggregation", 0))
|
||||
use_pd_disaggregation_per_chunk: int = int(os.getenv("FLAGS_use_pd_disaggregation_per_chunk", 0))
|
||||
|
||||
@@ -268,6 +268,11 @@ class EngineArgs:
|
||||
# This optimization is enabled by default, and can be disabled by using this flag.
|
||||
"""
|
||||
|
||||
shutdown_comm_group_if_worker_idle: bool = None
|
||||
"""
|
||||
Whether to shutdown the comm group when the weight is cleared.
|
||||
"""
|
||||
|
||||
engine_worker_queue_port: Optional[Union[int, str, list]] = None
|
||||
"""
|
||||
Port for worker queue communication.
|
||||
@@ -951,6 +956,12 @@ class EngineArgs:
|
||||
default=EngineArgs.chunked_moe_size,
|
||||
help="Chunked size of moe input.",
|
||||
)
|
||||
parallel_group.add_argument(
|
||||
"--shutdown-comm-group-if-worker-idle",
|
||||
action=argparse.BooleanOptionalAction,
|
||||
default=EngineArgs.shutdown_comm_group_if_worker_idle,
|
||||
help="Shutdown communication group when worker is idle.",
|
||||
)
|
||||
|
||||
# Load group
|
||||
load_group = parser.add_argument_group("Load Configuration")
|
||||
|
||||
@@ -590,6 +590,7 @@ class LLMEngine:
|
||||
"disable_sequence_parallel_moe": self.cfg.parallel_config.disable_sequence_parallel_moe,
|
||||
"enable_logprob": self.cfg.model_config.enable_logprob,
|
||||
"lm_head_fp32": self.cfg.model_config.lm_head_fp32,
|
||||
"shutdown_comm_group_if_worker_idle": self.cfg.parallel_config.shutdown_comm_group_if_worker_idle,
|
||||
}
|
||||
for worker_flag, value in worker_store_true_flag.items():
|
||||
if value:
|
||||
|
||||
@@ -258,15 +258,16 @@ class DynamicWeightManager:
|
||||
value[self.rank] = status
|
||||
|
||||
@staticmethod
|
||||
def check_model_weights_status(model_weights_status, model_runner, pid):
|
||||
def check_model_weights_status(model_weights_status, model_runner, pid, block):
|
||||
"""
|
||||
check model weights status
|
||||
"""
|
||||
# logger.info(f"dynamic weight manager is check model weights status! {model_weights_status.value[0]}")
|
||||
while (
|
||||
model_weights_status.value[0] != ModelWeightsStatus.NORMAL
|
||||
and model_weights_status.value[0] != ModelWeightsStatus.CLEARED
|
||||
while model_weights_status.value[0] != ModelWeightsStatus.NORMAL and (
|
||||
block or model_weights_status.value[0] != ModelWeightsStatus.CLEARED
|
||||
):
|
||||
# 如果为 block 模式,那么循环不会退出,直到权重更新、通信组重建
|
||||
# 如果为非 block 模式,那么循环在权重更新或清理后均会退出
|
||||
if model_weights_status.value[0] == ModelWeightsStatus.UPDATING:
|
||||
logger.info("infer engine stopped! start to load new checkpoint...")
|
||||
model_runner.clear_requests()
|
||||
|
||||
@@ -2770,7 +2770,9 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
if self.use_cudagraph:
|
||||
self.model.clear_grpah_opt_backend()
|
||||
# Clear parameters and Send single
|
||||
self.dynamic_weight_manager.clear_parameters(pid)
|
||||
self.dynamic_weight_manager.clear_parameters(
|
||||
pid, self.fd_config.parallel_config.shutdown_comm_group_if_worker_idle
|
||||
)
|
||||
self.clear_cache()
|
||||
paddle.device.cuda.empty_cache()
|
||||
|
||||
@@ -2787,7 +2789,9 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
def update_parameters(self, pid):
|
||||
"""Dynamic model loader use to update parameters use for RL"""
|
||||
# Update parameters
|
||||
self.dynamic_weight_manager.update_parameters(pid)
|
||||
self.dynamic_weight_manager.update_parameters(
|
||||
pid, self.fd_config.parallel_config.shutdown_comm_group_if_worker_idle
|
||||
)
|
||||
self.initialize_kv_cache()
|
||||
# Recapture CUDAGraph
|
||||
if self.use_cudagraph:
|
||||
|
||||
@@ -458,14 +458,18 @@ class PaddleDisWorkerProc:
|
||||
# model_weights_signal
|
||||
self.worker.model_runner,
|
||||
self.parallel_config.local_engine_worker_queue_port,
|
||||
self.parallel_config.shutdown_comm_group_if_worker_idle,
|
||||
)
|
||||
logger.info(f"current task queue data: {self.task_queue.num_tasks()}")
|
||||
self.task_queue.clear_data()
|
||||
self.model_weights_signal[0] = ModelWeightsStatus.NORMAL
|
||||
logger.info(f"Rank: {self.local_rank} has updated or cleared parameters.")
|
||||
while self.model_weights_status.value[0] == ModelWeightsStatus.CLEARED:
|
||||
time.sleep(0.01)
|
||||
continue
|
||||
|
||||
# 只有不关闭通信组时,清理权重后需要额外等待(否则信号量会同步混乱)
|
||||
if not self.fd_config.parallel_config.shutdown_comm_group_if_worker_idle:
|
||||
while self.model_weights_status.value[0] == ModelWeightsStatus.CLEARED:
|
||||
time.sleep(0.01)
|
||||
continue
|
||||
|
||||
if self.exist_task_signal.value[0] == ExistTaskStatus.EXIST or self.task_queue.read_finish_flag.get() == 1:
|
||||
logger.info(f"Rank: {self.local_rank} Detected new requests.")
|
||||
@@ -883,6 +887,12 @@ def parse_args():
|
||||
help="Configation of Rollout Routing Replay.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--shutdown_comm_group_if_worker_idle",
|
||||
action="store_true",
|
||||
help="Shutdown comm group if worker idle.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
Reference in New Issue
Block a user