mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
This reverts commit 93fcf7e4ec.
This commit is contained in:
@@ -25,8 +25,7 @@ template <typename T, int VecSize>
|
||||
__global__ void moe_token_type_ids_kernel(T *gating_output,
|
||||
const int *moe_token_type_ids_out,
|
||||
const int num_rows,
|
||||
const int num_experts,
|
||||
const int k) {
|
||||
const int num_experts, const int k) {
|
||||
const int moe_token_index = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
|
||||
if (moe_token_index >= num_rows) {
|
||||
@@ -45,8 +44,7 @@ template <typename T>
|
||||
void moe_token_type_ids_kernelLauncher(T *gating_output,
|
||||
const int *moe_token_type_ids_out,
|
||||
const int num_rows,
|
||||
const int num_experts,
|
||||
const int k,
|
||||
const int num_experts, const int k,
|
||||
cudaStream_t stream) {
|
||||
const int blocks = num_rows * k / 512 + 1;
|
||||
const int threads = 512;
|
||||
@@ -54,35 +52,26 @@ void moe_token_type_ids_kernelLauncher(T *gating_output,
|
||||
gating_output, moe_token_type_ids_out, num_rows, num_experts, k);
|
||||
}
|
||||
|
||||
template <typename T, typename NvType>
|
||||
class MoeHelper {
|
||||
public:
|
||||
using Fp16Traits =
|
||||
cutlass::WintQuantTraits<NvType, cutlass::WintQuantMethod::kNone>;
|
||||
using Int8Traits =
|
||||
cutlass::WintQuantTraits<NvType,
|
||||
cutlass::WintQuantMethod::kWeightOnlyInt8>;
|
||||
using Int4Traits =
|
||||
cutlass::WintQuantTraits<NvType,
|
||||
cutlass::WintQuantMethod::kWeightOnlyInt4>;
|
||||
template <typename T, typename NvType> class MoeHelper {
|
||||
public:
|
||||
using Fp16Traits = cutlass::WintQuantTraits<NvType, cutlass::WintQuantMethod::kNone>;
|
||||
using Int8Traits = cutlass::WintQuantTraits<NvType, cutlass::WintQuantMethod::kWeightOnlyInt8>;
|
||||
using Int4Traits = cutlass::WintQuantTraits<NvType, cutlass::WintQuantMethod::kWeightOnlyInt4>;
|
||||
|
||||
MoeHelper(const std::string gemm_method,
|
||||
MoeGemmRunner<NvType, Fp16Traits> *fp16_moe_gemm_runner,
|
||||
MoeGemmRunner<NvType, Int8Traits> *int8_moe_gemm_runner,
|
||||
MoeGemmRunner<NvType, Int4Traits> *int4_moe_gemm_runner,
|
||||
int layernum = 0)
|
||||
: gemm_method_(gemm_method),
|
||||
fp16_moe_gemm_runner_(fp16_moe_gemm_runner),
|
||||
MoeHelper(
|
||||
const std::string gemm_method,
|
||||
MoeGemmRunner<NvType, Fp16Traits> *fp16_moe_gemm_runner,
|
||||
MoeGemmRunner<NvType, Int8Traits> *int8_moe_gemm_runner,
|
||||
MoeGemmRunner<NvType, Int4Traits> *int4_moe_gemm_runner,
|
||||
int layernum = 0)
|
||||
: gemm_method_(gemm_method), fp16_moe_gemm_runner_(fp16_moe_gemm_runner),
|
||||
int8_moe_gemm_runner_(int8_moe_gemm_runner),
|
||||
int4_moe_gemm_runner_(int4_moe_gemm_runner),
|
||||
layernum_(layernum) {}
|
||||
int4_moe_gemm_runner_(int4_moe_gemm_runner), layernum_(layernum) {}
|
||||
|
||||
// -------- getWorkspaceSize -------- //
|
||||
template <typename KeyT>
|
||||
size_t getWorkspaceSize(const int64_t num_rows,
|
||||
const int64_t hidden_size,
|
||||
const int64_t inter_size,
|
||||
const int64_t num_experts,
|
||||
size_t getWorkspaceSize(const int64_t num_rows, const int64_t hidden_size,
|
||||
const int64_t inter_size, const int64_t num_experts,
|
||||
const int64_t k) {
|
||||
const size_t buf_size = AlignTo16(k * num_rows * hidden_size);
|
||||
const size_t interbuf_size = AlignTo16(k * num_rows * inter_size);
|
||||
@@ -93,10 +82,10 @@ class MoeHelper {
|
||||
// FfnLayer forward.
|
||||
size_t total_ws_bytes =
|
||||
5 * num_moe_inputs *
|
||||
sizeof(int); // source_rows_, permuted_rows_, permuted_experts_
|
||||
total_ws_bytes += buf_size * sizeof(KeyT); // permuted_data
|
||||
sizeof(int); // source_rows_, permuted_rows_, permuted_experts_
|
||||
total_ws_bytes += buf_size * sizeof(KeyT); // permuted_data
|
||||
total_ws_bytes +=
|
||||
padded_experts * sizeof(int64_t); // Hold total_rows_before_expert_
|
||||
padded_experts * sizeof(int64_t); // Hold total_rows_before_expert_
|
||||
|
||||
const size_t bytes_for_fc1_result = interbuf_size * sizeof(KeyT);
|
||||
const size_t sorter_ws_size_bytes =
|
||||
@@ -111,8 +100,8 @@ class MoeHelper {
|
||||
}
|
||||
|
||||
total_ws_bytes +=
|
||||
bytes_for_intermediate_and_sorting; // intermediate (fc1) output + cub
|
||||
// sorting workspace
|
||||
bytes_for_intermediate_and_sorting; // intermediate (fc1) output + cub
|
||||
// sorting workspace
|
||||
|
||||
int64_t num_softmax_outs = 0;
|
||||
const bool is_pow_2 =
|
||||
@@ -126,27 +115,20 @@ class MoeHelper {
|
||||
return total_ws_bytes;
|
||||
}
|
||||
|
||||
void ComputeFFN(const paddle::Tensor *input,
|
||||
const paddle::Tensor *gate_weight,
|
||||
const paddle::Tensor *up_gate_proj_weight,
|
||||
const paddle::Tensor *up_gate_proj_scale,
|
||||
const paddle::Tensor *up_gate_proj_bias,
|
||||
const paddle::Tensor *down_proj_weight,
|
||||
const paddle::Tensor *down_proj_scale,
|
||||
const paddle::Tensor *down_proj_bias,
|
||||
const paddle::Tensor *moe_token_type_ids,
|
||||
const int moe_topk,
|
||||
const bool group_moe,
|
||||
const bool norm_topk_prob,
|
||||
const float routed_scaling_factor,
|
||||
const std::string moe_type,
|
||||
paddle::Tensor *output) {
|
||||
void
|
||||
ComputeFFN(const paddle::Tensor *input, const paddle::Tensor *gate_weight,
|
||||
const paddle::Tensor *up_gate_proj_weight,
|
||||
const paddle::Tensor *up_gate_proj_scale, const paddle::Tensor *up_gate_proj_bias,
|
||||
const paddle::Tensor *down_proj_weight,
|
||||
const paddle::Tensor *down_proj_scale, const paddle::Tensor *down_proj_bias,
|
||||
const paddle::Tensor *moe_token_type_ids, const int moe_topk,
|
||||
const bool group_moe, const bool norm_topk_prob,
|
||||
const float routed_scaling_factor, const std::string moe_type,
|
||||
paddle::Tensor *output) {
|
||||
auto *input_activations = input->data<T>();
|
||||
auto *gating_weights = gate_weight->data<float>();
|
||||
const T *fc1_expert_biases =
|
||||
up_gate_proj_bias ? up_gate_proj_bias->data<T>() : nullptr;
|
||||
const T *fc2_expert_biases =
|
||||
down_proj_bias ? down_proj_bias->data<T>() : nullptr;
|
||||
const T *fc1_expert_biases = up_gate_proj_bias ? up_gate_proj_bias->data<T>() : nullptr;
|
||||
const T *fc2_expert_biases = down_proj_bias ? down_proj_bias->data<T>() : nullptr;
|
||||
|
||||
auto *output_ = output->data<T>();
|
||||
auto stream = input->stream();
|
||||
@@ -166,8 +148,7 @@ class MoeHelper {
|
||||
const int64_t hidden_size = up_gate_proj_dims[1];
|
||||
int64_t inter_dim = 0;
|
||||
if (moe_type == "qkv") {
|
||||
inter_dim =
|
||||
up_gate_proj_dims[2] * up_gate_proj_dims[3] * up_gate_proj_dims[4];
|
||||
inter_dim = up_gate_proj_dims[2] * up_gate_proj_dims[3] * up_gate_proj_dims[4];
|
||||
} else {
|
||||
inter_dim = up_gate_proj_dims[2];
|
||||
}
|
||||
@@ -251,79 +232,44 @@ class MoeHelper {
|
||||
if (moe_token_type_ids) {
|
||||
auto *moe_token_type_ids_out = moe_token_type_ids->data<int>();
|
||||
moe_token_type_ids_kernelLauncher<float>(gating_output,
|
||||
moe_token_type_ids_out,
|
||||
num_rows,
|
||||
num_experts,
|
||||
k,
|
||||
stream);
|
||||
moe_token_type_ids_out, num_rows,
|
||||
num_experts, k, stream);
|
||||
}
|
||||
|
||||
topk_gating_softmax_kernelLauncher<float, int>(gating_output,
|
||||
nullptr,
|
||||
expert_scales_float,
|
||||
softmax_out_,
|
||||
expert_for_source_row,
|
||||
source_rows_,
|
||||
softmax_max_prob,
|
||||
num_rows,
|
||||
num_experts,
|
||||
k,
|
||||
group_moe,
|
||||
stream);
|
||||
topk_gating_softmax_kernelLauncher<float, int>(
|
||||
gating_output, nullptr, expert_scales_float, softmax_out_,
|
||||
expert_for_source_row, source_rows_, softmax_max_prob, num_rows,
|
||||
num_experts, k, group_moe, stream);
|
||||
|
||||
const int64_t sorter_ws_size_bytes =
|
||||
AlignTo16(sorter_.getWorkspaceSize(int64_t(k * num_rows)));
|
||||
|
||||
sorter_.run(fc1_result_,
|
||||
sorter_ws_size_bytes,
|
||||
expert_for_source_row,
|
||||
permuted_experts_,
|
||||
source_rows_,
|
||||
permuted_rows_,
|
||||
k * num_rows,
|
||||
false,
|
||||
stream);
|
||||
sorter_.run(fc1_result_, sorter_ws_size_bytes, expert_for_source_row,
|
||||
permuted_experts_, source_rows_, permuted_rows_, k * num_rows,
|
||||
false, stream);
|
||||
|
||||
initialize_moe_routing_kernelLauncher(
|
||||
input_activations,
|
||||
permuted_data_,
|
||||
permuted_rows_,
|
||||
nullptr,
|
||||
nullptr,
|
||||
expanded_source_row_to_expanded_dest_row,
|
||||
nullptr,
|
||||
num_rows,
|
||||
num_rows,
|
||||
hidden_size,
|
||||
k,
|
||||
stream);
|
||||
input_activations, permuted_data_, permuted_rows_, nullptr, nullptr,
|
||||
expanded_source_row_to_expanded_dest_row, num_rows, num_rows,
|
||||
hidden_size, k, stream);
|
||||
|
||||
const int64_t expanded_active_expert_rows = k * num_rows;
|
||||
|
||||
compute_total_rows_before_expert(permuted_experts_,
|
||||
expanded_active_expert_rows,
|
||||
num_experts,
|
||||
total_rows_before_expert_,
|
||||
stream);
|
||||
expanded_active_expert_rows, num_experts,
|
||||
total_rows_before_expert_, stream);
|
||||
|
||||
if (gemm_method_ == "weight_only_int8") {
|
||||
typename Int8Traits::Arguments up_gate_proj_quant_args;
|
||||
int8_moe_gemm_runner_->moe_gemm_bias_act(
|
||||
reinterpret_cast<NvType *>(permuted_data_),
|
||||
reinterpret_cast<const uint8_t *>(
|
||||
up_gate_proj_weight->data<int8_t>()),
|
||||
reinterpret_cast<const uint8_t *>(up_gate_proj_weight->data<int8_t>()),
|
||||
reinterpret_cast<const NvType *>(up_gate_proj_scale->data<T>()),
|
||||
reinterpret_cast<const NvType *>(fc1_expert_biases),
|
||||
reinterpret_cast<NvType *>(fc1_out),
|
||||
total_rows_before_expert_,
|
||||
-1, // useless
|
||||
expanded_active_expert_rows,
|
||||
inter_size,
|
||||
hidden_size,
|
||||
num_experts,
|
||||
up_gate_proj_quant_args,
|
||||
"none",
|
||||
stream);
|
||||
reinterpret_cast<NvType *>(fc1_out), total_rows_before_expert_,
|
||||
-1, // useless
|
||||
expanded_active_expert_rows, inter_size, hidden_size, num_experts,
|
||||
up_gate_proj_quant_args, "none", stream);
|
||||
} else if (gemm_method_ == "weight_only_int4") {
|
||||
typename Int4Traits::Arguments up_gate_proj_quant_args;
|
||||
int4_moe_gemm_runner_->moe_gemm_bias_act(
|
||||
@@ -332,33 +278,20 @@ class MoeHelper {
|
||||
up_gate_proj_weight->data<int8_t>()),
|
||||
reinterpret_cast<const NvType *>(up_gate_proj_scale->data<T>()),
|
||||
reinterpret_cast<const NvType *>(fc1_expert_biases),
|
||||
reinterpret_cast<NvType *>(fc1_out),
|
||||
total_rows_before_expert_,
|
||||
-1, // useless
|
||||
expanded_active_expert_rows,
|
||||
inter_size,
|
||||
hidden_size,
|
||||
num_experts,
|
||||
up_gate_proj_quant_args,
|
||||
"none",
|
||||
stream);
|
||||
reinterpret_cast<NvType *>(fc1_out), total_rows_before_expert_,
|
||||
-1, // useless
|
||||
expanded_active_expert_rows, inter_size, hidden_size, num_experts,
|
||||
up_gate_proj_quant_args, "none", stream);
|
||||
} else {
|
||||
typename Fp16Traits::Arguments up_gate_proj_quant_args;
|
||||
fp16_moe_gemm_runner_->moe_gemm_bias_act(
|
||||
reinterpret_cast<NvType *>(permuted_data_),
|
||||
reinterpret_cast<const NvType *>(up_gate_proj_weight->data<T>()),
|
||||
nullptr,
|
||||
reinterpret_cast<const NvType *>(up_gate_proj_weight->data<T>()), nullptr,
|
||||
reinterpret_cast<const NvType *>(fc1_expert_biases),
|
||||
reinterpret_cast<NvType *>(fc1_out),
|
||||
total_rows_before_expert_,
|
||||
-1, // useless
|
||||
expanded_active_expert_rows,
|
||||
inter_size,
|
||||
hidden_size,
|
||||
num_experts,
|
||||
up_gate_proj_quant_args,
|
||||
"none",
|
||||
stream);
|
||||
reinterpret_cast<NvType *>(fc1_out), total_rows_before_expert_,
|
||||
-1, // useless
|
||||
expanded_active_expert_rows, inter_size, hidden_size, num_experts,
|
||||
up_gate_proj_quant_args, "none", stream);
|
||||
}
|
||||
|
||||
if (moe_type == "ffn") {
|
||||
@@ -376,15 +309,10 @@ class MoeHelper {
|
||||
reinterpret_cast<NvType *>(act_out),
|
||||
reinterpret_cast<const uint8_t *>(down_proj_weight->data<int8_t>()),
|
||||
reinterpret_cast<const NvType *>(down_proj_scale->data<T>()),
|
||||
reinterpret_cast<NvType *>(fc2_result),
|
||||
total_rows_before_expert_,
|
||||
-1, // useless
|
||||
expanded_active_expert_rows,
|
||||
hidden_size,
|
||||
inter_size / 2,
|
||||
num_experts,
|
||||
down_proj_quant_args,
|
||||
stream);
|
||||
reinterpret_cast<NvType *>(fc2_result), total_rows_before_expert_,
|
||||
-1, // useless
|
||||
expanded_active_expert_rows, hidden_size, inter_size / 2,
|
||||
num_experts, down_proj_quant_args, stream);
|
||||
} else if (gemm_method_ == "weight_only_int4") {
|
||||
typename Int4Traits::Arguments down_proj_quant_args;
|
||||
int4_moe_gemm_runner_->moe_gemm(
|
||||
@@ -392,66 +320,40 @@ class MoeHelper {
|
||||
reinterpret_cast<const cutlass::uint4b_t *>(
|
||||
down_proj_weight->data<int8_t>()),
|
||||
reinterpret_cast<const NvType *>(down_proj_scale->data<T>()),
|
||||
reinterpret_cast<NvType *>(fc2_result),
|
||||
total_rows_before_expert_,
|
||||
-1, // useless
|
||||
expanded_active_expert_rows,
|
||||
hidden_size,
|
||||
inter_size / 2,
|
||||
num_experts,
|
||||
down_proj_quant_args,
|
||||
stream);
|
||||
reinterpret_cast<NvType *>(fc2_result), total_rows_before_expert_,
|
||||
-1, // useless
|
||||
expanded_active_expert_rows, hidden_size, inter_size / 2,
|
||||
num_experts, down_proj_quant_args, stream);
|
||||
} else {
|
||||
typename Fp16Traits::Arguments down_proj_quant_args;
|
||||
fp16_moe_gemm_runner_->moe_gemm(
|
||||
reinterpret_cast<NvType *>(act_out),
|
||||
reinterpret_cast<const NvType *>(down_proj_weight->data<T>()),
|
||||
nullptr,
|
||||
reinterpret_cast<NvType *>(fc2_result),
|
||||
total_rows_before_expert_,
|
||||
-1, // useless
|
||||
expanded_active_expert_rows,
|
||||
hidden_size,
|
||||
inter_size / 2,
|
||||
num_experts,
|
||||
down_proj_quant_args,
|
||||
stream);
|
||||
reinterpret_cast<const NvType *>(down_proj_weight->data<T>()), nullptr,
|
||||
reinterpret_cast<NvType *>(fc2_result), total_rows_before_expert_,
|
||||
-1, // useless
|
||||
expanded_active_expert_rows, hidden_size, inter_size / 2,
|
||||
num_experts, down_proj_quant_args, stream);
|
||||
}
|
||||
|
||||
finalize_moe_routing_kernelLauncher(
|
||||
fc2_result,
|
||||
output_,
|
||||
fc2_expert_biases,
|
||||
fc2_result, output_, fc2_expert_biases,
|
||||
reinterpret_cast<float *>(expert_scales_float),
|
||||
expanded_source_row_to_expanded_dest_row,
|
||||
expert_for_source_row,
|
||||
num_rows,
|
||||
hidden_size,
|
||||
k,
|
||||
static_cast<int>(1),
|
||||
norm_topk_prob,
|
||||
routed_scaling_factor,
|
||||
stream);
|
||||
expanded_source_row_to_expanded_dest_row, expert_for_source_row,
|
||||
num_rows, hidden_size, k, static_cast<int>(1), norm_topk_prob,
|
||||
routed_scaling_factor, stream);
|
||||
} else {
|
||||
finalize_moe_routing_kernelLauncher(
|
||||
// fc2_result,
|
||||
fc1_out,
|
||||
output_,
|
||||
fc1_expert_biases, // fc2_expert_biases,
|
||||
fc1_out, output_,
|
||||
fc1_expert_biases, // fc2_expert_biases,
|
||||
reinterpret_cast<float *>(expert_scales_float),
|
||||
expanded_source_row_to_expanded_dest_row,
|
||||
expert_for_source_row,
|
||||
num_rows,
|
||||
inter_size,
|
||||
k,
|
||||
static_cast<int>(0),
|
||||
norm_topk_prob,
|
||||
routed_scaling_factor,
|
||||
stream);
|
||||
expanded_source_row_to_expanded_dest_row, expert_for_source_row,
|
||||
num_rows, inter_size, k, static_cast<int>(0), norm_topk_prob,
|
||||
routed_scaling_factor, stream);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
private:
|
||||
std::string gemm_method_;
|
||||
MoeGemmRunner<NvType, Fp16Traits> *fp16_moe_gemm_runner_;
|
||||
MoeGemmRunner<NvType, Int8Traits> *int8_moe_gemm_runner_;
|
||||
@@ -460,4 +362,4 @@ class MoeHelper {
|
||||
CubKeyValueSorter sorter_;
|
||||
};
|
||||
|
||||
} // namespace phi
|
||||
} // namespace phi
|
||||
|
||||
Reference in New Issue
Block a user