Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-10-07 17:41:52 +08:00)

[NewFeture]add ep rollout model init and update/clear ep buffer (#3927)

* add ep rollout model init && add deep update/clear
* fix test
@@ -25,52 +25,175 @@ try:
except:
    logger.warning("import deep_ep Failed!")

from typing import Optional

import fastdeploy
from fastdeploy.config import MoEPhase
from fastdeploy.utils import singleton

try:
    from fastdeploy.model_executor.ops.gpu import noaux_tc
except:
    logger.warning("import noaux_tc Failed!")


class DeepEPBufferManager:
    _engine: Optional["DeepEPEngine"] = None

    @classmethod
    def set_engine(cls, engine: "DeepEPEngine"):
        cls._engine = engine

    @classmethod
    def clear_buffer(cls):
        if cls._engine:
            cls._engine.clear_deep_ep_buffer()

    @classmethod
    def recreate_buffer(cls):
        if cls._engine:
            cls._engine.create_deep_ep_buffer()
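A minimal usage sketch of the manager above, as an RL rollout worker might drive it around an in-place weight update (the import path and update_model_weights() are illustrative assumptions; only clear_buffer()/recreate_buffer() come from this commit):

    # assumed module path; DeepEPBufferManager is defined in this diff's ep.py
    from fastdeploy.model_executor.layers.moe.ep import DeepEPBufferManager

    def reload_weights_for_rollout(update_model_weights):
        # Free the DeepEP NVL/RDMA buffers so their GPU memory is available
        # while the new weights are being materialized.
        DeepEPBufferManager.clear_buffer()
        update_model_weights()  # hypothetical weight-swap step supplied by the caller
        # Re-allocate the buffers before serving resumes.
        DeepEPBufferManager.recreate_buffer()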
The module-level get_moe_scores helper is removed from ep.py (a single shared copy is added to moe.py, see the last hunk below):

def get_moe_scores(
    gating_output: paddle.Tensor,
    n_group,
    topk_group,
    top_k,
    routed_scaling_factor,
    e_score_correction_bias,
) -> paddle.Tensor:
    """
    compute moe scores using e_score_correction_bias.
    """
    scores = paddle.nn.functional.sigmoid(gating_output)
    assert e_score_correction_bias is not None, "e_score_correction_bias is none!"
    scores_with_bias = scores + e_score_correction_bias
    scores, topk_values, topk_idx = noaux_tc(
        scores,
        scores_with_bias,
        n_group if n_group > 0 else 1,
        topk_group if topk_group > 0 else 1,
        top_k,
        routed_scaling_factor,
    )
    return scores, topk_values, topk_idx

In its place, the hunk introduces a buffer wrapper class:

class DeepEPBuffer:
    """
    Encapsulates DeepEP buffer creation, management and cleanup.
    """
    def __init__(
        self,
        group,
        hidden_size: int,
        num_experts: int,
        ep_size: int,
        num_max_dispatch_tokens_per_rank: int,
        splitwise_role: str,
        moe_phase: MoEPhase,
    ):
        self.group = group
        self.hidden_size = hidden_size
        self.num_experts = num_experts
        self.ep_size = ep_size
        self.num_max_dispatch_tokens_per_rank = num_max_dispatch_tokens_per_rank
        self.splitwise_role = splitwise_role
        self.moe_phase = moe_phase

        self.deepep_buffer = None
        self.num_nvl_bytes = 0
        self.num_rdma_bytes = 0

        # Precompute buffer sizes
        self._compute_buffer_sizes()

    def _compute_buffer_sizes(self, param_bytes: int = 2):
        hidden_bytes = self.hidden_size * param_bytes  # bf16 or fp16

        for config in (
            deep_ep.Buffer.get_dispatch_config(self.group.world_size),
            deep_ep.Buffer.get_combine_config(self.group.world_size),
        ):
            self.num_nvl_bytes = max(
                config.get_nvl_buffer_size_hint(hidden_bytes, self.group.world_size), self.num_nvl_bytes
            )
            self.num_rdma_bytes = max(
                config.get_rdma_buffer_size_hint(hidden_bytes, self.group.world_size), self.num_rdma_bytes
            )

        if self.splitwise_role == "mixed" or self.moe_phase.phase == "decode":
            num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(
                self.num_max_dispatch_tokens_per_rank,
                self.hidden_size,
                self.ep_size,
                self.num_experts,
            )
            self.num_rdma_bytes = max(self.num_rdma_bytes, num_rdma_bytes)

        logger.info(f"DeepEP num nvl bytes : {self.num_nvl_bytes}, num rdma bytes : {self.num_rdma_bytes}")
    def create_buffer(self):
        """Create or recreate buffer based on role and phase."""
        if self.deepep_buffer is not None:
            self.clear_buffer()

        if self.splitwise_role == "mixed":
            logger.info("Initializing mixed mode buffer (low latency).")
            self.deepep_buffer = deep_ep.Buffer(
                self.group,
                self.num_nvl_bytes,
                self.num_rdma_bytes,
                low_latency_mode=True,
                num_qps_per_rank=24,
            )
            self.deepep_buffer.set_num_sms(14)  # TODO: tune in future
        else:
            if self.moe_phase.phase == "decode":
                self._create_low_latency_buffer()
            elif self.moe_phase.phase == "prefill":
                logger.info("Initializing High Throughput Buffer for prefill phase.")
                self.deepep_buffer = deep_ep.Buffer(
                    self.group,
                    self.num_nvl_bytes,
                    0,
                    low_latency_mode=False,
                    num_qps_per_rank=1,
                )
            else:
                raise ValueError(f"Unknown generation phase: {self.moe_phase.phase}")

        logger.info("DeepEP buffer created successfully.")

    def _create_low_latency_buffer(self):
        num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(
            self.num_max_dispatch_tokens_per_rank,
            self.hidden_size,
            self.ep_size,
            self.num_experts,
        )

        if (
            self.deepep_buffer is None
            or self.deepep_buffer.group != self.group
            or not self.deepep_buffer.low_latency_mode
            or self.deepep_buffer.num_rdma_bytes < num_rdma_bytes
        ):
            assert self.num_experts % self.ep_size == 0
            self.deepep_buffer = deep_ep.Buffer(
                self.group,
                0,
                num_rdma_bytes,
                low_latency_mode=True,
                num_qps_per_rank=self.num_experts // self.ep_size,
            )

    def clear_buffer(self):
        """Clear buffer and free memory."""
        if self.deepep_buffer is not None:
            del self.deepep_buffer
            self.deepep_buffer = None
            logger.info("DeepEP buffer cleared.")

    def get_buffer(self):
        return self.deepep_buffer

    def clean_low_latency_buffer(self):
        if self.deepep_buffer is not None:
            self.deepep_buffer.clean_low_latency_buffer(
                self.num_max_dispatch_tokens_per_rank,
                self.hidden_size,
                self.num_experts,
            )

    def barrier_all(self):
        if self.deepep_buffer is not None:
            self.deepep_buffer.barrier_all()
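A rough standalone sketch of DeepEPBuffer itself (only meaningful in a multi-GPU environment with the deep_ep extension installed; the group setup, sizes, and the MoEPhase construction below are illustrative assumptions, not values taken from this commit):

    import paddle.distributed as dist
    # from fastdeploy.config import MoEPhase  # assumed import path

    dist.init_parallel_env()
    group = dist.new_group(range(8))      # 8-way expert parallelism (example)
    buf = DeepEPBuffer(
        group=group,
        hidden_size=7168,                 # example model width
        num_experts=64,
        ep_size=8,
        num_max_dispatch_tokens_per_rank=128,
        splitwise_role="mixed",
        moe_phase=MoEPhase("decode"),     # assumed MoEPhase constructor
    )
    buf.create_buffer()                   # allocate NVL/RDMA buffers sized by _compute_buffer_sizes()
    raw = buf.get_buffer()                # underlying deep_ep.Buffer used for dispatch/combine
    buf.clear_buffer()                    # free the buffers again (e.g. before a weight reload)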


@singleton
class DeepEPEngine:
    """
    A wrapper class for DeepEP engine.
    Manages buffer lifecycle based on role and phase.
    """

    def __init__(
        self,
        num_max_dispatch_tokens_per_rank: int,
        hidden_size: int,  # replaces the old `hidden` parameter
        num_experts: int,
        ep_size: int,
        ep_rank: int,
@@ -79,95 +202,48 @@ class DeepEPEngine:
        async_finish: bool = False,
        group=None,
    ):
        """
        Initialize the DeepEP engine.
        Args:
            group: The MPI group object.
            ep_size: The number of ranks.
            rank_id: The rank id.
            num_max_dispatch_tokens_per_rank: The maximum number of tokens per rank to dispatch.
            hidden: The hidden dimension of the model.
            num_experts: The number of experts.
        """
        # TODO(@wufeisheng): Support configurable EP size
        if group is None:
            group = paddle.distributed.new_group(range(ep_size))
        self.group = group
        self.ep_size = ep_size
        self.rank_id = ep_rank
        self.hidden_size = hidden_size
        self.num_experts = num_experts
        self.num_local_experts = num_experts // ep_size
        self.async_finish = async_finish

        from paddle.base.core import Config

        self.ep_config = Config(24, 6, 256)
        self.num_max_dispatch_tokens_per_rank = num_max_dispatch_tokens_per_rank

        # Store phase and role for buffer management
        self._splitwise_role = splitwise_role
        self._moe_phase = moe_phase

        # Initialize buffer manager
        self.buffer = DeepEPBuffer(
            group=self.group,
            hidden_size=hidden_size,
            num_experts=num_experts,
            ep_size=ep_size,
            num_max_dispatch_tokens_per_rank=num_max_dispatch_tokens_per_rank,
            splitwise_role=splitwise_role,
            moe_phase=moe_phase,
        )
        self.buffer.create_buffer()

        # Register for global buffer management
        DeepEPBufferManager.set_engine(self)

    @property
    def deepep_engine(self):
        """Backward compatibility alias."""
        return self.buffer.get_buffer()

    def clear_deep_ep_buffer(self):
        self.buffer.clear_buffer()

    def create_deep_ep_buffer(self):
        self.buffer.create_buffer()

The direct buffer construction that used to live in __init__ and the old get_low_latency_buffer() helper are removed; their logic now lives in DeepEPBuffer.create_buffer() / _create_low_latency_buffer(). Removed code for reference:

        self.hidden = hidden
        self.deepep_engine = None

        # In mixed EP mode on a single node, we dynamically switch between
        # high throughput and low latency modes.
        if splitwise_role == "mixed":
            self.deepep_engine = deep_ep.Buffer(
                self.group,
                int(2e9),
                int(6e9),
                low_latency_mode=True,
                num_qps_per_rank=24,
            )
        # In disaggregated mode on multiple nodes, we either use
        # high throughput mode or low latency mode.
        else:
            if moe_phase.phase == "decode":
                logger.info("Initializing Low Latency Buffer")
                self.get_low_latency_buffer()
            elif moe_phase.phase == "prefill":
                self.deepep_engine = deep_ep.Buffer(
                    self.group,
                    int(5e8),
                    0,
                    low_latency_mode=False,
                    num_qps_per_rank=1,
                )
            else:
                raise ValueError(f"Unknown generation phase {moe_phase}")

    def get_low_latency_buffer(self):
        """
        Get the DeepEP buffer.
        Args:
            group: The MPI group object.
            num_max_dispatch_tokens_per_rank: The maximum number of tokens per rank to dispatch.
            hidden: The hidden dimension of the model.
        """
        # NOTES: the low-latency mode will consume much more space than the normal mode
        # So we recommend that `num_max_dispatch_tokens_per_rank`
        # (the actual batch size in the decoding engine) should be less than 256
        num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(
            self.num_max_dispatch_tokens_per_rank,
            self.hidden,
            self.ep_size,
            self.num_experts,
        )
        # Allocate a buffer if not existed or not enough buffer size
        if (
            self.deepep_engine is None
            or self.deepep_engine.group != self.group
            or not self.deepep_engine.low_latency_mode
            or self.deepep_engine.num_rdma_bytes < num_rdma_bytes
        ):
            # NOTES: for best performance, the QP number **must** be equal to the number of the local experts
            assert self.num_experts % self.ep_size == 0
            self.deepep_engine = deep_ep.Buffer(
                self.group,
                0,
                num_rdma_bytes,
                low_latency_mode=True,
                num_qps_per_rank=self.num_experts // self.ep_size,
            )
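DeepEPEngine is wrapped with @singleton (imported from fastdeploy.utils above), so repeated constructions within a process return one shared engine, which is why DeepEPBufferManager only needs a single _engine slot. A typical implementation of such a decorator looks roughly like the sketch below (an illustrative assumption; FastDeploy's actual fastdeploy.utils.singleton may differ):

    def singleton(cls):
        """Cache and reuse a single instance of the decorated class per process."""
        instances = {}

        def get_instance(*args, **kwargs):
            # First call constructs the object; later calls return the cached one.
            if cls not in instances:
                instances[cls] = cls(*args, **kwargs)
            return instances[cls]

        return get_instance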

    def low_latency_dispatch(
        self,
@@ -176,22 +252,9 @@ class DeepEPEngine:
        expertwise_scale,
        use_fp8: bool = False,
    ):
        """
        Args:
            hidden_states: [token_num, hidden] 'bfloat16/int8'
            topk_idx: [token_num, num_topk] 'int64'

        Returns:
            recv_hidden_states: [num_local_experts,
                num_max_dispatch_tokens_per_rank * ep_size, hidden]
                ep_size * num_local_experts = num_experts
            recv_count: a tensor shaped `[num_local_experts]` with type `int`, indicating how many tokens each
                expert receives; not all tokens in `recv_x` are valid.
            handle: the communication handle to be used in the `low_latency_combine` function.
            event: the event after executing the kernel (valid only if `async_finish` is set).
            hook: the receiving hook function (valid only if `return_recv_hook` is set).
        """
        if self.deepep_engine is None:
            raise RuntimeError("DeepEP buffer not initialized!")

        (
            packed_recv_x,
            recv_expert_count,
@@ -202,7 +265,7 @@ class DeepEPEngine:
            hidden_states,
            topk_idx,
            expertwise_scale,
            self.buffer.num_max_dispatch_tokens_per_rank,  # was: self.num_max_dispatch_tokens_per_rank
            self.num_experts,
            use_fp8=use_fp8,
            async_finish=False,
@@ -218,27 +281,14 @@ class DeepEPEngine:
        topk_weights: paddle.Tensor,
        handle,
    ):
        """

        Return:
            combined_hidden_states: [num_tokens, hidden]
        """
        if paddle.__version__ != "0.0.0" and paddle.__version__ <= "3.1.0":  # not develop version of PaddlePaddle
            # TODO(@wanglongzhi): Delete them when deepep in PaddlePaddle is fixed
            # and when the default recommended version of PaddlePaddle is greater than 3.1.0
            src_info, layout_range, num_max_dispatch_tokens_per_rank, num_experts = handle
            handle = (src_info, layout_range, num_max_dispatch_tokens_per_rank, None, num_experts)

        if self.deepep_engine is None:
            raise RuntimeError("DeepEP buffer not initialized!")

        combined_hidden_states, _, combine_hook = self.deepep_engine.low_latency_combine(
            hidden_states,
@@ -251,18 +301,10 @@ class DeepEPEngine:
        return combined_hidden_states, combine_hook

    def clean_low_latency_buffer(self):
        """
        clean_low_latency_buffer
        """
        self.buffer.clean_low_latency_buffer()
        # was:
        # self.deepep_engine.clean_low_latency_buffer(
        #     self.num_max_dispatch_tokens_per_rank, self.hidden, self.num_experts
        # )

    def barrier_all(self):
        """
        barrier_all
        """
        self.buffer.barrier_all()
        # was: self.deepep_engine.barrier_all()
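A standalone illustration of the handle re-packing in low_latency_combine above (dummy values; the reason for the extra slot is an inference from the TODO note — the deep_ep bundled with Paddle <= 3.1.0 appears to expect a five-element handle):

    # four-element handle as produced by newer deep_ep low-latency dispatch
    handle = ("src_info", "layout_range", 128, 64)
    src_info, layout_range, num_max_dispatch_tokens_per_rank, num_experts = handle
    # insert a placeholder so the older integrated deep_ep receives five elements
    handle = (src_info, layout_range, num_max_dispatch_tokens_per_rank, None, num_experts)
    print(handle)  # ('src_info', 'layout_range', 128, None, 64)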


class EPRunner:
@@ -273,7 +315,7 @@ class EPRunner:
    def __init__(
        self,
        top_k: int,
        hidden_size: int,  # was: hidden: int
        num_experts: int,
        splitwise_role: str,
        moe_phase: MoEPhase,
@@ -288,7 +330,7 @@ class EPRunner:
        self.redundant_experts_num = redundant_experts_num
        self.ep_engine = DeepEPEngine(
            num_max_dispatch_tokens_per_rank=num_max_dispatch_tokens_per_rank,
            hidden_size=hidden_size,  # was: hidden=hidden
            num_experts=num_experts + redundant_experts_num,
            ep_size=ep_size,
            ep_rank=ep_rank,
@@ -298,9 +340,6 @@ class EPRunner:
        )

    def moe_select(self, layer: nn.Layer, gate_out: paddle.Tensor):
        """
        moe_select
        """
        if layer.redundant_table_manger is not None:
            (
                ep_rank_to_expert_id_list,
@@ -316,12 +355,14 @@ class EPRunner:
                tokens_per_expert_stats_list=tokens_per_expert_stats_list,
                bias=layer.gate_correction_bias,
                moe_topk=self.top_k,
                apply_norm_weight=True,
                enable_softmax_top_k_fused=False,
                redundant_ep_rank_num_plus_one=layer.fd_config.model_config.redundant_experts_num + 1,
            )
        else:
            if layer.topk_method == "noaux_tc":
                from .moe import get_moe_scores

                score, topk_weights, topk_idx = get_moe_scores(
                    gate_out,
                    layer.n_group,
@@ -335,28 +376,28 @@ class EPRunner:
                    gate_out,
                    layer.gate_correction_bias,
                    self.top_k,
                    True,  # apply_norm_weight
                    False,
                )
        return topk_idx, topk_weights

    @abstractmethod
    def dispatch(self, *args, **kwargs):
        """
        dispatch
        """
        raise NotImplementedError

    @abstractmethod
    def combine(self, *args, **kwargs):
        """
        combine
        """
        raise NotImplementedError

    def clean_low_latency_buffer(self):
        self.ep_engine.clean_low_latency_buffer()

    def clear_deep_ep_buffer(self):
        self.ep_engine.clear_deep_ep_buffer()

    def create_deep_ep_buffer(self):
        self.ep_engine.create_deep_ep_buffer()
class EPPrefillRunner(EPRunner):
    """
@@ -366,7 +407,7 @@ class EPPrefillRunner(EPRunner):
    def __init__(
        self,
        top_k: int,
        hidden_size: int,  # was: hidden: int
        num_experts: int,
        splitwise_role: str,
        num_max_dispatch_tokens_per_rank: int,
@@ -378,7 +419,7 @@ class EPPrefillRunner(EPRunner):
    ):
        super().__init__(
            top_k,
            hidden_size,  # was: hidden
            num_experts,
            splitwise_role,
            moe_phase,
@@ -397,6 +438,9 @@ class EPPrefillRunner(EPRunner):
        *args,
        **kwargs,
    ):
        buffer = self.ep_engine.deepep_engine
        if buffer is None:
            raise RuntimeError("DeepEP buffer not initialized!")

        (
            num_tokens_per_rank,
@@ -404,7 +448,7 @@ class EPPrefillRunner(EPRunner):
            num_tokens_per_expert,
            is_token_in_rank,
            _,
        ) = buffer.get_dispatch_layout(topk_idx, self.num_experts)
        # was: ) = self.ep_engine.deepep_engine.get_dispatch_layout(topk_idx, self.num_experts)

        x_scale_tensor = kwargs.get("x_scale_tensor", None)
        dispatch_args = {
@@ -413,12 +457,12 @@ class EPPrefillRunner(EPRunner):
            "num_tokens_per_rdma_rank": num_tokens_per_rdma_rank,
            "is_token_in_rank": is_token_in_rank,
            "num_tokens_per_expert": num_tokens_per_expert,
            "config": self.ep_engine.ep_config,  # assuming ep_config still in engine
            "async_finish": self.ep_engine.async_finish,
            "topk_idx": topk_idx,
            "topk_weights": topk_weights,
        }
        return buffer.dispatch(**dispatch_args)
        # was: return self.ep_engine.deepep_engine.dispatch(**dispatch_args)

    def combine(
        self,
@@ -426,6 +470,10 @@ class EPPrefillRunner(EPRunner):
        handle: tuple,
        recv_topk_weights: paddle.Tensor,
    ):
        buffer = self.ep_engine.deepep_engine
        if buffer is None:
            raise RuntimeError("DeepEP buffer not initialized!")

        combine_args = {
            "x": tmp_ffn_out,
            "handle": handle,
@@ -433,8 +481,7 @@ class EPPrefillRunner(EPRunner):
            "async_finish": self.ep_engine.async_finish,
            "topk_weights": recv_topk_weights,
        }
        fused_moe_out, _, _ = buffer.combine(**combine_args)
        # was: fused_moe_out, _, _ = self.ep_engine.deepep_engine.combine(**combine_args)

        return fused_moe_out


@@ -446,7 +493,7 @@ class EPDecoderRunner(EPRunner):
    def __init__(
        self,
        top_k: int,
        hidden_size: int,  # was: hidden: int
        num_experts: int,
        splitwise_role: str,
        num_max_dispatch_tokens_per_rank: int,
@@ -458,7 +505,7 @@ class EPDecoderRunner(EPRunner):
    ):
        super().__init__(
            top_k,
            hidden_size,  # was: hidden
            num_experts,
            splitwise_role,
            moe_phase,
@@ -40,66 +40,52 @@ class MoEMethodBase(QuantMethodBase):
            "down_proj_weight_scale",
        ]
        self.pack_num = 1
        self.ep_prefill_runner = None
        self.ep_decoder_runner = None

    def init_ep(self, layer: nn.Layer) -> None:
        """
        Initialize EP (Expert Parallel) related modules.
        """
        if layer.ep_size <= 1:
            return

        # Lazy import to avoid circular dependency or unnecessary loading
        from .ep import EPDecoderRunner, EPPrefillRunner

        # Common arguments for both runners
        common_args = {
            "top_k": layer.top_k,
            "hidden_size": layer.hidden_size,
            "num_experts": layer.num_experts,
            "splitwise_role": layer.fd_config.parallel_config.splitwise_role,
            "num_max_dispatch_tokens_per_rank": layer.fd_config.model_config.num_max_dispatch_tokens_per_rank,
            "ep_size": layer.ep_size,
            "ep_rank": layer.ep_rank,
            "redundant_experts_num": layer.fd_config.model_config.redundant_experts_num,
            "ep_group": layer.fd_config.parallel_config.ep_group,
        }

        config = layer.fd_config
        splitwise_role = config.parallel_config.splitwise_role
        load_strategy = config.load_config.load_strategy

        # For "mixed" splitwise role: conditionally initialize both or none
        if splitwise_role == "mixed":
            if load_strategy == "meta":
                # for RL init model without deepep buff
                return
            else:
                self.ep_prefill_runner = EPPrefillRunner(**common_args)
                self.ep_decoder_runner = EPDecoderRunner(**common_args)
            return

        # For non-mixed ep
        phase = config.parallel_config.moe_phase.phase
        if phase == "prefill":
            self.ep_prefill_runner = EPPrefillRunner(**common_args)
        else:
            self.ep_decoder_runner = EPDecoderRunner(**common_args)

The previous implementation (removed by this hunk) repeated the full positional argument list for every runner it built:

    def init_ep(self, layer: nn.Layer) -> None:
        """
        Init EP related module
        """
        if layer.ep_size > 1:
            if layer.fd_config.parallel_config.splitwise_role == "mixed":
                from .ep import EPDecoderRunner, EPPrefillRunner

                self.ep_prefill_runner = EPPrefillRunner(
                    layer.top_k,
                    layer.hidden_size,
                    layer.num_experts,
                    layer.fd_config.parallel_config.splitwise_role,
                    layer.fd_config.model_config.num_max_dispatch_tokens_per_rank,
                    layer.ep_size,
                    layer.ep_rank,
                    layer.fd_config.model_config.redundant_experts_num,
                    ep_group=layer.fd_config.parallel_config.ep_group,
                )
                self.ep_decoder_runner = EPDecoderRunner(
                    # same arguments as EPPrefillRunner above
                )
            else:
                if layer.fd_config.parallel_config.moe_phase.phase == "prefill":
                    from .ep import EPPrefillRunner

                    self.ep_prefill_runner = EPPrefillRunner(
                        # same arguments as above
                    )
                else:
                    from .ep import EPDecoderRunner

                    self.ep_decoder_runner = EPDecoderRunner(
                        # same arguments as above
                    )
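The branching above reduces to a small decision table. The standalone sketch below restates it in plain Python (not FastDeploy code; runner creation is replaced by strings so it runs anywhere, and the non-"meta" strategy value is a placeholder):

    def select_ep_runners(ep_size: int, splitwise_role: str, load_strategy: str, phase: str):
        """Return the EP runners init_ep would create for a given configuration."""
        if ep_size <= 1:
            return []                          # no expert parallelism, nothing to build
        if splitwise_role == "mixed":
            if load_strategy == "meta":
                return []                      # RL meta-init: skip DeepEP buffers until real weights are loaded
            return ["prefill", "decoder"]      # mixed serving needs both runners
        return ["prefill"] if phase == "prefill" else ["decoder"]

    assert select_ep_runners(8, "mixed", "meta", "decode") == []
    assert select_ep_runners(8, "mixed", "normal", "decode") == ["prefill", "decoder"]  # "normal" = any non-meta strategy
    assert select_ep_runners(8, "prefill", "normal", "prefill") == ["prefill"]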

    def process_loaded_weights(self, layer, weights) -> None:
        """
@@ -180,7 +166,7 @@ class MoEMethodBase(QuantMethodBase):
                return self.apply_ep_prefill(layer, x, gate)
            else:
                if layer.fd_config.parallel_config.splitwise_role == "mixed":
                    self.ep_decoder_runner.clean_low_latency_buffer()
                return self.apply_ep_decode(layer, x, gate)
        else:
            return self.apply_tp(layer, x, gate)
@@ -27,11 +27,7 @@ from ..utils import get_tensor
from .fused_moe_backend_base import UnquantizedFusedMoEMethod

if current_platform.is_cuda():
    from fastdeploy.model_executor.ops.gpu import moe_expert_dispatch, moe_expert_reduce
    # was:
    # from fastdeploy.model_executor.ops.gpu import (
    #     moe_expert_dispatch,
    #     moe_expert_reduce,
    #     noaux_tc,
    # )

    try:
        from fastdeploy.model_executor.ops.gpu import w4afp8_gemm_scale_permute
@@ -46,31 +42,6 @@ elif current_platform.is_iluvatar():
from fastdeploy.model_executor.utils import TensorTracker, free_tensor, set_weight_attrs


The per-backend copy of get_moe_scores is removed here as well (callers now import it from .moe):

# used for deepseek_v3
def get_moe_scores(
    gating_output: paddle.Tensor,
    n_group,
    topk_group,
    top_k,
    routed_scaling_factor,
    e_score_correction_bias,
) -> paddle.Tensor:
    """
    compute moe scores using e_score_correction_bias.
    """
    scores = paddle.nn.functional.sigmoid(gating_output)
    scores_with_bias = scores + e_score_correction_bias
    scores, topk_values, topk_idx = noaux_tc(
        scores,
        scores_with_bias,
        n_group,
        topk_group,
        top_k,
        routed_scaling_factor,
    )
    return scores, topk_values, topk_idx


class CutlassMoEMethod(UnquantizedFusedMoEMethod):
    """
    Use Cutlass Group Gemm to compute Fused MoE.
@@ -154,7 +125,7 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):

        # 3. Compute ffn
        if token_all_num > 0:
            logger.debug(f"token_all_num {token_all_num}")
            # was: logger.info(f"token_all_num {token_all_num}")
            (
                permute_input,
                permute_indices_per_token,
@@ -255,6 +226,8 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
        """
        gate_out = gate(x.cast("float32"))
        if layer.topk_method == "noaux_tc":
            from .moe import get_moe_scores

            gate_out, _, _ = get_moe_scores(
                gate_out,
                layer.n_group,
@@ -319,7 +319,7 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):

        # 4. Compute ffn
        if token_all_num > 0:
            logger.debug(f"token_all_num {token_all_num}")
            # was: logger.info(f"token_all_num {token_all_num}")
            (recv_x, recv_x_scale) = recv_x

            token_nums_this_rank = count_tokens_per_expert_func(recv_topk_idx, layer.num_local_experts)
@@ -481,7 +481,7 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
        gate_out = gate(x.cast("float32"))

        if layer.topk_method == "noaux_tc":
            from .moe import get_moe_scores
            # was: from .ep import get_moe_scores

            _, topk_weights, topk_ids = get_moe_scores(
                gate_out,
@@ -21,37 +21,12 @@ import fastdeploy
from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce
from fastdeploy.model_executor.ops.gpu import (
    MoeWna16MarlinGemmApi,
    # noaux_tc,  (dropped: only needed by the removed local get_moe_scores)
    tritonmoe_preprocess_func,
)

from ..quantization.quant_base import QuantMethodBase


The local get_moe_scores copy (this variant applied e_score_correction_bias with an extra unsqueeze(0)) is removed as well:

def get_moe_scores(
    gating_output: paddle.Tensor,
    n_group,
    topk_group,
    top_k,
    routed_scaling_factor,
    e_score_correction_bias,
) -> paddle.Tensor:
    """
    compute moe scores using e_score_correction_bias.
    """
    scores = paddle.nn.functional.sigmoid(gating_output)
    scores_with_bias = scores + e_score_correction_bias.unsqueeze(0)
    scores, topk_values, topk_idx = noaux_tc(
        scores,
        scores_with_bias,
        n_group,
        topk_group,
        top_k,
        routed_scaling_factor,
    )
    return scores, topk_values, topk_idx


def gptq_marlin_moe_repack(
    b_q_weight: paddle.Tensor,
    perm: paddle.Tensor,
@@ -279,6 +254,8 @@ class MarlinWeightOnlyMoEMethod(QuantMethodBase):
        topk_method = layer.topk_method

        if topk_method == "noaux_tc":
            from .moe import get_moe_scores

            gate_out, _, _ = get_moe_scores(
                gate_out,
                layer.n_group,


@@ -24,7 +24,7 @@ from fastdeploy.model_executor.utils import TensorTracker, set_weight_attrs
from fastdeploy.utils import ceil_div

from ..quantization.quant_base import QuantMethodBase
from .moe import get_moe_scores
# was: from .ep import get_moe_scores

try:
    from fastdeploy.model_executor.ops.gpu import tritonmoe_preprocess_func


@@ -27,6 +27,11 @@ from fastdeploy.model_executor.utils import slice_fn
from fastdeploy.platforms import current_platform
from fastdeploy.worker.experts_manager import RedundantExpertManger

try:
    from fastdeploy.model_executor.ops.gpu import noaux_tc
except:
    logger.warning("import noaux_tc Failed!")


def get_moe_method():
    """
@@ -54,6 +59,31 @@ def get_moe_method():
    raise NotImplementedError


def get_moe_scores(
    gating_output: paddle.Tensor,
    n_group,
    topk_group,
    top_k,
    routed_scaling_factor,
    e_score_correction_bias,
) -> paddle.Tensor:
    """
    compute moe scores using e_score_correction_bias.
    """
    scores = paddle.nn.functional.sigmoid(gating_output)
    assert e_score_correction_bias is not None, "e_score_correction_bias is none!"
    scores_with_bias = scores + e_score_correction_bias
    scores, topk_values, topk_idx = noaux_tc(
        scores,
        scores_with_bias,
        n_group if n_group > 0 else 1,
        topk_group if topk_group > 0 else 1,
        top_k,
        routed_scaling_factor,
    )
    return scores, topk_values, topk_idx
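A hedged call sketch for the consolidated get_moe_scores above (it only runs where FastDeploy's noaux_tc GPU op is available; the shapes and hyper-parameters are illustrative, roughly in the style of a DeepSeek-V3 router with 256 experts in 8 groups):

    import paddle
    # from fastdeploy.model_executor.layers.moe.moe import get_moe_scores  # assumed module path

    num_tokens, num_experts = 16, 256
    gating_output = paddle.randn([num_tokens, num_experts], dtype="float32")  # router logits
    e_score_correction_bias = paddle.zeros([num_experts], dtype="float32")    # per-expert bias term

    scores, topk_values, topk_idx = get_moe_scores(
        gating_output,
        n_group=8,                   # experts are scored group by group
        topk_group=4,                # keep the best 4 groups
        top_k=8,                     # then pick 8 experts overall
        routed_scaling_factor=2.5,   # example scaling factor
        e_score_correction_bias=e_score_correction_bias,
    )
    # topk_idx: [num_tokens, top_k] expert ids; topk_values: their routing weights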


class FusedMoE(nn.Layer):
    """
    FusedMoE is a layer that performs MoE (Mixture of Experts) computation.