mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-13 12:23:55 +08:00
[EP] Refactor DeepEP Engine Organization for Mixed Mode & Buffer Management Optimization (#3182)
* Add support for mixed-ep across multiple nodes
* Code refinement

---------

Co-authored-by: yuanxiaolan <yuanxiaolan01@baidu.com>
This commit is contained in:
@@ -68,8 +68,7 @@ class DeepEPEngine:
|
|||||||
self.num_local_experts = num_experts // ep_size
|
self.num_local_experts = num_experts // ep_size
|
||||||
self.async_finish = async_finish
|
self.async_finish = async_finish
|
||||||
|
|
||||||
self.prefill_deepep_engine = None
|
self.deepep_engine = None
|
||||||
self.decode_deepep_engine = None
|
|
||||||
|
|
||||||
self.ep_config = Config(24, 6, 256)
|
self.ep_config = Config(24, 6, 256)
|
||||||
self.num_max_dispatch_tokens_per_rank = num_max_dispatch_tokens_per_rank
|
self.num_max_dispatch_tokens_per_rank = num_max_dispatch_tokens_per_rank
|
||||||
@@ -77,16 +76,12 @@ class DeepEPEngine:
|
|||||||
# In mixed EP mode on a single node, we dynamically switch between
|
# In mixed EP mode on a single node, we dynamically switch between
|
||||||
# high throughput and low latency modes.
|
# high throughput and low latency modes.
|
||||||
if splitwise_role == "mixed":
|
if splitwise_role == "mixed":
|
||||||
# decode engine
|
self.deepep_engine = deep_ep.Buffer(
|
||||||
logger.info("Initializing Low Latency Buffer")
|
|
||||||
self.get_low_latency_buffer()
|
|
||||||
# prefill engine
|
|
||||||
self.prefill_deepep_engine = deep_ep.Buffer(
|
|
||||||
self.group,
|
self.group,
|
||||||
int(5e8),
|
int(2e9),
|
||||||
0,
|
int(5e9),
|
||||||
low_latency_mode=False,
|
low_latency_mode=True,
|
||||||
num_qps_per_rank=1,
|
num_qps_per_rank=24,
|
||||||
)
|
)
|
||||||
# In disaggregated mode on multiple nodes, we either use
|
# In disaggregated mode on multiple nodes, we either use
|
||||||
# high throughput mode or low latency mode.
|
# high throughput mode or low latency mode.
|
||||||
@@ -95,7 +90,7 @@ class DeepEPEngine:
|
|||||||
logger.info("Initializing Low Latency Buffer")
|
logger.info("Initializing Low Latency Buffer")
|
||||||
self.get_low_latency_buffer()
|
self.get_low_latency_buffer()
|
||||||
elif moe_phase.phase == "prefill":
|
elif moe_phase.phase == "prefill":
|
||||||
self.prefill_deepep_engine = deep_ep.Buffer(
|
self.deepep_engine = deep_ep.Buffer(
|
||||||
self.group,
|
self.group,
|
||||||
int(5e8),
|
int(5e8),
|
||||||
0,
|
0,
|
||||||
@@ -124,14 +119,14 @@ class DeepEPEngine:
|
|||||||
)
|
)
|
||||||
# Allocate a buffer if none exists or the existing one is not large enough
|
# Allocate a buffer if none exists or the existing one is not large enough
|
||||||
if (
|
if (
|
||||||
self.decode_deepep_engine is None
|
self.deepep_engine is None
|
||||||
or self.decode_deepep_engine.group != self.group
|
or self.deepep_engine.group != self.group
|
||||||
or not self.decode_deepep_engine.low_latency_mode
|
or not self.deepep_engine.low_latency_mode
|
||||||
or self.decode_deepep_engine.num_rdma_bytes < num_rdma_bytes
|
or self.deepep_engine.num_rdma_bytes < num_rdma_bytes
|
||||||
):
|
):
|
||||||
# NOTES: for best performance, the QP number **must** be equal to the number of the local experts
|
# NOTES: for best performance, the QP number **must** be equal to the number of the local experts
|
||||||
assert self.num_experts % self.ep_size == 0
|
assert self.num_experts % self.ep_size == 0
|
||||||
self.decode_deepep_engine = deep_ep.Buffer(
|
self.deepep_engine = deep_ep.Buffer(
|
||||||
self.group,
|
self.group,
|
||||||
0,
|
0,
|
||||||
num_rdma_bytes,
|
num_rdma_bytes,
|
||||||
@@ -168,7 +163,7 @@ class DeepEPEngine:
|
|||||||
handle,
|
handle,
|
||||||
_,
|
_,
|
||||||
dispatch_hook,
|
dispatch_hook,
|
||||||
) = self.decode_deepep_engine.low_latency_dispatch(
|
) = self.deepep_engine.low_latency_dispatch(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
topk_idx,
|
topk_idx,
|
||||||
expertwise_scale,
|
expertwise_scale,
|
||||||
@@ -210,7 +205,7 @@ class DeepEPEngine:
|
|||||||
num_experts,
|
num_experts,
|
||||||
)
|
)
|
||||||
|
|
||||||
combined_hidden_states, _, combine_hook = self.decode_deepep_engine.low_latency_combine(
|
combined_hidden_states, _, combine_hook = self.deepep_engine.low_latency_combine(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
topk_idx,
|
topk_idx,
|
||||||
topk_weights,
|
topk_weights,
|
||||||
@@ -224,7 +219,7 @@ class DeepEPEngine:
|
|||||||
"""
|
"""
|
||||||
clean_low_latency_buffer
|
clean_low_latency_buffer
|
||||||
"""
|
"""
|
||||||
self.decode_deepep_engine.clean_low_latency_buffer(
|
self.deepep_engine.clean_low_latency_buffer(
|
||||||
self.num_max_dispatch_tokens_per_rank, self.hidden, self.num_experts
|
self.num_max_dispatch_tokens_per_rank, self.hidden, self.num_experts
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -232,11 +227,7 @@ class DeepEPEngine:
|
|||||||
"""
|
"""
|
||||||
barrier_all
|
barrier_all
|
||||||
"""
|
"""
|
||||||
if self.prefill_deepep_engine is not None:
|
self.deepep_engine.barrier_all()
|
||||||
self.prefill_deepep_engine.barrier_all()
|
|
||||||
|
|
||||||
if self.decode_deepep_engine is not None:
|
|
||||||
self.decode_deepep_engine.barrier_all()
|
|
||||||
|
|
||||||
|
|
||||||
class EPRunner:
|
class EPRunner:
|
||||||
@@ -316,6 +307,9 @@ class EPRunner:
|
|||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def clean_low_latency_buffer(self):
|
||||||
|
self.ep_engine.clean_low_latency_buffer()
|
||||||
|
|
||||||
|
|
||||||
class EPPrefillRunner(EPRunner):
|
class EPPrefillRunner(EPRunner):
|
||||||
"""
|
"""
|
||||||
@@ -328,6 +322,7 @@ class EPPrefillRunner(EPRunner):
|
|||||||
hidden: int,
|
hidden: int,
|
||||||
num_experts: int,
|
num_experts: int,
|
||||||
splitwise_role: str,
|
splitwise_role: str,
|
||||||
|
num_max_dispatch_tokens_per_rank: int,
|
||||||
ep_size: int = 1,
|
ep_size: int = 1,
|
||||||
ep_rank: int = 0,
|
ep_rank: int = 0,
|
||||||
redundant_experts_num: int = 0,
|
redundant_experts_num: int = 0,
|
||||||
@@ -339,7 +334,7 @@ class EPPrefillRunner(EPRunner):
|
|||||||
num_experts,
|
num_experts,
|
||||||
splitwise_role,
|
splitwise_role,
|
||||||
moe_phase,
|
moe_phase,
|
||||||
num_max_dispatch_tokens_per_rank=256,
|
num_max_dispatch_tokens_per_rank=num_max_dispatch_tokens_per_rank,
|
||||||
ep_size=ep_size,
|
ep_size=ep_size,
|
||||||
ep_rank=ep_rank,
|
ep_rank=ep_rank,
|
||||||
redundant_experts_num=redundant_experts_num,
|
redundant_experts_num=redundant_experts_num,
|
||||||
@@ -359,7 +354,7 @@ class EPPrefillRunner(EPRunner):
|
|||||||
num_tokens_per_expert,
|
num_tokens_per_expert,
|
||||||
is_token_in_rank,
|
is_token_in_rank,
|
||||||
_,
|
_,
|
||||||
) = self.ep_engine.prefill_deepep_engine.get_dispatch_layout(topk_idx, self.num_experts)
|
) = self.ep_engine.deepep_engine.get_dispatch_layout(topk_idx, self.num_experts)
|
||||||
|
|
||||||
x_scale_tensor = kwargs.get("x_scale_tensor", None)
|
x_scale_tensor = kwargs.get("x_scale_tensor", None)
|
||||||
dispatch_args = {
|
dispatch_args = {
|
||||||
@@ -372,7 +367,7 @@ class EPPrefillRunner(EPRunner):
|
|||||||
"topk_idx": topk_idx,
|
"topk_idx": topk_idx,
|
||||||
"topk_weights": topk_weights,
|
"topk_weights": topk_weights,
|
||||||
}
|
}
|
||||||
return self.ep_engine.prefill_deepep_engine.dispatch(**dispatch_args)
|
return self.ep_engine.deepep_engine.dispatch(**dispatch_args)
|
||||||
|
|
||||||
def combine(
|
def combine(
|
||||||
self,
|
self,
|
||||||
@@ -387,14 +382,14 @@ class EPPrefillRunner(EPRunner):
|
|||||||
"async_finish": self.ep_engine.async_finish,
|
"async_finish": self.ep_engine.async_finish,
|
||||||
"topk_weights": recv_topk_weights,
|
"topk_weights": recv_topk_weights,
|
||||||
}
|
}
|
||||||
fused_moe_out, _, _ = self.ep_engine.prefill_deepep_engine.combine(**combine_args)
|
fused_moe_out, _, _ = self.ep_engine.deepep_engine.combine(**combine_args)
|
||||||
|
|
||||||
return fused_moe_out
|
return fused_moe_out
|
||||||
|
|
||||||
|
|
||||||
class EPDecoderRunner(EPRunner):
|
class EPDecoderRunner(EPRunner):
|
||||||
"""
|
"""
|
||||||
EPPrefillRunner
|
EPDecoderRunner
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
@@ -51,6 +51,7 @@ class MoEMethodBase(QuantMethodBase):
|
|||||||
layer.hidden_size,
|
layer.hidden_size,
|
||||||
layer.num_experts,
|
layer.num_experts,
|
||||||
layer.fd_config.parallel_config.splitwise_role,
|
layer.fd_config.parallel_config.splitwise_role,
|
||||||
|
layer.fd_config.model_config.num_max_dispatch_tokens_per_rank,
|
||||||
layer.ep_size,
|
layer.ep_size,
|
||||||
layer.ep_rank,
|
layer.ep_rank,
|
||||||
layer.fd_config.model_config.redundant_experts_num,
|
layer.fd_config.model_config.redundant_experts_num,
|
||||||
@@ -74,6 +75,7 @@ class MoEMethodBase(QuantMethodBase):
|
|||||||
layer.hidden_size,
|
layer.hidden_size,
|
||||||
layer.num_experts,
|
layer.num_experts,
|
||||||
layer.fd_config.parallel_config.splitwise_role,
|
layer.fd_config.parallel_config.splitwise_role,
|
||||||
|
layer.fd_config.model_config.num_max_dispatch_tokens_per_rank,
|
||||||
layer.ep_size,
|
layer.ep_size,
|
||||||
layer.ep_rank,
|
layer.ep_rank,
|
||||||
layer.fd_config.model_config.redundant_experts_num,
|
layer.fd_config.model_config.redundant_experts_num,
|
||||||
@@ -165,8 +167,10 @@ class MoEMethodBase(QuantMethodBase):
|
|||||||
"""
|
"""
|
||||||
if layer.ep_size > 1:
|
if layer.ep_size > 1:
|
||||||
if layer.fd_config.parallel_config.moe_phase.phase == "prefill":
|
if layer.fd_config.parallel_config.moe_phase.phase == "prefill":
|
||||||
|
self.ep_prefill_runner.clean_low_latency_buffer()
|
||||||
return self.apply_ep_prefill(layer, x, gate_out)
|
return self.apply_ep_prefill(layer, x, gate_out)
|
||||||
else:
|
else:
|
||||||
|
self.ep_decoder_runner.clean_low_latency_buffer()
|
||||||
return self.apply_ep_decode(layer, x, gate_out)
|
return self.apply_ep_decode(layer, x, gate_out)
|
||||||
else:
|
else:
|
||||||
return self.apply_tp(layer, x, gate_out)
|
return self.apply_tp(layer, x, gate_out)
|
||||||
|
Reference in New Issue
Block a user