diff --git a/fastdeploy/model_executor/layers/backends/xpu/moe/fused_moe.py b/fastdeploy/model_executor/layers/backends/xpu/moe/fused_moe.py index 20f47cf36..098a622a4 100644 --- a/fastdeploy/model_executor/layers/backends/xpu/moe/fused_moe.py +++ b/fastdeploy/model_executor/layers/backends/xpu/moe/fused_moe.py @@ -311,14 +311,9 @@ class XPUMoEMethod(MoEMethodBase): apply tp """ if self.moe_quant_type in ["w16a16"]: - using_ep_moe_algo = False - else: - using_ep_moe_algo = True - - if using_ep_moe_algo: - fused_moe_out = self.apply_tp_scatter_op(layer, x, gate) - else: fused_moe_out = self.apply_tp_fused_op(layer, x, gate) + else: + fused_moe_out = self.apply_tp_scatter_op(layer, x, gate) return fused_moe_out diff --git a/fastdeploy/model_executor/layers/moe/moe.py b/fastdeploy/model_executor/layers/moe/moe.py index 09330e549..6b0e8f386 100644 --- a/fastdeploy/model_executor/layers/moe/moe.py +++ b/fastdeploy/model_executor/layers/moe/moe.py @@ -564,6 +564,29 @@ class FusedMoE(nn.Layer): else: self.quant_method.process_loaded_weights(self, state_dict) + def forward_split_allgather(self, x: paddle.Tensor, gate: nn.Layer): + """ + Forward split allgather function. + """ + token_num = x.shape[0] + tp_size = self.fd_config.parallel_config.tensor_parallel_size + tp_rank = self.fd_config.parallel_config.tensor_parallel_rank + token_num_per_rank = (token_num + tp_size - 1) // tp_size + # AllGather will hang when the data shapes on multi-ranks are different! 
+ part_x = paddle.zeros(shape=[token_num_per_rank, x.shape[1]], dtype=x.dtype) + start_offset = tp_rank * token_num_per_rank + end_offset = (tp_rank + 1) * token_num_per_rank + if start_offset >= token_num: + start_offset = token_num + if end_offset > token_num: + end_offset = token_num + part_x[: (end_offset - start_offset), :] = x[start_offset:end_offset, :] + out = self.quant_method.apply(self, part_x, gate) + multi_outs = paddle.zeros([token_num_per_rank * tp_size, x.shape[1]], dtype=x.dtype) + paddle.distributed.all_gather(multi_outs, out, self.tp_group) + out = multi_outs[:token_num, :] + return out + def forward(self, x: paddle.Tensor, gate: nn.Layer): """ Defines the forward computation of the moe layer. @@ -575,5 +598,10 @@ class FusedMoE(nn.Layer): Tensor: Output tensor.s """ - out = self.quant_method.apply(self, x, gate) + token_num = x.shape[0] + tp_size = self.fd_config.parallel_config.tensor_parallel_size + if self.ep_size > 1 and tp_size > 1 and token_num >= tp_size: + out = self.forward_split_allgather(x, gate) + else: + out = self.quant_method.apply(self, x, gate) return out diff --git a/fastdeploy/model_executor/models/ernie4_5_moe.py b/fastdeploy/model_executor/models/ernie4_5_moe.py index c2baeb910..94f3a57a6 100644 --- a/fastdeploy/model_executor/models/ernie4_5_moe.py +++ b/fastdeploy/model_executor/models/ernie4_5_moe.py @@ -111,14 +111,6 @@ class Ernie4_5_MoE(nn.Layer): if hasattr(fd_config.quant_config, "moe_quant_type"): moe_quant_type = fd_config.quant_config.moe_quant_type - self.expert_parallel_size = fd_config.parallel_config.expert_parallel_size - self.tensor_parallel_size = fd_config.parallel_config.tensor_parallel_size - self.tensor_parallel_rank = fd_config.parallel_config.tensor_parallel_rank - self.tp_group = fd_config.parallel_config.tp_group - - self.use_ep = self.expert_parallel_size > 1 - self.use_tp = self.tensor_parallel_size > 1 - if moe_quant_type == "w4a8" or moe_quant_type == "w4afp8": weight_key_map = { 
"gate_weight_key": f"{prefix}.gate.weight", @@ -221,30 +213,8 @@ class Ernie4_5_MoE(nn.Layer): def update_state_dict(self, state_dict): self.fused_moe.load_state_dict(state_dict, True) - def split_allgather_out(self, hidden_states: paddle.Tensor, token_num: int): - token_num_per_rank = (token_num + self.tensor_parallel_size - 1) // self.tensor_parallel_size - # AllGather will hang when the data shapes on multi-ranks are different! - part_hidden_states = paddle.zeros( - shape=[token_num_per_rank, hidden_states.shape[1]], dtype=hidden_states.dtype - ) - start_offset = self.tensor_parallel_rank * token_num_per_rank - end_offset = (self.tensor_parallel_rank + 1) * token_num_per_rank - if end_offset > token_num: - end_offset = token_num - part_hidden_states[: (end_offset - start_offset), :] = hidden_states[start_offset:end_offset, :] - out = self.experts(part_hidden_states, self.gate) - multi_outs = [] - paddle.distributed.all_gather(multi_outs, out, self.tp_group) - out = paddle.concat(multi_outs, axis=0) - out = out[:token_num, :] - return out - def forward(self, hidden_states: paddle.Tensor): - token_num = hidden_states.shape[0] - if self.use_ep and self.use_tp and token_num >= self.tensor_parallel_size: - out = self.split_allgather_out(hidden_states, token_num) - else: - out = self.experts(hidden_states, self.gate) + out = self.experts(hidden_states, self.gate) if self.num_shared_experts > 0: s_x = self.shared_experts(hidden_states) out = out + s_x diff --git a/fastdeploy/model_executor/models/qwen3moe.py b/fastdeploy/model_executor/models/qwen3moe.py index 8e47a919b..12f02282e 100644 --- a/fastdeploy/model_executor/models/qwen3moe.py +++ b/fastdeploy/model_executor/models/qwen3moe.py @@ -56,14 +56,6 @@ class Qwen3MoeBlock(nn.Layer): ) -> None: super().__init__() - self.expert_parallel_size = fd_config.parallel_config.expert_parallel_size - self.tensor_parallel_size = fd_config.parallel_config.tensor_parallel_size - self.tensor_parallel_rank = 
fd_config.parallel_config.tensor_parallel_rank - self.tp_group = fd_config.parallel_config.tp_group - - self.use_ep = self.expert_parallel_size > 1 - self.use_tp = self.tensor_parallel_size > 1 - weight_key_map = { "up_gate_proj_expert_weight_key": f"{prefix}.experts.{{}}.up_gate_proj.weight", "down_proj_expert_weight_key": f"{prefix}.experts.{{}}.down_proj.weight", @@ -87,31 +79,8 @@ class Qwen3MoeBlock(nn.Layer): weight_dtype="float32", ) - def split_allgather_out(self, hidden_states: paddle.Tensor, token_num: int): - token_num_per_rank = (token_num + self.tensor_parallel_size - 1) // self.tensor_parallel_size - # AllGather will hang when the data shapes on multi-ranks are different! - part_hidden_states = paddle.zeros( - shape=[token_num_per_rank, hidden_states.shape[1]], dtype=hidden_states.dtype - ) - start_offset = self.tensor_parallel_rank * token_num_per_rank - end_offset = (self.tensor_parallel_rank + 1) * token_num_per_rank - if end_offset > token_num: - end_offset = token_num - part_hidden_states[: (end_offset - start_offset), :] = hidden_states[start_offset:end_offset, :] - out = self.experts(part_hidden_states, self.gate) - multi_outs = [] - paddle.distributed.all_gather(multi_outs, out, self.tp_group) - out = paddle.concat(multi_outs, axis=0) - out = out[:token_num, :] - return out - def forward(self, x): - token_num = x.shape[0] - if self.use_ep and self.use_tp and token_num >= self.tensor_parallel_size: - out = self.split_allgather_out(x, token_num) - else: - out = self.experts(x, self.gate) - return out + return self.experts(x, self.gate) def load_state_dict(self, state_dict): """ """ diff --git a/fastdeploy/worker/xpu_model_runner.py b/fastdeploy/worker/xpu_model_runner.py index 03011433c..41cf0cc89 100644 --- a/fastdeploy/worker/xpu_model_runner.py +++ b/fastdeploy/worker/xpu_model_runner.py @@ -1156,6 +1156,13 @@ class XPUModelRunner(ModelRunnerBase): # 1. Prepare inputs of model and decoder. 
        self._prepare_inputs(is_dummy_run=is_dummy_run)
+
+        # NOTE(wufeisheng): If `not_need_stop` is False, it means the current worker is in an idle state.
+        # This logic is not used in TP (Tensor Parallelism) mode. However, in EP (Expert Parallelism) mode,
+        # when there is data on other runners, the current runner is required to execute part of the model.
+        if not self.not_need_stop() and not is_dummy_run:
+            self._execute_empty_input()
+            return None
+
         # 2. Padding inputs for cuda grph

         # 3. Execute model
@@ -1225,6 +1232,17 @@ class XPUModelRunner(ModelRunnerBase):

         return None

+    def _execute_empty_input(self) -> None:
+        """
+        In certain scenarios, such as during EP,
+        the runner needs to execute partial modules of the model without input data.
+        This requires the model to implement the `empty_input_forward` method.
+        """
+        if hasattr(self.model, "empty_input_forward"):
+            self.model.empty_input_forward()
+        else:
+            raise ValueError(f"{type(self.model)} has no attribute 'empty_input_forward'")
+
     @profile_run_guard(True)
     def profile_run(self) -> None:
         """Execute a forward pass with dummy inputs to profile the memory usage of the model"""
diff --git a/fastdeploy/worker/xpu_worker.py b/fastdeploy/worker/xpu_worker.py
index 4746987fa..1bf2cde3f 100644
--- a/fastdeploy/worker/xpu_worker.py
+++ b/fastdeploy/worker/xpu_worker.py
@@ -110,8 +110,9 @@ class XpuWorker(WorkerBase):
         free_memory = xpu_get_free_global_memory(self.device_id)
         logger.info(
-            f"Before warm up, total_memory: {total_memory / 1024**3}GB--------, \
-            used_memory: {used_memory / 1024**3}GB--------, free_memory: {free_memory / 1024**3}GB--------"
+            f"Before warm up, total_memory: {total_memory / 1024**3}GB, "
+            f"used_memory: {used_memory / 1024**3}GB, "
+            f"free_memory: {free_memory / 1024**3}GB."
) if self.parallel_config.use_ep: @@ -131,8 +132,9 @@ class XpuWorker(WorkerBase): self.model_runner.clear_block_table() logger.info( - f"After warm up, total_available_memory: {total_available_memory / 1024**3}GB--------, \ - used_memory: {used_memory / 1024**3}GB--------, available_kv_cache_memory: {available_kv_cache_memory / 1024**3}GB--------" + f"After warm up, total_available_memory: {total_available_memory / 1024**3}GB, " + f"used_memory: {used_memory / 1024**3}GB, " + f"available_kv_cache_memory: {available_kv_cache_memory / 1024**3}GB." ) paddle.device.xpu.empty_cache() return available_kv_cache_memory # approximate value diff --git a/scripts/run_ci_xpu.sh b/scripts/run_ci_xpu.sh index f2be14195..d59996372 100644 --- a/scripts/run_ci_xpu.sh +++ b/scripts/run_ci_xpu.sh @@ -6,10 +6,14 @@ echo "$DIR" apt install -y lsof #先kill一遍 -ps -efww | grep -E 'cache_transfer_manager.py' | grep -v grep | awk '{print $2}' | xargs kill -9 || true -ps -efww | grep -E 'api_server' | grep -v grep | awk '{print $2}' | xargs kill -9 || true -ps -efww | grep -E '8188' | grep -v grep | awk '{print $2}' | xargs kill -9 || true -lsof -t -i :8188 | xargs kill -9 || true +function stop_processes() { + ps -efww | grep -E 'cache_transfer_manager.py' | grep -v grep | awk '{print $2}' | xargs kill -9 || true + ps -efww | grep -E 'api_server' | grep -v grep | awk '{print $2}' | xargs kill -9 || true + ps -efww | grep -E '8188' | grep -v grep | awk '{print $2}' | xargs kill -9 || true + lsof -t -i :8188 | xargs kill -9 || true +} +stop_processes + #设置模型路径 export model_path=${MODEL_PATH}/ERNIE-4.5-300B-A47B-Paddle @@ -38,6 +42,8 @@ unset http_proxy unset https_proxy unset no_proxy +stop_processes + # 起服务 rm -rf log/* rm -f core* @@ -71,7 +77,10 @@ while true; do # 超时判断 if [ $ELAPSED -ge $TIMEOUT ]; then echo -e "\n服务启动超时:经过 $((TIMEOUT/60)) 分钟服务仍未启动!" 
+ stop_processes + echo "server.log" cat server.log + echo "log/workerlog.0" cat log/workerlog.0 exit 1 fi @@ -93,10 +102,7 @@ python -m pytest tests/ci_use/XPU_45T/run_45T.py kv_block_test_exit_code=$? echo kv_block_test_exit_code is ${kv_block_test_exit_code} -ps -efww | grep -E 'cache_transfer_manager.py' | grep -v grep | awk '{print $2}' | xargs kill -9 || true -ps -efww | grep -E 'api_server' | grep -v grep | awk '{print $2}' | xargs kill -9 || true -ps -efww | grep -E '8188' | grep -v grep | awk '{print $2}' | xargs kill -9 || true -lsof -t -i :8188 | xargs kill -9 || true +stop_processes if [ ${kv_block_test_exit_code} -ne 0 ]; then echo "log/workerlog.0" @@ -139,7 +145,10 @@ while true; do # 超时判断 if [ $ELAPSED -ge $TIMEOUT ]; then echo -e "\n服务启动超时:经过 $((TIMEOUT/60)) 分钟服务仍未启动!" + stop_processes + echo "server.log" cat server.log + echo "log/workerlog.0" cat log/workerlog.0 exit 1 fi @@ -161,10 +170,7 @@ python -m pytest tests/ci_use/XPU_45T/run_w4a8.py w4a8_test_exit_code=$? echo w4a8_test_exit_code is ${w4a8_test_exit_code} -ps -efww | grep -E 'cache_transfer_manager.py' | grep -v grep | awk '{print $2}' | xargs kill -9 || true -ps -efww | grep -E 'api_server' | grep -v grep | awk '{print $2}' | xargs kill -9 || true -ps -efww | grep -E '8188' | grep -v grep | awk '{print $2}' | xargs kill -9 || true -lsof -t -i :8188 | xargs kill -9 || true +stop_processes if [ ${w4a8_test_exit_code} -ne 0 ]; then echo "log/workerlog.0" @@ -210,7 +216,10 @@ while true; do # 超时判断 if [ $ELAPSED -ge $TIMEOUT ]; then echo -e "\n服务启动超时:经过 $((TIMEOUT/60)) 分钟服务仍未启动!" + stop_processes + echo "server.log" cat server.log + echo "log/workerlog.0" cat log/workerlog.0 exit 1 fi @@ -232,10 +241,7 @@ python -m pytest tests/ci_use/XPU_45T/run_45vl.py vl_test_exit_code=$? 
echo vl_test_exit_code is ${vl_test_exit_code} -ps -efww | grep -E 'cache_transfer_manager.py' | grep -v grep | awk '{print $2}' | xargs kill -9 || true -ps -efww | grep -E 'api_server' | grep -v grep | awk '{print $2}' | xargs kill -9 || true -ps -efww | grep -E '8188' | grep -v grep | awk '{print $2}' | xargs kill -9 || true -lsof -t -i :8188 | xargs kill -9 || true +stop_processes if [ ${vl_test_exit_code} -ne 0 ]; then echo "log/workerlog.0" @@ -245,12 +251,13 @@ if [ ${vl_test_exit_code} -ne 0 ]; then fi -echo "============================开始EP并行测试!============================" +echo "============================开始 EP4TP1 测试!============================" sleep 5 rm -rf log/* rm -f core* +ipcrm --all=msg xpu-smi -export XPU_VISIBLE_DEVICES="0,1,2,3" +export XPU_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" export BKCL_ENABLE_XDR=1 export BKCL_RDMA_NICS=xgbe1,xgbe2,xgbe3,xgbe4 export BKCL_TRACE_TOPO=1 @@ -265,6 +272,9 @@ cd xDeepEP bash build.sh cd - +export enable_expert_parallel=1 +export enable_tensor_parallel=0 + python -m pytest -s --timeout=600 tests/ci_use/XPU_45T/run_ep.py ep_exit_code=$? 
@@ -275,14 +285,49 @@ unset BKCL_PCIE_RING unset XSHMEM_MODE unset XSHMEM_QP_NUM_PER_RANK unset BKCL_RDMA_VERBS -ps -efww | grep -E 'cache_transfer_manager.py' | grep -v grep | awk '{print $2}' | xargs kill -9 || true -ps -efww | grep -E 'api_server' | grep -v grep | awk '{print $2}' | xargs kill -9 || true -ps -efww | grep -E '8188' | grep -v grep | awk '{print $2}' | xargs kill -9 || true -lsof -t -i :8188 | xargs kill -9 || true +stop_processes if [ ${ep_exit_code} -ne 0 ]; then echo "log/workerlog.0" cat log/workerlog.0 - echo "EP并行 相关测试失败,请检查pr代码" + echo "EP4TP1 相关测试失败,请检查pr代码" + exit 1 +fi + + +echo "============================开始 EP4TP4 测试!============================" +sleep 5 +rm -rf log/* +rm -f core* +ipcrm --all=msg +xpu-smi +export XPU_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" +export BKCL_ENABLE_XDR=1 +export BKCL_RDMA_NICS=xgbe1,xgbe2,xgbe3,xgbe4 +export BKCL_TRACE_TOPO=1 +export BKCL_PCIE_RING=1 +export XSHMEM_MODE=1 +export XSHMEM_QP_NUM_PER_RANK=32 +export BKCL_RDMA_VERBS=1 + +export enable_expert_parallel=1 +export enable_tensor_parallel=1 + +python -m pytest -s --timeout=600 tests/ci_use/XPU_45T/run_ep.py +ep_exit_code=$? 
+ +unset BKCL_ENABLE_XDR +unset BKCL_RDMA_NICS +unset BKCL_TRACE_TOPO +unset BKCL_PCIE_RING +unset XSHMEM_MODE +unset XSHMEM_QP_NUM_PER_RANK +unset BKCL_RDMA_VERBS +stop_processes + +if [ ${ep_exit_code} -ne 0 ]; then + echo "log/workerlog.0" + cat log/workerlog.0 + echo "EP4TP4 相关测试失败,请检查pr代码" exit 1 fi diff --git a/tests/ci_use/XPU_45T/run_ep.py b/tests/ci_use/XPU_45T/run_ep.py index e411396d6..e8a1e7197 100644 --- a/tests/ci_use/XPU_45T/run_ep.py +++ b/tests/ci_use/XPU_45T/run_ep.py @@ -1,31 +1,39 @@ import os import psutil +from paddleformers.trainer import strtobool from fastdeploy import LLM, SamplingParams def test_fd_ep(): """ """ - msg1 = [ {"role": "system", "content": ""}, {"role": "user", "content": "北京天安门广场在哪里?"}, ] messages = [msg1] + print(f"[INFO] messages: {messages}") # 采样参数 sampling_params = SamplingParams(top_p=0, max_tokens=500) # 模型路径与设备配置 - model = os.getenv("model_path", "/home/ERNIE-4.5-300B-A47B-Paddle") + model_root = os.getenv("MODEL_PATH", "/home") + model = f"{model_root}/ERNIE-4.5-300B-A47B-Paddle" xpu_visible_devices = os.getenv("XPU_VISIBLE_DEVICES", "0") xpu_device_num = len(xpu_visible_devices.split(",")) - enable_expert_parallel = True + enable_expert_parallel = strtobool(os.getenv("enable_expert_parallel", "1")) + enable_tensor_parallel = strtobool(os.getenv("enable_tensor_parallel", "0")) + print(f"enable_expert_parallel: {enable_expert_parallel}, enable_tensor_parallel: {enable_tensor_parallel}") if enable_expert_parallel: - tensor_parallel_size = 1 - data_parallel_size = xpu_device_num + if enable_tensor_parallel: + tensor_parallel_size = xpu_device_num + data_parallel_size = 1 + else: + tensor_parallel_size = 1 + data_parallel_size = xpu_device_num else: tensor_parallel_size = xpu_device_num data_parallel_size = 1 @@ -33,8 +41,6 @@ def test_fd_ep(): engine_worker_queue_port = [str(8023 + i * 10) for i in range(data_parallel_size)] engine_worker_queue_port = ",".join(engine_worker_queue_port) - print(f"[INFO] messages: 
{messages}") - llm = LLM( model=model, enable_expert_parallel=enable_expert_parallel,