[XPU] [Optimization] [EP] EP communication optimization. (#5145)

commit e927c65742 (parent 620d1da1c9)
Author: zccjjj
Date: 2025-12-05 10:03:45 +08:00
Committed by: GitHub

6 changed files with 107 additions and 51 deletions


@@ -19,17 +19,15 @@ from abc import abstractmethod
 import deep_ep
 import paddle
 from paddle import nn
-from paddleformers.utils.log import logger

 import fastdeploy
 from fastdeploy.config import MoEPhase
 from fastdeploy.utils import singleton


-@singleton
-class DeepEPEngine:
+class DeepEPEngineBase:
     """
-    A wrapper class for DeepEP engine.
+    Base class for DeepEP engine implementations.
     """

     def __init__(
@@ -45,7 +43,7 @@ class DeepEPEngine:
         group=None,
     ):
         """
-        Initialize the DeepEP engine.
+        Initialize the DeepEP engine base.
         Args:
             group: The MPI group object.
             ep_size: The number of ranks.
@@ -68,27 +66,48 @@ class DeepEPEngine:
         self.group = group
         self.num_local_experts = num_experts // ep_size
         self.deepep_engine = None
-        self.init_deepep_engine()

-    def init_deepep_engine(self):
-        if self.splitwise_role == "mixed" or self.moe_phase.phase == "prefill":
-            self.deepep_engine = deep_ep.Buffer(
-                self.group,
-                int(1e9),
-                0,
-                num_experts=self.num_experts,
-                low_latency_mode=False,
-                num_qps_per_rank=1,
-            )
-        elif self.moe_phase.phase == "decode":
-            logger.info("Initializing Low Latency Buffer")
-            self.get_low_latency_buffer()
-        else:
-            raise ValueError(f"Unknown generation phase {self.moe_phase}")
+    def barrier_all(self):
+        """
+        barrier_all
+        """
+        if self.deepep_engine is not None:
+            self.deepep_engine.barrier_all()
+        else:
+            raise RuntimeError("The deepep engine has not been initialized yet.")
+
+
+@singleton
+class DeepEPEngineHighThroughput(DeepEPEngineBase):
+    """
+    High throughput version of DeepEP engine for prefill phase.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.deepep_engine = deep_ep.Buffer(
+            self.group,
+            int(1e9),
+            0,
+            num_experts=self.num_experts,
+            low_latency_mode=False,
+            num_qps_per_rank=1,
+        )
+
+
+@singleton
+class DeepEPEngineLowLatency(DeepEPEngineBase):
+    """
+    Low latency version of DeepEP engine for decode phase.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.get_low_latency_buffer()

     def get_low_latency_buffer(self):
         """
-        Get the DeepEP buffer.
+        Initialize low latency buffer for decode phase.
         Args:
             group: The MPI group object.
             num_max_dispatch_tokens_per_rank: The maximum number of tokens per rank to dispatch.
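
The hunk above replaces the single phase-switching engine with a base class plus two phase-specific singletons, so a mixed deployment can keep both engines around instead of re-initializing when the phase flips. A minimal, self-contained sketch of the pattern; the `singleton` decorator here is an illustrative stand-in for `fastdeploy.utils.singleton`, whose implementation is not shown in this diff:

    def singleton(cls):
        # Cache-on-first-call singleton: later calls return the same instance.
        instances = {}

        def get_instance(*args, **kwargs):
            if cls not in instances:
                instances[cls] = cls(*args, **kwargs)
            return instances[cls]

        return get_instance

    class EngineBase:
        def __init__(self, num_experts, ep_size):
            self.num_experts = num_experts
            self.ep_size = ep_size
            self.buffer = None  # each subclass allocates its own buffer

    @singleton
    class HighThroughputEngine(EngineBase):
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.buffer = "1 GB high-throughput buffer, 1 QP per rank"

    @singleton
    class LowLatencyEngine(EngineBase):
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.buffer = f"RDMA buffer, {self.num_experts // self.ep_size} QPs per rank"

    # Both engines coexist; repeated construction returns the cached instance.
    prefill = HighThroughputEngine(num_experts=64, ep_size=8)
    decode = LowLatencyEngine(num_experts=64, ep_size=8)
    assert prefill is HighThroughputEngine(num_experts=64, ep_size=8)

Because each class is a separate singleton, both buffers can stay alive at once rather than being rebuilt on every phase change.
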
@@ -103,23 +122,16 @@
             self.ep_size,
             self.num_experts,
         )
-        # Allocate a buffer if not existed or not enough buffer size
-        if (
-            self.deepep_engine is None
-            or self.deepep_engine.group != self.group
-            or not self.deepep_engine.low_latency_mode
-            or self.deepep_engine.num_rdma_bytes < num_rdma_bytes
-        ):
-            # NOTES: for best performance, the QP number **must** be equal to the number of the local experts
-            assert self.num_experts % self.ep_size == 0
-            self.deepep_engine = deep_ep.Buffer(
-                self.group,
-                0,
-                num_rdma_bytes,
-                self.num_experts,
-                low_latency_mode=True,
-                num_qps_per_rank=self.num_experts // self.num_ranks,
-            )
+        # NOTES: for best performance, the QP number **must** be equal to the number of the local experts
+        assert self.num_experts % self.ep_size == 0
+        self.deepep_engine = deep_ep.Buffer(
+            self.group,
+            0,
+            num_rdma_bytes,
+            self.num_experts,
+            low_latency_mode=True,
+            num_qps_per_rank=self.num_experts // self.ep_size,
+        )

     def low_latency_dispatch(
         self,
@@ -172,7 +184,6 @@
         handle,
     ):
         """
-
         Return:
             combined_hidden_states: [num_tokens, hidden_size]
         """
@@ -192,12 +203,6 @@
         """
         pass

-    def barrier_all(self):
-        """
-        barrier_all
-        """
-        self.deepep_engine.barrier_all()
-

 class XPUEPRunner:
     """
@@ -227,10 +232,15 @@ class XPUEPRunner:
         self.ep_rank = ep_rank
         self.redundant_experts_num = redundant_experts_num
         self.ep_group = ep_group
+        self.ep_engine = None
         self.init_ep_engine()

     def init_ep_engine(self):
-        self.ep_engine = DeepEPEngine(
+        """Initialize the EP engine with default implementation"""
+        self._init_ep_engine(self._get_engine_class())
+
+    def _init_ep_engine(self, engine_class):
+        self.ep_engine = engine_class(
             num_max_dispatch_tokens_per_rank=self.num_max_dispatch_tokens_per_rank,
             hidden_size=self.hidden_size,
             num_experts=self.num_experts + self.redundant_experts_num,
@@ -241,6 +251,11 @@
             group=self.ep_group,
         )

+    @abstractmethod
+    def _get_engine_class(self):
+        """Get the engine class to be initialized"""
+        raise NotImplementedError("Subclasses must implement this method")
+
     def moe_select(self, layer: nn.Layer, gate_out: paddle.Tensor):
         """
         moe_select
@@ -325,6 +340,9 @@ class XPUEPPrefillRunner(XPUEPRunner):
             ep_group=ep_group,
         )

+    def _get_engine_class(self):
+        return DeepEPEngineHighThroughput
+
     def dispatch(
         self,
         x: paddle.Tensor,
@@ -389,6 +407,9 @@ class XPUEPDecoderRunner(XPUEPRunner):
             ep_group=ep_group,
         )

+    def _get_engine_class(self):
+        return DeepEPEngineLowLatency
+
     def dispatch(
         self,
         x: paddle.Tensor,

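On the runner side, engine selection becomes a small template method: the base runner drives construction, and each subclass only names its engine class. A stripped-down sketch of that pattern, with constructor arguments elided and the two engine classes reduced to placeholder stubs:

    from abc import ABC, abstractmethod

    class HighThroughputEngine: pass  # stub
    class LowLatencyEngine: pass      # stub

    class Runner(ABC):
        def __init__(self):
            self.ep_engine = None
            self.init_ep_engine()

        def init_ep_engine(self):
            # The base class drives construction; the subclass only picks the class.
            self._init_ep_engine(self._get_engine_class())

        def _init_ep_engine(self, engine_class):
            self.ep_engine = engine_class()

        @abstractmethod
        def _get_engine_class(self):
            raise NotImplementedError

    class PrefillRunner(Runner):
        def _get_engine_class(self):
            return HighThroughputEngine

    class DecodeRunner(Runner):
        def _get_engine_class(self):
            return LowLatencyEngine

    assert isinstance(PrefillRunner().ep_engine, HighThroughputEngine)
    assert isinstance(DecodeRunner().ep_engine, LowLatencyEngine)
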

@@ -342,6 +342,30 @@ class XPUModelRunner(ModelRunnerBase):
             )
             rope_3d_position_ids["max_tokens_lst"].append(request.get("max_tokens", 2048))

+    def only_decode(self):
+        """
+        Return True when every EP rank's current batch is decode-only.
+        """
+        if_only_decode = True
+        prefill_exists = None
+        if self.fd_config.parallel_config.use_ep and self.fd_config.scheduler_config.splitwise_role == "mixed":
+            no_need_stop_list = []
+            no_need_stop = self.not_need_stop()
+            paddle.distributed.all_gather_object(no_need_stop_list, not no_need_stop)
+            if_all_device_empty = all(no_need_stop_list)
+            if if_all_device_empty:
+                if_only_decode = False
+            else:
+                only_decode_batch_list = []
+                prefill_exists = self.exist_prefill()
+                paddle.distributed.all_gather_object(only_decode_batch_list, not prefill_exists)
+                if_only_decode = all(only_decode_batch_list)
+        if_only_decode = if_only_decode and not (
+            prefill_exists if prefill_exists is not None else self.exist_prefill()
+        )
+        return if_only_decode
+
     def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int):
         """
         Process scheduler output tasks, used when ENABLE_V1_KVCACHE_SCHEDULER=1
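
`only_decode` must return the same answer on every rank, because each rank uses it to choose a collective engine; a rank that disagreed would hang the dispatch. A pure-Python simulation of the gathered decision rule; `paddle.distributed.all_gather_object` is modeled by handing every rank the full list of per-rank flags, and the method's trailing local re-check is folded into those flags:

    def only_decode(ranks):
        # ranks: list of (is_idle, has_prefill) tuples, one per EP rank,
        # i.e. the values that all_gather_object would collect.
        if all(is_idle for is_idle, _ in ranks):
            return False  # all devices empty: stay on the prefill path
        return all(not has_prefill for _, has_prefill in ranks)

    # One rank still prefilling forces every rank onto the prefill engine:
    assert only_decode([(False, False), (False, True)]) is False
    # Every rank decode-only: all ranks may take the low-latency path:
    assert only_decode([(False, False), (False, False)]) is True
    # All ranks idle counts as not-only-decode:
    assert only_decode([(True, False), (True, False)]) is False
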
@@ -898,8 +922,16 @@
         self.forward_meta.pos_emb_type = self.share_inputs["pos_emb_type"]
         self.forward_meta.attn_backend = self.attn_backends[0]
         self.initialize_attention_backend()
         if self.pd_disaggregation_mode == "per_chunk" or self.pd_disaggregation_mode == "per_query":
             self.forward_meta.kv_signal_sender = self.kv_signal_sender
+
+        if (
+            self.fd_config.scheduler_config.splitwise_role == "mixed"
+        ):  # Centralized (mixed) scenario: the phase is initialized as "prefill" by default;
+            # during inference, the type of the current batch switches the phase here.
+            if_only_decode = self.only_decode()
+            self.fd_config.model_config.moe_phase.phase = "decode" if if_only_decode else "prefill"
+
         # Get sampling metadata
         # TODO(lilujia): sync with GPU
         self.sampling_metadata = SamplingMetadata(

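With the phase-specific singleton engines in place, this runtime flip is cheap: switching `moe_phase.phase` selects an already-built engine rather than re-allocating buffers. A toy illustration under the same stub assumptions as the sketches above:

    ENGINES = {
        "prefill": "high-throughput engine (1 GB buffer, 1 QP per rank)",
        "decode": "low-latency engine (RDMA buffer, 1 QP per local expert)",
    }

    class MoEPhase:
        phase = "prefill"  # default in the centralized (mixed) scenario

    def step(moe_phase, if_only_decode):
        # Mirrors the model-runner hunk: the batch type decides the phase.
        moe_phase.phase = "decode" if if_only_decode else "prefill"
        return ENGINES[moe_phase.phase]

    phase = MoEPhase()
    assert step(phase, if_only_decode=True).startswith("low-latency")
    assert step(phase, if_only_decode=False).startswith("high-throughput")
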

@@ -445,6 +445,7 @@ export BKCL_PCIE_RING=1
 export XSHMEM_MODE=1
 export XSHMEM_QP_NUM_PER_RANK=32
 export BKCL_RDMA_VERBS=1
+export MOE_FFN_USE_DENSE_INPUT=1

 wget -q https://paddle-qa.bj.bcebos.com/xpu_third_party/xDeepEP.tar.gz
 tar -xzf xDeepEP.tar.gz
@@ -511,6 +512,7 @@ unset BKCL_PCIE_RING
 unset XSHMEM_MODE
 unset XSHMEM_QP_NUM_PER_RANK
 unset BKCL_RDMA_VERBS
+unset MOE_FFN_USE_DENSE_INPUT

 stop_processes >kill.log 2>&1
 if [ ${ep_online_exit_code} -ne 0 ]; then
@@ -540,6 +542,7 @@ export BKCL_PCIE_RING=1
 export XSHMEM_MODE=1
 export XSHMEM_QP_NUM_PER_RANK=32
 export BKCL_RDMA_VERBS=1
+export MOE_FFN_USE_DENSE_INPUT=1

 export port_num=$((8188 + XPU_ID * 100))
 # Start the service
@@ -597,6 +600,7 @@ unset BKCL_PCIE_RING
 unset XSHMEM_MODE
 unset XSHMEM_QP_NUM_PER_RANK
 unset BKCL_RDMA_VERBS
+unset MOE_FFN_USE_DENSE_INPUT

 stop_processes >kill.log 2>&1
 if [ ${ep_online_exit_code} -ne 0 ]; then
@@ -627,6 +631,7 @@ export BKCL_PCIE_RING=1
 export XSHMEM_MODE=1
 export XSHMEM_QP_NUM_PER_RANK=32
 export BKCL_RDMA_VERBS=1
+export MOE_FFN_USE_DENSE_INPUT=1

 export port_num=$((8188 + XPU_ID * 100))
 # Start the service
@@ -686,6 +691,7 @@ unset BKCL_PCIE_RING
 unset XSHMEM_MODE
 unset XSHMEM_QP_NUM_PER_RANK
 unset BKCL_RDMA_VERBS
+unset MOE_FFN_USE_DENSE_INPUT

 stop_processes >kill.log 2>&1
 if [ ${ep_online_exit_code} -ne 0 ]; then

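The CI scripts export `MOE_FFN_USE_DENSE_INPUT=1` around the EP runs and unset it afterwards. The flag's consumer is not part of this diff; as a hedged sketch, such switches are usually read as boolean environment toggles, roughly like this (the reader function and any gated path are hypothetical):

    import os

    def moe_ffn_use_dense_input() -> bool:
        # Hypothetical reader for the toggle exported above; FastDeploy's
        # actual consumer of MOE_FFN_USE_DENSE_INPUT is not shown in this diff.
        return os.getenv("MOE_FFN_USE_DENSE_INPUT", "0") == "1"

    print(moe_ffn_use_dense_input())  # False unless the scripts above exported it
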

@@ -284,12 +284,13 @@ def setup_ep_env():
     """
     env_vars = {
         "BKCL_ENABLE_XDR": "1",
-        "BKCL_RDMA_NICS": "xgbe1,xgbe2,xgbe3,xgbe4",
+        "BKCL_RDMA_NICS": "eth1,eth1,eth2,eth2",
         "BKCL_TRACE_TOPO": "1",
         "BKCL_PCIE_RING": "1",
         "XSHMEM_MODE": "1",
         "XSHMEM_QP_NUM_PER_RANK": "32",
         "BKCL_RDMA_VERBS": "1",
+        "MOE_FFN_USE_DENSE_INPUT": "1",
     }

     # Save original values

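`setup_ep_env` saves the original values before exporting (the `# Save original values` comment above), so the test can restore the environment afterwards. A minimal sketch of that save/set/restore pattern as a context manager:

    import os
    from contextlib import contextmanager

    @contextmanager
    def ep_env(env_vars):
        saved = {key: os.environ.get(key) for key in env_vars}  # save originals
        os.environ.update(env_vars)
        try:
            yield
        finally:
            for key, old in saved.items():
                if old is None:
                    os.environ.pop(key, None)  # was unset before the test
                else:
                    os.environ[key] = old

    with ep_env({"BKCL_RDMA_VERBS": "1", "MOE_FFN_USE_DENSE_INPUT": "1"}):
        pass  # launch the EP test server here
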

@@ -81,8 +81,6 @@ def test_ep4tp4_all2all(xpu_env):
             str(port_num + 47873),
             "--gpu-memory-utilization",
             "0.9",
-            "--load-choices",
-            "default",
         ]

         # Start the server


@@ -81,8 +81,6 @@ def test_ep4tp4_online(xpu_env):
             "--disable-sequence-parallel-moe",
             "--gpu-memory-utilization",
             "0.9",
-            "--load-choices",
-            "default",
         ]

         # Start the server