[Feature] support model weight update in ep (#3802)

* Update config.py

* Update ep.py

* Update fused_moe_backend_base.py

* Update dynamic_weight_manager.py

* Update worker_process.py

* fix ci
Author: ltd0924
Date: 2025-09-02 20:52:47 +08:00
Committed by: GitHub
Parent: d1d063e4af
Commit: 0f42771a84
5 changed files with 43 additions and 19 deletions

File: config.py

@@ -350,8 +350,8 @@ class ParallelConfig:
                 )
             )
             # same ep group id
-            # (TODO:gaoziyuan move this gid config to ep.py)
             dist.collective._set_custom_gid(self.data_parallel_size + tp_gid_offset)
+            self.ep_group = dist.new_group(range(self.expert_parallel_size))
             logger.info(
                 f"data_parallel_size: {self.data_parallel_size}, tensor_parallel_size: {self.tensor_parallel_size}, expert_parallel_size: {self.expert_parallel_size}, data_parallel_rank: {self.data_parallel_rank}, tensor_parallel_rank: {self.tensor_parallel_rank}, expert_parallel_rank: {self.expert_parallel_rank}, tp_group: {self.tp_group}."
             )
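
The net effect of this hunk is that `ParallelConfig` now owns a single expert-parallel communication group (`self.ep_group`) that downstream code can reuse. Below is a minimal sketch of that "create once, inject everywhere" pattern; the class names are hypothetical stand-ins rather than FastDeploy's API, and it assumes `paddle.distributed` has already been initialized by the launcher.

```python
# Illustrative sketch only: shows the "create the EP group once in the config,
# then hand it to every consumer" pattern this diff introduces. The classes
# below are hypothetical, not FastDeploy's real API.
import paddle.distributed as dist


class MyParallelConfig:  # hypothetical stand-in for ParallelConfig
    def __init__(self, expert_parallel_size: int, tensor_parallel_size: int):
        self.expert_parallel_size = expert_parallel_size
        self.tensor_parallel_size = tensor_parallel_size
        # One EP group for the whole process, built exactly once here.
        self.ep_group = dist.new_group(range(expert_parallel_size))


class MyMoERunner:  # hypothetical consumer
    def __init__(self, config: MyParallelConfig):
        # Reuse the shared group instead of calling dist.new_group() again,
        # so shutdown/restart of the group stays centrally manageable.
        self.group = config.ep_group
```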

File: ep.py

@@ -78,6 +78,7 @@ class DeepEPEngine:
         splitwise_role: str,
         moe_phase: MoEPhase,
         async_finish: bool = False,
+        group=None,
     ):
         """
         Initialize the DeepEP engine.
@@ -90,7 +91,9 @@ class DeepEPEngine:
             num_experts: The number of experts.
         """
         # TODO(@wufeisheng): Support configurable EP size
-        self.group = paddle.distributed.new_group(range(ep_size))
+        if group is None:
+            group = paddle.distributed.new_group(range(ep_size))
+        self.group = group
         self.ep_size = ep_size
         self.rank_id = ep_rank
         self.hidden = hidden
@@ -277,6 +280,7 @@ class EPRunner:
         ep_size: int = 1,
         ep_rank: int = 0,
         redundant_experts_num: int = 0,
+        ep_group=None,
     ):
         self.top_k = top_k
         self.num_experts = num_experts
@@ -289,6 +293,7 @@ class EPRunner:
             ep_rank=ep_rank,
             splitwise_role=splitwise_role,
             moe_phase=moe_phase,
+            group=ep_group,
         )

     def moe_select(self, layer: nn.Layer, gate_out: paddle.Tensor):
@@ -368,6 +373,7 @@ class EPPrefillRunner(EPRunner):
         ep_rank: int = 0,
         redundant_experts_num: int = 0,
         moe_phase: MoEPhase = MoEPhase("prefill"),
+        ep_group=None,
     ):
         super().__init__(
             top_k,
@@ -379,6 +385,7 @@ class EPPrefillRunner(EPRunner):
             ep_size=ep_size,
             ep_rank=ep_rank,
             redundant_experts_num=redundant_experts_num,
+            ep_group=ep_group,
         )

     def dispatch(
@@ -445,6 +452,7 @@ class EPDecoderRunner(EPRunner):
         ep_size: int = 1,
         ep_rank: int = 0,
         redundant_experts_num: int = 0,
+        ep_group=None,
         moe_phase: MoEPhase = MoEPhase("decode"),
     ):
         super().__init__(
@@ -457,6 +465,7 @@ class EPDecoderRunner(EPRunner):
             ep_size=ep_size,
             ep_rank=ep_rank,
             redundant_experts_num=redundant_experts_num,
+            ep_group=ep_group,
         )

     def dispatch(
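
All of the constructor changes above follow one pattern: accept an optional group, and fall back to creating a fresh one only when the caller passes nothing, so existing call sites keep working while new ones can inject the shared `ep_group`. A minimal sketch of that pattern, with a hypothetical `Engine` class and assuming an initialized `paddle.distributed`:

```python
# Minimal sketch of the optional-group pattern added to DeepEPEngine/EPRunner.
# "Engine" is a hypothetical stand-in; only the group handling mirrors the diff.
import paddle.distributed as dist


class Engine:
    def __init__(self, ep_size: int, ep_rank: int, group=None):
        # Fall back to creating a fresh group so existing callers that do not
        # pass one keep working; new callers can inject the shared ep_group.
        if group is None:
            group = dist.new_group(range(ep_size))
        self.group = group
        self.ep_size = ep_size
        self.rank_id = ep_rank


# Old-style call (still creates its own group):
#   engine = Engine(ep_size=8, ep_rank=0)
# New-style call (reuses the group built in ParallelConfig):
#   engine = Engine(ep_size=8, ep_rank=0, group=parallel_config.ep_group)
```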

File: fused_moe_backend_base.py

@@ -58,6 +58,7 @@ class MoEMethodBase(QuantMethodBase):
                 layer.ep_size,
                 layer.ep_rank,
                 layer.fd_config.model_config.redundant_experts_num,
+                ep_group=layer.fd_config.parallel_config.ep_group,
             )
             self.ep_decoder_runner = EPDecoderRunner(
                 layer.top_k,
@@ -68,6 +69,7 @@ class MoEMethodBase(QuantMethodBase):
                 layer.ep_size,
                 layer.ep_rank,
                 layer.fd_config.model_config.redundant_experts_num,
+                ep_group=layer.fd_config.parallel_config.ep_group,
             )
         else:
             if layer.fd_config.parallel_config.moe_phase.phase == "prefill":
@@ -82,6 +84,7 @@ class MoEMethodBase(QuantMethodBase):
                     layer.ep_size,
                     layer.ep_rank,
                     layer.fd_config.model_config.redundant_experts_num,
+                    ep_group=layer.fd_config.parallel_config.ep_group,
                 )
             else:
                 from .ep import EPDecoderRunner
@@ -95,6 +98,7 @@ class MoEMethodBase(QuantMethodBase):
                     layer.ep_size,
                     layer.ep_rank,
                     layer.fd_config.model_config.redundant_experts_num,
+                    ep_group=layer.fd_config.parallel_config.ep_group,
                 )

     def process_loaded_weights(self, layer, weights) -> None:
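
Here the MoE backend simply threads `layer.fd_config.parallel_config.ep_group` into every `EPPrefillRunner` and `EPDecoderRunner` it builds, so prefill and decode share one communicator. An illustrative sketch of that idea, using stand-in classes rather than the real runners:

```python
# Illustrative only: both runners now receive the same group object, so a
# weight update can quiesce and restart a single communicator instead of one
# per runner. _Runner is a hypothetical stand-in, not FastDeploy API.
class _Runner:
    def __init__(self, name: str, ep_group=None):
        self.name = name
        self.ep_group = ep_group


def build_runners(parallel_config):
    shared = parallel_config.ep_group            # created once in ParallelConfig
    prefill = _Runner("prefill", ep_group=shared)
    decode = _Runner("decode", ep_group=shared)
    assert prefill.ep_group is decode.ep_group   # one communicator to manage
    return prefill, decode
```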

File: dynamic_weight_manager.py

@@ -63,7 +63,9 @@ class DynamicWeightManager:
         paddle.device.cuda.empty_cache()

         if not self.first_load:
-            paddle.distributed.restart_process_group()
+            paddle.distributed.restart_process_group(self.parallel_config.tp_group)
+            if self.parallel_config.enable_expert_parallel:
+                paddle.distributed.restart_process_group(self.parallel_config.ep_group)

         strategy_handlers = {
             "ipc_snapshot": self._update_ipc_snapshot,
@@ -110,8 +112,12 @@ class DynamicWeightManager:
             param._clear_data()

         self._verify_parameters("clearance")

-        if self.nranks > 1:
-            paddle.distributed.barrier()
+        if self.parallel_config.tensor_parallel_size > 1:
+            paddle.distributed.barrier(self.parallel_config.tp_group)
+            paddle.distributed.shutdown_process_group(self.parallel_config.tp_group)
+        if self.parallel_config.enable_expert_parallel:
+            paddle.distributed.barrier(self.parallel_config.ep_group)
+            paddle.distributed.shutdown_process_group(self.parallel_config.ep_group)
         paddle.distributed.shutdown_process_group()
         self._update_shared_status(pid, -2)
@@ -141,8 +147,8 @@ class DynamicWeightManager:
     def _finalize_update(self, pid: int):
         """Finalize update process with verification."""
         self._verify_parameters("update")
-        if self.nranks > 1:
-            paddle.distributed.barrier()
+        if self.parallel_config.tensor_parallel_size > 1:
+            paddle.distributed.barrier(self.parallel_config.tp_group)
         if not self.first_load:
             self._update_shared_status(pid, 0)
         self.first_load = False
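
The weight manager now quiesces and restarts the TP and EP groups explicitly: barrier on a group, shut it down before the old weights are cleared, and restart it once the new weights are loaded. A hedged sketch of that ordering is below; `shutdown_process_group`/`restart_process_group` are used as they appear in the diff above, while the function names and the config object are illustrative, not a drop-in replacement for `DynamicWeightManager`.

```python
# Sketch of the per-group teardown/bring-up order used around a weight update,
# assuming paddle.distributed is already initialized.
import paddle
import paddle.distributed as dist


def pause_groups(parallel_config):
    """Quiesce communicators before the old weights are cleared."""
    if parallel_config.tensor_parallel_size > 1:
        dist.barrier(parallel_config.tp_group)                 # all ranks reached the update point
        dist.shutdown_process_group(parallel_config.tp_group)  # stop TP collectives
    if parallel_config.enable_expert_parallel:
        dist.barrier(parallel_config.ep_group)
        dist.shutdown_process_group(parallel_config.ep_group)  # stop EP all-to-all
    dist.shutdown_process_group()                              # finally, the default group


def resume_groups(parallel_config):
    """Bring communicators back once the new weights are in place."""
    paddle.device.cuda.empty_cache()
    dist.restart_process_group(parallel_config.tp_group)
    if parallel_config.enable_expert_parallel:
        dist.restart_process_group(parallel_config.ep_group)
```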

File: worker_process.py

@@ -254,27 +254,25 @@ class PaddleDisWorkerProc:
         """
         # Currently, only support single node
         self.nnode = int((self.parallel_config.tensor_parallel_size + 7) // 8)
-        mp_num_per_node = self.parallel_config.tensor_parallel_size // self.nnode
         req_ids = []
         num_running_requests = 0
         local_rank = self.local_rank % self.parallel_config.tensor_parallel_size
+        self.model_weights_signal = paddle.zeros([1], dtype=paddle.int32)
         while True:
-            if self.local_rank == 0:
+            if self.local_rank % self.parallel_config.tensor_parallel_size == 0:
                 if self.model_weights_status.value[0] != 0:
-                    self.exist_task_signal.value[0] = 2
-                else:
-                    self.exist_task_signal.value[0] = 0
-
-            if self.fd_config.load_config.dynamic_load_weight:
-                if self.parallel_config.tensor_parallel_size > 1:
-                    # Synchronize before updating weights
-                    paddle.distributed.barrier(self.parallel_config.tp_group)
+                    self.model_weights_signal[0] = int(self.model_weights_status.value[0])
+            if self.fd_config.load_config.dynamic_load_weight and self.parallel_config.enable_expert_parallel:
+                paddle.distributed.broadcast(self.model_weights_signal, src=0, group=self.parallel_config.ep_group)
+            paddle.distributed.broadcast(self.model_weights_signal, src=0, group=self.parallel_config.tp_group)

             self.insert_step = False
             req_dicts = None
             self.worker_healthy_live_signal.value[local_rank % self.max_chips_per_node] = int(time.time())

             # The first worker detects whether there are tasks in the task queue
-            if self.local_rank % mp_num_per_node == 0:
+            if self.local_rank % self.parallel_config.tensor_parallel_size == 0:
                 if self.task_queue.num_tasks() > 0:
                     # VL only support 1 batch to prefill
                     if envs.ENABLE_V1_KVCACHE_SCHEDULER or not (
@@ -290,16 +288,23 @@ class PaddleDisWorkerProc:
                     paddle.distributed.barrier(self.parallel_config.tp_group)

             if self.fd_config.load_config.dynamic_load_weight:
-                if self.exist_task_signal.value[0] == 2:
+                if self.parallel_config.enable_expert_parallel:
+                    paddle.distributed.barrier(self.parallel_config.ep_group)
+                else:
+                    paddle.distributed.barrier(self.parallel_config.tp_group)
+                if self.model_weights_signal[0] != 0:
+                    logger.info(f"Rank: {self.local_rank} has updated parameters.")
                     from fastdeploy.rl.dynamic_weight_manager import (
                         DynamicWeightManager,
                     )

+                    self.model_weights_status.value[0] = self.model_weights_signal[0]
                     DynamicWeightManager.check_model_weights_status(
                         self.model_weights_status,
                         self.worker.model_runner,
-                        self.parallel_config.engine_pid,
+                        self.parallel_config.engine_worker_queue_port,
                     )
+                    self.model_weights_signal[0] = 0

             if self.exist_task_signal.value[0] == 1 or self.task_queue.read_finish_flag.get() == 1:
                 logger.info(f"Rank: {self.local_rank} Detected new requests.")