support ep+tp at op layer (#4688)

This commit is contained in:
zhupengyang
2025-11-05 11:15:57 +08:00
committed by GitHub
parent 937eb3c6ed
commit 2fd254e5b7
8 changed files with 138 additions and 105 deletions

View File

@@ -311,14 +311,9 @@ class XPUMoEMethod(MoEMethodBase):
apply tp
"""
if self.moe_quant_type in ["w16a16"]:
using_ep_moe_algo = False
else:
using_ep_moe_algo = True
if using_ep_moe_algo:
fused_moe_out = self.apply_tp_scatter_op(layer, x, gate)
else:
fused_moe_out = self.apply_tp_fused_op(layer, x, gate)
else:
fused_moe_out = self.apply_tp_scatter_op(layer, x, gate)
return fused_moe_out

View File

@@ -564,6 +564,29 @@ class FusedMoE(nn.Layer):
else:
self.quant_method.process_loaded_weights(self, state_dict)
def forward_split_allgather(self, x: paddle.Tensor, gate: nn.Layer):
    """
    Apply the MoE quant-method to this rank's slice of the tokens, then
    all-gather the per-rank partial outputs so every tensor-parallel rank
    ends up with the full result.

    Args:
        x: Input activations; assumed 2-D [token_num, hidden_dim] -- TODO confirm.
        gate: Gating layer forwarded unchanged to the quant method.

    Returns:
        Tensor with the same leading size as ``x`` holding the gathered outputs.
    """
    token_num = x.shape[0]
    tp_size = self.fd_config.parallel_config.tensor_parallel_size
    tp_rank = self.fd_config.parallel_config.tensor_parallel_rank
    # Ceil-divide so every rank reserves the same number of rows.
    token_num_per_rank = (token_num + tp_size - 1) // tp_size
    # AllGather will hang when the data shapes on multi-ranks are different!
    # Zero-pad each rank's slice to a fixed shape to keep the collective safe.
    part_x = paddle.zeros(shape=[token_num_per_rank, x.shape[1]], dtype=x.dtype)
    start_offset = tp_rank * token_num_per_rank
    end_offset = (tp_rank + 1) * token_num_per_rank
    # Clamp both offsets so trailing ranks that own no real tokens copy
    # nothing (and the slice length below never goes negative).
    if start_offset >= token_num:
        start_offset = token_num
    if end_offset > token_num:
        end_offset = token_num
    part_x[: (end_offset - start_offset), :] = x[start_offset:end_offset, :]
    out = self.quant_method.apply(self, part_x, gate)
    # NOTE(review): all_gather here writes into a preallocated tensor instead
    # of appending to a Python list -- presumably supported by this paddle
    # build on XPU; verify against the paddle.distributed.all_gather API.
    multi_outs = paddle.zeros([token_num_per_rank * tp_size, x.shape[1]], dtype=x.dtype)
    paddle.distributed.all_gather(multi_outs, out, self.tp_group)
    # Drop the zero-padded rows contributed by the trailing rank(s).
    out = multi_outs[:token_num, :]
    return out
def forward(self, x: paddle.Tensor, gate: nn.Layer):
"""
Defines the forward computation of the moe layer.
@@ -575,5 +598,10 @@ class FusedMoE(nn.Layer):
Tensor: Output tensor.s
"""
out = self.quant_method.apply(self, x, gate)
token_num = x.shape[0]
tp_size = self.fd_config.parallel_config.tensor_parallel_size
if self.ep_size > 1 and tp_size > 1 and token_num >= tp_size:
out = self.forward_split_allgather(x, gate)
else:
out = self.quant_method.apply(self, x, gate)
return out

View File

@@ -111,14 +111,6 @@ class Ernie4_5_MoE(nn.Layer):
if hasattr(fd_config.quant_config, "moe_quant_type"):
moe_quant_type = fd_config.quant_config.moe_quant_type
self.expert_parallel_size = fd_config.parallel_config.expert_parallel_size
self.tensor_parallel_size = fd_config.parallel_config.tensor_parallel_size
self.tensor_parallel_rank = fd_config.parallel_config.tensor_parallel_rank
self.tp_group = fd_config.parallel_config.tp_group
self.use_ep = self.expert_parallel_size > 1
self.use_tp = self.tensor_parallel_size > 1
if moe_quant_type == "w4a8" or moe_quant_type == "w4afp8":
weight_key_map = {
"gate_weight_key": f"{prefix}.gate.weight",
@@ -221,30 +213,8 @@ class Ernie4_5_MoE(nn.Layer):
def update_state_dict(self, state_dict):
self.fused_moe.load_state_dict(state_dict, True)
def split_allgather_out(self, hidden_states: paddle.Tensor, token_num: int):
    """
    Run the experts on this rank's token slice and all-gather the partial
    outputs across tensor-parallel ranks.

    Args:
        hidden_states: 2-D activations [token_num, hidden_dim].
        token_num: Number of valid rows in ``hidden_states``.

    Returns:
        Tensor [token_num, hidden_dim] with the gathered expert outputs.
    """
    token_num_per_rank = (token_num + self.tensor_parallel_size - 1) // self.tensor_parallel_size
    # AllGather will hang when the data shapes on multi-ranks are different!
    # Pad each rank's slice to a fixed shape so the collective stays safe.
    part_hidden_states = paddle.zeros(
        shape=[token_num_per_rank, hidden_states.shape[1]], dtype=hidden_states.dtype
    )
    start_offset = self.tensor_parallel_rank * token_num_per_rank
    end_offset = (self.tensor_parallel_rank + 1) * token_num_per_rank
    # Clamp BOTH offsets: with ceil division, a trailing rank's start_offset
    # can exceed token_num (e.g. token_num=5, tp=4 -> rank 3 start=6), which
    # would make (end_offset - start_offset) negative and corrupt the slice
    # assignment below. Clamping only end_offset is not enough.
    if start_offset > token_num:
        start_offset = token_num
    if end_offset > token_num:
        end_offset = token_num
    part_hidden_states[: (end_offset - start_offset), :] = hidden_states[start_offset:end_offset, :]
    out = self.experts(part_hidden_states, self.gate)
    multi_outs = []
    paddle.distributed.all_gather(multi_outs, out, self.tp_group)
    out = paddle.concat(multi_outs, axis=0)
    # Discard the zero-padded trailing rows.
    out = out[:token_num, :]
    return out
def forward(self, hidden_states: paddle.Tensor):
token_num = hidden_states.shape[0]
if self.use_ep and self.use_tp and token_num >= self.tensor_parallel_size:
out = self.split_allgather_out(hidden_states, token_num)
else:
out = self.experts(hidden_states, self.gate)
out = self.experts(hidden_states, self.gate)
if self.num_shared_experts > 0:
s_x = self.shared_experts(hidden_states)
out = out + s_x

View File

@@ -56,14 +56,6 @@ class Qwen3MoeBlock(nn.Layer):
) -> None:
super().__init__()
self.expert_parallel_size = fd_config.parallel_config.expert_parallel_size
self.tensor_parallel_size = fd_config.parallel_config.tensor_parallel_size
self.tensor_parallel_rank = fd_config.parallel_config.tensor_parallel_rank
self.tp_group = fd_config.parallel_config.tp_group
self.use_ep = self.expert_parallel_size > 1
self.use_tp = self.tensor_parallel_size > 1
weight_key_map = {
"up_gate_proj_expert_weight_key": f"{prefix}.experts.{{}}.up_gate_proj.weight",
"down_proj_expert_weight_key": f"{prefix}.experts.{{}}.down_proj.weight",
@@ -87,31 +79,8 @@ class Qwen3MoeBlock(nn.Layer):
weight_dtype="float32",
)
def split_allgather_out(self, hidden_states: paddle.Tensor, token_num: int):
    """
    Run the experts on this rank's token slice and all-gather the partial
    outputs across tensor-parallel ranks.

    Args:
        hidden_states: 2-D activations [token_num, hidden_dim].
        token_num: Number of valid rows in ``hidden_states``.

    Returns:
        Tensor [token_num, hidden_dim] with the gathered expert outputs.
    """
    token_num_per_rank = (token_num + self.tensor_parallel_size - 1) // self.tensor_parallel_size
    # AllGather will hang when the data shapes on multi-ranks are different!
    # Pad each rank's slice to a fixed shape so the collective stays safe.
    part_hidden_states = paddle.zeros(
        shape=[token_num_per_rank, hidden_states.shape[1]], dtype=hidden_states.dtype
    )
    start_offset = self.tensor_parallel_rank * token_num_per_rank
    end_offset = (self.tensor_parallel_rank + 1) * token_num_per_rank
    # Clamp BOTH offsets: with ceil division, a trailing rank's start_offset
    # can exceed token_num (e.g. token_num=5, tp=4 -> rank 3 start=6), which
    # would make (end_offset - start_offset) negative and corrupt the slice
    # assignment below. Clamping only end_offset is not enough.
    if start_offset > token_num:
        start_offset = token_num
    if end_offset > token_num:
        end_offset = token_num
    part_hidden_states[: (end_offset - start_offset), :] = hidden_states[start_offset:end_offset, :]
    out = self.experts(part_hidden_states, self.gate)
    multi_outs = []
    paddle.distributed.all_gather(multi_outs, out, self.tp_group)
    out = paddle.concat(multi_outs, axis=0)
    # Discard the zero-padded trailing rows.
    out = out[:token_num, :]
    return out
def forward(self, x):
token_num = x.shape[0]
if self.use_ep and self.use_tp and token_num >= self.tensor_parallel_size:
out = self.split_allgather_out(x, token_num)
else:
out = self.experts(x, self.gate)
return out
return self.experts(x, self.gate)
def load_state_dict(self, state_dict):
""" """

View File

@@ -1156,6 +1156,13 @@ class XPUModelRunner(ModelRunnerBase):
# 1. Prepare inputs of model and decoder.
self._prepare_inputs(is_dummy_run=is_dummy_run)
# NOTE(wufeisheng): If `not_need_stop`` is False, it means the current worker is in an idle state.
# This logic is not used in TP (Tensor Parallelism) mode. However, in EP (Expert Parallelism) mode,
# when there is data on other runner, the current runner is required to execute part of the model.
if not self.not_need_stop() and not is_dummy_run:
self._execute_empty_input()
return None
# 2. Padding inputs for cuda grph
# 3. Execute model
@@ -1225,6 +1232,17 @@ class XPUModelRunner(ModelRunnerBase):
return None
def _execute_empty_input(self) -> None:
"""
In certain scenarios, such as during EP,
the runner needs to execute partial modules of the model without input data.
This requires the model to implement the `empty_input_forward` method.
"""
if hasattr(self.model, "empty_input_forward"):
self.model.empty_input_forward()
else:
raise ValueError(f"{type(self.model)} has no attribute 'empty_input_forward")
@profile_run_guard(True)
def profile_run(self) -> None:
"""Execute a forward pass with dummy inputs to profile the memory usage of the model"""

View File

@@ -110,8 +110,9 @@ class XpuWorker(WorkerBase):
free_memory = xpu_get_free_global_memory(self.device_id)
logger.info(
f"Before warm up, total_memory: {total_memory / 1024**3}GB--------, \
used_memory: {used_memory / 1024**3}GB--------, free_memory: {free_memory / 1024**3}GB--------"
f"Before warm up, total_memory: {total_memory / 1024**3}GB, "
f"used_memory: {used_memory / 1024**3}GB, "
f"free_memory: {free_memory / 1024**3}GB."
)
if self.parallel_config.use_ep:
@@ -131,8 +132,9 @@ class XpuWorker(WorkerBase):
self.model_runner.clear_block_table()
logger.info(
f"After warm up, total_available_memory: {total_available_memory / 1024**3}GB--------, \
used_memory: {used_memory / 1024**3}GB--------, available_kv_cache_memory: {available_kv_cache_memory / 1024**3}GB--------"
f"After warm up, total_available_memory: {total_available_memory / 1024**3}GB, "
f"used_memory: {used_memory / 1024**3}GB, "
f"available_kv_cache_memory: {available_kv_cache_memory / 1024**3}GB."
)
paddle.device.xpu.empty_cache()
return available_kv_cache_memory # approximate value

View File

@@ -6,10 +6,14 @@ echo "$DIR"
apt install -y lsof
# Kill any leftover processes from a previous run first
ps -efww | grep -E 'cache_transfer_manager.py' | grep -v grep | awk '{print $2}' | xargs kill -9 || true
ps -efww | grep -E 'api_server' | grep -v grep | awk '{print $2}' | xargs kill -9 || true
ps -efww | grep -E '8188' | grep -v grep | awk '{print $2}' | xargs kill -9 || true
lsof -t -i :8188 | xargs kill -9 || true
function stop_processes() {
    # Kill any serving processes left over from a previous run; `|| true`
    # keeps the script alive when no matching process exists.
    ps -efww | grep -E 'cache_transfer_manager.py' | grep -v grep | awk '{print $2}' | xargs kill -9 || true
    ps -efww | grep -E 'api_server' | grep -v grep | awk '{print $2}' | xargs kill -9 || true
    ps -efww | grep -E '8188' | grep -v grep | awk '{print $2}' | xargs kill -9 || true
    # Finally, kill whatever still holds port 8188.
    lsof -t -i :8188 | xargs kill -9 || true
}
stop_processes
#设置模型路径
export model_path=${MODEL_PATH}/ERNIE-4.5-300B-A47B-Paddle
@@ -38,6 +42,8 @@ unset http_proxy
unset https_proxy
unset no_proxy
stop_processes
# 起服务
rm -rf log/*
rm -f core*
@@ -71,7 +77,10 @@ while true; do
# 超时判断
if [ $ELAPSED -ge $TIMEOUT ]; then
echo -e "\n服务启动超时经过 $((TIMEOUT/60)) 分钟服务仍未启动!"
stop_processes
echo "server.log"
cat server.log
echo "log/workerlog.0"
cat log/workerlog.0
exit 1
fi
@@ -93,10 +102,7 @@ python -m pytest tests/ci_use/XPU_45T/run_45T.py
kv_block_test_exit_code=$?
echo kv_block_test_exit_code is ${kv_block_test_exit_code}
ps -efww | grep -E 'cache_transfer_manager.py' | grep -v grep | awk '{print $2}' | xargs kill -9 || true
ps -efww | grep -E 'api_server' | grep -v grep | awk '{print $2}' | xargs kill -9 || true
ps -efww | grep -E '8188' | grep -v grep | awk '{print $2}' | xargs kill -9 || true
lsof -t -i :8188 | xargs kill -9 || true
stop_processes
if [ ${kv_block_test_exit_code} -ne 0 ]; then
echo "log/workerlog.0"
@@ -139,7 +145,10 @@ while true; do
# 超时判断
if [ $ELAPSED -ge $TIMEOUT ]; then
echo -e "\n服务启动超时经过 $((TIMEOUT/60)) 分钟服务仍未启动!"
stop_processes
echo "server.log"
cat server.log
echo "log/workerlog.0"
cat log/workerlog.0
exit 1
fi
@@ -161,10 +170,7 @@ python -m pytest tests/ci_use/XPU_45T/run_w4a8.py
w4a8_test_exit_code=$?
echo w4a8_test_exit_code is ${w4a8_test_exit_code}
ps -efww | grep -E 'cache_transfer_manager.py' | grep -v grep | awk '{print $2}' | xargs kill -9 || true
ps -efww | grep -E 'api_server' | grep -v grep | awk '{print $2}' | xargs kill -9 || true
ps -efww | grep -E '8188' | grep -v grep | awk '{print $2}' | xargs kill -9 || true
lsof -t -i :8188 | xargs kill -9 || true
stop_processes
if [ ${w4a8_test_exit_code} -ne 0 ]; then
echo "log/workerlog.0"
@@ -210,7 +216,10 @@ while true; do
# 超时判断
if [ $ELAPSED -ge $TIMEOUT ]; then
echo -e "\n服务启动超时经过 $((TIMEOUT/60)) 分钟服务仍未启动!"
stop_processes
echo "server.log"
cat server.log
echo "log/workerlog.0"
cat log/workerlog.0
exit 1
fi
@@ -232,10 +241,7 @@ python -m pytest tests/ci_use/XPU_45T/run_45vl.py
vl_test_exit_code=$?
echo vl_test_exit_code is ${vl_test_exit_code}
ps -efww | grep -E 'cache_transfer_manager.py' | grep -v grep | awk '{print $2}' | xargs kill -9 || true
ps -efww | grep -E 'api_server' | grep -v grep | awk '{print $2}' | xargs kill -9 || true
ps -efww | grep -E '8188' | grep -v grep | awk '{print $2}' | xargs kill -9 || true
lsof -t -i :8188 | xargs kill -9 || true
stop_processes
if [ ${vl_test_exit_code} -ne 0 ]; then
echo "log/workerlog.0"
@@ -245,12 +251,13 @@ if [ ${vl_test_exit_code} -ne 0 ]; then
fi
echo "============================开始EP并行测试!============================"
echo "============================开始 EP4TP1 测试!============================"
sleep 5
rm -rf log/*
rm -f core*
ipcrm --all=msg
xpu-smi
export XPU_VISIBLE_DEVICES="0,1,2,3"
export XPU_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
export BKCL_ENABLE_XDR=1
export BKCL_RDMA_NICS=xgbe1,xgbe2,xgbe3,xgbe4
export BKCL_TRACE_TOPO=1
@@ -265,6 +272,9 @@ cd xDeepEP
bash build.sh
cd -
export enable_expert_parallel=1
export enable_tensor_parallel=0
python -m pytest -s --timeout=600 tests/ci_use/XPU_45T/run_ep.py
ep_exit_code=$?
@@ -275,14 +285,49 @@ unset BKCL_PCIE_RING
unset XSHMEM_MODE
unset XSHMEM_QP_NUM_PER_RANK
unset BKCL_RDMA_VERBS
ps -efww | grep -E 'cache_transfer_manager.py' | grep -v grep | awk '{print $2}' | xargs kill -9 || true
ps -efww | grep -E 'api_server' | grep -v grep | awk '{print $2}' | xargs kill -9 || true
ps -efww | grep -E '8188' | grep -v grep | awk '{print $2}' | xargs kill -9 || true
lsof -t -i :8188 | xargs kill -9 || true
stop_processes
if [ ${ep_exit_code} -ne 0 ]; then
echo "log/workerlog.0"
cat log/workerlog.0
echo "EP并行 相关测试失败请检查pr代码"
echo "EP4TP1 相关测试失败请检查pr代码"
exit 1
fi
echo "============================开始 EP4TP4 测试!============================"
sleep 5
rm -rf log/*
rm -f core*
ipcrm --all=msg
xpu-smi
export XPU_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
export BKCL_ENABLE_XDR=1
export BKCL_RDMA_NICS=xgbe1,xgbe2,xgbe3,xgbe4
export BKCL_TRACE_TOPO=1
export BKCL_PCIE_RING=1
export XSHMEM_MODE=1
export XSHMEM_QP_NUM_PER_RANK=32
export BKCL_RDMA_VERBS=1
export enable_expert_parallel=1
export enable_tensor_parallel=1
python -m pytest -s --timeout=600 tests/ci_use/XPU_45T/run_ep.py
ep_exit_code=$?
unset BKCL_ENABLE_XDR
unset BKCL_RDMA_NICS
unset BKCL_TRACE_TOPO
unset BKCL_PCIE_RING
unset XSHMEM_MODE
unset XSHMEM_QP_NUM_PER_RANK
unset BKCL_RDMA_VERBS
stop_processes
if [ ${ep_exit_code} -ne 0 ]; then
echo "log/workerlog.0"
cat log/workerlog.0
echo "EP4TP4 相关测试失败请检查pr代码"
exit 1
fi

View File

@@ -1,31 +1,39 @@
import os
import psutil
from paddleformers.trainer import strtobool
from fastdeploy import LLM, SamplingParams
def test_fd_ep():
""" """
msg1 = [
{"role": "system", "content": ""},
{"role": "user", "content": "北京天安门广场在哪里?"},
]
messages = [msg1]
print(f"[INFO] messages: {messages}")
# 采样参数
sampling_params = SamplingParams(top_p=0, max_tokens=500)
# 模型路径与设备配置
model = os.getenv("model_path", "/home/ERNIE-4.5-300B-A47B-Paddle")
model_root = os.getenv("MODEL_PATH", "/home")
model = f"{model_root}/ERNIE-4.5-300B-A47B-Paddle"
xpu_visible_devices = os.getenv("XPU_VISIBLE_DEVICES", "0")
xpu_device_num = len(xpu_visible_devices.split(","))
enable_expert_parallel = True
enable_expert_parallel = strtobool(os.getenv("enable_expert_parallel", "1"))
enable_tensor_parallel = strtobool(os.getenv("enable_tensor_parallel", "0"))
print(f"enable_expert_parallel: {enable_expert_parallel}, enable_tensor_parallel: {enable_tensor_parallel}")
if enable_expert_parallel:
tensor_parallel_size = 1
data_parallel_size = xpu_device_num
if enable_tensor_parallel:
tensor_parallel_size = xpu_device_num
data_parallel_size = 1
else:
tensor_parallel_size = 1
data_parallel_size = xpu_device_num
else:
tensor_parallel_size = xpu_device_num
data_parallel_size = 1
@@ -33,8 +41,6 @@ def test_fd_ep():
engine_worker_queue_port = [str(8023 + i * 10) for i in range(data_parallel_size)]
engine_worker_queue_port = ",".join(engine_worker_queue_port)
print(f"[INFO] messages: {messages}")
llm = LLM(
model=model,
enable_expert_parallel=enable_expert_parallel,