clean code (#4020)

This commit is contained in:
周周周
2025-09-10 10:56:15 +08:00
committed by GitHub
parent f078a959b6
commit dbab579299
4 changed files with 12 additions and 20 deletions

View File

@@ -236,7 +236,7 @@ public:
num_experts, k, stream);
}
topk_gating_softmax_kernelLauncher<float, int>::run(
topk_gating_softmax_kernelLauncher<float, int>(
gating_output, nullptr, expert_scales_float, softmax_out_,
expert_for_source_row, source_rows_, softmax_max_prob, num_rows,
num_experts, k, group_moe, stream);
@@ -248,7 +248,7 @@ public:
permuted_experts_, source_rows_, permuted_rows_, k * num_rows,
false, stream);
initialize_moe_routing_kernelLauncher<T>::run(
initialize_moe_routing_kernelLauncher(
input_activations, permuted_data_, permuted_rows_, nullptr, nullptr,
expanded_source_row_to_expanded_dest_row, num_rows, num_rows,
hidden_size, k, stream);
@@ -335,14 +335,14 @@ public:
num_experts, down_proj_quant_args, stream);
}
finalize_moe_routing_kernelLauncher<T>::run(
finalize_moe_routing_kernelLauncher(
fc2_result, output_, fc2_expert_biases,
reinterpret_cast<float *>(expert_scales_float),
expanded_source_row_to_expanded_dest_row, expert_for_source_row,
num_rows, hidden_size, k, static_cast<int>(1), norm_topk_prob,
routed_scaling_factor, stream);
} else {
finalize_moe_routing_kernelLauncher<T>::run(
finalize_moe_routing_kernelLauncher(
// fc2_result,
fc1_out, output_,
fc1_expert_biases, // fc2_expert_biases,

View File

@@ -1139,9 +1139,7 @@ void topk_gating_softmax_launcher_helper(const T* input,
}
template <typename T, typename IdxT = int>
struct topk_gating_softmax_kernelLauncher{
static void run(const T* input,
void topk_gating_softmax_kernelLauncher(const T* input,
const T* gating_correction_bias,
T* output,
T* softmax,
@@ -1221,7 +1219,6 @@ static void run(const T* input,
}
}
}
};
// ========================== Permutation things
// =======================================
@@ -1316,9 +1313,7 @@ __global__ void initialize_moe_routing_kernel(
}
template <typename T, typename OutT = T>
struct initialize_moe_routing_kernelLauncher{
static void run(
void initialize_moe_routing_kernelLauncher(
const T* unpermuted_input,
OutT* permuted_output,
const int* expanded_dest_row_to_expanded_source_row,
@@ -1361,7 +1356,6 @@ static void run(
num_rows * k);
}
}
};
// ============================== Infer GEMM sizes
// =================================
@@ -1472,8 +1466,7 @@ __global__ void finalize_moe_routing_kernel(
}
template <typename T>
struct finalize_moe_routing_kernelLauncher{
static void run(
void finalize_moe_routing_kernelLauncher(
const T* expanded_permuted_rows,
T* reduced_unpermuted_output,
const T* bias,
@@ -1505,5 +1498,4 @@ static void run(
routed_scaling_factor,
num_rows);
}
};
} // namespace phi

View File

@@ -100,7 +100,7 @@ void MoeDispatchKernel(
softmax_out_ = nullptr;
}
topk_gating_softmax_kernelLauncher<float, int>::run(
topk_gating_softmax_kernelLauncher(
gating_output.data<float>(),
gating_correction_bias ? gating_correction_bias.get().data<float>()
: nullptr,
@@ -114,13 +114,13 @@ void MoeDispatchKernel(
if (w4a8_in_scale) {
if (permute_input->dtype() == paddle::DataType::INT8) {
initialize_moe_routing_kernelLauncher<data_t, int8_t>::run(
initialize_moe_routing_kernelLauncher(
input.data<data_t>(), permute_input->data<int8_t>(), permuted_rows_,
expert_idx_per_token->data<int32_t>(), w4a8_in_scale->data<float>(),
permute_indices_per_token->data<int32_t>(), num_rows, num_rows,
hidden_size, moe_topk, stream);
} else if (permute_input->dtype() == paddle::DataType::FLOAT8_E4M3FN) {
initialize_moe_routing_kernelLauncher<data_t, float8_e4m3fn>::run(
initialize_moe_routing_kernelLauncher(
input.data<data_t>(), permute_input->data<float8_e4m3fn>(),
permuted_rows_, expert_idx_per_token->data<int32_t>(),
w4a8_in_scale->data<float>(),
@@ -128,7 +128,7 @@ void MoeDispatchKernel(
hidden_size, moe_topk, stream);
}
} else {
initialize_moe_routing_kernelLauncher<data_t>::run(
initialize_moe_routing_kernelLauncher(
input.data<data_t>(), permute_input->data<data_t>(), permuted_rows_,
expert_idx_per_token->data<int32_t>(), nullptr,
permute_indices_per_token->data<int32_t>(), num_rows, num_rows,

View File

@@ -36,7 +36,7 @@ void MoeReduceKernel(const paddle::Tensor &ffn_out,
typedef typename traits_::data_t data_t;
auto stream = ffn_out.stream();
finalize_moe_routing_kernelLauncher<data_t>::run(
finalize_moe_routing_kernelLauncher(
ffn_out.data<data_t>(), output->data<data_t>(),
down_proj_bias ? down_proj_bias->data<data_t>() : nullptr,
top_k_weight.data<float>(), permute_indices_per_token.data<int32_t>(),