[Feature] Support model weight update in expert-parallel (EP) mode (#3765)

* support model weight update in ep

* support model weight update in ep

* support model weight update in ep

* support model weight update in ep

* Update fused_moe_backend_base.py

* Update worker_process.py

* Update worker_process.py

* Update dynamic_weight_manager.py
This commit is contained in:
ltd0924
2025-09-02 17:16:03 +08:00
committed by GitHub
parent 1908465542
commit 905d89e42f
5 changed files with 47 additions and 22 deletions

View File

@@ -63,7 +63,9 @@ class DynamicWeightManager:
paddle.device.cuda.empty_cache()
if not self.first_load:
paddle.distributed.restart_process_group()
paddle.distributed.restart_process_group(self.parallel_config.tp_group)
if self.parallel_config.enable_expert_parallel:
paddle.distributed.restart_process_group(self.parallel_config.ep_group)
strategy_handlers = {
"ipc_snapshot": self._update_ipc_snapshot,
@@ -110,9 +112,12 @@ class DynamicWeightManager:
param._clear_data()
self._verify_parameters("clearance")
if self.nranks > 1:
paddle.distributed.barrier()
paddle.distributed.shutdown_process_group()
if self.parallel_config.tensor_parallel_size > 1:
paddle.distributed.barrier(self.parallel_config.tp_group)
paddle.distributed.shutdown_process_group(self.parallel_config.tp_group)
if self.parallel_config.enable_expert_parallel:
paddle.distributed.barrier(self.parallel_config.ep_group)
paddle.distributed.shutdown_process_group(self.parallel_config.ep_group)
self._update_shared_status(pid, -2)
def _update_model_from_state(self, state_dict: Dict[str, paddle.Tensor], src_type: str):
@@ -141,8 +146,8 @@ class DynamicWeightManager:
def _finalize_update(self, pid: int):
"""Finalize update process with verification."""
self._verify_parameters("update")
if self.nranks > 1:
paddle.distributed.barrier()
if self.parallel_config.tensor_parallel_size > 1:
paddle.distributed.barrier(self.parallel_config.tp_group)
if not self.first_load:
self._update_shared_status(pid, 0)
self.first_load = False