[NewFeature] Add EP rollout model init and update/clear EP buffer (#4039)

* fix gid

* merge

* fix test

* fix bug

* fix

* fix ci
Author: gaoziyuan
Date: 2025-09-17 20:24:53 +08:00
Committed by: GitHub
Parent: 0d3a57a2c6
Commit: 896e3bb606
12 changed files with 348 additions and 293 deletions

@@ -20,168 +20,139 @@ import paddle
from paddle import nn
from paddleformers.utils.log import logger
from fastdeploy.platforms import current_platform
try:
from paddle.distributed.communication import deep_ep
except:
logger.warning("import deep_ep Failed!")
if current_platform.is_cuda():
try:
from paddle.distributed.communication import deep_ep
except:
logger.warning("import deep_ep Failed!")
from typing import Optional
import fastdeploy
from fastdeploy.config import MoEPhase
from fastdeploy.model_executor.layers.moe.moe import get_moe_scores
from fastdeploy.utils import singleton
class DeepEPEngineBase:
class DeepEPBufferManager:
_engine: Optional["DeepEPEngine"] = None
@classmethod
def set_engine(cls, engine: "DeepEPEngine"):
cls._engine = engine
@classmethod
def clear_buffer(cls):
if cls._engine:
cls._engine.clear_deep_ep_buffer()
@classmethod
def recreate_buffer(cls):
if cls._engine:
cls._engine.create_deep_ep_buffer()
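
A minimal usage sketch of the manager (hedged: a DeepEPEngine constructed earlier is assumed to have registered itself via set_engine(), and the weight-update helper is a placeholder, not part of this patch):

# Hypothetical flow; only DeepEPBufferManager and its classmethods come from this file.
def update_rollout_weights():
    """Placeholder for the caller's own weight-reload step (not part of this patch)."""

DeepEPBufferManager.clear_buffer()      # frees the deep_ep.Buffer before reloading weights
update_rollout_weights()
DeepEPBufferManager.recreate_buffer()   # re-allocates the buffer before serving resumes
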
class DeepEPBuffer:
"""
A wrapper class for DeepEP engine.
Encapsulates DeepEP buffer creation, management and cleanup.
"""
def __init__(
self,
num_max_dispatch_tokens_per_rank: int,
hidden: int,
group,
hidden_size: int,
num_experts: int,
ep_size: int,
ep_rank: int,
num_max_dispatch_tokens_per_rank: int,
splitwise_role: str,
moe_phase: MoEPhase,
async_finish: bool = False,
group=None,
):
"""
Initialize the DeepEP engine.
Args:
group: The MPI group object.
ep_size: The number of ranks.
rank_id: The rank id.
num_max_dispatch_tokens_per_rank: The maximum number of tokens per rank to dispatch.
hidden: The hidden dimension of the model.
num_experts: The number of experts.
"""
self.num_max_dispatch_tokens_per_rank = num_max_dispatch_tokens_per_rank
self.hidden = hidden
self.group = group
self.hidden_size = hidden_size
self.num_experts = num_experts
self.ep_size = ep_size
self.rank_id = ep_rank
self.num_max_dispatch_tokens_per_rank = num_max_dispatch_tokens_per_rank
self.splitwise_role = splitwise_role
self.moe_phase = moe_phase
self.async_finish = async_finish
# TODO(@wufeisheng): Support configurable EP size
if group is None:
group = paddle.distributed.new_group(range(ep_size))
self.group = group
self.num_local_experts = num_experts // ep_size
self.deepep_engine = None
self.init_deepep_engine()
@abstractmethod
def init_deepep_engine(self):
raise NotImplementedError
self.deepep_buffer = None
self.num_nvl_bytes = 0
self.num_rdma_bytes = 0
# Precompute buffer sizes
self._compute_buffer_sizes()
@singleton
class DeepEPEngine(DeepEPEngineBase):
"""
A wrapper class for DeepEP engine.
"""
def _compute_buffer_sizes(self, param_bytes: int = 2):
hidden_bytes = self.hidden_size * param_bytes # bf16 or fp16
def __init__(
self,
num_max_dispatch_tokens_per_rank: int,
hidden: int,
num_experts: int,
ep_size: int,
ep_rank: int,
splitwise_role: str,
moe_phase: MoEPhase,
async_finish: bool = False,
group=None,
):
"""
Initialize the DeepEP engine.
Args:
group: The MPI group object.
ep_size: The number of ranks.
rank_id: The rank id.
num_max_dispatch_tokens_per_rank: The maximum number of tokens per rank to dispatch.
hidden: The hidden dimension of the model.
num_experts: The number of experts.
"""
super().__init__(
num_max_dispatch_tokens_per_rank,
hidden,
num_experts,
ep_size,
ep_rank,
splitwise_role,
moe_phase,
async_finish,
group,
)
for config in (
deep_ep.Buffer.get_dispatch_config(self.group.world_size),
deep_ep.Buffer.get_combine_config(self.group.world_size),
):
self.num_nvl_bytes = max(
config.get_nvl_buffer_size_hint(hidden_bytes, self.group.world_size), self.num_nvl_bytes
)
self.num_rdma_bytes = max(
config.get_rdma_buffer_size_hint(hidden_bytes, self.group.world_size), self.num_rdma_bytes
)
def init_deepep_engine(self):
from paddle.base.core import Config
if self.splitwise_role == "mixed" or self.moe_phase.phase == "decode":
num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(
self.num_max_dispatch_tokens_per_rank,
self.hidden_size,
self.ep_size,
self.num_experts,
)
self.num_rdma_bytes = max(self.num_rdma_bytes, num_rdma_bytes)
self.ep_config = Config(24, 6, 256)
logger.info(f"DeepEP num nvl bytes : {self.num_nvl_bytes}, num rdma bytes : {self.num_rdma_bytes}")
# In mixed EP mode on a single node, we dynamically switch between
# high throughput and low latency modes.
def create_buffer(self):
"""Create or recreate buffer based on role and phase."""
if self.deepep_buffer is not None:
self.clear_buffer()
if self.splitwise_role == "mixed":
self.deepep_engine = deep_ep.Buffer(
logger.info("Initializing mixed mode buffer (low latency).")
self.deepep_buffer = deep_ep.Buffer(
self.group,
int(2e9),
int(6e9),
self.num_nvl_bytes,
self.num_rdma_bytes,
low_latency_mode=True,
num_qps_per_rank=24,
)
# In disaggregated mode on multiple nodes, we either use
# high throughput mode or low latency mode.
self.deepep_buffer.set_num_sms(14) # TODO: tune in future
else:
if self.moe_phase.phase == "decode":
logger.info("Initializing Low Latency Buffer")
self.get_low_latency_buffer()
self._create_low_latency_buffer()
elif self.moe_phase.phase == "prefill":
self.deepep_engine = deep_ep.Buffer(
logger.info("Initializing High Throughput Buffer for prefill phase.")
self.deepep_buffer = deep_ep.Buffer(
self.group,
int(5e8),
self.num_nvl_bytes,
0,
low_latency_mode=False,
num_qps_per_rank=1,
)
else:
raise ValueError(f"Unknown generation phase {self.moe_phase}")
raise ValueError(f"Unknown generation phase: {self.moe_phase.phase}")
def get_low_latency_buffer(self):
"""
Get the DeepEP buffer.
Args:
group: The MPI group object.
num_max_dispatch_tokens_per_rank: The maximum number of tokens per rank to dispatch.
hidden: The hidden dimension of the model.
"""
# NOTES: the low-latency mode will consume much more space than the normal mode
# So we recommend that `num_max_dispatch_tokens_per_rank`
# (the actual batch size in the decoding engine) should be less than 256
logger.info("DeepEP buffer created successfully.")
def _create_low_latency_buffer(self):
num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(
self.num_max_dispatch_tokens_per_rank,
self.hidden,
self.hidden_size,
self.ep_size,
self.num_experts,
)
# Allocate a buffer if one does not already exist or the existing buffer is too small
if (
self.deepep_engine is None
or self.deepep_engine.group != self.group
or not self.deepep_engine.low_latency_mode
or self.deepep_engine.num_rdma_bytes < num_rdma_bytes
self.deepep_buffer is None
or self.deepep_buffer.group != self.group
or not self.deepep_buffer.low_latency_mode
or self.deepep_buffer.num_rdma_bytes < num_rdma_bytes
):
# NOTES: for best performance, the QP number **must** equal the number of local experts
assert self.num_experts % self.ep_size == 0
self.deepep_engine = deep_ep.Buffer(
self.deepep_buffer = deep_ep.Buffer(
self.group,
0,
num_rdma_bytes,
@@ -189,6 +160,91 @@ class DeepEPEngine(DeepEPEngineBase):
num_qps_per_rank=self.num_experts // self.ep_size,
)
def clear_buffer(self):
"""Clear buffer and free memory."""
if self.deepep_buffer is not None:
del self.deepep_buffer
self.deepep_buffer = None
logger.info("DeepEP buffer cleared.")
def get_buffer(self):
return self.deepep_buffer
def clean_low_latency_buffer(self):
if self.deepep_buffer is not None:
self.deepep_buffer.clean_low_latency_buffer(
self.num_max_dispatch_tokens_per_rank,
self.hidden_size,
self.num_experts,
)
def barrier_all(self):
if self.deepep_buffer is not None:
self.deepep_buffer.barrier_all()
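
For orientation, a hedged sketch of driving this helper directly. The keyword arguments mirror the call DeepEPEngine.__init__ makes below; the concrete values (group size, hidden size, expert count) are illustrative, not taken from the patch:

# Illustrative values only; assumes a distributed run in which the EP group can be created.
ep_group = paddle.distributed.new_group(range(8))
buf = DeepEPBuffer(
    group=ep_group,
    hidden_size=7168,
    num_experts=64,
    ep_size=8,
    num_max_dispatch_tokens_per_rank=128,  # keep small: low-latency mode is memory hungry
    splitwise_role="mixed",
    moe_phase=MoEPhase("decode"),
)
buf.create_buffer()             # allocates a deep_ep.Buffer according to role/phase
buf.clean_low_latency_buffer()  # optional: reset low-latency state between batches
buf.clear_buffer()              # releases the NVL/RDMA memory again
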
@singleton
class DeepEPEngine:
"""
A wrapper class for DeepEP engine.
Manages buffer lifecycle based on role and phase.
"""
def __init__(
self,
num_max_dispatch_tokens_per_rank: int,
hidden_size: int,
num_experts: int,
ep_size: int,
ep_rank: int,
splitwise_role: str,
moe_phase: MoEPhase,
async_finish: bool = False,
group=None,
):
if group is None:
group = paddle.distributed.new_group(range(ep_size))
self.group = group
self.ep_size = ep_size
self.rank_id = ep_rank
self.hidden_size = hidden_size
self.num_experts = num_experts
self.num_local_experts = num_experts // ep_size
self.async_finish = async_finish
from paddle.base.core import Config
self.ep_config = Config(24, 6, 256)
# Store phase and role for buffer management
self._splitwise_role = splitwise_role
self._moe_phase = moe_phase
# Initialize buffer manager
self.buffer = DeepEPBuffer(
group=self.group,
hidden_size=hidden_size,
num_experts=num_experts,
ep_size=ep_size,
num_max_dispatch_tokens_per_rank=num_max_dispatch_tokens_per_rank,
splitwise_role=splitwise_role,
moe_phase=moe_phase,
)
self.buffer.create_buffer()
# Register for global buffer management
DeepEPBufferManager.set_engine(self)
@property
def deepep_engine(self):
"""Backward compatibility alias."""
return self.buffer.get_buffer()
def clear_deep_ep_buffer(self):
self.buffer.clear_buffer()
def create_deep_ep_buffer(self):
self.buffer.create_buffer()
def low_latency_dispatch(
self,
hidden_states: paddle.Tensor,
@@ -196,22 +252,9 @@ class DeepEPEngine(DeepEPEngineBase):
expertwise_scale,
use_fp8: bool = False,
):
"""
Args:
hidden_states: [token_num, hidden] 'bfloat16/int8'
topk_idx: [token_num, num_topk] 'int64'
if self.deepep_engine is None:
raise RuntimeError("DeepEP buffer not initialized!")
Returns:
recv_hidden_states: [num_local_experts,
num_max_dispatch_tokens_per_rank * ep_size, hidden]
ep_size * num_local_experts = num_experts
recv_count: [num_local_experts]
recv_count: a tensor shaped `[num_local_experts]` with type `torch.int`, indicating how many tokens each
expert receives. As mentioned before, not all tokens in `recv_x` are valid.
handle: the communication handle to be used in the `low_latency_combine` function.
event: the event after executing the kernel (valid only if `async_finish` is set).
hook: the receiving hook function (valid only if `return_recv_hook` is set).
"""
(
packed_recv_x,
recv_expert_count,
@@ -222,7 +265,7 @@ class DeepEPEngine(DeepEPEngineBase):
hidden_states,
topk_idx,
expertwise_scale,
self.num_max_dispatch_tokens_per_rank,
self.buffer.num_max_dispatch_tokens_per_rank,
self.num_experts,
use_fp8=use_fp8,
async_finish=False,
@@ -238,27 +281,14 @@ class DeepEPEngine(DeepEPEngineBase):
topk_weights: paddle.Tensor,
handle,
):
"""
Return:
combined_hidden_states: [num_tokens, hidden]
"""
if paddle.__version__ != "0.0.0" and paddle.__version__ <= "3.1.0": # not develop version of PaddlePaddle
if paddle.__version__ != "0.0.0" and paddle.__version__ <= "3.1.0":
# TODO(@wanglongzhi): Delete them when deepep in PaddlePaddle is fixed
# and when the default recommended version of PaddlePaddle is greater than 3.1.0
(
src_info,
layout_range,
num_max_dispatch_tokens_per_rank,
num_experts,
) = handle
handle = (
src_info,
layout_range,
num_max_dispatch_tokens_per_rank,
None,
num_experts,
)
src_info, layout_range, num_max_dispatch_tokens_per_rank, num_experts = handle
handle = (src_info, layout_range, num_max_dispatch_tokens_per_rank, None, num_experts)
if self.deepep_engine is None:
raise RuntimeError("DeepEP buffer not initialized!")
combined_hidden_states, _, combine_hook = self.deepep_engine.low_latency_combine(
hidden_states,
@@ -271,18 +301,10 @@ class DeepEPEngine(DeepEPEngineBase):
return combined_hidden_states, combine_hook
def clean_low_latency_buffer(self):
"""
clean_low_latency_buffer
"""
self.deepep_engine.clean_low_latency_buffer(
self.num_max_dispatch_tokens_per_rank, self.hidden, self.num_experts
)
self.buffer.clean_low_latency_buffer()
def barrier_all(self):
"""
barrier_all
"""
self.deepep_engine.barrier_all()
self.buffer.barrier_all()
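
Because clear_deep_ep_buffer() leaves get_buffer() returning None, the low-latency dispatch/combine paths now fail fast with a RuntimeError instead of crashing inside deep_ep. A hedged sketch of that guard (engine is assumed to be an already constructed DeepEPEngine; the tensors are tiny placeholders):

# Assumes `engine` is a live DeepEPEngine; shapes/dtypes below are placeholders.
hidden_states = paddle.zeros([4, engine.hidden_size], dtype="bfloat16")
topk_idx = paddle.zeros([4, 8], dtype="int64")
engine.clear_deep_ep_buffer()
try:
    engine.low_latency_dispatch(hidden_states, topk_idx, None, use_fp8=False)
except RuntimeError as err:
    print(err)                   # "DeepEP buffer not initialized!"
engine.create_deep_ep_buffer()   # restore the buffer before dispatching again
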
class EPRunner:
@@ -293,7 +315,7 @@ class EPRunner:
def __init__(
self,
top_k: int,
hidden: int,
hidden_size: int,
num_experts: int,
splitwise_role: str,
moe_phase: MoEPhase,
@@ -304,33 +326,20 @@ class EPRunner:
ep_group=None,
):
self.top_k = top_k
self.hidden = hidden
self.num_experts = num_experts
self.splitwise_role = splitwise_role
self.moe_phase = moe_phase
self.num_max_dispatch_tokens_per_rank = num_max_dispatch_tokens_per_rank
self.ep_size = ep_size
self.ep_rank = ep_rank
self.redundant_experts_num = redundant_experts_num
self.ep_group = ep_group
self.init_ep_engine()
def init_ep_engine(self):
self.ep_engine = DeepEPEngine(
num_max_dispatch_tokens_per_rank=self.num_max_dispatch_tokens_per_rank,
hidden=self.hidden,
num_experts=self.num_experts + self.redundant_experts_num,
ep_size=self.ep_size,
ep_rank=self.ep_rank,
splitwise_role=self.splitwise_role,
moe_phase=self.moe_phase,
group=self.ep_group,
num_max_dispatch_tokens_per_rank=num_max_dispatch_tokens_per_rank,
hidden_size=hidden_size,
num_experts=num_experts + redundant_experts_num,
ep_size=ep_size,
ep_rank=ep_rank,
splitwise_role=splitwise_role,
moe_phase=moe_phase,
group=ep_group,
)
def moe_select(self, layer: nn.Layer, gate_out: paddle.Tensor):
"""
moe_select
"""
if layer.redundant_table_manger is not None:
(
ep_rank_to_expert_id_list,
@@ -346,12 +355,14 @@ class EPRunner:
tokens_per_expert_stats_list=tokens_per_expert_stats_list,
bias=layer.gate_correction_bias,
moe_topk=self.top_k,
apply_norm_weight=True, # apply_norm_weight
apply_norm_weight=True,
enable_softmax_top_k_fused=False,
redundant_ep_rank_num_plus_one=layer.fd_config.model_config.redundant_experts_num + 1,
)
else:
if layer.topk_method == "noaux_tc":
from fastdeploy.model_executor.layers.moe.moe import get_moe_scores
score, topk_weights, topk_idx = get_moe_scores(
gate_out,
layer.n_group,
@@ -365,28 +376,28 @@ class EPRunner:
gate_out,
layer.gate_correction_bias,
self.top_k,
True, # apply_norm_weight,
True,
False,
)
return topk_idx, topk_weights
@abstractmethod
def dispatch(self, *args, **kwargs):
"""
dispatch
"""
raise NotImplementedError
@abstractmethod
def combine(self, *args, **kwargs):
"""
combine
"""
raise NotImplementedError
def clean_low_latency_buffer(self):
self.ep_engine.clean_low_latency_buffer()
def clear_deep_ep_buffer(self):
self.ep_engine.clear_deep_ep_buffer()
def create_deep_ep_buffer(self):
self.ep_engine.create_deep_ep_buffer()
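
At the runner level these are thin delegations to the engine, which lets the model runner drive the commit's update/clear flow once per MoE layer. A hedged sketch of that flow (the helper function and loop are illustrative, not part of this patch):

# Hypothetical model-runner helper; only the EPRunner methods come from this patch.
def reload_rollout_model(ep_runners, load_state_dict_fn):
    for runner in ep_runners:            # one EPRunner per MoE layer
        runner.clear_deep_ep_buffer()    # release EP communication memory first
    load_state_dict_fn()                 # reload or update the rollout model weights
    for runner in ep_runners:
        runner.create_deep_ep_buffer()   # re-allocate buffers before inference resumes
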
class EPPrefillRunner(EPRunner):
"""
@@ -396,19 +407,19 @@ class EPPrefillRunner(EPRunner):
def __init__(
self,
top_k: int,
hidden: int,
hidden_size: int,
num_experts: int,
splitwise_role: str,
num_max_dispatch_tokens_per_rank: int,
ep_size: int = 1,
ep_rank: int = 0,
redundant_experts_num: int = 0,
ep_group=None,
moe_phase: MoEPhase = MoEPhase("prefill"),
ep_group=None,
):
super().__init__(
top_k,
hidden,
hidden_size,
num_experts,
splitwise_role,
moe_phase,
@@ -427,6 +438,9 @@ class EPPrefillRunner(EPRunner):
*args,
**kwargs,
):
buffer = self.ep_engine.deepep_engine
if buffer is None:
raise RuntimeError("DeepEP buffer not initialized!")
(
num_tokens_per_rank,
@@ -434,7 +448,7 @@ class EPPrefillRunner(EPRunner):
num_tokens_per_expert,
is_token_in_rank,
_,
) = self.ep_engine.deepep_engine.get_dispatch_layout(topk_idx, self.num_experts)
) = buffer.get_dispatch_layout(topk_idx, self.num_experts)
x_scale_tensor = kwargs.get("x_scale_tensor", None)
dispatch_args = {
@@ -443,12 +457,12 @@ class EPPrefillRunner(EPRunner):
"num_tokens_per_rdma_rank": num_tokens_per_rdma_rank,
"is_token_in_rank": is_token_in_rank,
"num_tokens_per_expert": num_tokens_per_expert,
"config": self.ep_engine.ep_config,
"config": self.ep_engine.ep_config, # assuming ep_config still in engine
"async_finish": self.ep_engine.async_finish,
"topk_idx": topk_idx,
"topk_weights": topk_weights,
}
return self.ep_engine.deepep_engine.dispatch(**dispatch_args)
return buffer.dispatch(**dispatch_args)
def combine(
self,
@@ -456,6 +470,10 @@ class EPPrefillRunner(EPRunner):
handle: tuple,
recv_topk_weights: paddle.Tensor,
):
buffer = self.ep_engine.deepep_engine
if buffer is None:
raise RuntimeError("DeepEP buffer not initialized!")
combine_args = {
"x": tmp_ffn_out,
"handle": handle,
@@ -463,8 +481,7 @@ class EPPrefillRunner(EPRunner):
"async_finish": self.ep_engine.async_finish,
"topk_weights": recv_topk_weights,
}
fused_moe_out, _, _ = self.ep_engine.deepep_engine.combine(**combine_args)
fused_moe_out, _, _ = buffer.combine(**combine_args)
return fused_moe_out
@@ -476,7 +493,7 @@ class EPDecoderRunner(EPRunner):
def __init__(
self,
top_k: int,
hidden: int,
hidden_size: int,
num_experts: int,
splitwise_role: str,
num_max_dispatch_tokens_per_rank: int,
@@ -488,7 +505,7 @@ class EPDecoderRunner(EPRunner):
):
super().__init__(
top_k,
hidden,
hidden_size,
num_experts,
splitwise_role,
moe_phase,