[EP] Refactor DeepEP Engine Organization for Mixed Mode & Buffer Management Optimization (#3182)

* Add support for mixed EP across multiple nodes

* Code refinement

---------

Co-authored-by: yuanxiaolan <yuanxiaolan01@baidu.com>
Author: RichardWooSJTU
Committed by: GitHub
Date: 2025-08-05 15:40:11 +08:00
Parent: 14ed75f7d3
Commit: f5c64a074c
2 changed files with 29 additions and 30 deletions
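For orientation before the hunks: the heart of the change is that DeepEPEngine now keeps a single communication-buffer handle instead of separate prefill and decode handles, so dispatch, combine, clean and barrier calls all go through the same attribute. A minimal before/after sketch (the Before/After class names are illustrative; bodies are reduced to the affected attribute and one method taken from the diff):

    # Before (sketch): two handles; callers pick one and guard against None.
    class DeepEPEngineBefore:
        def __init__(self):
            self.prefill_deepep_engine = None
            self.decode_deepep_engine = None

        def barrier_all(self):
            if self.prefill_deepep_engine is not None:
                self.prefill_deepep_engine.barrier_all()
            if self.decode_deepep_engine is not None:
                self.decode_deepep_engine.barrier_all()


    # After (sketch): one handle, allocated according to splitwise_role / moe_phase.
    class DeepEPEngineAfter:
        def __init__(self):
            self.deepep_engine = None  # becomes a deep_ep.Buffer during __init__

        def barrier_all(self):
            self.deepep_engine.barrier_all()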

File 1 of 2

@@ -68,8 +68,7 @@ class DeepEPEngine:
         self.num_local_experts = num_experts // ep_size
         self.async_finish = async_finish

-        self.prefill_deepep_engine = None
-        self.decode_deepep_engine = None
+        self.deepep_engine = None

         self.ep_config = Config(24, 6, 256)
         self.num_max_dispatch_tokens_per_rank = num_max_dispatch_tokens_per_rank
@@ -77,16 +76,12 @@ class DeepEPEngine:
         # In mixed EP mode on a single node, we dynamically switch between
         # high throughput and low latency modes.
         if splitwise_role == "mixed":
-            # decode engine
-            logger.info("Initializing Low Latency Buffer")
-            self.get_low_latency_buffer()
-            # prefill engine
-            self.prefill_deepep_engine = deep_ep.Buffer(
+            self.deepep_engine = deep_ep.Buffer(
                 self.group,
-                int(5e8),
-                0,
-                low_latency_mode=False,
-                num_qps_per_rank=1,
+                int(2e9),
+                int(5e9),
+                low_latency_mode=True,
+                num_qps_per_rank=24,
             )
         # In disaggregated mode on mutiple nodes, we either use
         # high throughput mode or low latency mode.
@@ -95,7 +90,7 @@ class DeepEPEngine:
                 logger.info("Initializing Low Latency Buffer")
                 self.get_low_latency_buffer()
             elif moe_phase.phase == "prefill":
-                self.prefill_deepep_engine = deep_ep.Buffer(
+                self.deepep_engine = deep_ep.Buffer(
                     self.group,
                     int(5e8),
                     0,
@@ -124,14 +119,14 @@ class DeepEPEngine:
         )
         # Allocate a buffer if not existed or not enough buffer size
         if (
-            self.decode_deepep_engine is None
-            or self.decode_deepep_engine.group != self.group
-            or not self.decode_deepep_engine.low_latency_mode
-            or self.decode_deepep_engine.num_rdma_bytes < num_rdma_bytes
+            self.deepep_engine is None
+            or self.deepep_engine.group != self.group
+            or not self.deepep_engine.low_latency_mode
+            or self.deepep_engine.num_rdma_bytes < num_rdma_bytes
         ):
             # NOTES: for best performance, the QP number **must** be equal to the number of the local experts
             assert self.num_experts % self.ep_size == 0
-            self.decode_deepep_engine = deep_ep.Buffer(
+            self.deepep_engine = deep_ep.Buffer(
                 self.group,
                 0,
                 num_rdma_bytes,
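The guard above reuses the existing buffer whenever possible and only reallocates when it is missing, bound to a different group, not in low-latency mode, or too small. A condensed sketch of that logic as a standalone method (the hunk is truncated after num_rdma_bytes, so the trailing low_latency_mode and num_qps_per_rank arguments below are inferred from the surrounding checks and the QP note, not shown verbatim in the diff):

    def ensure_low_latency_buffer(self, num_rdma_bytes: int):
        # Reallocate only when the current deep_ep.Buffer cannot be reused.
        reusable = (
            self.deepep_engine is not None
            and self.deepep_engine.group == self.group
            and self.deepep_engine.low_latency_mode
            and self.deepep_engine.num_rdma_bytes >= num_rdma_bytes
        )
        if not reusable:
            # Per the NOTE in the original code, the QP count must equal the
            # number of local experts for best performance.
            assert self.num_experts % self.ep_size == 0
            self.deepep_engine = deep_ep.Buffer(
                self.group,
                0,                  # no NVLink buffer needed on this path
                num_rdma_bytes,     # RDMA buffer sized for low-latency dispatch
                low_latency_mode=True,                               # inferred
                num_qps_per_rank=self.num_experts // self.ep_size,   # inferred
            )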
@@ -168,7 +163,7 @@ class DeepEPEngine:
             handle,
             _,
             dispatch_hook,
-        ) = self.decode_deepep_engine.low_latency_dispatch(
+        ) = self.deepep_engine.low_latency_dispatch(
             hidden_states,
             topk_idx,
             expertwise_scale,
@@ -210,7 +205,7 @@ class DeepEPEngine:
             num_experts,
         )

-        combined_hidden_states, _, combine_hook = self.decode_deepep_engine.low_latency_combine(
+        combined_hidden_states, _, combine_hook = self.deepep_engine.low_latency_combine(
             hidden_states,
             topk_idx,
             topk_weights,
@@ -224,7 +219,7 @@ class DeepEPEngine:
         """
         clean_low_latency_buffer
         """
-        self.decode_deepep_engine.clean_low_latency_buffer(
+        self.deepep_engine.clean_low_latency_buffer(
             self.num_max_dispatch_tokens_per_rank, self.hidden, self.num_experts
         )
@@ -232,11 +227,7 @@ class DeepEPEngine:
         """
         barrier_all
         """
-        if self.prefill_deepep_engine is not None:
-            self.prefill_deepep_engine.barrier_all()
-        if self.decode_deepep_engine is not None:
-            self.decode_deepep_engine.barrier_all()
+        self.deepep_engine.barrier_all()


 class EPRunner:
@@ -316,6 +307,9 @@ class EPRunner:
         """
         raise NotImplementedError

+    def clean_low_latency_buffer(self):
+        self.ep_engine.clean_low_latency_buffer()
+

 class EPPrefillRunner(EPRunner):
     """
@@ -328,6 +322,7 @@ class EPPrefillRunner(EPRunner):
         hidden: int,
         num_experts: int,
         splitwise_role: str,
+        num_max_dispatch_tokens_per_rank: int,
         ep_size: int = 1,
         ep_rank: int = 0,
         redundant_experts_num: int = 0,
@@ -339,7 +334,7 @@ class EPPrefillRunner(EPRunner):
             num_experts,
             splitwise_role,
             moe_phase,
-            num_max_dispatch_tokens_per_rank=256,
+            num_max_dispatch_tokens_per_rank=num_max_dispatch_tokens_per_rank,
             ep_size=ep_size,
             ep_rank=ep_rank,
             redundant_experts_num=redundant_experts_num,
@@ -359,7 +354,7 @@ class EPPrefillRunner(EPRunner):
             num_tokens_per_expert,
             is_token_in_rank,
             _,
-        ) = self.ep_engine.prefill_deepep_engine.get_dispatch_layout(topk_idx, self.num_experts)
+        ) = self.ep_engine.deepep_engine.get_dispatch_layout(topk_idx, self.num_experts)

         x_scale_tensor = kwargs.get("x_scale_tensor", None)
         dispatch_args = {
@@ -372,7 +367,7 @@ class EPPrefillRunner(EPRunner):
             "topk_idx": topk_idx,
             "topk_weights": topk_weights,
         }
-        return self.ep_engine.prefill_deepep_engine.dispatch(**dispatch_args)
+        return self.ep_engine.deepep_engine.dispatch(**dispatch_args)

     def combine(
         self,
@@ -387,14 +382,14 @@ class EPPrefillRunner(EPRunner):
             "async_finish": self.ep_engine.async_finish,
             "topk_weights": recv_topk_weights,
         }
-        fused_moe_out, _, _ = self.ep_engine.prefill_deepep_engine.combine(**combine_args)
+        fused_moe_out, _, _ = self.ep_engine.deepep_engine.combine(**combine_args)

         return fused_moe_out


 class EPDecoderRunner(EPRunner):
     """
-    EPPrefillRunner
+    EPDecoderRunner
     """

     def __init__(

File 2 of 2

@@ -51,6 +51,7 @@ class MoEMethodBase(QuantMethodBase):
                 layer.hidden_size,
                 layer.num_experts,
                 layer.fd_config.parallel_config.splitwise_role,
+                layer.fd_config.model_config.num_max_dispatch_tokens_per_rank,
                 layer.ep_size,
                 layer.ep_rank,
                 layer.fd_config.model_config.redundant_experts_num,
@@ -74,6 +75,7 @@ class MoEMethodBase(QuantMethodBase):
                 layer.hidden_size,
                 layer.num_experts,
                 layer.fd_config.parallel_config.splitwise_role,
+                layer.fd_config.model_config.num_max_dispatch_tokens_per_rank,
                 layer.ep_size,
                 layer.ep_rank,
                 layer.fd_config.model_config.redundant_experts_num,
@@ -165,8 +167,10 @@ class MoEMethodBase(QuantMethodBase):
         """
         if layer.ep_size > 1:
             if layer.fd_config.parallel_config.moe_phase.phase == "prefill":
+                self.ep_prefill_runner.clean_low_latency_buffer()
                 return self.apply_ep_prefill(layer, x, gate_out)
             else:
+                self.ep_decoder_runner.clean_low_latency_buffer()
                 return self.apply_ep_decode(layer, x, gate_out)
         else:
             return self.apply_tp(layer, x, gate_out)
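Taken together, the second file wires the new plumbing into the MoE backend: the runner constructors receive num_max_dispatch_tokens_per_rank from the model config, and apply() flushes the low-latency buffer before entering either EP path. The delegation chain introduced by this change, sketched with trimmed class bodies (the Sketch suffix marks illustrative names; method names come from the diff):

    class EPRunnerSketch:
        def clean_low_latency_buffer(self):
            # New helper on the runner so the quant backend never touches deep_ep directly.
            self.ep_engine.clean_low_latency_buffer()


    class DeepEPEngineSketch:
        def clean_low_latency_buffer(self):
            # The engine forwards to the single deep_ep.Buffer with its sizing parameters.
            self.deepep_engine.clean_low_latency_buffer(
                self.num_max_dispatch_tokens_per_rank, self.hidden, self.num_experts
            )


    # In MoEMethodBase.apply(), the matching runner is flushed right before each EP step,
    # e.g. self.ep_prefill_runner.clean_low_latency_buffer() before apply_ep_prefill().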