mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
* [Metax] refactor moe and flash attention backend --------- Co-authored-by: zhangchenyi_dl <16219492+zhangchenyidl@user.noreply.gitee.com>
392 lines
14 KiB
C++
392 lines
14 KiB
C++
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
#pragma once
|
|
|
|
#include "fused_moe_gemm_kernels.h"
|
|
#include "fused_moe_imp_op.h"
|
|
#include "fused_moe_op.h"
|
|
#include "mctlass/numeric_conversion.h"
|
|
#include "mctlassEx/mctlassEx.h"
|
|
|
|
namespace phi {
|
|
|
|
template <typename T, int VecSize>
|
|
__global__ void moe_token_type_ids_kernel(T* gating_output,
|
|
const int* moe_token_type_ids_out,
|
|
const int num_rows,
|
|
const int num_experts,
|
|
const int k) {
|
|
const int moe_token_index = blockIdx.x * blockDim.x + threadIdx.x;
|
|
|
|
if (moe_token_index >= num_rows) {
|
|
return;
|
|
}
|
|
|
|
gating_output[moe_token_index * 2] =
|
|
gating_output[moe_token_index * 2] +
|
|
(moe_token_type_ids_out[moe_token_index]) * -1e10;
|
|
gating_output[moe_token_index * 2 + 1] =
|
|
gating_output[moe_token_index * 2 + 1] +
|
|
(1 - moe_token_type_ids_out[moe_token_index]) * -1e10;
|
|
}
|
|
|
|
template <typename T>
|
|
void moe_token_type_ids_kernelLauncher(T* gating_output,
|
|
const int* moe_token_type_ids_out,
|
|
const int num_rows,
|
|
const int num_experts,
|
|
const int k,
|
|
cudaStream_t stream) {
|
|
const int blocks = num_rows * k / 512 + 1;
|
|
const int threads = 512;
|
|
moe_token_type_ids_kernel<T, 1><<<blocks, 512, 0, stream>>>(
|
|
gating_output, moe_token_type_ids_out, num_rows, num_experts, k);
|
|
}
|
|
|
|
template <typename T, typename MacaType>
|
|
class McMoeHelper {
|
|
public:
|
|
McMoeHelper(const std::string gemm_method,
|
|
McMoeGemmRunner<MacaType, int8_t>* int8_moe_gemm_runner)
|
|
: gemm_method_(gemm_method),
|
|
int8_moe_gemm_runner_(int8_moe_gemm_runner) {}
|
|
|
|
// -------- getWorkspaceSize -------- //
|
|
template <typename KeyT>
|
|
size_t getWorkspaceSize(const int64_t num_rows,
|
|
const int64_t hidden_size,
|
|
const int64_t inter_size,
|
|
const int64_t num_experts,
|
|
const int64_t k) {
|
|
const size_t buf_size = AlignTo16(k * num_rows * hidden_size);
|
|
const size_t interbuf_size = AlignTo16(k * num_rows * inter_size);
|
|
const size_t padded_experts = AlignTo16(num_experts);
|
|
const size_t num_moe_inputs = AlignTo16(k * num_rows);
|
|
// softmax output, permuted_rows and permuted_experts have moved to outside
|
|
// of moe kernel, allocate them in Encoder or Decoder before invoking
|
|
// FfnLayer forward.
|
|
size_t total_ws_bytes =
|
|
5 * num_moe_inputs *
|
|
sizeof(int); // source_rows_, permuted_rows_, permuted_experts_
|
|
total_ws_bytes += buf_size * sizeof(KeyT); // permuted_data
|
|
total_ws_bytes +=
|
|
padded_experts * sizeof(int32_t); // Hold total_rows_before_expert_
|
|
|
|
const size_t bytes_for_fc1_result = interbuf_size * sizeof(KeyT);
|
|
const size_t sorter_ws_size_bytes =
|
|
AlignTo16(sorter_.getWorkspaceSize(num_rows));
|
|
sorter_.update_num_experts(num_experts);
|
|
|
|
int64_t bytes_for_intermediate_and_sorting = bytes_for_fc1_result;
|
|
if (sorter_ws_size_bytes > bytes_for_fc1_result) {
|
|
int64_t remaining_bytes =
|
|
AlignTo16(sorter_ws_size_bytes - bytes_for_fc1_result);
|
|
bytes_for_intermediate_and_sorting += remaining_bytes;
|
|
}
|
|
|
|
total_ws_bytes +=
|
|
bytes_for_intermediate_and_sorting; // intermediate (fc1) output + cub
|
|
// sorting workspace
|
|
|
|
int64_t num_softmax_outs = 0;
|
|
const bool is_pow_2 =
|
|
(num_experts != 0) && ((num_experts & (num_experts - 1)) == 0);
|
|
if (!is_pow_2 || num_experts > 256) {
|
|
num_softmax_outs = AlignTo16(num_rows * num_experts);
|
|
}
|
|
|
|
total_ws_bytes += num_softmax_outs * sizeof(float);
|
|
|
|
return total_ws_bytes;
|
|
}
|
|
|
|
void computeFFN(const paddle::Tensor* input,
|
|
const paddle::Tensor* gate_weight,
|
|
const paddle::Tensor* up_gate_proj_weight,
|
|
const paddle::Tensor* up_gate_proj_scale,
|
|
const paddle::Tensor* up_gate_proj_bias,
|
|
const paddle::Tensor* down_proj_weight,
|
|
const paddle::Tensor* down_proj_scale,
|
|
const paddle::Tensor* down_proj_bias,
|
|
const paddle::Tensor* moe_token_type_ids,
|
|
const int moe_topk,
|
|
const bool group_moe,
|
|
const bool norm_topk_prob,
|
|
const float routed_scaling_factor,
|
|
const std::string moe_type,
|
|
paddle::Tensor* output) {
|
|
auto* input_activations = input->data<T>();
|
|
auto* gating_weights = gate_weight->data<float>();
|
|
const T* fc1_expert_biases =
|
|
up_gate_proj_bias ? up_gate_proj_bias->data<T>() : nullptr;
|
|
const T* fc2_expert_biases =
|
|
down_proj_bias ? down_proj_bias->data<T>() : nullptr;
|
|
|
|
auto* output_ = output->data<T>();
|
|
auto stream = input->stream();
|
|
auto place = input->place();
|
|
auto input_type = input->dtype();
|
|
|
|
auto input_dims = input->dims();
|
|
auto up_gate_proj_dims = up_gate_proj_weight->dims();
|
|
int64_t token_num = 0;
|
|
if (input_dims.size() == 3) {
|
|
token_num = input_dims[0] * input_dims[1];
|
|
} else {
|
|
token_num = input_dims[0];
|
|
}
|
|
const int64_t num_rows = token_num;
|
|
|
|
const int64_t hidden_size = up_gate_proj_dims[2];
|
|
int64_t inter_dim = 0;
|
|
if (moe_type == "qkv") {
|
|
inter_dim =
|
|
up_gate_proj_dims[2] * up_gate_proj_dims[3] * up_gate_proj_dims[4];
|
|
} else {
|
|
inter_dim = up_gate_proj_dims[1];
|
|
}
|
|
|
|
// if (gemm_method_ == "weight_only_int4") {
|
|
// inter_dim = inter_dim * 2;
|
|
// }
|
|
|
|
const int64_t inter_size = inter_dim;
|
|
const int64_t num_experts = up_gate_proj_dims[0];
|
|
const int64_t k = moe_topk;
|
|
|
|
int64_t bytes =
|
|
getWorkspaceSize<T>(num_rows, hidden_size, inter_size, num_experts, k);
|
|
|
|
// Pointers
|
|
int* expert_for_source_row;
|
|
int* source_rows_;
|
|
int* permuted_rows_;
|
|
int* permuted_experts_;
|
|
int* expanded_source_row_to_expanded_dest_row;
|
|
|
|
T* permuted_data_;
|
|
int32_t* total_rows_before_expert_;
|
|
T* fc1_result_;
|
|
float* softmax_out_;
|
|
|
|
paddle::Tensor ws_ptr_tensor =
|
|
GetEmptyTensor({bytes}, paddle::DataType::INT8, place);
|
|
int8_t* ws_ptr = ws_ptr_tensor.data<int8_t>();
|
|
|
|
const int64_t buf_size = AlignTo16(k * num_rows * hidden_size);
|
|
const int64_t interbuf_size = AlignTo16(k * num_rows * inter_size);
|
|
const int64_t padded_experts = AlignTo16(num_experts);
|
|
const int64_t num_moe_inputs = AlignTo16(k * num_rows);
|
|
|
|
expert_for_source_row = reinterpret_cast<int*>(ws_ptr);
|
|
source_rows_ = expert_for_source_row + num_moe_inputs;
|
|
permuted_rows_ = source_rows_ + num_moe_inputs;
|
|
permuted_experts_ = permuted_rows_ + num_moe_inputs;
|
|
expanded_source_row_to_expanded_dest_row =
|
|
permuted_experts_ + num_moe_inputs;
|
|
permuted_data_ = reinterpret_cast<T*>(
|
|
expanded_source_row_to_expanded_dest_row + num_moe_inputs);
|
|
total_rows_before_expert_ =
|
|
reinterpret_cast<int32_t*>(permuted_data_ + buf_size);
|
|
fc1_result_ =
|
|
reinterpret_cast<T*>(total_rows_before_expert_ + padded_experts);
|
|
|
|
const bool is_pow_2 =
|
|
(num_experts != 0) && ((num_experts & (num_experts - 1)) == 0);
|
|
if (!is_pow_2 || num_experts > 256) {
|
|
softmax_out_ = reinterpret_cast<float*>(fc1_result_ + interbuf_size);
|
|
} else {
|
|
softmax_out_ = nullptr;
|
|
}
|
|
|
|
paddle::Tensor expert_scales_float_tensor =
|
|
GetEmptyTensor({num_rows, moe_topk}, paddle::DataType::FLOAT32, place);
|
|
float* expert_scales_float = expert_scales_float_tensor.data<float>();
|
|
|
|
float* softmax_max_prob = nullptr;
|
|
if (group_moe) {
|
|
paddle::Tensor softmax_max_prob_tensor = GetEmptyTensor(
|
|
{num_rows, moe_topk}, paddle::DataType::FLOAT32, place);
|
|
// (TODO: check fill success ?)
|
|
paddle::experimental::fill(softmax_max_prob_tensor, 0.f);
|
|
softmax_max_prob = softmax_max_prob_tensor.data<float>();
|
|
}
|
|
|
|
paddle::Tensor fc1_out_tensor =
|
|
GetEmptyTensor({num_rows * k, inter_size}, input_type, place);
|
|
T* fc1_out = fc1_out_tensor.data<T>();
|
|
|
|
auto input_cast_tensor =
|
|
paddle::experimental::cast(*input, paddle::DataType::FLOAT32);
|
|
auto gate_tensor =
|
|
paddle::experimental::matmul(input_cast_tensor, *gate_weight);
|
|
float* gating_output = gate_tensor.data<float>();
|
|
|
|
if (moe_token_type_ids) {
|
|
auto* moe_token_type_ids_out = moe_token_type_ids->data<int>();
|
|
moe_token_type_ids_kernelLauncher<float>(gating_output,
|
|
moe_token_type_ids_out,
|
|
num_rows,
|
|
num_experts,
|
|
k,
|
|
stream);
|
|
}
|
|
|
|
topk_gating_softmax_kernelLauncher<float, int>(gating_output,
|
|
nullptr,
|
|
expert_scales_float,
|
|
softmax_out_,
|
|
expert_for_source_row,
|
|
source_rows_,
|
|
softmax_max_prob,
|
|
num_rows,
|
|
num_experts,
|
|
k,
|
|
group_moe,
|
|
stream);
|
|
|
|
const int64_t sorter_ws_size_bytes =
|
|
AlignTo16(sorter_.getWorkspaceSize(int64_t(k * num_rows)));
|
|
|
|
sorter_.run(fc1_result_,
|
|
sorter_ws_size_bytes,
|
|
expert_for_source_row,
|
|
permuted_experts_,
|
|
source_rows_,
|
|
permuted_rows_,
|
|
k * num_rows,
|
|
false,
|
|
stream);
|
|
|
|
initialize_moe_routing_kernelLauncher(
|
|
input_activations,
|
|
permuted_data_,
|
|
permuted_rows_,
|
|
nullptr,
|
|
nullptr,
|
|
expanded_source_row_to_expanded_dest_row,
|
|
num_rows,
|
|
num_rows,
|
|
hidden_size,
|
|
k,
|
|
stream);
|
|
|
|
const int64_t expanded_active_expert_rows = k * num_rows;
|
|
|
|
compute_total_rows_before_expert(permuted_experts_,
|
|
expanded_active_expert_rows,
|
|
num_experts,
|
|
total_rows_before_expert_,
|
|
stream);
|
|
|
|
mctlassExOrder_t row_major = mctlassExOrder_t::MCTLASS_EX_ORDER_ROW_MAJOR;
|
|
mctlassExOrder_t column_major =
|
|
mctlassExOrder_t::MCTLASS_EX_ORDER_COLUMN_MAJOR;
|
|
auto m_num_tile =
|
|
GetEmptyTensor({num_experts}, paddle::DataType::INT32, place);
|
|
int* m_num_tile_ptr = reinterpret_cast<int*>(m_num_tile.data<int>());
|
|
|
|
if (gemm_method_ == "weight_only_int8") {
|
|
int8_moe_gemm_runner_->mc_grouped_gemm_basic_kernel(
|
|
reinterpret_cast<const MacaType*>(permuted_data_),
|
|
row_major,
|
|
reinterpret_cast<const int8_t*>(up_gate_proj_weight->data<int8_t>()),
|
|
column_major,
|
|
reinterpret_cast<const MacaType*>(up_gate_proj_scale->data<T>()),
|
|
reinterpret_cast<const MacaType*>(fc1_expert_biases),
|
|
reinterpret_cast<MacaType*>(fc1_out),
|
|
row_major,
|
|
total_rows_before_expert_,
|
|
m_num_tile_ptr,
|
|
num_experts,
|
|
expanded_active_expert_rows,
|
|
inter_size,
|
|
hidden_size,
|
|
stream);
|
|
} else {
|
|
throw std::runtime_error("Unsupported gemm method: " + gemm_method_);
|
|
}
|
|
|
|
if (moe_type == "ffn") {
|
|
auto act_out_tensor =
|
|
paddle::experimental::swiglu(fc1_out_tensor, nullptr);
|
|
auto act_out = act_out_tensor.data<T>();
|
|
|
|
paddle::Tensor fc2_output_tensor =
|
|
GetEmptyTensor({k * num_rows, hidden_size}, input_type, place);
|
|
T* fc2_result = fc2_output_tensor.data<T>();
|
|
|
|
if (gemm_method_ == "weight_only_int8") {
|
|
int8_moe_gemm_runner_->mc_grouped_gemm_basic_kernel(
|
|
reinterpret_cast<const MacaType*>(act_out),
|
|
row_major,
|
|
reinterpret_cast<const int8_t*>(down_proj_weight->data<int8_t>()),
|
|
column_major,
|
|
reinterpret_cast<const MacaType*>(down_proj_scale->data<T>()),
|
|
nullptr,
|
|
reinterpret_cast<MacaType*>(fc2_result),
|
|
row_major,
|
|
total_rows_before_expert_,
|
|
m_num_tile_ptr,
|
|
num_experts,
|
|
expanded_active_expert_rows,
|
|
hidden_size,
|
|
inter_size / 2,
|
|
stream);
|
|
} else {
|
|
throw std::runtime_error("Unsupported gemm method: " + gemm_method_);
|
|
}
|
|
|
|
finalize_moe_routing_kernelLauncher(
|
|
fc2_result,
|
|
output_,
|
|
fc2_expert_biases,
|
|
reinterpret_cast<float*>(expert_scales_float),
|
|
expanded_source_row_to_expanded_dest_row,
|
|
expert_for_source_row,
|
|
num_rows,
|
|
hidden_size,
|
|
k,
|
|
static_cast<int>(1),
|
|
norm_topk_prob,
|
|
routed_scaling_factor,
|
|
stream);
|
|
} else {
|
|
finalize_moe_routing_kernelLauncher(
|
|
// fc2_result,
|
|
fc1_out,
|
|
output_,
|
|
fc1_expert_biases, // fc2_expert_biases,
|
|
reinterpret_cast<float*>(expert_scales_float),
|
|
expanded_source_row_to_expanded_dest_row,
|
|
expert_for_source_row,
|
|
num_rows,
|
|
inter_size,
|
|
k,
|
|
static_cast<int>(0),
|
|
norm_topk_prob,
|
|
routed_scaling_factor,
|
|
stream);
|
|
}
|
|
}
|
|
|
|
private:
|
|
McMoeGemmRunner<MacaType, int8_t>* int8_moe_gemm_runner_;
|
|
std::string gemm_method_;
|
|
CubKeyValueSorter sorter_;
|
|
};
|
|
|
|
} // namespace phi
|