Mirror of https://github.com/PaddlePaddle/FastDeploy.git, synced 2025-10-05 00:33:03 +08:00

[Feature] Support mixed deployment with yiyan adapter in release2.2

* fix metrics
* add unit test
* add unit test
* add unit test
* Support pd ep deployment with yiyan adapter
* Support pd ep deployment with yiyan adapter
* refactor cache messager
* support scheduler v1 in PD
* support pd v1 + chunk prefill
* support pd v1 + chunk prefill
* add eplb
* support eplb
* support eplb
* support eplb
* support v1
* fix
* fix
* fix bug
* remove eplb support
* support prefix cache in P
* fix bug
* fix bug
* support one stop in V1
* fix bug
* fix ci
* fix ci
* fix
* fix
* fix
* fix
* fix

Co-authored-by: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com>
545 lines
17 KiB
Python
"""
|
|
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""
|
|
|
|

from abc import abstractmethod
from typing import Optional

import paddle
from paddle import nn
from paddleformers.utils.log import logger

try:
    from paddle.distributed.communication import deep_ep
except ImportError:
    logger.warning("Failed to import deep_ep!")

import fastdeploy
from fastdeploy.config import MoEPhase
from fastdeploy.utils import singleton


class DeepEPBufferManager:
    """Class-level registry used to clear and recreate the DeepEP buffer of the active engine."""

    _engine: Optional["DeepEPEngine"] = None

    @classmethod
    def set_engine(cls, engine: "DeepEPEngine"):
        cls._engine = engine

    @classmethod
    def clear_buffer(cls):
        """Clear the DeepEP buffer of the registered engine, if any."""
        if cls._engine:
            cls._engine.clear_deep_ep_buffer()

    @classmethod
    def recreate_buffer(cls):
        """Recreate the DeepEP buffer of the registered engine, if any."""
        if cls._engine:
            cls._engine.create_deep_ep_buffer()
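

# Usage sketch (illustrative, not part of the original module): components that
# need to free GPU memory temporarily can drop and later rebuild the DeepEP
# buffer through this class-level registry, without holding an engine handle:
#
#   DeepEPBufferManager.clear_buffer()     # free the NVLink/RDMA buffer
#   ...                                    # e.g. reload or resize model weights
#   DeepEPBufferManager.recreate_buffer()  # rebuild it for the registered engine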


class DeepEPBuffer:
    """
    Encapsulates DeepEP buffer creation, management and cleanup.
    """

    def __init__(
        self,
        group,
        hidden_size: int,
        num_experts: int,
        ep_size: int,
        num_max_dispatch_tokens_per_rank: int,
        splitwise_role: str,
        moe_phase: MoEPhase,
    ):
        self.group = group
        self.hidden_size = hidden_size
        self.num_experts = num_experts
        self.ep_size = ep_size
        self.num_max_dispatch_tokens_per_rank = num_max_dispatch_tokens_per_rank
        self.splitwise_role = splitwise_role
        self.moe_phase = moe_phase

        self.deepep_buffer = None
        self.num_nvl_bytes = 0
        self.num_rdma_bytes = 0

        # Precompute buffer sizes
        self._compute_buffer_sizes()

    def _compute_buffer_sizes(self, param_bytes: int = 2):
        hidden_bytes = self.hidden_size * param_bytes  # bf16 or fp16

        for config in (
            deep_ep.Buffer.get_dispatch_config(self.group.world_size),
            deep_ep.Buffer.get_combine_config(self.group.world_size),
        ):
            self.num_nvl_bytes = max(
                config.get_nvl_buffer_size_hint(hidden_bytes, self.group.world_size), self.num_nvl_bytes
            )
            self.num_rdma_bytes = max(
                config.get_rdma_buffer_size_hint(hidden_bytes, self.group.world_size), self.num_rdma_bytes
            )

        if self.splitwise_role == "mixed" or self.moe_phase.phase == "decode":
            num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(
                self.num_max_dispatch_tokens_per_rank,
                self.hidden_size,
                self.ep_size,
                self.num_experts,
            )
            self.num_rdma_bytes = max(self.num_rdma_bytes, num_rdma_bytes)

        logger.info(f"DeepEP num nvl bytes : {self.num_nvl_bytes}, num rdma bytes : {self.num_rdma_bytes}")

    def create_buffer(self):
        """Create or recreate buffer based on role and phase."""
        if self.deepep_buffer is not None:
            self.clear_buffer()

        if self.splitwise_role == "mixed":
            logger.info("Initializing mixed mode buffer (low latency).")
            self.deepep_buffer = deep_ep.Buffer(
                self.group,
                self.num_nvl_bytes,
                self.num_rdma_bytes,
                low_latency_mode=True,
                num_qps_per_rank=24,
            )
            self.deepep_buffer.set_num_sms(14)  # TODO: tune in future
        else:
            if self.moe_phase.phase == "decode":
                self._create_low_latency_buffer()
            elif self.moe_phase.phase == "prefill":
                logger.info("Initializing High Throughput Buffer for prefill phase.")
                self.deepep_buffer = deep_ep.Buffer(
                    self.group,
                    self.num_nvl_bytes,
                    0,
                    low_latency_mode=False,
                    num_qps_per_rank=1,
                )
            else:
                raise ValueError(f"Unknown generation phase: {self.moe_phase.phase}")

        logger.info("DeepEP buffer created successfully.")
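
    # Mode summary (added comment; see create_buffer above): a "mixed" role uses
    # one low-latency buffer for both phases; otherwise decode gets an RDMA-sized
    # low-latency buffer (_create_low_latency_buffer) and prefill gets an
    # NVLink-backed high-throughput buffer with no RDMA allocation.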

    def _create_low_latency_buffer(self):
        num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(
            self.num_max_dispatch_tokens_per_rank,
            self.hidden_size,
            self.ep_size,
            self.num_experts,
        )

        if (
            self.deepep_buffer is None
            or self.deepep_buffer.group != self.group
            or not self.deepep_buffer.low_latency_mode
            or self.deepep_buffer.num_rdma_bytes < num_rdma_bytes
        ):
            assert self.num_experts % self.ep_size == 0
            self.deepep_buffer = deep_ep.Buffer(
                self.group,
                0,
                num_rdma_bytes,
                low_latency_mode=True,
                num_qps_per_rank=self.num_experts // self.ep_size,
            )

    def clear_buffer(self):
        """Clear buffer and free memory."""
        if self.deepep_buffer is not None:
            del self.deepep_buffer
            self.deepep_buffer = None
            logger.info("DeepEP buffer cleared.")

    def get_buffer(self):
        return self.deepep_buffer

    def clean_low_latency_buffer(self):
        if self.deepep_buffer is not None:
            self.deepep_buffer.clean_low_latency_buffer(
                self.num_max_dispatch_tokens_per_rank,
                self.hidden_size,
                self.num_experts,
            )

    def barrier_all(self):
        if self.deepep_buffer is not None:
            self.deepep_buffer.barrier_all()


@singleton
class DeepEPEngine:
    """
    A wrapper class for the DeepEP engine.
    Manages buffer lifecycle based on role and phase.
    """

    def __init__(
        self,
        num_max_dispatch_tokens_per_rank: int,
        hidden_size: int,
        num_experts: int,
        ep_size: int,
        ep_rank: int,
        splitwise_role: str,
        moe_phase: MoEPhase,
        async_finish: bool = False,
        group=None,
    ):
        if group is None:
            group = paddle.distributed.new_group(range(ep_size))
        self.group = group
        self.ep_size = ep_size
        self.rank_id = ep_rank
        self.hidden_size = hidden_size
        self.num_experts = num_experts
        self.num_local_experts = num_experts // ep_size
        self.async_finish = async_finish

        self.ep_config = None

        # Store phase and role for buffer management
        self._splitwise_role = splitwise_role
        self._moe_phase = moe_phase

        # Initialize buffer manager
        self.buffer = DeepEPBuffer(
            group=self.group,
            hidden_size=hidden_size,
            num_experts=num_experts,
            ep_size=ep_size,
            num_max_dispatch_tokens_per_rank=num_max_dispatch_tokens_per_rank,
            splitwise_role=splitwise_role,
            moe_phase=moe_phase,
        )
        self.buffer.create_buffer()

        # Register for global buffer management
        DeepEPBufferManager.set_engine(self)

    @property
    def deepep_engine(self):
        """Backward compatibility alias."""
        return self.buffer.get_buffer()

    def clear_deep_ep_buffer(self):
        self.buffer.clear_buffer()

    def create_deep_ep_buffer(self):
        self.buffer.create_buffer()

    def low_latency_dispatch(
        self,
        hidden_states: paddle.Tensor,
        topk_idx: paddle.Tensor,
        expertwise_scale,
        use_fp8: bool = False,
    ):
        if self.deepep_engine is None:
            raise RuntimeError("DeepEP buffer not initialized!")

        (
            packed_recv_x,
            recv_expert_count,
            handle,
            _,
            dispatch_hook,
        ) = self.deepep_engine.low_latency_dispatch(
            hidden_states,
            topk_idx,
            expertwise_scale,
            self.buffer.num_max_dispatch_tokens_per_rank,
            self.num_experts,
            use_fp8=use_fp8,
            async_finish=False,
            return_recv_hook=True,
        )

        return packed_recv_x, recv_expert_count, handle, dispatch_hook
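
    # Hook-pattern sketch (illustrative comment, not original code): with
    # return_recv_hook=True the dispatch call returns immediately and the data
    # only lands once the returned hook runs, letting callers overlap compute
    # with communication:
    #
    #   recv_x, recv_count, handle, hook = engine.low_latency_dispatch(x, topk_idx, None)
    #   ...           # overlapped work here
    #   if hook is not None:
    #       hook()    # block until dispatched tokens have arrived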

    def low_latency_combine(
        self,
        hidden_states: paddle.Tensor,
        topk_idx: paddle.Tensor,
        topk_weights: paddle.Tensor,
        handle,
    ):
        if paddle.__version__ != "0.0.0" and paddle.__version__ <= "3.1.0":
            # TODO(@wanglongzhi): Delete this compatibility shim once deep_ep in
            # PaddlePaddle is fixed and the default recommended version of
            # PaddlePaddle is greater than 3.1.0.
            src_info, layout_range, num_max_dispatch_tokens_per_rank, num_experts = handle
            handle = (src_info, layout_range, num_max_dispatch_tokens_per_rank, None, num_experts)

        if self.deepep_engine is None:
            raise RuntimeError("DeepEP buffer not initialized!")

        combined_hidden_states, _, combine_hook = self.deepep_engine.low_latency_combine(
            hidden_states,
            topk_idx,
            topk_weights,
            handle,
            async_finish=False,
            return_recv_hook=True,
        )
        return combined_hidden_states, combine_hook

    def clean_low_latency_buffer(self):
        self.buffer.clean_low_latency_buffer()

    def barrier_all(self):
        self.buffer.barrier_all()
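

# Construction sketch (illustrative values; assumes the @singleton decorator
# from fastdeploy.utils caches the instance, so repeated construction reuses
# one engine):
#
#   engine = DeepEPEngine(
#       num_max_dispatch_tokens_per_rank=128,
#       hidden_size=7168,
#       num_experts=64,
#       ep_size=8,
#       ep_rank=paddle.distributed.get_rank(),
#       splitwise_role="mixed",
#       moe_phase=MoEPhase("prefill"),
#   )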


class EPRunner:
    """
    Base class for expert-parallel (EP) runners.
    """

    def __init__(
        self,
        top_k: int,
        hidden_size: int,
        num_experts: int,
        splitwise_role: str,
        moe_phase: MoEPhase,
        num_max_dispatch_tokens_per_rank: int = 1,
        ep_size: int = 1,
        ep_rank: int = 0,
        redundant_experts_num: int = 0,
        ep_group=None,
    ):
        self.top_k = top_k
        self.num_experts = num_experts
        self.redundant_experts_num = redundant_experts_num
        self.ep_engine = DeepEPEngine(
            num_max_dispatch_tokens_per_rank=num_max_dispatch_tokens_per_rank,
            hidden_size=hidden_size,
            num_experts=num_experts + redundant_experts_num,
            ep_size=ep_size,
            ep_rank=ep_rank,
            splitwise_role=splitwise_role,
            moe_phase=moe_phase,
            group=ep_group,
        )

    def moe_select(self, layer: nn.Layer, gate_out: paddle.Tensor):
        if layer.redundant_table_manger is not None:
            # Redundant-expert routing: map logical experts onto (possibly
            # replicated) physical experts via the per-layer redundancy table.
            (
                ep_rank_to_expert_id_list,
                expert_id_to_ep_rank_array,
                expert_in_rank_num_list,
                tokens_per_expert_stats_list,
            ) = layer.redundant_table_manger.get_ep_rank_to_expert_id_list_by_layer(layer.layer_idx)

            topk_idx, topk_weights = fastdeploy.model_executor.ops.gpu.moe_redundant_topk_select(
                gating_logits=gate_out,
                expert_id_to_ep_rank_array=expert_id_to_ep_rank_array,
                expert_in_rank_num_list=expert_in_rank_num_list,
                tokens_per_expert_stats_list=tokens_per_expert_stats_list,
                bias=layer.gate_correction_bias,
                moe_topk=self.top_k,
                apply_norm_weight=True,
                enable_softmax_top_k_fused=False,
                redundant_ep_rank_num_plus_one=layer.fd_config.model_config.redundant_experts_num + 1,
            )
        else:
            if layer.topk_method == "noaux_tc":
                from fastdeploy.model_executor.layers.moe.moe import get_moe_scores

                score, topk_weights, topk_idx = get_moe_scores(
                    gate_out,
                    layer.n_group,
                    layer.topk_group,
                    layer.top_k,
                    layer.routed_scaling_factor,
                    layer.gate_correction_bias,
                )
            else:
                topk_idx, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
                    gate_out,
                    layer.gate_correction_bias,
                    self.top_k,
                    True,  # apply_norm_weight, mirroring the kwarg in the redundant variant above
                    False,  # enable_softmax_top_k_fused
                )
        return topk_idx, topk_weights
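
    # Shape sketch (illustrative, assuming the default routing path): for gate
    # logits of shape [num_tokens, num_experts], moe_select returns expert ids
    # and routing weights, each of shape [num_tokens, top_k]:
    #
    #   topk_idx, topk_weights = self.moe_select(layer, gate_out)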

    @abstractmethod
    def dispatch(self, *args, **kwargs):
        raise NotImplementedError

    @abstractmethod
    def combine(self, *args, **kwargs):
        raise NotImplementedError

    def clean_low_latency_buffer(self):
        self.ep_engine.clean_low_latency_buffer()

    def clear_deep_ep_buffer(self):
        self.ep_engine.clear_deep_ep_buffer()

    def create_deep_ep_buffer(self):
        self.ep_engine.create_deep_ep_buffer()


class EPPrefillRunner(EPRunner):
    """
    EPRunner for the prefill phase (high-throughput dispatch/combine).
    """

    def __init__(
        self,
        top_k: int,
        hidden_size: int,
        num_experts: int,
        splitwise_role: str,
        num_max_dispatch_tokens_per_rank: int,
        ep_size: int = 1,
        ep_rank: int = 0,
        redundant_experts_num: int = 0,
        moe_phase: MoEPhase = MoEPhase("prefill"),
        ep_group=None,
    ):
        super().__init__(
            top_k,
            hidden_size,
            num_experts,
            splitwise_role,
            moe_phase,
            num_max_dispatch_tokens_per_rank=num_max_dispatch_tokens_per_rank,
            ep_size=ep_size,
            ep_rank=ep_rank,
            redundant_experts_num=redundant_experts_num,
            ep_group=ep_group,
        )

    def dispatch(
        self,
        x: paddle.Tensor,
        topk_idx: paddle.Tensor,
        topk_weights: paddle.Tensor,
        *args,
        **kwargs,
    ):
        buffer = self.ep_engine.deepep_engine
        if buffer is None:
            raise RuntimeError("DeepEP buffer not initialized!")

        (
            num_tokens_per_rank,
            num_tokens_per_rdma_rank,
            num_tokens_per_expert,
            is_token_in_rank,
            _,
        ) = buffer.get_dispatch_layout(topk_idx, self.num_experts)

        x_scale_tensor = kwargs.get("x_scale_tensor", None)
        dispatch_args = {
            "x": (x, x_scale_tensor) if x_scale_tensor is not None else x,
            "num_tokens_per_rank": num_tokens_per_rank,
            "num_tokens_per_rdma_rank": num_tokens_per_rdma_rank,
            "is_token_in_rank": is_token_in_rank,
            "num_tokens_per_expert": num_tokens_per_expert,
            "config": self.ep_engine.ep_config,  # assuming ep_config still in engine
            "async_finish": self.ep_engine.async_finish,
            "topk_idx": topk_idx,
            "topk_weights": topk_weights,
        }
        return buffer.dispatch(**dispatch_args)

    def combine(
        self,
        tmp_ffn_out: paddle.Tensor,
        handle: tuple,
        recv_topk_weights: paddle.Tensor,
    ):
        buffer = self.ep_engine.deepep_engine
        if buffer is None:
            raise RuntimeError("DeepEP buffer not initialized!")

        combine_args = {
            "x": tmp_ffn_out,
            "handle": handle,
            "config": self.ep_engine.ep_config,
            "async_finish": self.ep_engine.async_finish,
            "topk_weights": recv_topk_weights,
        }
        fused_moe_out, _, _ = buffer.combine(**combine_args)
        return fused_moe_out
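

# Prefill flow sketch (illustrative; `handle` and `recv_topk_weights` below come
# out of the tuple returned by deep_ep.Buffer.dispatch, whose exact layout is
# not repeated here):
#
#   runner = EPPrefillRunner(top_k, hidden_size, num_experts, "prefill",
#                            num_max_dispatch_tokens_per_rank, ep_size, ep_rank)
#   topk_idx, topk_weights = runner.moe_select(layer, gate_out)
#   dispatched = runner.dispatch(x, topk_idx, topk_weights)
#   ...  # run expert FFNs over the received tokens -> tmp_ffn_out
#   fused_moe_out = runner.combine(tmp_ffn_out, handle, recv_topk_weights)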


class EPDecoderRunner(EPRunner):
    """
    EPRunner for the decode phase (low-latency dispatch/combine).
    """

    def __init__(
        self,
        top_k: int,
        hidden_size: int,
        num_experts: int,
        splitwise_role: str,
        num_max_dispatch_tokens_per_rank: int,
        ep_size: int = 1,
        ep_rank: int = 0,
        redundant_experts_num: int = 0,
        ep_group=None,
        moe_phase: MoEPhase = MoEPhase("decode"),
    ):
        super().__init__(
            top_k,
            hidden_size,
            num_experts,
            splitwise_role,
            moe_phase,
            num_max_dispatch_tokens_per_rank,
            ep_size=ep_size,
            ep_rank=ep_rank,
            redundant_experts_num=redundant_experts_num,
            ep_group=ep_group,
        )

    def dispatch(
        self,
        x: paddle.Tensor,
        topk_idx: paddle.Tensor,
        topk_weights: paddle.Tensor,
        *args,
        **kwargs,
    ):
        expertwise_scale = kwargs.get("expertwise_scale", None)
        use_fp8 = kwargs.get("use_fp8", False)

        recv_hidden_states, recv_expert_count, handle, dispatch_hook = self.ep_engine.low_latency_dispatch(
            x, topk_idx, expertwise_scale, use_fp8
        )
        if dispatch_hook is not None:
            dispatch_hook()

        return recv_hidden_states, recv_expert_count, handle

    def combine(self, ffn_out, topk_idx, topk_weights, handle):
        combined_hidden_states, combine_hook = self.ep_engine.low_latency_combine(
            ffn_out, topk_idx, topk_weights, handle
        )
        if combine_hook is not None:
            combine_hook()

        return combined_hidden_states
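

# Decode flow sketch (illustrative only): the decoder runner uses the
# low-latency path and invokes the recv hooks immediately, so dispatch and
# combine appear synchronous to the caller:
#
#   runner = EPDecoderRunner(top_k, hidden_size, num_experts, "decode",
#                            num_max_dispatch_tokens_per_rank, ep_size, ep_rank)
#   topk_idx, topk_weights = runner.moe_select(layer, gate_out)
#   recv_x, recv_count, handle = runner.dispatch(x, topk_idx, topk_weights)
#   ...  # expert FFNs on recv_x -> ffn_out
#   out = runner.combine(ffn_out, topk_idx, topk_weights, handle)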