Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-10-04 16:22:57 +08:00)
[Feat] support mixed ep (#2969)
Some checks failed
Deploy GitHub Pages / deploy (push) Has been cancelled
* Support mixed ep
* fix comment
* fix comment
* update mixep
* fix conflict
* fix typo
* update
* fix typo
* fix code style
* fix conflict
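At a high level, DeepEPEngine now tracks two buffers, prefill_deepep_engine and decode_deepep_engine, and decides which to build from splitwise_role instead of keying a single deepep_engine on moe_phase. Below is a minimal sketch of that selection logic, condensed from the diff; the function name select_deepep_buffers and the make_prefill_buffer / make_decode_buffer factories are illustrative stand-ins for the deep_ep.Buffer construction shown in the diff, not names from the code.

# Sketch only: condensed from the diff, not the full implementation.
from types import SimpleNamespace


def select_deepep_buffers(splitwise_role: str, moe_phase, make_prefill_buffer, make_decode_buffer):
    """Return (prefill_engine, decode_engine); either may be None."""
    prefill_engine = None
    decode_engine = None

    if splitwise_role == "mixed":
        # Mixed EP on a single node: keep both buffers so the engine can
        # switch between high-throughput (prefill) and low-latency (decode).
        decode_engine = make_decode_buffer()
        prefill_engine = make_prefill_buffer()
    elif moe_phase.phase == "decode":
        # Disaggregated multi-node deployment: only the low-latency buffer.
        decode_engine = make_decode_buffer()
    elif moe_phase.phase == "prefill":
        # Disaggregated multi-node deployment: only the high-throughput buffer.
        prefill_engine = make_prefill_buffer()
    else:
        raise ValueError(f"Unknown generation phase {moe_phase}")

    return prefill_engine, decode_engine


# Example: the "mixed" role creates both engines.
prefill, decode = select_deepep_buffers(
    "mixed", SimpleNamespace(phase="decode"), make_prefill_buffer=object, make_decode_buffer=object
)
assert prefill is not None and decode is not None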
@@ -43,9 +43,10 @@ class DeepEPEngine:
         num_max_dispatch_tokens_per_rank: int,
         hidden: int,
         num_experts: int,
-        moe_phase: MoEPhase,
         ep_size: int,
         ep_rank: int,
+        splitwise_role: str,
+        moe_phase: MoEPhase,
         async_finish: bool = False,
     ):
         """
@@ -65,26 +66,44 @@ class DeepEPEngine:
         self.hidden = hidden
         self.num_experts = num_experts
         self.num_local_experts = num_experts // ep_size
         self.moe_phase = moe_phase
         self.async_finish = async_finish
 
-        self.deepep_engine = None
+        self.prefill_deepep_engine = None
+        self.decode_deepep_engine = None
 
-        if moe_phase == MoEPhase.DECODER:
+        self.ep_config = Config(24, 6, 256)
+        self.num_max_dispatch_tokens_per_rank = num_max_dispatch_tokens_per_rank
+
+        # In mixed EP mode on a single node, we dynamically switch between
+        # high throughput and low latency modes.
+        if splitwise_role == "mixed":
+            # decode engine
             logger.info("Initializing Low Latency Buffer")
-            self.num_max_dispatch_tokens_per_rank = num_max_dispatch_tokens_per_rank
             self.get_low_latency_buffer()
-        elif moe_phase == MoEPhase.PREFILL:
-            self.deepep_engine = deep_ep.Buffer(
+            # prefill engine
+            self.prefill_deepep_engine = deep_ep.Buffer(
                 self.group,
                 int(5e8),
                 0,
                 low_latency_mode=False,
                 num_qps_per_rank=1,
             )
-            self.ep_config = Config(24, 6, 256)
+        # In disaggregated mode on mutiple nodes, we either use
+        # high throughput mode or low latency mode.
         else:
-            raise ValueError(f"Unknown generation phase {moe_phase}")
+            if moe_phase.phase == "decode":
+                logger.info("Initializing Low Latency Buffer")
+                self.get_low_latency_buffer()
+            elif moe_phase.phase == "prefill":
+                self.prefill_deepep_engine = deep_ep.Buffer(
+                    self.group,
+                    int(5e8),
+                    0,
+                    low_latency_mode=False,
+                    num_qps_per_rank=1,
+                )
+            else:
+                raise ValueError(f"Unknown generation phase {moe_phase}")
 
     def get_low_latency_buffer(self):
         """
@@ -105,14 +124,14 @@ class DeepEPEngine:
         )
         # Allocate a buffer if not existed or not enough buffer size
         if (
-            self.deepep_engine is None
-            or self.deepep_engine.group != self.group
-            or not self.deepep_engine.low_latency_mode
-            or self.deepep_engine.num_rdma_bytes < num_rdma_bytes
+            self.decode_deepep_engine is None
+            or self.decode_deepep_engine.group != self.group
+            or not self.decode_deepep_engine.low_latency_mode
+            or self.decode_deepep_engine.num_rdma_bytes < num_rdma_bytes
         ):
             # NOTES: for best performance, the QP number **must** be equal to the number of the local experts
             assert self.num_experts % self.ep_size == 0
-            self.deepep_engine = deep_ep.Buffer(
+            self.decode_deepep_engine = deep_ep.Buffer(
                 self.group,
                 0,
                 num_rdma_bytes,
@@ -149,7 +168,7 @@ class DeepEPEngine:
             handle,
             _,
             dispatch_hook,
-        ) = self.deepep_engine.low_latency_dispatch(
+        ) = self.decode_deepep_engine.low_latency_dispatch(
             hidden_states,
             topk_idx,
             expertwise_scale,
@@ -174,8 +193,22 @@ class DeepEPEngine:
         Return:
             combined_hidden_states: [num_tokens, hidden]
         """
+        # TODO(@wufeisheng): Delete them when deepep in PaddlePaddle is fixed
+        (
+            src_info,
+            layout_range,
+            num_max_dispatch_tokens_per_rank,
+            num_experts,
+        ) = handle
+        handle = (
+            src_info,
+            layout_range,
+            num_max_dispatch_tokens_per_rank,
+            None,
+            num_experts,
+        )
 
-        combined_hidden_states, _, combine_hook = self.deepep_engine.low_latency_combine(
+        combined_hidden_states, _, combine_hook = self.decode_deepep_engine.low_latency_combine(
             hidden_states,
             topk_idx,
             topk_weights,
@@ -189,7 +222,7 @@ class DeepEPEngine:
         """
         clean_low_latency_buffer
         """
-        self.deepep_engine.clean_low_latency_buffer(
+        self.decode_deepep_engine.clean_low_latency_buffer(
             self.num_max_dispatch_tokens_per_rank, self.hidden, self.num_experts
         )
 
@@ -197,7 +230,11 @@ class DeepEPEngine:
         """
         barrier_all
         """
-        self.deepep_engine.barrier_all()
+        if self.prefill_deepep_engine is not None:
+            self.prefill_deepep_engine.barrier_all()
+
+        if self.decode_deepep_engine is not None:
+            self.decode_deepep_engine.barrier_all()
 
 
 class EPRunner:
@@ -210,6 +247,7 @@ class EPRunner:
         top_k: int,
         hidden: int,
         num_experts: int,
+        splitwise_role: str,
         moe_phase: MoEPhase,
         num_max_dispatch_tokens_per_rank: int = 1,
         ep_size: int = 1,
@@ -223,9 +261,10 @@ class EPRunner:
             num_max_dispatch_tokens_per_rank=num_max_dispatch_tokens_per_rank,
             hidden=hidden,
             num_experts=num_experts + redundant_experts_num,
-            moe_phase=moe_phase,
             ep_size=ep_size,
             ep_rank=ep_rank,
+            splitwise_role=splitwise_role,
+            moe_phase=moe_phase,
         )
 
     def moe_select(self, layer: nn.Layer, gate_out: paddle.Tensor):
@@ -286,15 +325,19 @@ class EPPrefillRunner(EPRunner):
         top_k: int,
         hidden: int,
         num_experts: int,
+        splitwise_role: str,
         ep_size: int = 1,
         ep_rank: int = 0,
         redundant_experts_num: int = 0,
+        moe_phase: MoEPhase = MoEPhase("prefill"),
     ):
         super().__init__(
             top_k,
             hidden,
             num_experts,
-            MoEPhase.PREFILL,
+            splitwise_role,
+            moe_phase,
             num_max_dispatch_tokens_per_rank=256,
             ep_size=ep_size,
             ep_rank=ep_rank,
             redundant_experts_num=redundant_experts_num,
@@ -314,7 +357,7 @@ class EPPrefillRunner(EPRunner):
             num_tokens_per_expert,
             is_token_in_rank,
             _,
-        ) = self.ep_engine.deepep_engine.get_dispatch_layout(topk_idx, self.num_experts)
+        ) = self.ep_engine.prefill_deepep_engine.get_dispatch_layout(topk_idx, self.num_experts)
 
         x_scale_tensor = kwargs.get("x_scale_tensor", None)
         dispatch_args = {
@@ -327,7 +370,7 @@ class EPPrefillRunner(EPRunner):
             "topk_idx": topk_idx,
             "topk_weights": topk_weights,
         }
-        return self.ep_engine.deepep_engine.dispatch(**dispatch_args)
+        return self.ep_engine.prefill_deepep_engine.dispatch(**dispatch_args)
 
     def combine(
         self,
@@ -342,7 +385,7 @@ class EPPrefillRunner(EPRunner):
             "async_finish": self.ep_engine.async_finish,
             "topk_weights": recv_topk_weights,
         }
-        fused_moe_out, _, _ = self.ep_engine.deepep_engine.combine(**combine_args)
+        fused_moe_out, _, _ = self.ep_engine.prefill_deepep_engine.combine(**combine_args)
 
         return fused_moe_out
 
@@ -357,16 +400,19 @@ class EPDecoderRunner(EPRunner):
         top_k: int,
         hidden: int,
         num_experts: int,
+        splitwise_role: str,
         num_max_dispatch_tokens_per_rank: int,
         ep_size: int = 1,
         ep_rank: int = 0,
         redundant_experts_num: int = 0,
+        moe_phase: MoEPhase = MoEPhase("decode"),
     ):
         super().__init__(
             top_k,
             hidden,
             num_experts,
-            MoEPhase.DECODER,
+            splitwise_role,
+            moe_phase,
             num_max_dispatch_tokens_per_rank,
             ep_size=ep_size,
             ep_rank=ep_rank,