[Feature] Support model weight update in EP (expert parallel) mode (#3802)

* Update config.py

* Update ep.py

* Update fused_moe_backend_base.py

* Update dynamic_weight_manager.py

* Update worker_process.py

* Fix CI
This commit is contained in:
ltd0924
2025-09-02 20:52:47 +08:00
committed by GitHub
parent d1d063e4af
commit 0f42771a84
5 changed files with 43 additions and 19 deletions

View File

@@ -63,7 +63,9 @@ class DynamicWeightManager:
paddle.device.cuda.empty_cache()
if not self.first_load:
paddle.distributed.restart_process_group()
paddle.distributed.restart_process_group(self.parallel_config.tp_group)
if self.parallel_config.enable_expert_parallel:
paddle.distributed.restart_process_group(self.parallel_config.ep_group)
strategy_handlers = {
"ipc_snapshot": self._update_ipc_snapshot,
@@ -110,8 +112,12 @@ class DynamicWeightManager:
param._clear_data()
self._verify_parameters("clearance")
if self.nranks > 1:
paddle.distributed.barrier()
if self.parallel_config.tensor_parallel_size > 1:
paddle.distributed.barrier(self.parallel_config.tp_group)
paddle.distributed.shutdown_process_group(self.parallel_config.tp_group)
if self.parallel_config.enable_expert_parallel:
paddle.distributed.barrier(self.parallel_config.ep_group)
paddle.distributed.shutdown_process_group(self.parallel_config.ep_group)
paddle.distributed.shutdown_process_group()
self._update_shared_status(pid, -2)
@@ -141,8 +147,8 @@ class DynamicWeightManager:
def _finalize_update(self, pid: int):
"""Finalize update process with verification."""
self._verify_parameters("update")
if self.nranks > 1:
paddle.distributed.barrier()
if self.parallel_config.tensor_parallel_size > 1:
paddle.distributed.barrier(self.parallel_config.tp_group)
if not self.first_load:
self._update_shared_status(pid, 0)
self.first_load = False