mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
clean code (#4020)
This commit is contained in:
@@ -236,7 +236,7 @@ public:
|
||||
num_experts, k, stream);
|
||||
}
|
||||
|
||||
topk_gating_softmax_kernelLauncher<float, int>::run(
|
||||
topk_gating_softmax_kernelLauncher<float, int>(
|
||||
gating_output, nullptr, expert_scales_float, softmax_out_,
|
||||
expert_for_source_row, source_rows_, softmax_max_prob, num_rows,
|
||||
num_experts, k, group_moe, stream);
|
||||
@@ -248,7 +248,7 @@ public:
|
||||
permuted_experts_, source_rows_, permuted_rows_, k * num_rows,
|
||||
false, stream);
|
||||
|
||||
initialize_moe_routing_kernelLauncher<T>::run(
|
||||
initialize_moe_routing_kernelLauncher(
|
||||
input_activations, permuted_data_, permuted_rows_, nullptr, nullptr,
|
||||
expanded_source_row_to_expanded_dest_row, num_rows, num_rows,
|
||||
hidden_size, k, stream);
|
||||
@@ -335,14 +335,14 @@ public:
|
||||
num_experts, down_proj_quant_args, stream);
|
||||
}
|
||||
|
||||
finalize_moe_routing_kernelLauncher<T>::run(
|
||||
finalize_moe_routing_kernelLauncher(
|
||||
fc2_result, output_, fc2_expert_biases,
|
||||
reinterpret_cast<float *>(expert_scales_float),
|
||||
expanded_source_row_to_expanded_dest_row, expert_for_source_row,
|
||||
num_rows, hidden_size, k, static_cast<int>(1), norm_topk_prob,
|
||||
routed_scaling_factor, stream);
|
||||
} else {
|
||||
finalize_moe_routing_kernelLauncher<T>::run(
|
||||
finalize_moe_routing_kernelLauncher(
|
||||
// fc2_result,
|
||||
fc1_out, output_,
|
||||
fc1_expert_biases, // fc2_expert_biases,
|
||||
|
||||
@@ -1139,9 +1139,7 @@ void topk_gating_softmax_launcher_helper(const T* input,
|
||||
}
|
||||
|
||||
template <typename T, typename IdxT = int>
|
||||
struct topk_gating_softmax_kernelLauncher{
|
||||
|
||||
static void run(const T* input,
|
||||
void topk_gating_softmax_kernelLauncher(const T* input,
|
||||
const T* gating_correction_bias,
|
||||
T* output,
|
||||
T* softmax,
|
||||
@@ -1221,7 +1219,6 @@ static void run(const T* input,
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// ========================== Permutation things
|
||||
// =======================================
|
||||
@@ -1316,9 +1313,7 @@ __global__ void initialize_moe_routing_kernel(
|
||||
}
|
||||
|
||||
template <typename T, typename OutT = T>
|
||||
struct initialize_moe_routing_kernelLauncher{
|
||||
|
||||
static void run(
|
||||
void initialize_moe_routing_kernelLauncher(
|
||||
const T* unpermuted_input,
|
||||
OutT* permuted_output,
|
||||
const int* expanded_dest_row_to_expanded_source_row,
|
||||
@@ -1361,7 +1356,6 @@ static void run(
|
||||
num_rows * k);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// ============================== Infer GEMM sizes
|
||||
// =================================
|
||||
@@ -1472,8 +1466,7 @@ __global__ void finalize_moe_routing_kernel(
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
struct finalize_moe_routing_kernelLauncher{
|
||||
static void run(
|
||||
void finalize_moe_routing_kernelLauncher(
|
||||
const T* expanded_permuted_rows,
|
||||
T* reduced_unpermuted_output,
|
||||
const T* bias,
|
||||
@@ -1505,5 +1498,4 @@ static void run(
|
||||
routed_scaling_factor,
|
||||
num_rows);
|
||||
}
|
||||
};
|
||||
} // namespace phi
|
||||
|
||||
@@ -100,7 +100,7 @@ void MoeDispatchKernel(
|
||||
softmax_out_ = nullptr;
|
||||
}
|
||||
|
||||
topk_gating_softmax_kernelLauncher<float, int>::run(
|
||||
topk_gating_softmax_kernelLauncher(
|
||||
gating_output.data<float>(),
|
||||
gating_correction_bias ? gating_correction_bias.get().data<float>()
|
||||
: nullptr,
|
||||
@@ -114,13 +114,13 @@ void MoeDispatchKernel(
|
||||
|
||||
if (w4a8_in_scale) {
|
||||
if (permute_input->dtype() == paddle::DataType::INT8) {
|
||||
initialize_moe_routing_kernelLauncher<data_t, int8_t>::run(
|
||||
initialize_moe_routing_kernelLauncher(
|
||||
input.data<data_t>(), permute_input->data<int8_t>(), permuted_rows_,
|
||||
expert_idx_per_token->data<int32_t>(), w4a8_in_scale->data<float>(),
|
||||
permute_indices_per_token->data<int32_t>(), num_rows, num_rows,
|
||||
hidden_size, moe_topk, stream);
|
||||
} else if (permute_input->dtype() == paddle::DataType::FLOAT8_E4M3FN) {
|
||||
initialize_moe_routing_kernelLauncher<data_t, float8_e4m3fn>::run(
|
||||
initialize_moe_routing_kernelLauncher(
|
||||
input.data<data_t>(), permute_input->data<float8_e4m3fn>(),
|
||||
permuted_rows_, expert_idx_per_token->data<int32_t>(),
|
||||
w4a8_in_scale->data<float>(),
|
||||
@@ -128,7 +128,7 @@ void MoeDispatchKernel(
|
||||
hidden_size, moe_topk, stream);
|
||||
}
|
||||
} else {
|
||||
initialize_moe_routing_kernelLauncher<data_t>::run(
|
||||
initialize_moe_routing_kernelLauncher(
|
||||
input.data<data_t>(), permute_input->data<data_t>(), permuted_rows_,
|
||||
expert_idx_per_token->data<int32_t>(), nullptr,
|
||||
permute_indices_per_token->data<int32_t>(), num_rows, num_rows,
|
||||
|
||||
@@ -36,7 +36,7 @@ void MoeReduceKernel(const paddle::Tensor &ffn_out,
|
||||
typedef typename traits_::data_t data_t;
|
||||
auto stream = ffn_out.stream();
|
||||
|
||||
finalize_moe_routing_kernelLauncher<data_t>::run(
|
||||
finalize_moe_routing_kernelLauncher(
|
||||
ffn_out.data<data_t>(), output->data<data_t>(),
|
||||
down_proj_bias ? down_proj_bias->data<data_t>() : nullptr,
|
||||
top_k_weight.data<float>(), permute_indices_per_token.data<int32_t>(),
|
||||
|
||||
Reference in New Issue
Block a user