[Feat] support mixed ep (#2969)

* Support mixed ep

* fix comment

* fix comment

* update mixep

* fix conflict

* fix typo

* update

* fix typo

* fix code style

* fix conflict
Longzhi Wang authored on 2025-07-25 15:29:30 +08:00, committed by GitHub
parent 332154f504, commit 0700c90caa
4 changed files with 140 additions and 51 deletions

View File

@@ -18,7 +18,6 @@ from __future__ import annotations
import os
from dataclasses import dataclass, field
from enum import Enum
from typing import Literal, Optional
from paddleformers.transformers.configuration_utils import PretrainedConfig
@@ -30,13 +29,24 @@ from fastdeploy.utils import get_logger
logger = get_logger("config", "config.log")
class MoEPhase(Enum):
class MoEPhase:
"""
The generation phase of the MoE.
"""
PREFILL = 1
DECODER = 2
def __init__(self, phase="prefill"):
self._phase = phase
@property
def phase(self):
return self._phase
@phase.setter
def phase(self, value):
if value not in ["prefill", "decode"]:
raise ValueError(f"The moe_phase is invalid, only support prefill and decode, but got {value}")
else:
self._phase = value
class ErnieArchitectures:
@@ -146,7 +156,7 @@ class ParallelConfig:
):
self.sequence_parallel = False # Whether to enable sequence parallelism.
self.use_ep = False # Whether to enable Expert Parallelism
self.moe_phase = MoEPhase.PREFILL # Generation phase
self.moe_phase = MoEPhase("prefill") # Generation phase
self.msg_queue_id = 1 # message queue id
self.tensor_parallel_rank = 0 # TP rank ID
@@ -210,11 +220,11 @@ class ParallelConfig:
setattr(self, key, value)
self.use_ep = args["expert_parallel_size"] > 1
if self.splitwise_role == "mixed":
self.moe_phase = MoEPhase.PREFILL
self.moe_phase = MoEPhase(phase="prefill")
elif self.splitwise_role == "prefill":
self.moe_phase = MoEPhase.PREFILL
self.moe_phase = MoEPhase(phase="prefill")
elif self.splitwise_role == "decode":
self.moe_phase = MoEPhase.DECODER
self.moe_phase = MoEPhase(phase="decode")
else:
raise NotImplementedError
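For illustration only (not part of the diff), a minimal sketch of how the new string-based MoEPhase behaves, using nothing beyond the class and wiring shown above:

from fastdeploy.config import MoEPhase

phase = MoEPhase("prefill")
phase.phase = "decode"        # allowed: the phase can be flipped in place at runtime
try:
    phase.phase = "finished"  # anything other than "prefill"/"decode" is rejected
except ValueError as err:
    print(err)

Because the phase is now a mutable attribute rather than a fixed Enum member, a single ParallelConfig can switch between prefill and decode from step to step, which is what the mixed splitwise_role relies on.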

View File

@@ -43,9 +43,10 @@ class DeepEPEngine:
num_max_dispatch_tokens_per_rank: int,
hidden: int,
num_experts: int,
moe_phase: MoEPhase,
ep_size: int,
ep_rank: int,
splitwise_role: str,
moe_phase: MoEPhase,
async_finish: bool = False,
):
"""
@@ -65,24 +66,42 @@ class DeepEPEngine:
self.hidden = hidden
self.num_experts = num_experts
self.num_local_experts = num_experts // ep_size
self.moe_phase = moe_phase
self.async_finish = async_finish
self.deepep_engine = None
self.prefill_deepep_engine = None
self.decode_deepep_engine = None
if moe_phase == MoEPhase.DECODER:
logger.info("Initializing Low Latency Buffer")
self.ep_config = Config(24, 6, 256)
self.num_max_dispatch_tokens_per_rank = num_max_dispatch_tokens_per_rank
# In mixed EP mode on a single node, we dynamically switch between
# high throughput and low latency modes.
if splitwise_role == "mixed":
# decode engine
logger.info("Initializing Low Latency Buffer")
self.get_low_latency_buffer()
elif moe_phase == MoEPhase.PREFILL:
self.deepep_engine = deep_ep.Buffer(
# prefill engine
self.prefill_deepep_engine = deep_ep.Buffer(
self.group,
int(5e8),
0,
low_latency_mode=False,
num_qps_per_rank=1,
)
# In disaggregated mode on multiple nodes, we either use
# high throughput mode or low latency mode.
else:
if moe_phase.phase == "decode":
logger.info("Initializing Low Latency Buffer")
self.get_low_latency_buffer()
elif moe_phase.phase == "prefill":
self.prefill_deepep_engine = deep_ep.Buffer(
self.group,
int(5e8),
0,
low_latency_mode=False,
num_qps_per_rank=1,
)
self.ep_config = Config(24, 6, 256)
else:
raise ValueError(f"Unknown generation phase {moe_phase}")
@@ -105,14 +124,14 @@ class DeepEPEngine:
)
# Allocate a buffer if none exists or the existing one is too small
if (
self.deepep_engine is None
or self.deepep_engine.group != self.group
or not self.deepep_engine.low_latency_mode
or self.deepep_engine.num_rdma_bytes < num_rdma_bytes
self.decode_deepep_engine is None
or self.decode_deepep_engine.group != self.group
or not self.decode_deepep_engine.low_latency_mode
or self.decode_deepep_engine.num_rdma_bytes < num_rdma_bytes
):
# NOTES: for best performance, the QP number **must** be equal to the number of the local experts
assert self.num_experts % self.ep_size == 0
self.deepep_engine = deep_ep.Buffer(
self.decode_deepep_engine = deep_ep.Buffer(
self.group,
0,
num_rdma_bytes,
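As a reference for the NOTE above, a hedged sketch of the low-latency buffer sizing rule; the helper is illustrative, and the argument names follow the deep_ep.Buffer calls already visible in this diff:

import deep_ep

def make_low_latency_buffer(group, num_rdma_bytes, num_experts, ep_size):
    # For best performance the QP count must equal the number of local experts.
    assert num_experts % ep_size == 0
    return deep_ep.Buffer(
        group,
        0,               # assumption: no NVLink buffer is used in low-latency mode
        num_rdma_bytes,
        low_latency_mode=True,
        num_qps_per_rank=num_experts // ep_size,
    )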
@@ -149,7 +168,7 @@ class DeepEPEngine:
handle,
_,
dispatch_hook,
) = self.deepep_engine.low_latency_dispatch(
) = self.decode_deepep_engine.low_latency_dispatch(
hidden_states,
topk_idx,
expertwise_scale,
@@ -174,8 +193,22 @@ class DeepEPEngine:
Return:
combined_hidden_states: [num_tokens, hidden]
"""
# TODO(@wufeisheng): Delete them when deepep in PaddlePaddle is fixed
(
src_info,
layout_range,
num_max_dispatch_tokens_per_rank,
num_experts,
) = handle
handle = (
src_info,
layout_range,
num_max_dispatch_tokens_per_rank,
None,
num_experts,
)
combined_hidden_states, _, combine_hook = self.deepep_engine.low_latency_combine(
combined_hidden_states, _, combine_hook = self.decode_deepep_engine.low_latency_combine(
hidden_states,
topk_idx,
topk_weights,
@@ -189,7 +222,7 @@ class DeepEPEngine:
"""
clean_low_latency_buffer
"""
self.deepep_engine.clean_low_latency_buffer(
self.decode_deepep_engine.clean_low_latency_buffer(
self.num_max_dispatch_tokens_per_rank, self.hidden, self.num_experts
)
@@ -197,7 +230,11 @@ class DeepEPEngine:
"""
barrier_all
"""
self.deepep_engine.barrier_all()
if self.prefill_deepep_engine is not None:
self.prefill_deepep_engine.barrier_all()
if self.decode_deepep_engine is not None:
self.decode_deepep_engine.barrier_all()
class EPRunner:
@@ -210,6 +247,7 @@ class EPRunner:
top_k: int,
hidden: int,
num_experts: int,
splitwise_role: str,
moe_phase: MoEPhase,
num_max_dispatch_tokens_per_rank: int = 1,
ep_size: int = 1,
@@ -223,9 +261,10 @@ class EPRunner:
num_max_dispatch_tokens_per_rank=num_max_dispatch_tokens_per_rank,
hidden=hidden,
num_experts=num_experts + redundant_experts_num,
moe_phase=moe_phase,
ep_size=ep_size,
ep_rank=ep_rank,
splitwise_role=splitwise_role,
moe_phase=moe_phase,
)
def moe_select(self, layer: nn.Layer, gate_out: paddle.Tensor):
@@ -286,15 +325,19 @@ class EPPrefillRunner(EPRunner):
top_k: int,
hidden: int,
num_experts: int,
splitwise_role: str,
ep_size: int = 1,
ep_rank: int = 0,
redundant_experts_num: int = 0,
moe_phase: MoEPhase = MoEPhase("prefill"),
):
super().__init__(
top_k,
hidden,
num_experts,
MoEPhase.PREFILL,
splitwise_role,
moe_phase,
num_max_dispatch_tokens_per_rank=256,
ep_size=ep_size,
ep_rank=ep_rank,
redundant_experts_num=redundant_experts_num,
@@ -314,7 +357,7 @@ class EPPrefillRunner(EPRunner):
num_tokens_per_expert,
is_token_in_rank,
_,
) = self.ep_engine.deepep_engine.get_dispatch_layout(topk_idx, self.num_experts)
) = self.ep_engine.prefill_deepep_engine.get_dispatch_layout(topk_idx, self.num_experts)
x_scale_tensor = kwargs.get("x_scale_tensor", None)
dispatch_args = {
@@ -327,7 +370,7 @@ class EPPrefillRunner(EPRunner):
"topk_idx": topk_idx,
"topk_weights": topk_weights,
}
return self.ep_engine.deepep_engine.dispatch(**dispatch_args)
return self.ep_engine.prefill_deepep_engine.dispatch(**dispatch_args)
def combine(
self,
@@ -342,7 +385,7 @@ class EPPrefillRunner(EPRunner):
"async_finish": self.ep_engine.async_finish,
"topk_weights": recv_topk_weights,
}
fused_moe_out, _, _ = self.ep_engine.deepep_engine.combine(**combine_args)
fused_moe_out, _, _ = self.ep_engine.prefill_deepep_engine.combine(**combine_args)
return fused_moe_out
@@ -357,16 +400,19 @@ class EPDecoderRunner(EPRunner):
top_k: int,
hidden: int,
num_experts: int,
splitwise_role: str,
num_max_dispatch_tokens_per_rank: int,
ep_size: int = 1,
ep_rank: int = 0,
redundant_experts_num: int = 0,
moe_phase: MoEPhase = MoEPhase("decode"),
):
super().__init__(
top_k,
hidden,
num_experts,
MoEPhase.DECODER,
splitwise_role,
moe_phase,
num_max_dispatch_tokens_per_rank,
ep_size=ep_size,
ep_rank=ep_rank,

View File

@@ -19,8 +19,6 @@ from abc import abstractmethod
import paddle
from paddle import nn
from fastdeploy.config import MoEPhase
from ..quantization.quant_base import QuantMethodBase
@@ -45,25 +43,50 @@ class MoEMethodBase(QuantMethodBase):
Init EP related module
"""
if layer.ep_size > 1:
if layer.fd_config.parallel_config.moe_phase == MoEPhase.DECODER:
from .ep import EPDecoderRunner
if layer.fd_config.parallel_config.splitwise_role == "mixed":
from .ep import EPDecoderRunner, EPPrefillRunner
self.ep_prefill_runner = EPPrefillRunner(
layer.top_k,
layer.hidden_size,
layer.num_experts,
layer.fd_config.parallel_config.splitwise_role,
layer.ep_size,
layer.ep_rank,
layer.fd_config.model_config.redundant_experts_num,
)
self.ep_decoder_runner = EPDecoderRunner(
layer.top_k,
layer.hidden_size,
layer.num_experts,
layer.fd_config.parallel_config.splitwise_role,
layer.fd_config.model_config.num_max_dispatch_tokens_per_rank,
layer.ep_size,
layer.ep_rank,
layer.fd_config.model_config.redundant_experts_num,
)
else:
if layer.fd_config.parallel_config.moe_phase.phase == "prefill":
from .ep import EPPrefillRunner
self.ep_prefill_runner = EPPrefillRunner(
layer.top_k,
layer.hidden_size,
layer.num_experts,
layer.fd_config.parallel_config.splitwise_role,
layer.ep_size,
layer.ep_rank,
layer.fd_config.model_config.redundant_experts_num,
)
else:
from .ep import EPDecoderRunner
self.ep_decoder_runner = EPDecoderRunner(
layer.top_k,
layer.hidden_size,
layer.num_experts,
layer.moe_config.num_max_dispatch_tokens_per_rank,
layer.fd_config.parallel_config.splitwise_role,
layer.ep_size,
layer.ep_rank,
layer.fd_config.model_config.redundant_experts_num,
@@ -141,7 +164,7 @@ class MoEMethodBase(QuantMethodBase):
Paddle Cutlass compute Fused MoE.
"""
if layer.ep_size > 1:
if layer.fd_config.parallel_config.moe_phase == MoEPhase.PREFILL:
if layer.fd_config.parallel_config.moe_phase.phase == "prefill":
return self.apply_ep_prefill(layer, x, gate_out)
else:
return self.apply_ep_decode(layer, x, gate_out)

View File

@@ -794,6 +794,14 @@ class GPUModelRunner(ModelRunnerBase):
# Update Batch type for cuda graph
# TODO(gongshaotian): Use seq_lens_encoder to set is_decode_batch
is_decode_batch = not ((self.share_inputs["seq_lens_this_time"] > 1).sum() > 0)
# mixed EP on a single node
if self.fd_config.parallel_config.use_ep and self.fd_config.parallel_config.splitwise_role == "mixed":
is_decode_batch_list = []
paddle.distributed.all_gather_object(is_decode_batch_list, is_decode_batch)
is_decode_batch = all(is_decode_batch_list)
self.fd_config.parallel_config.moe_phase.phase = "decode" if is_decode_batch else "prefill"
self.forward_meta.step_use_cudagraph = self.use_cudagraph and is_decode_batch
# Initialize attention meta data
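The all_gather_object step above is what keeps every EP rank in the same phase for a given step. A standalone sketch of the same consensus rule, with a hypothetical helper and assuming an already-initialized paddle.distributed process group:

import paddle.distributed as dist

def agree_on_decode(local_is_decode: bool) -> bool:
    # Every rank contributes its local view; the step counts as decode only if
    # all ranks see a pure-decode batch. A single rank with a prefill request
    # pulls the whole EP group into prefill mode together.
    gathered = []
    dist.all_gather_object(gathered, local_is_decode)
    return all(gathered)

The agreed result is then written into parallel_config.moe_phase.phase, the same object the MoE layers read when choosing between apply_ep_prefill and apply_ep_decode.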
@@ -1163,16 +1171,18 @@ class GPUModelRunner(ModelRunnerBase):
We plan to replace it with 'ModelForwardBatch'.
intermediate_tensors:
"""
# NOTE(wufeisheng): For Expert Parallelism
if not self.not_need_stop():
self._execute_empty_input()
return None
# 1. Prepare inputs of model and sampler.
skip_idx_list = self._get_skip_idx(model_forward_batch)
self._prepare_inputs()
self.sampler.pre_process(skip_idx_list)
# NOTE(wufeisheng): If `not_need_stop` is False, it means the current worker is in an idle state.
# This logic is not used in TP (Tensor Parallelism) mode. However, in EP (Expert Parallelism) mode,
# when there is data on other runners, the current runner is required to execute part of the model.
if not self.not_need_stop():
self._execute_empty_input()
return None
# 2. Padding inputs for cuda graph
self.padding_cudagraph_inputs()
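In the hunk above, the not_need_stop early-exit and its NOTE move from before input preparation to after _prepare_inputs() and pre_process(), so an idle EP rank still builds its per-step state before returning. A minimal sketch of the resulting order; method names come from the diff, while the wrapper itself is illustrative:

def execute_model_step(runner, model_forward_batch=None):
    # 1. Always prepare inputs first so per-rank state stays in sync across EP ranks.
    skip_idx_list = runner._get_skip_idx(model_forward_batch)
    runner._prepare_inputs()
    runner.sampler.pre_process(skip_idx_list)

    # 2. Only then may an idle EP rank bail out with an empty forward pass.
    if not runner.not_need_stop():
        runner._execute_empty_input()
        return None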