[Feat] support mixed ep (#2969)

* Support mixed ep

* fix comment

* fix comment

* update mixep

* fix conflict

* fix typo

* update

* fix typo

* fix code style

* fix conflict
Author: Longzhi Wang
Date: 2025-07-25 15:29:30 +08:00
Committed by: GitHub
Parent: 332154f504
Commit: 0700c90caa
4 changed files with 140 additions and 51 deletions

View File

@@ -18,7 +18,6 @@ from __future__ import annotations
 import os
 from dataclasses import dataclass, field
-from enum import Enum
 from typing import Literal, Optional

 from paddleformers.transformers.configuration_utils import PretrainedConfig
@@ -30,13 +29,24 @@ from fastdeploy.utils import get_logger
 logger = get_logger("config", "config.log")

-class MoEPhase(Enum):
+class MoEPhase:
     """
     The generation phase of the moe.
     """
-    PREFILL = 1
-    DECODER = 2
+
+    def __init__(self, phase="prefill"):
+        self._phase = phase
+
+    @property
+    def phase(self):
+        return self._phase
+
+    @phase.setter
+    def phase(self, value):
+        if value not in ["prefill", "decode"]:
+            raise ValueError(f"The moe_phase is invalid, only support prefill and decode, but got {value}")
+        else:
+            self._phase = value

 class ErnieArchitectures:
@@ -146,7 +156,7 @@ class ParallelConfig:
     ):
         self.sequence_parallel = False  # Whether to enable sequence parallelism.
         self.use_ep = False  # Whether to enable Expert Parallelism
-        self.moe_phase = MoEPhase.PREFILL  # Generation phase
+        self.moe_phase = MoEPhase("prefill")  # Generation phase
         self.msg_queue_id = 1  # mesage queue id

         self.tensor_parallel_rank = 0  # TP rank ID
@@ -210,11 +220,11 @@ class ParallelConfig:
             setattr(self, key, value)

         self.use_ep = args["expert_parallel_size"] > 1
         if self.splitwise_role == "mixed":
-            self.moe_phase = MoEPhase.PREFILL
+            self.moe_phase = MoEPhase(phase="prefill")
         elif self.splitwise_role == "prefill":
-            self.moe_phase = MoEPhase.PREFILL
+            self.moe_phase = MoEPhase(phase="prefill")
         elif self.splitwise_role == "decode":
-            self.moe_phase = MoEPhase.DECODER
+            self.moe_phase = MoEPhase(phase="decode")
         else:
             raise NotImplementedError
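The two hunks above replace the MoEPhase enum with a small mutable wrapper so the phase can be reassigned while the engine is running. Below is a minimal, self-contained sketch of how the new API behaves; it mirrors the class added above rather than importing FastDeploy.

# Sketch of the reworked MoEPhase and the role-to-phase mapping done in ParallelConfig.
class MoEPhase:
    def __init__(self, phase="prefill"):
        self._phase = phase

    @property
    def phase(self):
        return self._phase

    @phase.setter
    def phase(self, value):
        if value not in ["prefill", "decode"]:
            raise ValueError(f"The moe_phase is invalid, only support prefill and decode, but got {value}")
        self._phase = value

# splitwise_role "mixed" and "prefill" both start in the prefill phase, "decode" starts in decode.
moe_phase = MoEPhase(phase="prefill")
moe_phase.phase = "decode"        # mixed EP flips this per batch at runtime
try:
    moe_phase.phase = "encode"    # anything else is rejected by the setter
except ValueError as err:
    print(err)

In mixed splitwise mode this is what lets the same ParallelConfig.moe_phase object be switched between "prefill" and "decode" on a per-batch basis.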

View File

@@ -43,9 +43,10 @@ class DeepEPEngine:
         num_max_dispatch_tokens_per_rank: int,
         hidden: int,
         num_experts: int,
-        moe_phase: MoEPhase,
         ep_size: int,
         ep_rank: int,
+        splitwise_role: str,
+        moe_phase: MoEPhase,
         async_finish: bool = False,
     ):
         """
@@ -65,26 +66,44 @@ class DeepEPEngine:
         self.hidden = hidden
         self.num_experts = num_experts
         self.num_local_experts = num_experts // ep_size
-        self.moe_phase = moe_phase
         self.async_finish = async_finish

-        self.deepep_engine = None
+        self.prefill_deepep_engine = None
+        self.decode_deepep_engine = None

-        if moe_phase == MoEPhase.DECODER:
+        self.ep_config = Config(24, 6, 256)
+        self.num_max_dispatch_tokens_per_rank = num_max_dispatch_tokens_per_rank
+
+        # In mixed EP mode on a single node, we dynamically switch between
+        # high throughput and low latency modes.
+        if splitwise_role == "mixed":
+            # decode engine
             logger.info("Initializing Low Latency Buffer")
-            self.num_max_dispatch_tokens_per_rank = num_max_dispatch_tokens_per_rank
             self.get_low_latency_buffer()
-        elif moe_phase == MoEPhase.PREFILL:
-            self.deepep_engine = deep_ep.Buffer(
+            # prefill engine
+            self.prefill_deepep_engine = deep_ep.Buffer(
                 self.group,
                 int(5e8),
                 0,
                 low_latency_mode=False,
                 num_qps_per_rank=1,
             )
-            self.ep_config = Config(24, 6, 256)
+        # In disaggregated mode on mutiple nodes, we either use
+        # high throughput mode or low latency mode.
         else:
-            raise ValueError(f"Unknown generation phase {moe_phase}")
+            if moe_phase.phase == "decode":
+                logger.info("Initializing Low Latency Buffer")
+                self.get_low_latency_buffer()
+            elif moe_phase.phase == "prefill":
+                self.prefill_deepep_engine = deep_ep.Buffer(
+                    self.group,
+                    int(5e8),
+                    0,
+                    low_latency_mode=False,
+                    num_qps_per_rank=1,
+                )
+            else:
+                raise ValueError(f"Unknown generation phase {moe_phase}")

     def get_low_latency_buffer(self):
         """
@@ -105,14 +124,14 @@ class DeepEPEngine:
         )
         # Allocate a buffer if not existed or not enough buffer size
         if (
-            self.deepep_engine is None
-            or self.deepep_engine.group != self.group
-            or not self.deepep_engine.low_latency_mode
-            or self.deepep_engine.num_rdma_bytes < num_rdma_bytes
+            self.decode_deepep_engine is None
+            or self.decode_deepep_engine.group != self.group
+            or not self.decode_deepep_engine.low_latency_mode
+            or self.decode_deepep_engine.num_rdma_bytes < num_rdma_bytes
         ):
             # NOTES: for best performance, the QP number **must** be equal to the number of the local experts
             assert self.num_experts % self.ep_size == 0
-            self.deepep_engine = deep_ep.Buffer(
+            self.decode_deepep_engine = deep_ep.Buffer(
                 self.group,
                 0,
                 num_rdma_bytes,
@@ -149,7 +168,7 @@ class DeepEPEngine:
             handle,
             _,
             dispatch_hook,
-        ) = self.deepep_engine.low_latency_dispatch(
+        ) = self.decode_deepep_engine.low_latency_dispatch(
             hidden_states,
             topk_idx,
             expertwise_scale,
@@ -174,8 +193,22 @@ class DeepEPEngine:
         Return:
             combined_hidden_states: [num_tokens, hidden]
         """
+        # TODO(@wufeisheng): Delete them when deepep in PaddlePaddle is fixed
+        (
+            src_info,
+            layout_range,
+            num_max_dispatch_tokens_per_rank,
+            num_experts,
+        ) = handle
+        handle = (
+            src_info,
+            layout_range,
+            num_max_dispatch_tokens_per_rank,
+            None,
+            num_experts,
+        )
+
-        combined_hidden_states, _, combine_hook = self.deepep_engine.low_latency_combine(
+        combined_hidden_states, _, combine_hook = self.decode_deepep_engine.low_latency_combine(
             hidden_states,
             topk_idx,
             topk_weights,
@@ -189,7 +222,7 @@ class DeepEPEngine:
""" """
clean_low_latency_buffer clean_low_latency_buffer
""" """
self.deepep_engine.clean_low_latency_buffer( self.decode_deepep_engine.clean_low_latency_buffer(
self.num_max_dispatch_tokens_per_rank, self.hidden, self.num_experts self.num_max_dispatch_tokens_per_rank, self.hidden, self.num_experts
) )
@@ -197,7 +230,11 @@ class DeepEPEngine:
""" """
barrier_all barrier_all
""" """
self.deepep_engine.barrier_all() if self.prefill_deepep_engine is not None:
self.prefill_deepep_engine.barrier_all()
if self.decode_deepep_engine is not None:
self.decode_deepep_engine.barrier_all()
class EPRunner: class EPRunner:
@@ -210,6 +247,7 @@ class EPRunner:
         top_k: int,
         hidden: int,
         num_experts: int,
+        splitwise_role: str,
         moe_phase: MoEPhase,
         num_max_dispatch_tokens_per_rank: int = 1,
         ep_size: int = 1,
@@ -223,9 +261,10 @@ class EPRunner:
             num_max_dispatch_tokens_per_rank=num_max_dispatch_tokens_per_rank,
             hidden=hidden,
             num_experts=num_experts + redundant_experts_num,
-            moe_phase=moe_phase,
             ep_size=ep_size,
             ep_rank=ep_rank,
+            splitwise_role=splitwise_role,
+            moe_phase=moe_phase,
         )

     def moe_select(self, layer: nn.Layer, gate_out: paddle.Tensor):
@@ -286,15 +325,19 @@ class EPPrefillRunner(EPRunner):
         top_k: int,
         hidden: int,
         num_experts: int,
+        splitwise_role: str,
         ep_size: int = 1,
         ep_rank: int = 0,
         redundant_experts_num: int = 0,
+        moe_phase: MoEPhase = MoEPhase("prefill"),
     ):
         super().__init__(
             top_k,
             hidden,
             num_experts,
-            MoEPhase.PREFILL,
+            splitwise_role,
+            moe_phase,
+            num_max_dispatch_tokens_per_rank=256,
             ep_size=ep_size,
             ep_rank=ep_rank,
             redundant_experts_num=redundant_experts_num,
@@ -314,7 +357,7 @@ class EPPrefillRunner(EPRunner):
             num_tokens_per_expert,
             is_token_in_rank,
             _,
-        ) = self.ep_engine.deepep_engine.get_dispatch_layout(topk_idx, self.num_experts)
+        ) = self.ep_engine.prefill_deepep_engine.get_dispatch_layout(topk_idx, self.num_experts)

         x_scale_tensor = kwargs.get("x_scale_tensor", None)
         dispatch_args = {
@@ -327,7 +370,7 @@ class EPPrefillRunner(EPRunner):
"topk_idx": topk_idx, "topk_idx": topk_idx,
"topk_weights": topk_weights, "topk_weights": topk_weights,
} }
return self.ep_engine.deepep_engine.dispatch(**dispatch_args) return self.ep_engine.prefill_deepep_engine.dispatch(**dispatch_args)
def combine( def combine(
self, self,
@@ -342,7 +385,7 @@ class EPPrefillRunner(EPRunner):
"async_finish": self.ep_engine.async_finish, "async_finish": self.ep_engine.async_finish,
"topk_weights": recv_topk_weights, "topk_weights": recv_topk_weights,
} }
fused_moe_out, _, _ = self.ep_engine.deepep_engine.combine(**combine_args) fused_moe_out, _, _ = self.ep_engine.prefill_deepep_engine.combine(**combine_args)
return fused_moe_out return fused_moe_out
@@ -357,16 +400,19 @@ class EPDecoderRunner(EPRunner):
         top_k: int,
         hidden: int,
         num_experts: int,
+        splitwise_role: str,
         num_max_dispatch_tokens_per_rank: int,
         ep_size: int = 1,
         ep_rank: int = 0,
         redundant_experts_num: int = 0,
+        moe_phase: MoEPhase = MoEPhase("decode"),
     ):
         super().__init__(
             top_k,
             hidden,
             num_experts,
-            MoEPhase.DECODER,
+            splitwise_role,
+            moe_phase,
             num_max_dispatch_tokens_per_rank,
             ep_size=ep_size,
             ep_rank=ep_rank,

View File

@@ -19,8 +19,6 @@ from abc import abstractmethod
 import paddle
 from paddle import nn

-from fastdeploy.config import MoEPhase
-
 from ..quantization.quant_base import QuantMethodBase
@@ -45,29 +43,54 @@ class MoEMethodBase(QuantMethodBase):
         Init EP related module
         """
         if layer.ep_size > 1:
-            if layer.fd_config.parallel_config.moe_phase == MoEPhase.DECODER:
-                from .ep import EPDecoderRunner
+            if layer.fd_config.parallel_config.splitwise_role == "mixed":
+                from .ep import EPDecoderRunner, EPPrefillRunner
+
+                self.ep_prefill_runner = EPPrefillRunner(
+                    layer.top_k,
+                    layer.hidden_size,
+                    layer.num_experts,
+                    layer.fd_config.parallel_config.splitwise_role,
+                    layer.ep_size,
+                    layer.ep_rank,
+                    layer.fd_config.model_config.redundant_experts_num,
+                )
                 self.ep_decoder_runner = EPDecoderRunner(
                     layer.top_k,
                     layer.hidden_size,
                     layer.num_experts,
+                    layer.fd_config.parallel_config.splitwise_role,
                     layer.fd_config.model_config.num_max_dispatch_tokens_per_rank,
                     layer.ep_size,
                     layer.ep_rank,
                     layer.fd_config.model_config.redundant_experts_num,
                 )
             else:
-                from .ep import EPPrefillRunner
-
-                self.ep_prefill_runner = EPPrefillRunner(
-                    layer.top_k,
-                    layer.hidden_size,
-                    layer.num_experts,
-                    layer.ep_size,
-                    layer.ep_rank,
-                    layer.fd_config.model_config.redundant_experts_num,
-                )
+                if layer.fd_config.parallel_config.moe_phase == "prefill":
+                    from .ep import EPPrefillRunner
+
+                    self.ep_prefill_runner = EPPrefillRunner(
+                        layer.top_k,
+                        layer.hidden_size,
+                        layer.num_experts,
+                        layer.fd_config.parallel_config.splitwise_role,
+                        layer.ep_size,
+                        layer.ep_rank,
+                        layer.fd_config.model_config.redundant_experts_num,
+                    )
+                else:
+                    from .ep import EPDecoderRunner
+
+                    self.ep_decoder_runner = EPDecoderRunner(
+                        layer.top_k,
+                        layer.hidden_size,
+                        layer.num_experts,
+                        layer.moe_config.num_max_dispatch_tokens_per_rank,
+                        layer.fd_config.parallel_config.splitwise_role,
+                        layer.ep_size,
+                        layer.ep_rank,
+                        layer.fd_config.model_config.redundant_experts_num,
+                    )

     def process_loaded_weights(self, layer, weights) -> None:
         """
@@ -141,7 +164,7 @@ class MoEMethodBase(QuantMethodBase):
         Paddle Cutlass compute Fused MoE.
         """
         if layer.ep_size > 1:
-            if layer.fd_config.parallel_config.moe_phase == MoEPhase.PREFILL:
+            if layer.fd_config.parallel_config.moe_phase.phase == "prefill":
                 return self.apply_ep_prefill(layer, x, gate_out)
             else:
                 return self.apply_ep_decode(layer, x, gate_out)

View File

@@ -794,6 +794,14 @@ class GPUModelRunner(ModelRunnerBase):
         # Update Batch type for cuda graph
         # TODO(gongshaotian): Use seq_lens_encoder to set is_decode_batch
         is_decode_batch = not ((self.share_inputs["seq_lens_this_time"] > 1).sum() > 0)
+
+        # mix ep in single node
+        if self.fd_config.parallel_config.use_ep and self.fd_config.parallel_config.splitwise_role == "mixed":
+            is_decode_batch_list = []
+            paddle.distributed.all_gather_object(is_decode_batch_list, is_decode_batch)
+            is_decode_batch = all(is_decode_batch_list)
+            self.fd_config.parallel_config.moe_phase.phase = "decode" if is_decode_batch else "prefill"
+
         self.forward_meta.step_use_cudagraph = self.use_cudagraph and is_decode_batch

         # Initialzie attention meta data
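The added block above makes every EP rank agree on the batch type before choosing a MoE path: each rank gathers the others' is_decode_batch flags and only treats the step as decode if all ranks concur, so all ranks run the same dispatch/combine kernels. A minimal sketch of that agreement step, assuming the Paddle distributed environment is already initialized (e.g. via init_parallel_env()):

import paddle.distributed as dist

def agree_on_decode_batch(local_is_decode: bool) -> bool:
    flags = []
    dist.all_gather_object(flags, local_is_decode)   # one flag per EP rank
    return all(flags)   # a single prefilling rank forces the prefill path everywhere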
@@ -1163,16 +1171,18 @@ class GPUModelRunner(ModelRunnerBase):
         We plan to replace it with 'ModelForwardBatch'.
             intermediate_tensors:
         """
-        # NOTE(wufeisheng): For Expert Parallelism
-        if not self.not_need_stop():
-            self._execute_empty_input()
-            return None
-
         # 1. Prepare inputs of model and sampler.
         skip_idx_list = self._get_skip_idx(model_forward_batch)
         self._prepare_inputs()
         self.sampler.pre_process(skip_idx_list)

+        # NOTE(wufeisheng): If `not_need_stop`` is False, it means the current worker is in an idle state.
+        # This logic is not used in TP (Tensor Parallelism) mode. However, in EP (Expert Parallelism) mode,
+        # when there is data on other runner, the current runner is required to execute part of the model.
+        if not self.not_need_stop():
+            self._execute_empty_input()
+            return None
+
         # 2. Padding inputs for cuda graph
         self.padding_cudagraph_inputs()
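The relocated check above reflects the EP-specific requirement spelled out in the new NOTE: a rank with no local work must still run a forward pass on an empty batch, otherwise the expert-parallel dispatch/combine collectives issued by busy ranks would be left waiting. Below is a hypothetical, self-contained sketch of that rule; the stub class reuses the method names from the diff but is not FastDeploy's GPUModelRunner.

class TinyRunner:
    def __init__(self, num_local_requests: int):
        self.num_local_requests = num_local_requests

    def not_need_stop(self) -> bool:
        return self.num_local_requests > 0   # True while this rank has its own work

    def _execute_empty_input(self):
        print("idle EP rank: forward an empty batch to join dispatch/combine")

    def execute_model(self):
        # inputs would be prepared first (as in the reordered code above), then idle ranks bail out
        if not self.not_need_stop():
            self._execute_empty_input()
            return None
        return "real forward output"

print(TinyRunner(0).execute_model())   # idle-rank path
print(TinyRunner(2).execute_model())   # normal path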