mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
[RL] provide options for whether shutdown comm group after weights cleared (#5663)
Some checks failed
Publish Job / publish_pre_check (push) Has been cancelled
Publish Job / print_publish_pre_check_outputs (push) Has been cancelled
Publish Job / FD-Clone-Linux (push) Has been cancelled
Publish Job / Show Code Archive Output (push) Has been cancelled
Publish Job / BUILD_SM8090 (push) Has been cancelled
Publish Job / BUILD_SM8689 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8090 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8689 (push) Has been cancelled
Publish Job / Run FD Image Build (push) Has been cancelled
Publish Job / Run FastDeploy Unit Tests and Coverage (push) Has been cancelled
Publish Job / Run FastDeploy LogProb Tests (push) Has been cancelled
Publish Job / Extracted partial CE model tasks to run in CI. (push) Has been cancelled
Publish Job / Run Base Tests (push) Has been cancelled
Publish Job / Run Accuracy Tests (push) Has been cancelled
Publish Job / Run Stable Tests (push) Has been cancelled
CI Images Build / FD-Clone-Linux (push) Has been cancelled
CI Images Build / Show Code Archive Output (push) Has been cancelled
CI Images Build / CI Images Build (push) Has been cancelled
CI Images Build / BUILD_SM8090 (push) Has been cancelled
CI Images Build / Run FastDeploy Unit Tests and Coverage (push) Has been cancelled
CI Images Build / Run FastDeploy LogProb Tests (push) Has been cancelled
CI Images Build / Extracted partial CE model tasks to run in CI. (push) Has been cancelled
CI Images Build / Run Base Tests (push) Has been cancelled
CI Images Build / Publish Docker Images Pre Check (push) Has been cancelled
CE Compile Job / ce_job_pre_check (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Some checks failed
Publish Job / publish_pre_check (push) Has been cancelled
Publish Job / print_publish_pre_check_outputs (push) Has been cancelled
Publish Job / FD-Clone-Linux (push) Has been cancelled
Publish Job / Show Code Archive Output (push) Has been cancelled
Publish Job / BUILD_SM8090 (push) Has been cancelled
Publish Job / BUILD_SM8689 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8090 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8689 (push) Has been cancelled
Publish Job / Run FD Image Build (push) Has been cancelled
Publish Job / Run FastDeploy Unit Tests and Coverage (push) Has been cancelled
Publish Job / Run FastDeploy LogProb Tests (push) Has been cancelled
Publish Job / Extracted partial CE model tasks to run in CI. (push) Has been cancelled
Publish Job / Run Base Tests (push) Has been cancelled
Publish Job / Run Accuracy Tests (push) Has been cancelled
Publish Job / Run Stable Tests (push) Has been cancelled
CI Images Build / FD-Clone-Linux (push) Has been cancelled
CI Images Build / Show Code Archive Output (push) Has been cancelled
CI Images Build / CI Images Build (push) Has been cancelled
CI Images Build / BUILD_SM8090 (push) Has been cancelled
CI Images Build / Run FastDeploy Unit Tests and Coverage (push) Has been cancelled
CI Images Build / Run FastDeploy LogProb Tests (push) Has been cancelled
CI Images Build / Extracted partial CE model tasks to run in CI. (push) Has been cancelled
CI Images Build / Run Base Tests (push) Has been cancelled
CI Images Build / Publish Docker Images Pre Check (push) Has been cancelled
CE Compile Job / ce_job_pre_check (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
* [rl] provide options for whether shutdown comm group after weights cleared * [fix] fix args hardcode * [fix] change args type * [fix] add worker process args
This commit is contained in:
@@ -579,6 +579,8 @@ class ParallelConfig:
|
|||||||
self.use_internode_ll_two_stage: bool = False
|
self.use_internode_ll_two_stage: bool = False
|
||||||
# disable sequence parallel moe
|
# disable sequence parallel moe
|
||||||
self.disable_sequence_parallel_moe: bool = False
|
self.disable_sequence_parallel_moe: bool = False
|
||||||
|
# shutdown comm group if worker idle
|
||||||
|
self.shutdown_comm_group_if_worker_idle: bool = None
|
||||||
|
|
||||||
self.pod_ip: str = None
|
self.pod_ip: str = None
|
||||||
# enable the custom all-reduce kernel and fall back to NCCL(dist.all_reduce).
|
# enable the custom all-reduce kernel and fall back to NCCL(dist.all_reduce).
|
||||||
@@ -596,6 +598,9 @@ class ParallelConfig:
|
|||||||
self.expert_parallel_size = 1
|
self.expert_parallel_size = 1
|
||||||
self.use_ep = self.expert_parallel_size > 1
|
self.use_ep = self.expert_parallel_size > 1
|
||||||
|
|
||||||
|
if self.shutdown_comm_group_if_worker_idle is None:
|
||||||
|
self.shutdown_comm_group_if_worker_idle = not self.use_ep
|
||||||
|
|
||||||
# pd_disaggregation
|
# pd_disaggregation
|
||||||
use_pd_disaggregation: int = int(os.getenv("FLAGS_use_pd_disaggregation", 0))
|
use_pd_disaggregation: int = int(os.getenv("FLAGS_use_pd_disaggregation", 0))
|
||||||
use_pd_disaggregation_per_chunk: int = int(os.getenv("FLAGS_use_pd_disaggregation_per_chunk", 0))
|
use_pd_disaggregation_per_chunk: int = int(os.getenv("FLAGS_use_pd_disaggregation_per_chunk", 0))
|
||||||
|
|||||||
@@ -268,6 +268,11 @@ class EngineArgs:
|
|||||||
# This optimization is enabled by default, and can be disabled by using this flag.
|
# This optimization is enabled by default, and can be disabled by using this flag.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
shutdown_comm_group_if_worker_idle: bool = None
|
||||||
|
"""
|
||||||
|
Whether to shutdown the comm group when the weight is cleared.
|
||||||
|
"""
|
||||||
|
|
||||||
engine_worker_queue_port: Optional[Union[int, str, list]] = None
|
engine_worker_queue_port: Optional[Union[int, str, list]] = None
|
||||||
"""
|
"""
|
||||||
Port for worker queue communication.
|
Port for worker queue communication.
|
||||||
@@ -951,6 +956,12 @@ class EngineArgs:
|
|||||||
default=EngineArgs.chunked_moe_size,
|
default=EngineArgs.chunked_moe_size,
|
||||||
help="Chunked size of moe input.",
|
help="Chunked size of moe input.",
|
||||||
)
|
)
|
||||||
|
parallel_group.add_argument(
|
||||||
|
"--shutdown-comm-group-if-worker-idle",
|
||||||
|
action=argparse.BooleanOptionalAction,
|
||||||
|
default=EngineArgs.shutdown_comm_group_if_worker_idle,
|
||||||
|
help="Shutdown communication group when worker is idle.",
|
||||||
|
)
|
||||||
|
|
||||||
# Load group
|
# Load group
|
||||||
load_group = parser.add_argument_group("Load Configuration")
|
load_group = parser.add_argument_group("Load Configuration")
|
||||||
|
|||||||
@@ -590,6 +590,7 @@ class LLMEngine:
|
|||||||
"disable_sequence_parallel_moe": self.cfg.parallel_config.disable_sequence_parallel_moe,
|
"disable_sequence_parallel_moe": self.cfg.parallel_config.disable_sequence_parallel_moe,
|
||||||
"enable_logprob": self.cfg.model_config.enable_logprob,
|
"enable_logprob": self.cfg.model_config.enable_logprob,
|
||||||
"lm_head_fp32": self.cfg.model_config.lm_head_fp32,
|
"lm_head_fp32": self.cfg.model_config.lm_head_fp32,
|
||||||
|
"shutdown_comm_group_if_worker_idle": self.cfg.parallel_config.shutdown_comm_group_if_worker_idle,
|
||||||
}
|
}
|
||||||
for worker_flag, value in worker_store_true_flag.items():
|
for worker_flag, value in worker_store_true_flag.items():
|
||||||
if value:
|
if value:
|
||||||
|
|||||||
@@ -258,15 +258,16 @@ class DynamicWeightManager:
|
|||||||
value[self.rank] = status
|
value[self.rank] = status
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def check_model_weights_status(model_weights_status, model_runner, pid):
|
def check_model_weights_status(model_weights_status, model_runner, pid, block):
|
||||||
"""
|
"""
|
||||||
check model weights status
|
check model weights status
|
||||||
"""
|
"""
|
||||||
# logger.info(f"dynamic weight manager is check model weights status! {model_weights_status.value[0]}")
|
# logger.info(f"dynamic weight manager is check model weights status! {model_weights_status.value[0]}")
|
||||||
while (
|
while model_weights_status.value[0] != ModelWeightsStatus.NORMAL and (
|
||||||
model_weights_status.value[0] != ModelWeightsStatus.NORMAL
|
block or model_weights_status.value[0] != ModelWeightsStatus.CLEARED
|
||||||
and model_weights_status.value[0] != ModelWeightsStatus.CLEARED
|
|
||||||
):
|
):
|
||||||
|
# 如果为 block 模式,那么循环不会退出,直到权重更新、通信组重建
|
||||||
|
# 如果为非 block 模式,那么循环在权重更新或清理后均会退出
|
||||||
if model_weights_status.value[0] == ModelWeightsStatus.UPDATING:
|
if model_weights_status.value[0] == ModelWeightsStatus.UPDATING:
|
||||||
logger.info("infer engine stopped! start to load new checkpoint...")
|
logger.info("infer engine stopped! start to load new checkpoint...")
|
||||||
model_runner.clear_requests()
|
model_runner.clear_requests()
|
||||||
|
|||||||
@@ -2770,7 +2770,9 @@ class GPUModelRunner(ModelRunnerBase):
|
|||||||
if self.use_cudagraph:
|
if self.use_cudagraph:
|
||||||
self.model.clear_grpah_opt_backend()
|
self.model.clear_grpah_opt_backend()
|
||||||
# Clear parameters and Send single
|
# Clear parameters and Send single
|
||||||
self.dynamic_weight_manager.clear_parameters(pid)
|
self.dynamic_weight_manager.clear_parameters(
|
||||||
|
pid, self.fd_config.parallel_config.shutdown_comm_group_if_worker_idle
|
||||||
|
)
|
||||||
self.clear_cache()
|
self.clear_cache()
|
||||||
paddle.device.cuda.empty_cache()
|
paddle.device.cuda.empty_cache()
|
||||||
|
|
||||||
@@ -2787,7 +2789,9 @@ class GPUModelRunner(ModelRunnerBase):
|
|||||||
def update_parameters(self, pid):
|
def update_parameters(self, pid):
|
||||||
"""Dynamic model loader use to update parameters use for RL"""
|
"""Dynamic model loader use to update parameters use for RL"""
|
||||||
# Update parameters
|
# Update parameters
|
||||||
self.dynamic_weight_manager.update_parameters(pid)
|
self.dynamic_weight_manager.update_parameters(
|
||||||
|
pid, self.fd_config.parallel_config.shutdown_comm_group_if_worker_idle
|
||||||
|
)
|
||||||
self.initialize_kv_cache()
|
self.initialize_kv_cache()
|
||||||
# Recapture CUDAGraph
|
# Recapture CUDAGraph
|
||||||
if self.use_cudagraph:
|
if self.use_cudagraph:
|
||||||
|
|||||||
@@ -458,11 +458,15 @@ class PaddleDisWorkerProc:
|
|||||||
# model_weights_signal
|
# model_weights_signal
|
||||||
self.worker.model_runner,
|
self.worker.model_runner,
|
||||||
self.parallel_config.local_engine_worker_queue_port,
|
self.parallel_config.local_engine_worker_queue_port,
|
||||||
|
self.parallel_config.shutdown_comm_group_if_worker_idle,
|
||||||
)
|
)
|
||||||
logger.info(f"current task queue data: {self.task_queue.num_tasks()}")
|
logger.info(f"current task queue data: {self.task_queue.num_tasks()}")
|
||||||
self.task_queue.clear_data()
|
self.task_queue.clear_data()
|
||||||
self.model_weights_signal[0] = ModelWeightsStatus.NORMAL
|
self.model_weights_signal[0] = ModelWeightsStatus.NORMAL
|
||||||
logger.info(f"Rank: {self.local_rank} has updated or cleared parameters.")
|
logger.info(f"Rank: {self.local_rank} has updated or cleared parameters.")
|
||||||
|
|
||||||
|
# 只有不关闭通信组时,清理权重后需要额外等待(否则信号量会同步混乱)
|
||||||
|
if not self.fd_config.parallel_config.shutdown_comm_group_if_worker_idle:
|
||||||
while self.model_weights_status.value[0] == ModelWeightsStatus.CLEARED:
|
while self.model_weights_status.value[0] == ModelWeightsStatus.CLEARED:
|
||||||
time.sleep(0.01)
|
time.sleep(0.01)
|
||||||
continue
|
continue
|
||||||
@@ -883,6 +887,12 @@ def parse_args():
|
|||||||
help="Configation of Rollout Routing Replay.",
|
help="Configation of Rollout Routing Replay.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--shutdown_comm_group_if_worker_idle",
|
||||||
|
action="store_true",
|
||||||
|
help="Shutdown comm group if worker idle.",
|
||||||
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
return args
|
return args
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user