mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
[XPU] [Optimization] [EP] EP communication optimization. (#5145)
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled
This commit is contained in:
@@ -19,17 +19,15 @@ from abc import abstractmethod
|
||||
import deep_ep
|
||||
import paddle
|
||||
from paddle import nn
|
||||
from paddleformers.utils.log import logger
|
||||
|
||||
import fastdeploy
|
||||
from fastdeploy.config import MoEPhase
|
||||
from fastdeploy.utils import singleton
|
||||
|
||||
|
||||
@singleton
|
||||
class DeepEPEngine:
|
||||
class DeepEPEngineBase:
|
||||
"""
|
||||
A wrapper class for DeepEP engine.
|
||||
Base class for DeepEP engine implementations.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -45,7 +43,7 @@ class DeepEPEngine:
|
||||
group=None,
|
||||
):
|
||||
"""
|
||||
Initialize the DeepEP engine.
|
||||
Initialize the DeepEP engine base.
|
||||
Args:
|
||||
group: The MPI group object.
|
||||
ep_size: The number of ranks.
|
||||
@@ -68,27 +66,48 @@ class DeepEPEngine:
|
||||
self.group = group
|
||||
self.num_local_experts = num_experts // ep_size
|
||||
self.deepep_engine = None
|
||||
self.init_deepep_engine()
|
||||
|
||||
def init_deepep_engine(self):
|
||||
if self.splitwise_role == "mixed" or self.moe_phase.phase == "prefill":
|
||||
self.deepep_engine = deep_ep.Buffer(
|
||||
self.group,
|
||||
int(1e9),
|
||||
0,
|
||||
num_experts=self.num_experts,
|
||||
low_latency_mode=False,
|
||||
num_qps_per_rank=1,
|
||||
)
|
||||
elif self.moe_phase.phase == "decode":
|
||||
logger.info("Initializing Low Latency Buffer")
|
||||
self.get_low_latency_buffer()
|
||||
def barrier_all(self):
    """
    Run a barrier on the underlying DeepEP buffer.

    Delegates to ``self.deepep_engine.barrier_all()``.

    Raises:
        RuntimeError: If ``self.deepep_engine`` is still ``None``, i.e. no
            buffer has been created for this engine yet.
    """
    # Guard first so callers get a clear error instead of an AttributeError on None.
    if self.deepep_engine is None:
        raise RuntimeError("The deepep engine has not been initialized yet.")
    self.deepep_engine.barrier_all()
|
||||
|
||||
|
||||
@singleton
class DeepEPEngineHighThroughput(DeepEPEngineBase):
    """
    High throughput version of DeepEP engine for prefill phase.

    After the base-class initialization, this engine creates a
    ``deep_ep.Buffer`` with ``low_latency_mode=False`` and stores it in
    ``self.deepep_engine``.
    """

    def __init__(self, *args, **kwargs):
        """
        Initialize the base engine, then allocate the high-throughput buffer.

        Args:
            *args: Positional arguments forwarded unchanged to ``DeepEPEngineBase.__init__``.
            **kwargs: Keyword arguments forwarded unchanged to ``DeepEPEngineBase.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.deepep_engine = deep_ep.Buffer(
            self.group,
            # NOTE(review): presumably a buffer size in bytes (~1 GB) -- confirm against deep_ep.Buffer.
            int(1e9),
            # NOTE(review): the low-latency engine passes num_rdma_bytes in this slot,
            # so 0 presumably means no RDMA buffer here -- confirm.
            0,
            num_experts=self.num_experts,
            low_latency_mode=False,
            num_qps_per_rank=1,
        )
|
||||
|
||||
|
||||
@singleton
|
||||
class DeepEPEngineLowLatency(DeepEPEngineBase):
|
||||
"""
|
||||
Low latency version of DeepEP engine for decode phase.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.get_low_latency_buffer()
|
||||
|
||||
def get_low_latency_buffer(self):
|
||||
"""
|
||||
Get the DeepEP buffer.
|
||||
Initialize low latency buffer for decode phase.
|
||||
Args:
|
||||
group: The MPI group object.
|
||||
num_max_dispatch_tokens_per_rank: The maximum number of tokens per rank to dispatch.
|
||||
@@ -103,23 +122,16 @@ class DeepEPEngine:
|
||||
self.ep_size,
|
||||
self.num_experts,
|
||||
)
|
||||
# Allocate a buffer if not existed or not enough buffer size
|
||||
if (
|
||||
self.deepep_engine is None
|
||||
or self.deepep_engine.group != self.group
|
||||
or not self.deepep_engine.low_latency_mode
|
||||
or self.deepep_engine.num_rdma_bytes < num_rdma_bytes
|
||||
):
|
||||
# NOTES: for best performance, the QP number **must** be equal to the number of the local experts
|
||||
assert self.num_experts % self.ep_size == 0
|
||||
self.deepep_engine = deep_ep.Buffer(
|
||||
self.group,
|
||||
0,
|
||||
num_rdma_bytes,
|
||||
self.num_experts,
|
||||
low_latency_mode=True,
|
||||
num_qps_per_rank=self.num_experts // self.num_ranks,
|
||||
)
|
||||
# NOTES: for best performance, the QP number **must** be equal to the number of the local experts
|
||||
assert self.num_experts % self.ep_size == 0
|
||||
self.deepep_engine = deep_ep.Buffer(
|
||||
self.group,
|
||||
0,
|
||||
num_rdma_bytes,
|
||||
self.num_experts,
|
||||
low_latency_mode=True,
|
||||
num_qps_per_rank=self.num_experts // self.ep_size,
|
||||
)
|
||||
|
||||
def low_latency_dispatch(
|
||||
self,
|
||||
@@ -172,7 +184,6 @@ class DeepEPEngine:
|
||||
handle,
|
||||
):
|
||||
"""
|
||||
|
||||
Return:
|
||||
combined_hidden_states: [num_tokens, hidden_size]
|
||||
"""
|
||||
@@ -192,12 +203,6 @@ class DeepEPEngine:
|
||||
"""
|
||||
pass
|
||||
|
||||
def barrier_all(self):
|
||||
"""
|
||||
barrier_all
|
||||
"""
|
||||
self.deepep_engine.barrier_all()
|
||||
|
||||
|
||||
class XPUEPRunner:
|
||||
"""
|
||||
@@ -227,10 +232,15 @@ class XPUEPRunner:
|
||||
self.ep_rank = ep_rank
|
||||
self.redundant_experts_num = redundant_experts_num
|
||||
self.ep_group = ep_group
|
||||
self.ep_engine = None
|
||||
self.init_ep_engine()
|
||||
|
||||
def init_ep_engine(self):
    """
    Initialize the EP engine with default implementation.

    Asks ``self._get_engine_class()`` for the engine class and hands it to
    ``self._init_ep_engine`` to build the engine.
    """
    self._init_ep_engine(self._get_engine_class())
|
||||
|
||||
def _init_ep_engine(self, engine_class):
|
||||
self.ep_engine = engine_class(
|
||||
num_max_dispatch_tokens_per_rank=self.num_max_dispatch_tokens_per_rank,
|
||||
hidden_size=self.hidden_size,
|
||||
num_experts=self.num_experts + self.redundant_experts_num,
|
||||
@@ -241,6 +251,11 @@ class XPUEPRunner:
|
||||
group=self.ep_group,
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def _get_engine_class(self):
|
||||
"""Get the engine class to be initialized"""
|
||||
raise NotImplementedError("Subclasses must implement this method")
|
||||
|
||||
def moe_select(self, layer: nn.Layer, gate_out: paddle.Tensor):
|
||||
"""
|
||||
moe_select
|
||||
@@ -325,6 +340,9 @@ class XPUEPPrefillRunner(XPUEPRunner):
|
||||
ep_group=ep_group,
|
||||
)
|
||||
|
||||
def _get_engine_class(self):
    """Return the engine class used by this runner: ``DeepEPEngineHighThroughput``."""
    return DeepEPEngineHighThroughput
|
||||
|
||||
def dispatch(
|
||||
self,
|
||||
x: paddle.Tensor,
|
||||
@@ -389,6 +407,9 @@ class XPUEPDecoderRunner(XPUEPRunner):
|
||||
ep_group=ep_group,
|
||||
)
|
||||
|
||||
def _get_engine_class(self):
    """Return the engine class used by this runner: ``DeepEPEngineLowLatency``."""
    return DeepEPEngineLowLatency
|
||||
|
||||
def dispatch(
|
||||
self,
|
||||
x: paddle.Tensor,
|
||||
|
||||
@@ -342,6 +342,30 @@ class XPUModelRunner(ModelRunnerBase):
|
||||
)
|
||||
rope_3d_position_ids["max_tokens_lst"].append(request.get("max_tokens", 2048))
|
||||
|
||||
def only_decode(self):
    """
    Decide whether the current batch contains decode work only.

    Without expert parallelism in the "mixed" splitwise role, the result is
    simply ``not self.exist_prefill()``. With it, the decision is made from
    flags gathered via ``paddle.distributed.all_gather_object``:

    * If every gathered ``not self.not_need_stop()`` flag is true (the
      "all device empty" case), the result is ``False``.
    * Otherwise each rank contributes ``not self.exist_prefill()`` and the
      result is true only if every gathered flag is true and this rank has
      no prefill either.

    Returns:
        bool: ``True`` when the batch is decode-only, ``False`` otherwise.
    """
    fd_config = self.fd_config
    if_only_decode = True
    prefill_exists = None

    if fd_config.parallel_config.use_ep and fd_config.scheduler_config.splitwise_role == "mixed":
        # First collective: gather each rank's `not not_need_stop()` flag.
        # The order and conditions of the collectives must stay identical on all ranks.
        no_need_stop_list = []
        paddle.distributed.all_gather_object(no_need_stop_list, not self.not_need_stop())
        if all(no_need_stop_list):
            # All devices are empty.
            if_only_decode = False
        else:
            # Second collective: gather whether each rank is free of prefill requests.
            only_decode_batch_list = []
            prefill_exists = self.exist_prefill()
            paddle.distributed.all_gather_object(only_decode_batch_list, not prefill_exists)
            if_only_decode = all(only_decode_batch_list)

    if not if_only_decode:
        return False
    # Reuse the value computed for the collective when available, so
    # exist_prefill() is evaluated at most once per call.
    if prefill_exists is None:
        prefill_exists = self.exist_prefill()
    return not prefill_exists
|
||||
|
||||
def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int):
|
||||
"""
|
||||
Process scheduler output tasks, used when ENABLE_V1_KVCACHE_SCHEDULER=1
|
||||
@@ -898,8 +922,16 @@ class XPUModelRunner(ModelRunnerBase):
|
||||
self.forward_meta.pos_emb_type = self.share_inputs["pos_emb_type"]
|
||||
self.forward_meta.attn_backend = self.attn_backends[0]
|
||||
self.initialize_attention_backend()
|
||||
|
||||
if self.pd_disaggregation_mode == "per_chunk" or self.pd_disaggregation_mode == "per_query":
|
||||
self.forward_meta.kv_signal_sender = self.kv_signal_sender
|
||||
|
||||
if (
|
||||
self.fd_config.scheduler_config.splitwise_role == "mixed"
|
||||
): # Centralized scenario: the phase is initialized as "prefill" by default. During inference runtime, different types of batches can achieve phase switching at this point.
|
||||
if_only_decode = self.only_decode()
|
||||
self.fd_config.model_config.moe_phase.phase = "decode" if if_only_decode else "prefill"
|
||||
|
||||
# Get sampling metadata
|
||||
# TODO(lilujia): sync with GPU
|
||||
self.sampling_metadata = SamplingMetadata(
|
||||
|
||||
@@ -445,6 +445,7 @@ export BKCL_PCIE_RING=1
|
||||
export XSHMEM_MODE=1
|
||||
export XSHMEM_QP_NUM_PER_RANK=32
|
||||
export BKCL_RDMA_VERBS=1
|
||||
export MOE_FFN_USE_DENSE_INPUT=1
|
||||
|
||||
wget -q https://paddle-qa.bj.bcebos.com/xpu_third_party/xDeepEP.tar.gz
|
||||
tar -xzf xDeepEP.tar.gz
|
||||
@@ -511,6 +512,7 @@ unset BKCL_PCIE_RING
|
||||
unset XSHMEM_MODE
|
||||
unset XSHMEM_QP_NUM_PER_RANK
|
||||
unset BKCL_RDMA_VERBS
|
||||
unset MOE_FFN_USE_DENSE_INPUT
|
||||
stop_processes >kill.log 2>&1
|
||||
|
||||
if [ ${ep_online_exit_code} -ne 0 ]; then
|
||||
@@ -540,6 +542,7 @@ export BKCL_PCIE_RING=1
|
||||
export XSHMEM_MODE=1
|
||||
export XSHMEM_QP_NUM_PER_RANK=32
|
||||
export BKCL_RDMA_VERBS=1
|
||||
export MOE_FFN_USE_DENSE_INPUT=1
|
||||
|
||||
export port_num=$((8188 + XPU_ID * 100))
|
||||
# 启动服务
|
||||
@@ -597,6 +600,7 @@ unset BKCL_PCIE_RING
|
||||
unset XSHMEM_MODE
|
||||
unset XSHMEM_QP_NUM_PER_RANK
|
||||
unset BKCL_RDMA_VERBS
|
||||
unset MOE_FFN_USE_DENSE_INPUT
|
||||
stop_processes >kill.log 2>&1
|
||||
|
||||
if [ ${ep_online_exit_code} -ne 0 ]; then
|
||||
@@ -627,6 +631,7 @@ export BKCL_PCIE_RING=1
|
||||
export XSHMEM_MODE=1
|
||||
export XSHMEM_QP_NUM_PER_RANK=32
|
||||
export BKCL_RDMA_VERBS=1
|
||||
export MOE_FFN_USE_DENSE_INPUT=1
|
||||
|
||||
export port_num=$((8188 + XPU_ID * 100))
|
||||
# 启动服务
|
||||
@@ -686,6 +691,7 @@ unset BKCL_PCIE_RING
|
||||
unset XSHMEM_MODE
|
||||
unset XSHMEM_QP_NUM_PER_RANK
|
||||
unset BKCL_RDMA_VERBS
|
||||
unset MOE_FFN_USE_DENSE_INPUT
|
||||
stop_processes >kill.log 2>&1
|
||||
|
||||
if [ ${ep_online_exit_code} -ne 0 ]; then
|
||||
|
||||
@@ -284,12 +284,13 @@ def setup_ep_env():
|
||||
"""
|
||||
env_vars = {
|
||||
"BKCL_ENABLE_XDR": "1",
|
||||
"BKCL_RDMA_NICS": "xgbe1,xgbe2,xgbe3,xgbe4",
|
||||
"BKCL_RDMA_NICS": "eth1,eth1,eth2,eth2",
|
||||
"BKCL_TRACE_TOPO": "1",
|
||||
"BKCL_PCIE_RING": "1",
|
||||
"XSHMEM_MODE": "1",
|
||||
"XSHMEM_QP_NUM_PER_RANK": "32",
|
||||
"BKCL_RDMA_VERBS": "1",
|
||||
"MOE_FFN_USE_DENSE_INPUT": "1",
|
||||
}
|
||||
|
||||
# 保存原始值
|
||||
|
||||
@@ -81,8 +81,6 @@ def test_ep4tp4_all2all(xpu_env):
|
||||
str(port_num + 47873),
|
||||
"--gpu-memory-utilization",
|
||||
"0.9",
|
||||
"--load-choices",
|
||||
"default",
|
||||
]
|
||||
|
||||
# 启动服务器
|
||||
|
||||
@@ -81,8 +81,6 @@ def test_ep4tp4_online(xpu_env):
|
||||
"--disable-sequence-parallel-moe",
|
||||
"--gpu-memory-utilization",
|
||||
"0.9",
|
||||
"--load-choices",
|
||||
"default",
|
||||
]
|
||||
|
||||
# 启动服务器
|
||||
|
||||
Reference in New Issue
Block a user