From 3a43dbf82dd97bc341e68647ccadae5109a74ec9 Mon Sep 17 00:00:00 2001 From: zhupengyang <1165938320@qq.com> Date: Thu, 23 Oct 2025 19:09:58 +0800 Subject: [PATCH] [XPU] merge apply_tp, ops support token_num = 0 (#4507) --- custom_ops/xpu_ops/src/ops/moe_ep_combine.cc | 27 +- custom_ops/xpu_ops/src/ops/moe_ep_dispatch.cc | 76 ++-- custom_ops/xpu_ops/src/ops/moe_expert_ffn.cc | 6 + .../layers/backends/xpu/moe/fused_moe.py | 336 +++++++----------- 4 files changed, 191 insertions(+), 254 deletions(-) diff --git a/custom_ops/xpu_ops/src/ops/moe_ep_combine.cc b/custom_ops/xpu_ops/src/ops/moe_ep_combine.cc index 7ae0782b6..f10563e10 100644 --- a/custom_ops/xpu_ops/src/ops/moe_ep_combine.cc +++ b/custom_ops/xpu_ops/src/ops/moe_ep_combine.cc @@ -40,20 +40,21 @@ std::vector MoeEPCombineKernel( auto combined_out = paddle::empty( {recv_token_num, hidden_dim}, ffn_out.dtype(), ffn_out.place()); - const float* dequant_score = nullptr; - int ret = infer_ops::moe_ep_ffn_post_fusion( - xpu_ctx->x_context(), - reinterpret_cast(ffn_out.data()), - moe_index.data(), - reinterpret_cast(weights.data()), - dequant_score, - reinterpret_cast(combined_out.mutable_data()), - recv_token_num, - hidden_dim, - topk, - expand_token_num); - PD_CHECK(ret == 0); + if (recv_token_num > 0) { + int ret = infer_ops::moe_ep_ffn_post_fusion( + xpu_ctx->x_context(), + reinterpret_cast(ffn_out.data()), + moe_index.data(), + reinterpret_cast(weights.data()), + dequant_score, + reinterpret_cast(combined_out.mutable_data()), + recv_token_num, + hidden_dim, + topk, + expand_token_num); + PD_CHECK(ret == 0); + } return {combined_out}; } diff --git a/custom_ops/xpu_ops/src/ops/moe_ep_dispatch.cc b/custom_ops/xpu_ops/src/ops/moe_ep_dispatch.cc index 2690b8b13..c974073da 100644 --- a/custom_ops/xpu_ops/src/ops/moe_ep_dispatch.cc +++ b/custom_ops/xpu_ops/src/ops/moe_ep_dispatch.cc @@ -60,44 +60,48 @@ std::vector EPMoeExpertDispatchKernel( if (std::is_same::value) { permute_input = paddle::empty({token_nums_this_rank, n}, paddle::DataType::INT8, place); - auto ret = infer_ops::moe_ffn_pre_sorted_quant_pe( - xpu_ctx->x_context(), - reinterpret_cast(input.data()), - topk_ids.data(), - input_scales.get_ptr()->data(), - nullptr, - reinterpret_cast(permute_input.data()), - const_cast(permute_indices_per_token.data()), - const_cast(expert_m.data()), - const_cast(recv_num_tokens_per_expert_list_cumsum.data()), - expand_input_scales.data(), - m, - n, - expert_num, - topk, - block_num, - token_nums_this_rank); - PD_CHECK(ret == 0, "moe_ep_ffn_pre_sorted failed"); + if (token_nums_this_rank > 0) { + auto ret = infer_ops::moe_ffn_pre_sorted_quant_pe( + xpu_ctx->x_context(), + reinterpret_cast(input.data()), + topk_ids.data(), + input_scales.get_ptr()->data(), + nullptr, + reinterpret_cast(permute_input.data()), + const_cast(permute_indices_per_token.data()), + const_cast(expert_m.data()), + const_cast(recv_num_tokens_per_expert_list_cumsum.data()), + expand_input_scales.data(), + m, + n, + expert_num, + topk, + block_num, + token_nums_this_rank); + PD_CHECK(ret == 0, "moe_ep_ffn_pre_sorted failed"); + } } else { permute_input = paddle::empty({token_nums_this_rank, n}, input_type, place); - auto ret = infer_ops::moe_ep_ffn_pre_sorted( - xpu_ctx->x_context(), - reinterpret_cast(input.data()), - topk_ids.data(), - nullptr, - reinterpret_cast(permute_input.data()), - const_cast(permute_indices_per_token.data()), - const_cast(expert_m.data()), - const_cast(recv_num_tokens_per_expert_list_cumsum.data()), - m, - n, - expert_num, - topk, - block_num, - 
ep_size, - ep_rank, - token_nums_this_rank); - PD_CHECK(ret == 0, "moe_ep_ffn_pre_sorted failed"); + if (token_nums_this_rank > 0) { + auto ret = infer_ops::moe_ep_ffn_pre_sorted( + xpu_ctx->x_context(), + reinterpret_cast(input.data()), + topk_ids.data(), + nullptr, + reinterpret_cast(permute_input.data()), + const_cast(permute_indices_per_token.data()), + const_cast(expert_m.data()), + const_cast(recv_num_tokens_per_expert_list_cumsum.data()), + m, + n, + expert_num, + topk, + block_num, + ep_size, + ep_rank, + token_nums_this_rank); + PD_CHECK(ret == 0, "moe_ep_ffn_pre_sorted failed"); + } } return {permute_input, permute_indices_per_token, diff --git a/custom_ops/xpu_ops/src/ops/moe_expert_ffn.cc b/custom_ops/xpu_ops/src/ops/moe_expert_ffn.cc index 0b064044b..860fd8503 100644 --- a/custom_ops/xpu_ops/src/ops/moe_expert_ffn.cc +++ b/custom_ops/xpu_ops/src/ops/moe_expert_ffn.cc @@ -441,6 +441,12 @@ std::vector MoeExpertFFN( const std::string& quant_method, const int hadamard_blocksize, const int valid_token_num) { + if (ffn_in.numel() == 0) { + paddle::Tensor ffn2_out = + paddle::empty_like(ffn_in, paddle::DataType::BFLOAT16); + return {ffn2_out}; + } + const auto x_type = ffn_in.dtype(); const auto w_type = ffn1_weight.dtype(); diff --git a/fastdeploy/model_executor/layers/backends/xpu/moe/fused_moe.py b/fastdeploy/model_executor/layers/backends/xpu/moe/fused_moe.py index cc9b36629..e95e1b1a5 100644 --- a/fastdeploy/model_executor/layers/backends/xpu/moe/fused_moe.py +++ b/fastdeploy/model_executor/layers/backends/xpu/moe/fused_moe.py @@ -146,14 +146,14 @@ class XPUMoEMethod(MoEMethodBase): layer.up_gate_proj_weight.set_value(stacked_up_gate_proj_weights) layer.down_proj_weight.set_value(stacked_down_proj_weights) - def apply_tp( + def apply_tp_fused_op( self, layer: nn.Layer, x: paddle.Tensor, gate: nn.Layer, ) -> paddle.Tensor: """ - XPU compute Fused MoE. + Apply TP Fused Op. """ from fastdeploy.model_executor.ops.xpu import xpu_moe_layer @@ -165,9 +165,9 @@ class XPUMoEMethod(MoEMethodBase): layer.down_proj_weight, None, # up_gate_proj bias None, # down_proj bias - getattr(layer, "up_gate_proj_weight_scale", None), - getattr(layer, "down_proj_weight_scale", None), - getattr(layer, "up_gate_proj_in_scale", None), + getattr(layer, self.added_scale_attrs[0], None), + getattr(layer, self.added_scale_attrs[1], None), + getattr(layer, self.added_in_scale_attrs[0], None), self.moe_quant_type, layer.top_k, False, # moe group, used in deepseek @@ -177,21 +177,103 @@ class XPUMoEMethod(MoEMethodBase): return fused_moe_out + def apply_tp_scatter_op( + self, + layer: nn.Layer, + x: paddle.Tensor, + gate: nn.Layer, + ) -> paddle.Tensor: + """ + Apply TP Scatter Op. 
+ """ + gate_out = gate(x.cast("float32")) + topk_idx, topk_weights = moe_topk_select( + gate_out, + layer.gate_correction_bias, + layer.top_k, + True, + ) + token_nums_per_expert_list = list(range(64)) # placeholder, not use + ( + permute_input, + permute_indices_per_token, + token_num_lod, + dst_weights, + ffn1_act_scale_per_token, + ) = ep_moe_expert_dispatch( + x, + topk_idx, + topk_weights, + getattr(layer, self.added_in_scale_attrs[0], None), + token_nums_per_expert_list, + x.shape[0] * layer.top_k, + self.moe_quant_type, + ) + + if not hasattr(layer, self.added_in_scale_attrs[0]): + ffn1_act_scale_per_token = None + ffn_out = self.compute_ffn( + layer, + permute_input, + token_num_lod, + x.shape[0] * layer.top_k, + ffn1_act_scale_per_token, + ) + + topk_weights_bf16 = topk_weights.astype("bfloat16") + tmp_ffn_out = ep_moe_expert_combine( + ffn_out, + permute_indices_per_token, + topk_weights_bf16, + permute_indices_per_token.shape[0], + ffn_out.shape[0], + ffn_out.shape[1], + permute_indices_per_token.shape[1], + ) + + if layer.reduce_results and layer.tp_size > 1: + tmp_ffn_out = tensor_model_parallel_all_reduce(tmp_ffn_out) + return tmp_ffn_out + + def apply_tp( + self, + layer: nn.Layer, + x: paddle.Tensor, + gate: nn.Layer, + ) -> paddle.Tensor: + """ + apply tp + """ + if self.moe_quant_type in ["w16a16"]: + using_ep_moe_algo = False + elif self.moe_quant_type in ["w4a8"]: + using_ep_moe_algo = True + else: + using_ep_moe_algo = int(os.environ.get("USING_EP_MOE_ALGO", 0)) != 0 + print(f"using_ep_moe_algo: {using_ep_moe_algo}") + + if using_ep_moe_algo: + fused_moe_out = self.apply_tp_scatter_op(layer, x, gate) + else: + fused_moe_out = self.apply_tp_fused_op(layer, x, gate) + + return fused_moe_out + def compute_ffn( self, layer: nn.Layer, permute_input, token_num_lod, - valid_token_num=-1, - extra_ffn1_in_scale=None, + valid_token_num, + ffn1_act_scale_per_token=None, ): """ Calculate moe """ - # ffn1_in_scale = extra_ffn1_in_scale - moe_ffn1_scale = None - moe_ffn2_scale = None - + if self.moe_quant_type in ["w4a8"]: + hadamard_block_size = getattr(layer.moe_quant_config, "hadamard_block_size", 128) + else: + hadamard_block_size = -1 ffn_out = moe_expert_ffn( permute_input, token_num_lod, @@ -199,14 +281,14 @@ class XPUMoEMethod(MoEMethodBase): getattr(layer, self.added_weight_attrs[1]), None, None, - moe_ffn1_scale, - moe_ffn2_scale, - getattr(layer, self.added_scale_attrs[0]), - getattr(layer, self.added_scale_attrs[1]), + ffn1_act_scale_per_token, + getattr(layer, self.added_in_scale_attrs[1], None), + getattr(layer, self.added_scale_attrs[0], None), + getattr(layer, self.added_scale_attrs[1], None), None, None, self.moe_quant_type, - -1, + hadamard_block_size, valid_token_num, ) return ffn_out @@ -245,45 +327,41 @@ class XPUMoEMethod(MoEMethodBase): token_all_num = sum(token_num_per_expert) # 4. 
Compute ffn - if token_all_num > 0: - moe_dispatch_scale = None - ( - permute_input, - permute_indices_per_token, - token_num_lod, - dst_weights, - ffn1_act_scale_per_token, - ) = ep_moe_expert_dispatch( - recv_x, - recv_topk_idx, - recv_topk_weights, - moe_dispatch_scale, - token_num_per_expert, - token_all_num, - self.moe_quant_type, - ) + moe_dispatch_scale = None + ( + permute_input, + permute_indices_per_token, + token_num_lod, + dst_weights, + ffn1_act_scale_per_token, + ) = ep_moe_expert_dispatch( + recv_x, + recv_topk_idx, + recv_topk_weights, + moe_dispatch_scale, + token_num_per_expert, + token_all_num, + self.moe_quant_type, + ) - ffn_out = self.compute_ffn( - layer, - permute_input, - token_num_lod, - token_all_num, - ) + ffn_out = self.compute_ffn( + layer, + permute_input, + token_num_lod, + token_all_num, + ) - # prmt back per rank - recv_topk_weights_bf16 = recv_topk_weights.astype("bfloat16") - tmp_ffn_out = ep_moe_expert_combine( - ffn_out, - permute_indices_per_token, - recv_topk_weights_bf16, - permute_indices_per_token.shape[0], - ffn_out.shape[0], - ffn_out.shape[1], - permute_indices_per_token.shape[1], - ) - - else: - tmp_ffn_out = paddle.empty(recv_x.shape, "bfloat16") + # prmt back per rank + recv_topk_weights_bf16 = recv_topk_weights.astype("bfloat16") + tmp_ffn_out = ep_moe_expert_combine( + ffn_out, + permute_indices_per_token, + recv_topk_weights_bf16, + permute_indices_per_token.shape[0], + ffn_out.shape[0], + ffn_out.shape[1], + permute_indices_per_token.shape[1], + ) # 5. EP combine handle = None @@ -395,98 +473,6 @@ class XPUWeightOnlyMoEMethod(XPUMoEMethod): quanted_weight_scale = paddle.stack(weight_scale_list, axis=0) getattr(layer, scale_name).set_value(quanted_weight_scale) - def apply_tp( - self, - layer: nn.Layer, - x: paddle.Tensor, - gate: nn.Layer, - ) -> paddle.Tensor: - """ - XPU compute Fused MoE. 
- """ - USING_EP_MOE_ALGO = int(os.environ.get("USING_EP_MOE_ALGO", 0)) - if USING_EP_MOE_ALGO: - token_num = x.shape[0] - if token_num > 0: - gate_out = gate(x.cast("float32")) - topk_idx, topk_weights = moe_topk_select( - gate_out, - layer.gate_correction_bias, - layer.top_k, - True, - ) - token_nums_per_expert_list = list(range(64)) # 填充做占位符 - ( - permute_input, - permute_indices_per_token, - token_num_lod, - dst_weights, - ffn1_act_scale_per_token, - ) = ep_moe_expert_dispatch( - x, - topk_idx, - topk_weights, - getattr(layer, "up_gate_proj_in_scale", None), - token_nums_per_expert_list, - x.shape[0] * layer.top_k, - self.moe_quant_type, - ) - - ffn_out = moe_expert_ffn( - permute_input, - token_num_lod, - layer.up_gate_proj_weight, - layer.down_proj_weight, - None, # moe_ffn1_bias - None, # moe_ffn2_bias - None, # ffn1 in scale - None, # ffn2 in scale - getattr(layer, "up_gate_proj_weight_scale", None), - getattr(layer, "down_proj_weight_scale", None), - None, # moe_ffn2_shift - None, # moe_ffn2_smooth - self.moe_quant_type, - -1, - x.shape[0] * layer.top_k, # token_all_num - ) - topk_weights_bf16 = topk_weights.astype("bfloat16") - tmp_ffn_out = ep_moe_expert_combine( - ffn_out, - permute_indices_per_token, - topk_weights_bf16, - permute_indices_per_token.shape[0], - ffn_out.shape[0], - ffn_out.shape[1], - permute_indices_per_token.shape[1], - ) - else: - tmp_ffn_out = paddle.empty(x.shape, x.dtype) - - if layer.reduce_results and layer.tp_size > 1: - tensor_model_parallel_all_reduce(tmp_ffn_out) - return tmp_ffn_out - else: - from fastdeploy.model_executor.ops.xpu import xpu_moe_layer - - fused_moe_out = xpu_moe_layer( - x, - gate.weight.transpose([1, 0]), - layer.gate_correction_bias, - layer.up_gate_proj_weight, - layer.down_proj_weight, - None, # up_gate_proj bias - None, # down_proj bias - (layer.up_gate_proj_weight_scale if hasattr(layer, "up_gate_proj_weight_scale") else None), - (layer.down_proj_weight_scale if hasattr(layer, "down_proj_weight_scale") else None), - (layer.down_proj_in_scale if hasattr(layer, "down_proj_in_scale") else None), - self.moe_quant_type, - layer.top_k, - False, # moe group, used in deepseek - ) - if layer.reduce_results and layer.tp_size > 1: - tensor_model_parallel_all_reduce(fused_moe_out) - return fused_moe_out - class XPUW4A8MoEMethod(XPUMoEMethod): """ @@ -607,63 +593,3 @@ class XPUW4A8MoEMethod(XPUMoEMethod): for weight_scale_name in self.added_scale_attrs: getattr(layer, weight_scale_name).set_value(paddle.stack(scale_weight_map[weight_scale_name], axis=0)) - - def apply_tp( - self, - layer: nn.Layer, - x: paddle.Tensor, - gate: nn.Layer, - ) -> paddle.Tensor: - gate_out = gate(x.cast("float32")) - topk_idx, topk_weights = moe_topk_select( - gate_out, - layer.gate_correction_bias, - layer.top_k, - True, - ) - token_nums_per_expert_list = list(range(64)) # 填充做占位符 - ( - permute_input, - permute_indices_per_token, - token_num_lod, - dst_weights, - ffn1_act_scale_per_token, - ) = ep_moe_expert_dispatch( - x, - topk_idx, - topk_weights, - getattr(layer, "up_gate_proj_in_scale", None), - token_nums_per_expert_list, - x.shape[0] * layer.top_k, - self.moe_quant_type, - ) - ffn_out = moe_expert_ffn( - permute_input, - token_num_lod, - layer.up_gate_proj_weight, - layer.down_proj_weight, - None, # moe_ffn1_bias - None, # moe_ffn2_bias - (ffn1_act_scale_per_token if hasattr(layer, "up_gate_proj_in_scale") else None), - getattr(layer, "down_proj_in_scale", None), - getattr(layer, "up_gate_proj_weight_scale", None), - getattr(layer, 
"down_proj_weight_scale", None), - None, # moe_ffn2_shift - None, # moe_ffn2_smooth - self.moe_quant_type, - getattr(layer.moe_quant_config, "hadamard_block_size", 128), - x.shape[0] * layer.top_k, # token_all_num - ) - topk_weights_bf16 = topk_weights.astype("bfloat16") - tmp_ffn_out = ep_moe_expert_combine( - ffn_out, - permute_indices_per_token, - topk_weights_bf16, - permute_indices_per_token.shape[0], - ffn_out.shape[0], - ffn_out.shape[1], - permute_indices_per_token.shape[1], - ) - if layer.reduce_results and layer.tp_size > 1: - tmp_ffn_out = tensor_model_parallel_all_reduce(tmp_ffn_out) - return tmp_ffn_out