diff --git a/custom_ops/gpu_ops/moe/moe_topk_select.cu b/custom_ops/gpu_ops/moe/moe_topk_select.cu index bbdaabdf2..821d54bce 100644 --- a/custom_ops/gpu_ops/moe/moe_topk_select.cu +++ b/custom_ops/gpu_ops/moe/moe_topk_select.cu @@ -12,7 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. - // Ignore CUTLASS warnings about type punning #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wstrict-aliasing" @@ -39,20 +38,35 @@ void moe_topk_select_kernel(const T* input, const int64_t k, cudaStream_t stream, const bool apply_norm_weight = false, - const bool enable_softmax_top_k_fused = false - ) { + const bool enable_softmax_top_k_fused = false) { static constexpr int WARPS_PER_TB = 4; - #define LAUNCH_TOPK_GATING_SOFTMAX_HELPER(N) \ - case N: { \ - if (apply_norm_weight) { \ - topk_gating_softmax_launcher_helper( \ - input, bias, output, indices, source_row, num_rows, num_experts, k, stream); \ - } else { \ - topk_gating_softmax_launcher_helper( \ - input, bias, output, indices, source_row, num_rows, num_experts, k, stream); \ - } \ - break; \ +#define LAUNCH_TOPK_GATING_SOFTMAX_HELPER(N) \ + case N: { \ + if (apply_norm_weight) { \ + topk_gating_softmax_launcher_helper( \ + input, \ + bias, \ + output, \ + indices, \ + source_row, \ + num_rows, \ + num_experts, \ + k, \ + stream); \ + } else { \ + topk_gating_softmax_launcher_helper( \ + input, \ + bias, \ + output, \ + indices, \ + source_row, \ + num_rows, \ + num_experts, \ + k, \ + stream); \ + } \ + break; \ } switch (num_experts) { LAUNCH_TOPK_GATING_SOFTMAX_HELPER(2) @@ -68,56 +82,56 @@ void moe_topk_select_kernel(const T* input, static constexpr int TPB = 256; const auto config_topk = Get1DBlocksAnd2DGridsMoe(num_rows); if (!enable_softmax_top_k_fused) { - moe_softmax<<>>( - input, softmax, num_experts, num_rows); - if (apply_norm_weight) { - moe_top_k - <<>>(softmax, - bias, - output, - indices, - source_row, - num_experts, - k, - num_rows); - } else { - moe_top_k - <<>>(softmax, - bias, - output, - indices, - source_row, - num_experts, - k, - num_rows); - } - cudaGetLastError(); + moe_softmax<<>>( + input, softmax, num_experts, num_rows); + if (apply_norm_weight) { + moe_top_k + <<>>( + softmax, + bias, + output, + indices, + source_row, + num_experts, + k, + num_rows); + } else { + moe_top_k + <<>>(softmax, + bias, + output, + indices, + source_row, + num_experts, + k, + num_rows); + } + cudaGetLastError(); + } else { + assert(k <= TPB); + if (apply_norm_weight) { + moe_softmax_top_k_fused + <<>>( + input, + bias, + output, + indices, + source_row, + num_experts, + k, + num_rows); + } else { + moe_softmax_top_k_fused + <<>>(input, + bias, + output, + indices, + source_row, + num_experts, + k, + num_rows); + } } - else { - assert(k<=TPB); - if (apply_norm_weight) { - moe_softmax_top_k_fused - <<>>(input, - bias, - output, - indices, - source_row, - num_experts, - k, - num_rows); - } else { - moe_softmax_top_k_fused - <<>>(input, - bias, - output, - indices, - source_row, - num_experts, - k, - num_rows); - } - } - } } } @@ -146,6 +160,13 @@ std::vector MoETopKSelectKernel( auto topk_weights = GetEmptyTensor({num_rows, moe_topk}, paddle::DataType::FLOAT32, place); + // NOTE(sunxin): Avoid "invalid configuration argument" error caused by empty + // tensors. + if (gating_dims[0] == 0) { + cudaGetLastError(); + return {topk_ids, topk_weights}; + } + const int num_moe_inputs = AlignTo16(num_rows * moe_topk); const int bytes = num_moe_inputs * sizeof(int); @@ -213,8 +234,7 @@ std::vector> MoETopKSelectKernelInferShape( } const int num_rows = token_rows; - return {{num_rows, moe_topk}, - {num_rows, moe_topk}}; + return {{num_rows, moe_topk}, {num_rows, moe_topk}}; } std::vector MoETopKSelectKernelInferDtype( @@ -223,16 +243,15 @@ std::vector MoETopKSelectKernelInferDtype( const int moe_topk, const bool apply_norm_weight, const bool enable_softmax_top_k_fused) { - return {paddle::DataType::INT64, - paddle::DataType::FLOAT32}; + return {paddle::DataType::INT64, paddle::DataType::FLOAT32}; } - PD_BUILD_STATIC_OP(moe_topk_select) .Inputs({"gating_logits", paddle::Optional("bias")}) - .Outputs({"topk_ids", - "topk_weights"}) - .Attrs({"moe_topk:int", "apply_norm_weight:bool", "enable_softmax_top_k_fused:bool"}) + .Outputs({"topk_ids", "topk_weights"}) + .Attrs({"moe_topk:int", + "apply_norm_weight:bool", + "enable_softmax_top_k_fused:bool"}) .SetKernelFn(PD_KERNEL(MoETopKSelectKernel)) .SetInferShapeFn(PD_INFER_SHAPE(MoETopKSelectKernelInferShape)) .SetInferDtypeFn(PD_INFER_DTYPE(MoETopKSelectKernelInferDtype));