diff --git a/custom_ops/gpu_ops/moe/fused_moe_helper.h b/custom_ops/gpu_ops/moe/fused_moe_helper.h index 22bf0f1f9..703a7c11f 100644 --- a/custom_ops/gpu_ops/moe/fused_moe_helper.h +++ b/custom_ops/gpu_ops/moe/fused_moe_helper.h @@ -236,7 +236,7 @@ public: num_experts, k, stream); } - topk_gating_softmax_kernelLauncher::run( + topk_gating_softmax_kernelLauncher( gating_output, nullptr, expert_scales_float, softmax_out_, expert_for_source_row, source_rows_, softmax_max_prob, num_rows, num_experts, k, group_moe, stream); @@ -248,7 +248,7 @@ public: permuted_experts_, source_rows_, permuted_rows_, k * num_rows, false, stream); - initialize_moe_routing_kernelLauncher::run( + initialize_moe_routing_kernelLauncher( input_activations, permuted_data_, permuted_rows_, nullptr, nullptr, expanded_source_row_to_expanded_dest_row, num_rows, num_rows, hidden_size, k, stream); @@ -335,14 +335,14 @@ public: num_experts, down_proj_quant_args, stream); } - finalize_moe_routing_kernelLauncher::run( + finalize_moe_routing_kernelLauncher( fc2_result, output_, fc2_expert_biases, reinterpret_cast(expert_scales_float), expanded_source_row_to_expanded_dest_row, expert_for_source_row, num_rows, hidden_size, k, static_cast(1), norm_topk_prob, routed_scaling_factor, stream); } else { - finalize_moe_routing_kernelLauncher::run( + finalize_moe_routing_kernelLauncher( // fc2_result, fc1_out, output_, fc1_expert_biases, // fc2_expert_biases, diff --git a/custom_ops/gpu_ops/moe/fused_moe_op.h b/custom_ops/gpu_ops/moe/fused_moe_op.h index 2abaae5dd..eeaecb716 100644 --- a/custom_ops/gpu_ops/moe/fused_moe_op.h +++ b/custom_ops/gpu_ops/moe/fused_moe_op.h @@ -1139,9 +1139,7 @@ void topk_gating_softmax_launcher_helper(const T* input, } template -struct topk_gating_softmax_kernelLauncher{ - -static void run(const T* input, +void topk_gating_softmax_kernelLauncher(const T* input, const T* gating_correction_bias, T* output, T* softmax, @@ -1221,7 +1219,6 @@ static void run(const T* input, } } } -}; // ========================== Permutation things // ======================================= @@ -1316,9 +1313,7 @@ __global__ void initialize_moe_routing_kernel( } template -struct initialize_moe_routing_kernelLauncher{ - -static void run( +void initialize_moe_routing_kernelLauncher( const T* unpermuted_input, OutT* permuted_output, const int* expanded_dest_row_to_expanded_source_row, @@ -1361,7 +1356,6 @@ static void run( num_rows * k); } } -}; // ============================== Infer GEMM sizes // ================================= @@ -1472,8 +1466,7 @@ __global__ void finalize_moe_routing_kernel( } template -struct finalize_moe_routing_kernelLauncher{ -static void run( +void finalize_moe_routing_kernelLauncher( const T* expanded_permuted_rows, T* reduced_unpermuted_output, const T* bias, @@ -1505,5 +1498,4 @@ static void run( routed_scaling_factor, num_rows); } -}; } // namespace phi diff --git a/custom_ops/gpu_ops/moe/moe_dispatch.cu b/custom_ops/gpu_ops/moe/moe_dispatch.cu index d42b9f36b..8fa663c10 100644 --- a/custom_ops/gpu_ops/moe/moe_dispatch.cu +++ b/custom_ops/gpu_ops/moe/moe_dispatch.cu @@ -100,7 +100,7 @@ void MoeDispatchKernel( softmax_out_ = nullptr; } - topk_gating_softmax_kernelLauncher::run( + topk_gating_softmax_kernelLauncher( gating_output.data(), gating_correction_bias ? gating_correction_bias.get().data() : nullptr, @@ -114,13 +114,13 @@ void MoeDispatchKernel( if (w4a8_in_scale) { if (permute_input->dtype() == paddle::DataType::INT8) { - initialize_moe_routing_kernelLauncher::run( + initialize_moe_routing_kernelLauncher( input.data(), permute_input->data(), permuted_rows_, expert_idx_per_token->data(), w4a8_in_scale->data(), permute_indices_per_token->data(), num_rows, num_rows, hidden_size, moe_topk, stream); } else if (permute_input->dtype() == paddle::DataType::FLOAT8_E4M3FN) { - initialize_moe_routing_kernelLauncher::run( + initialize_moe_routing_kernelLauncher( input.data(), permute_input->data(), permuted_rows_, expert_idx_per_token->data(), w4a8_in_scale->data(), @@ -128,7 +128,7 @@ void MoeDispatchKernel( hidden_size, moe_topk, stream); } } else { - initialize_moe_routing_kernelLauncher::run( + initialize_moe_routing_kernelLauncher( input.data(), permute_input->data(), permuted_rows_, expert_idx_per_token->data(), nullptr, permute_indices_per_token->data(), num_rows, num_rows, diff --git a/custom_ops/gpu_ops/moe/moe_reduce.cu b/custom_ops/gpu_ops/moe/moe_reduce.cu index e10bf9121..e8532d5cd 100644 --- a/custom_ops/gpu_ops/moe/moe_reduce.cu +++ b/custom_ops/gpu_ops/moe/moe_reduce.cu @@ -36,7 +36,7 @@ void MoeReduceKernel(const paddle::Tensor &ffn_out, typedef typename traits_::data_t data_t; auto stream = ffn_out.stream(); - finalize_moe_routing_kernelLauncher::run( + finalize_moe_routing_kernelLauncher( ffn_out.data(), output->data(), down_proj_bias ? down_proj_bias->data() : nullptr, top_k_weight.data(), permute_indices_per_token.data(),