mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
c++ code format (#4527)
This commit is contained in:
@@ -12,12 +12,11 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "fused_moe_op.h"
|
||||
#include "helper.h"
|
||||
#include "mc_fused_moe_helper.h"
|
||||
#include "fused_moe_op.h"
|
||||
|
||||
__global__ void compute_total_rows_before_expert_kernel(
|
||||
int* sorted_experts,
|
||||
@@ -43,7 +42,10 @@ void compute_total_rows_before_expert(int* sorted_indices,
|
||||
sorted_indices, total_indices, num_experts, total_rows_before_expert);
|
||||
}
|
||||
|
||||
template <paddle::DataType T, typename ElementA, typename ElementB, typename ElementC>
|
||||
template <paddle::DataType T,
|
||||
typename ElementA,
|
||||
typename ElementB,
|
||||
typename ElementC>
|
||||
void FusedMoeKernel(const paddle::Tensor& input,
|
||||
const paddle::Tensor& gate_weight,
|
||||
const paddle::Tensor& ffn1_weight,
|
||||
@@ -63,27 +65,26 @@ void FusedMoeKernel(const paddle::Tensor& input,
|
||||
|
||||
auto* output_data = output->data<data_t>();
|
||||
|
||||
auto moe_compute = McMoeHelper<data_t, ElementA, ElementB, ElementC>(quant_method);
|
||||
auto moe_compute =
|
||||
McMoeHelper<data_t, ElementA, ElementB, ElementC>(quant_method);
|
||||
|
||||
moe_compute.computeFFN(
|
||||
&input,
|
||||
&gate_weight,
|
||||
&ffn1_weight,
|
||||
ffn1_scale ? ffn1_scale.get_ptr() : nullptr,
|
||||
ffn1_bias ? ffn1_bias.get_ptr() : nullptr,
|
||||
&ffn2_weight,
|
||||
ffn2_scale ? ffn2_scale.get_ptr() : nullptr,
|
||||
ffn2_bias ? ffn2_bias.get_ptr() : nullptr,
|
||||
nullptr,
|
||||
moe_topk,
|
||||
group_moe,
|
||||
norm_topk_prob,
|
||||
1.0, // ComputeFFN
|
||||
"ffn",
|
||||
output);
|
||||
moe_compute.computeFFN(&input,
|
||||
&gate_weight,
|
||||
&ffn1_weight,
|
||||
ffn1_scale ? ffn1_scale.get_ptr() : nullptr,
|
||||
ffn1_bias ? ffn1_bias.get_ptr() : nullptr,
|
||||
&ffn2_weight,
|
||||
ffn2_scale ? ffn2_scale.get_ptr() : nullptr,
|
||||
ffn2_bias ? ffn2_bias.get_ptr() : nullptr,
|
||||
nullptr,
|
||||
moe_topk,
|
||||
group_moe,
|
||||
norm_topk_prob,
|
||||
1.0, // ComputeFFN
|
||||
"ffn",
|
||||
output);
|
||||
}
|
||||
|
||||
|
||||
std::vector<paddle::Tensor> FusedExpertMoe(
|
||||
const paddle::Tensor& input,
|
||||
const paddle::Tensor& gate_weight,
|
||||
@@ -102,19 +103,22 @@ std::vector<paddle::Tensor> FusedExpertMoe(
|
||||
|
||||
switch (input_type) {
|
||||
case paddle::DataType::BFLOAT16:
|
||||
FusedMoeKernel<paddle::DataType::BFLOAT16, maca_bfloat16, int8_t, maca_bfloat16>(input,
|
||||
gate_weight,
|
||||
ffn1_weight,
|
||||
ffn1_scale,
|
||||
ffn1_bias,
|
||||
ffn2_weight,
|
||||
ffn2_scale,
|
||||
ffn2_bias,
|
||||
quant_method,
|
||||
moe_topk,
|
||||
group_moe,
|
||||
norm_topk_prob,
|
||||
&output);
|
||||
FusedMoeKernel<paddle::DataType::BFLOAT16,
|
||||
maca_bfloat16,
|
||||
int8_t,
|
||||
maca_bfloat16>(input,
|
||||
gate_weight,
|
||||
ffn1_weight,
|
||||
ffn1_scale,
|
||||
ffn1_bias,
|
||||
ffn2_weight,
|
||||
ffn2_scale,
|
||||
ffn2_bias,
|
||||
quant_method,
|
||||
moe_topk,
|
||||
group_moe,
|
||||
norm_topk_prob,
|
||||
&output);
|
||||
break;
|
||||
// case paddle::DataType::FLOAT16:
|
||||
// FusedMoeKernel<paddle::DataType::FLOAT16>(input,
|
||||
@@ -161,7 +165,6 @@ std::vector<paddle::DataType> FusedExpertMoeInferDtype(
|
||||
return {input_dtype};
|
||||
}
|
||||
|
||||
|
||||
PD_BUILD_OP(fused_expert_moe)
|
||||
.Inputs({"input",
|
||||
"gate_weight",
|
||||
|
||||
@@ -16,8 +16,8 @@
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
#include <string>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include "cub/cub.cuh"
|
||||
|
||||
static const float HALF_FLT_MAX = 65504.F;
|
||||
|
||||
@@ -19,9 +19,9 @@
|
||||
|
||||
#include <cuda.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include "fused_moe_imp_op.h"
|
||||
#include "fused_moe_helper.h"
|
||||
#include "mctlass/numeric_conversion.h" // BUILD_MARK
|
||||
#include "fused_moe_imp_op.h"
|
||||
#include "mctlass/numeric_conversion.h" // BUILD_MARK
|
||||
// Ignore mctlass warnings about type punning
|
||||
#pragma GCC diagnostic push
|
||||
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
|
||||
@@ -35,8 +35,8 @@
|
||||
#define WARP_SIZE 32
|
||||
|
||||
struct GpuLaunchConfig {
|
||||
dim3 block_per_grid;
|
||||
dim3 thread_per_block;
|
||||
dim3 block_per_grid;
|
||||
dim3 thread_per_block;
|
||||
};
|
||||
|
||||
inline GpuLaunchConfig Get1DBlocksAnd2DGridsMoe(const int64_t cols) {
|
||||
@@ -82,7 +82,6 @@ __launch_bounds__(TPB) __global__
|
||||
cub::Sum sum;
|
||||
float threadData(-FLT_MAX);
|
||||
|
||||
|
||||
for (int ii = threadIdx.x; ii < num_cols; ii += TPB) {
|
||||
const int idx = thread_row_offset + ii;
|
||||
threadData = max(static_cast<float>(input[idx]), threadData);
|
||||
@@ -603,7 +602,7 @@ void topk_gating_softmax_kernelLauncher(const T* input,
|
||||
}
|
||||
static constexpr int WARPS_PER_TB = 4;
|
||||
|
||||
#define LAUNCH_TOPK_GATING_SOFTMAX_HELPER(N) \
|
||||
#define LAUNCH_TOPK_GATING_SOFTMAX_HELPER(N) \
|
||||
case N: { \
|
||||
topk_gating_softmax_launcher_helper<T, N, WARPS_PER_TB>( \
|
||||
input, output, indices, source_row, num_rows, num_experts, k, stream); \
|
||||
@@ -646,14 +645,8 @@ void topk_gating_softmax_kernelLauncher(const T* input,
|
||||
const auto config_topk = Get1DBlocksAnd2DGridsMoe(num_rows);
|
||||
moe_softmax<T, TPB><<<config_topk.block_per_grid, TPB, 0, stream>>>(
|
||||
input, softmax, num_experts, num_rows);
|
||||
moe_top_k<T, TPB>
|
||||
<<<config_topk.block_per_grid, TPB, 0, stream>>>(softmax,
|
||||
output,
|
||||
indices,
|
||||
source_row,
|
||||
num_experts,
|
||||
k,
|
||||
num_rows);
|
||||
moe_top_k<T, TPB><<<config_topk.block_per_grid, TPB, 0, stream>>>(
|
||||
softmax, output, indices, source_row, num_experts, k, num_rows);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,52 +1,71 @@
|
||||
#include "fused_moe_helper.h"
|
||||
#include "mctlass/numeric_conversion.h"
|
||||
#include "mctlassEx/mctlassEx.h"
|
||||
#include "fused_moe_helper.h"
|
||||
|
||||
|
||||
template <typename ElementA, typename ElementB, typename ElementC>
|
||||
void mc_grouped_gemm_basic_kernel(
|
||||
const ElementA* ptrA,
|
||||
mctlassExOrder_t majorA,
|
||||
const ElementB* ptrB,
|
||||
mctlassExOrder_t majorB,
|
||||
const ElementA* ptrScale,
|
||||
const ElementA* ptrBias,
|
||||
ElementC* ptrC,
|
||||
mctlassExOrder_t majorC,
|
||||
const int *ptrSegInd,
|
||||
int numExperts,
|
||||
int m, // expanded_active_expert_rows
|
||||
int n, // inter_dim
|
||||
int k, // hidden_size
|
||||
mcStream_t stream) {
|
||||
void mc_grouped_gemm_basic_kernel(const ElementA *ptrA,
|
||||
mctlassExOrder_t majorA,
|
||||
const ElementB *ptrB,
|
||||
mctlassExOrder_t majorB,
|
||||
const ElementA *ptrScale,
|
||||
const ElementA *ptrBias,
|
||||
ElementC *ptrC,
|
||||
mctlassExOrder_t majorC,
|
||||
const int *ptrSegInd,
|
||||
int numExperts,
|
||||
int m, // expanded_active_expert_rows
|
||||
int n, // inter_dim
|
||||
int k, // hidden_size
|
||||
mcStream_t stream) {
|
||||
mctlassExHandle_t handle;
|
||||
mctlassExHandleCreate(&handle);
|
||||
|
||||
int* ptrMNumTilesInd;
|
||||
mcMallocAsync((void**)&ptrMNumTilesInd, sizeof(int) * numExperts, stream);
|
||||
int *ptrMNumTilesInd;
|
||||
mcMallocAsync((void **)&ptrMNumTilesInd, sizeof(int) * numExperts, stream);
|
||||
|
||||
mctlassExMatrixLayout_t matLayoutA;
|
||||
mctlassExMatrixLayout_t matLayoutB;
|
||||
mctlassExMatrixLayout_t matLayoutC;
|
||||
|
||||
// mat A: (m, k)
|
||||
mctlassExMatrixLayoutCreate(&matLayoutA, mctlassExDataType::MCTLASS_EX_DATATYPE_BF16, m, k, k);
|
||||
mctlassExMatrixLayoutSetAttribute(matLayoutA, mctlassExMatrixLayoutAttribute_t::MCTLASS_EX_MATRIX_LAYOUT_ORDER,
|
||||
&majorA, sizeof(mctlassExOrder_t));
|
||||
mctlassExMatrixLayoutSetAttribute(matLayoutA, mctlassExMatrixLayoutAttribute_t::MCTLASS_EX_MATRIX_LAYOUT_BATCH_COUNT,
|
||||
&numExperts, sizeof(int));
|
||||
mctlassExMatrixLayoutCreate(
|
||||
&matLayoutA, mctlassExDataType::MCTLASS_EX_DATATYPE_BF16, m, k, k);
|
||||
mctlassExMatrixLayoutSetAttribute(
|
||||
matLayoutA,
|
||||
mctlassExMatrixLayoutAttribute_t::MCTLASS_EX_MATRIX_LAYOUT_ORDER,
|
||||
&majorA,
|
||||
sizeof(mctlassExOrder_t));
|
||||
mctlassExMatrixLayoutSetAttribute(
|
||||
matLayoutA,
|
||||
mctlassExMatrixLayoutAttribute_t::MCTLASS_EX_MATRIX_LAYOUT_BATCH_COUNT,
|
||||
&numExperts,
|
||||
sizeof(int));
|
||||
// mat B: (num_experts, n, k)
|
||||
mctlassExMatrixLayoutCreate(&matLayoutB, mctlassExDataType::MCTLASS_EX_DATATYPE_INT8, k, n, k);
|
||||
mctlassExMatrixLayoutSetAttribute(matLayoutB, mctlassExMatrixLayoutAttribute_t::MCTLASS_EX_MATRIX_LAYOUT_ORDER,
|
||||
&majorB, sizeof(mctlassExOrder_t));
|
||||
mctlassExMatrixLayoutSetAttribute(matLayoutB, mctlassExMatrixLayoutAttribute_t::MCTLASS_EX_MATRIX_LAYOUT_BATCH_COUNT,
|
||||
&numExperts, sizeof(int));
|
||||
mctlassExMatrixLayoutCreate(
|
||||
&matLayoutB, mctlassExDataType::MCTLASS_EX_DATATYPE_INT8, k, n, k);
|
||||
mctlassExMatrixLayoutSetAttribute(
|
||||
matLayoutB,
|
||||
mctlassExMatrixLayoutAttribute_t::MCTLASS_EX_MATRIX_LAYOUT_ORDER,
|
||||
&majorB,
|
||||
sizeof(mctlassExOrder_t));
|
||||
mctlassExMatrixLayoutSetAttribute(
|
||||
matLayoutB,
|
||||
mctlassExMatrixLayoutAttribute_t::MCTLASS_EX_MATRIX_LAYOUT_BATCH_COUNT,
|
||||
&numExperts,
|
||||
sizeof(int));
|
||||
// mat C: (m, n)
|
||||
mctlassExMatrixLayoutCreate(&matLayoutC, mctlassExDataType::MCTLASS_EX_DATATYPE_BF16, m, n, n);
|
||||
mctlassExMatrixLayoutSetAttribute(matLayoutC, mctlassExMatrixLayoutAttribute_t::MCTLASS_EX_MATRIX_LAYOUT_ORDER,
|
||||
&majorC, sizeof(mctlassExOrder_t));
|
||||
mctlassExMatrixLayoutSetAttribute(matLayoutC, mctlassExMatrixLayoutAttribute_t::MCTLASS_EX_MATRIX_LAYOUT_BATCH_COUNT,
|
||||
&numExperts, sizeof(int));
|
||||
mctlassExMatrixLayoutCreate(
|
||||
&matLayoutC, mctlassExDataType::MCTLASS_EX_DATATYPE_BF16, m, n, n);
|
||||
mctlassExMatrixLayoutSetAttribute(
|
||||
matLayoutC,
|
||||
mctlassExMatrixLayoutAttribute_t::MCTLASS_EX_MATRIX_LAYOUT_ORDER,
|
||||
&majorC,
|
||||
sizeof(mctlassExOrder_t));
|
||||
mctlassExMatrixLayoutSetAttribute(
|
||||
matLayoutC,
|
||||
mctlassExMatrixLayoutAttribute_t::MCTLASS_EX_MATRIX_LAYOUT_BATCH_COUNT,
|
||||
&numExperts,
|
||||
sizeof(int));
|
||||
// bias: (num_experts, n)
|
||||
// scale: (num, n)
|
||||
|
||||
@@ -55,44 +74,81 @@ void mc_grouped_gemm_basic_kernel(
|
||||
mctlassExDataType input_type = mctlassExDataType::MCTLASS_EX_DATATYPE_BF16;
|
||||
mctlassExDataType scale_type = mctlassExDataType::MCTLASS_EX_DATATYPE_INT8;
|
||||
mctlassExDataType compute_type = mctlassExDataType::MCTLASS_EX_DATATYPE_FP32;
|
||||
mctlassExEpilogueType epilogue_type = mctlassExEpilogueType::MCTLASS_EX_EPILOGUE_TYPE_DEFAULT;
|
||||
mctlassExEpilogueType epilogue_type =
|
||||
mctlassExEpilogueType::MCTLASS_EX_EPILOGUE_TYPE_DEFAULT;
|
||||
if (ptrBias) {
|
||||
epilogue_type = mctlassExEpilogueType::MCTLASS_EX_EPILOGUE_TYPE_BIAS;
|
||||
}
|
||||
// set scale
|
||||
mctlassExDescSetAttribute(mctlass_desc, mctlassExDescAttributes_t::MCTLASS_EX_DESC_B_SCALE_POINTER,
|
||||
&ptrScale, sizeof(ptrScale));
|
||||
mctlassExDescSetAttribute(mctlass_desc, mctlassExDescAttributes_t::MCTLASS_EX_DESC_B_SCALE_TYPE,
|
||||
&input_type, sizeof(mctlassExDataType));
|
||||
mctlassExDescSetAttribute(
|
||||
mctlass_desc,
|
||||
mctlassExDescAttributes_t::MCTLASS_EX_DESC_B_SCALE_POINTER,
|
||||
&ptrScale,
|
||||
sizeof(ptrScale));
|
||||
mctlassExDescSetAttribute(
|
||||
mctlass_desc,
|
||||
mctlassExDescAttributes_t::MCTLASS_EX_DESC_B_SCALE_TYPE,
|
||||
&input_type,
|
||||
sizeof(mctlassExDataType));
|
||||
// set bias
|
||||
if (ptrBias) {
|
||||
mctlassExDescSetAttribute(mctlass_desc, mctlassExDescAttributes_t::MCTLASS_EX_DESC_BIAS_POINTER,
|
||||
&ptrBias, sizeof(ptrBias));
|
||||
mctlassExDescSetAttribute(
|
||||
mctlass_desc,
|
||||
mctlassExDescAttributes_t::MCTLASS_EX_DESC_BIAS_POINTER,
|
||||
&ptrBias,
|
||||
sizeof(ptrBias));
|
||||
}
|
||||
// set coumpute type
|
||||
mctlassExDescSetAttribute(mctlass_desc, mctlassExDescAttributes_t::MCTLASS_EX_DESC_COMPUTE_TYPE,
|
||||
&compute_type, sizeof(mctlassExDataType));
|
||||
mctlassExDescSetAttribute(
|
||||
mctlass_desc,
|
||||
mctlassExDescAttributes_t::MCTLASS_EX_DESC_COMPUTE_TYPE,
|
||||
&compute_type,
|
||||
sizeof(mctlassExDataType));
|
||||
// set epilogue type
|
||||
mctlassExDescSetAttribute(mctlass_desc, mctlassExDescAttributes_t::MCTLASS_EX_DESC_EPILOGUE_TYPE,
|
||||
&epilogue_type, sizeof(mctlassExEpilogueType));
|
||||
mctlassExDescSetAttribute(
|
||||
mctlass_desc,
|
||||
mctlassExDescAttributes_t::MCTLASS_EX_DESC_EPILOGUE_TYPE,
|
||||
&epilogue_type,
|
||||
sizeof(mctlassExEpilogueType));
|
||||
|
||||
const mctlassExContiguousGroupedGemmAlgo_t algo = mctlassExContiguousGroupedGemmAlgo_t::MCTLASS_EX_CONTIGUOUS_GROUPED_ALGO_DEFAULT;
|
||||
const mctlassExContiguousGroupedGemmAlgo_t algo =
|
||||
mctlassExContiguousGroupedGemmAlgo_t::
|
||||
MCTLASS_EX_CONTIGUOUS_GROUPED_ALGO_DEFAULT;
|
||||
mctlassExContiguousGroupedDesc_t contiguous_group_desc;
|
||||
mctlassExContiguousGroupedDescCreate(&contiguous_group_desc,
|
||||
ptrSegInd,
|
||||
nullptr,
|
||||
ptrMNumTilesInd,
|
||||
1);
|
||||
mctlassExContiguousGroupedDescCreate(
|
||||
&contiguous_group_desc, ptrSegInd, nullptr, ptrMNumTilesInd, 1);
|
||||
int blocksizeM;
|
||||
mctlassExContiguousGroupedGemmGetBlocksizeM(handle, mctlass_desc, matLayoutA, matLayoutB, matLayoutC, &algo, &blocksizeM);
|
||||
mctlassExContiguousGroupedGemmComputeMNumTilesIndptr(handle, mctlass_desc, matLayoutA, matLayoutB, matLayoutC, &algo, contiguous_group_desc, numExperts, blocksizeM, stream);
|
||||
mctlassExContiguousGroupedGemmGetBlocksizeM(handle,
|
||||
mctlass_desc,
|
||||
matLayoutA,
|
||||
matLayoutB,
|
||||
matLayoutC,
|
||||
&algo,
|
||||
&blocksizeM);
|
||||
mctlassExContiguousGroupedGemmComputeMNumTilesIndptr(handle,
|
||||
mctlass_desc,
|
||||
matLayoutA,
|
||||
matLayoutB,
|
||||
matLayoutC,
|
||||
&algo,
|
||||
contiguous_group_desc,
|
||||
numExperts,
|
||||
blocksizeM,
|
||||
stream);
|
||||
|
||||
mctlassExContiguousGroupedGemmBasic(handle, mctlass_desc,
|
||||
ptrA, matLayoutA,
|
||||
ptrB, matLayoutB,
|
||||
ptrC, matLayoutC,
|
||||
mctlassExContiguousGroupedGemmBasic(handle,
|
||||
mctlass_desc,
|
||||
ptrA,
|
||||
matLayoutA,
|
||||
ptrB,
|
||||
matLayoutB,
|
||||
ptrC,
|
||||
matLayoutC,
|
||||
contiguous_group_desc,
|
||||
&algo, nullptr, 0, stream);
|
||||
&algo,
|
||||
nullptr,
|
||||
0,
|
||||
stream);
|
||||
|
||||
mctlassExHandleDestroy(handle);
|
||||
mctlassExMatrixLayoutDestroy(matLayoutA);
|
||||
@@ -103,312 +159,312 @@ void mc_grouped_gemm_basic_kernel(
|
||||
mcFreeAsync(ptrMNumTilesInd, stream);
|
||||
}
|
||||
|
||||
template<typename T, typename ElementA, typename ElementB, typename ElementC>
|
||||
template <typename T, typename ElementA, typename ElementB, typename ElementC>
|
||||
class McMoeHelper {
|
||||
public:
|
||||
McMoeHelper(const std::string gemm_method): gemm_method_(gemm_method) {}
|
||||
public:
|
||||
McMoeHelper(const std::string gemm_method) : gemm_method_(gemm_method) {}
|
||||
|
||||
// -------- getWorkspaceSize -------- //
|
||||
template <typename KeyT>
|
||||
size_t getWorkspaceSize(const int64_t num_rows,
|
||||
const int64_t hidden_size,
|
||||
const int64_t inter_size,
|
||||
const int64_t num_experts,
|
||||
const int64_t k) {
|
||||
const size_t buf_size = AlignTo16(k * num_rows * hidden_size);
|
||||
const size_t interbuf_size = AlignTo16(k * num_rows * inter_size);
|
||||
const size_t padded_experts = AlignTo16(num_experts);
|
||||
const size_t num_moe_inputs = AlignTo16(k * num_rows);
|
||||
// softmax output, permuted_rows and permuted_experts have moved to outside
|
||||
// of moe kernel, allocate them in Encoder or Decoder before invoking
|
||||
// FfnLayer forward.
|
||||
size_t total_ws_bytes =
|
||||
5 * num_moe_inputs *
|
||||
sizeof(int); // source_rows_, permuted_rows_, permuted_experts_
|
||||
total_ws_bytes += buf_size * sizeof(KeyT); // permuted_data
|
||||
total_ws_bytes +=
|
||||
padded_experts * sizeof(int32_t); // Hold total_rows_before_expert_
|
||||
// -------- getWorkspaceSize -------- //
|
||||
template <typename KeyT>
|
||||
size_t getWorkspaceSize(const int64_t num_rows,
|
||||
const int64_t hidden_size,
|
||||
const int64_t inter_size,
|
||||
const int64_t num_experts,
|
||||
const int64_t k) {
|
||||
const size_t buf_size = AlignTo16(k * num_rows * hidden_size);
|
||||
const size_t interbuf_size = AlignTo16(k * num_rows * inter_size);
|
||||
const size_t padded_experts = AlignTo16(num_experts);
|
||||
const size_t num_moe_inputs = AlignTo16(k * num_rows);
|
||||
// softmax output, permuted_rows and permuted_experts have moved to outside
|
||||
// of moe kernel, allocate them in Encoder or Decoder before invoking
|
||||
// FfnLayer forward.
|
||||
size_t total_ws_bytes =
|
||||
5 * num_moe_inputs *
|
||||
sizeof(int); // source_rows_, permuted_rows_, permuted_experts_
|
||||
total_ws_bytes += buf_size * sizeof(KeyT); // permuted_data
|
||||
total_ws_bytes +=
|
||||
padded_experts * sizeof(int32_t); // Hold total_rows_before_expert_
|
||||
|
||||
const size_t bytes_for_fc1_result = interbuf_size * sizeof(KeyT);
|
||||
const size_t sorter_ws_size_bytes =
|
||||
AlignTo16(sorter_.getWorkspaceSize(num_rows));
|
||||
sorter_.update_num_experts(num_experts);
|
||||
const size_t bytes_for_fc1_result = interbuf_size * sizeof(KeyT);
|
||||
const size_t sorter_ws_size_bytes =
|
||||
AlignTo16(sorter_.getWorkspaceSize(num_rows));
|
||||
sorter_.update_num_experts(num_experts);
|
||||
|
||||
int64_t bytes_for_intermediate_and_sorting = bytes_for_fc1_result;
|
||||
if (sorter_ws_size_bytes > bytes_for_fc1_result) {
|
||||
int64_t remaining_bytes =
|
||||
AlignTo16(sorter_ws_size_bytes - bytes_for_fc1_result);
|
||||
bytes_for_intermediate_and_sorting += remaining_bytes;
|
||||
}
|
||||
|
||||
total_ws_bytes +=
|
||||
bytes_for_intermediate_and_sorting; // intermediate (fc1) output + cub
|
||||
// sorting workspace
|
||||
|
||||
int64_t num_softmax_outs = 0;
|
||||
const bool is_pow_2 =
|
||||
(num_experts != 0) && ((num_experts & (num_experts - 1)) == 0);
|
||||
if (!is_pow_2 || num_experts > 256) {
|
||||
num_softmax_outs = AlignTo16(num_rows * num_experts);
|
||||
}
|
||||
|
||||
total_ws_bytes += num_softmax_outs * sizeof(float);
|
||||
|
||||
return total_ws_bytes;
|
||||
int64_t bytes_for_intermediate_and_sorting = bytes_for_fc1_result;
|
||||
if (sorter_ws_size_bytes > bytes_for_fc1_result) {
|
||||
int64_t remaining_bytes =
|
||||
AlignTo16(sorter_ws_size_bytes - bytes_for_fc1_result);
|
||||
bytes_for_intermediate_and_sorting += remaining_bytes;
|
||||
}
|
||||
|
||||
void computeFFN(const paddle::Tensor *input,
|
||||
const paddle::Tensor *gate_weight,
|
||||
const paddle::Tensor *ffn1_weight,
|
||||
const paddle::Tensor *ffn1_scale,
|
||||
const paddle::Tensor *ffn1_bias,
|
||||
const paddle::Tensor *ffn2_weight,
|
||||
const paddle::Tensor *ffn2_scale,
|
||||
const paddle::Tensor *ffn2_bias,
|
||||
const paddle::Tensor *moe_token_type_ids,
|
||||
const int moe_topk,
|
||||
const bool group_moe,
|
||||
const bool norm_topk_prob,
|
||||
const float routed_scaling_factor,
|
||||
const std::string moe_type,
|
||||
paddle::Tensor *output) {
|
||||
auto *input_activations = input->data<T>();
|
||||
auto *gating_weights = gate_weight->data<float>();
|
||||
const T *fc1_expert_biases = ffn1_bias ? ffn1_bias->data<T>() : nullptr;
|
||||
const T *fc2_expert_biases = ffn2_bias ? ffn2_bias->data<T>() : nullptr;
|
||||
total_ws_bytes +=
|
||||
bytes_for_intermediate_and_sorting; // intermediate (fc1) output + cub
|
||||
// sorting workspace
|
||||
|
||||
auto *output_ = output->data<T>();
|
||||
auto stream = input->stream();
|
||||
auto place = input->place();
|
||||
auto input_type = input->dtype();
|
||||
int64_t num_softmax_outs = 0;
|
||||
const bool is_pow_2 =
|
||||
(num_experts != 0) && ((num_experts & (num_experts - 1)) == 0);
|
||||
if (!is_pow_2 || num_experts > 256) {
|
||||
num_softmax_outs = AlignTo16(num_rows * num_experts);
|
||||
}
|
||||
|
||||
auto input_dims = input->dims();
|
||||
auto ffn1_dims = ffn1_weight->dims();
|
||||
int64_t token_num = 0;
|
||||
if (input_dims.size() == 3) {
|
||||
token_num = input_dims[0] * input_dims[1];
|
||||
} else {
|
||||
token_num = input_dims[0];
|
||||
}
|
||||
const int64_t num_rows = token_num;
|
||||
total_ws_bytes += num_softmax_outs * sizeof(float);
|
||||
|
||||
const int64_t hidden_size = ffn1_dims[2];
|
||||
int64_t inter_dim = 0;
|
||||
if (moe_type == "qkv") {
|
||||
inter_dim = ffn1_dims[2] * ffn1_dims[3] * ffn1_dims[4];
|
||||
} else {
|
||||
inter_dim = ffn1_dims[1];
|
||||
}
|
||||
return total_ws_bytes;
|
||||
}
|
||||
|
||||
// if (gemm_method == "weight_only_int4") {
|
||||
// inter_dim = inter_dim * 2;
|
||||
// }
|
||||
void computeFFN(const paddle::Tensor *input,
|
||||
const paddle::Tensor *gate_weight,
|
||||
const paddle::Tensor *ffn1_weight,
|
||||
const paddle::Tensor *ffn1_scale,
|
||||
const paddle::Tensor *ffn1_bias,
|
||||
const paddle::Tensor *ffn2_weight,
|
||||
const paddle::Tensor *ffn2_scale,
|
||||
const paddle::Tensor *ffn2_bias,
|
||||
const paddle::Tensor *moe_token_type_ids,
|
||||
const int moe_topk,
|
||||
const bool group_moe,
|
||||
const bool norm_topk_prob,
|
||||
const float routed_scaling_factor,
|
||||
const std::string moe_type,
|
||||
paddle::Tensor *output) {
|
||||
auto *input_activations = input->data<T>();
|
||||
auto *gating_weights = gate_weight->data<float>();
|
||||
const T *fc1_expert_biases = ffn1_bias ? ffn1_bias->data<T>() : nullptr;
|
||||
const T *fc2_expert_biases = ffn2_bias ? ffn2_bias->data<T>() : nullptr;
|
||||
|
||||
const int64_t inter_size = inter_dim;
|
||||
const int64_t num_experts = ffn1_dims[0];
|
||||
const int64_t k = moe_topk;
|
||||
auto *output_ = output->data<T>();
|
||||
auto stream = input->stream();
|
||||
auto place = input->place();
|
||||
auto input_type = input->dtype();
|
||||
|
||||
auto input_dims = input->dims();
|
||||
auto ffn1_dims = ffn1_weight->dims();
|
||||
int64_t token_num = 0;
|
||||
if (input_dims.size() == 3) {
|
||||
token_num = input_dims[0] * input_dims[1];
|
||||
} else {
|
||||
token_num = input_dims[0];
|
||||
}
|
||||
const int64_t num_rows = token_num;
|
||||
|
||||
int64_t bytes =
|
||||
getWorkspaceSize<T>(num_rows, hidden_size, inter_size, num_experts, k);
|
||||
const int64_t hidden_size = ffn1_dims[2];
|
||||
int64_t inter_dim = 0;
|
||||
if (moe_type == "qkv") {
|
||||
inter_dim = ffn1_dims[2] * ffn1_dims[3] * ffn1_dims[4];
|
||||
} else {
|
||||
inter_dim = ffn1_dims[1];
|
||||
}
|
||||
|
||||
// Pointers
|
||||
int *expert_for_source_row;
|
||||
int *source_rows_;
|
||||
int *permuted_rows_;
|
||||
int *permuted_experts_;
|
||||
int *expanded_source_row_to_expanded_dest_row;
|
||||
// if (gemm_method == "weight_only_int4") {
|
||||
// inter_dim = inter_dim * 2;
|
||||
// }
|
||||
|
||||
T *permuted_data_;
|
||||
int32_t *total_rows_before_expert_;
|
||||
T *fc1_result_;
|
||||
float *softmax_out_;
|
||||
const int64_t inter_size = inter_dim;
|
||||
const int64_t num_experts = ffn1_dims[0];
|
||||
const int64_t k = moe_topk;
|
||||
|
||||
paddle::Tensor ws_ptr_tensor =
|
||||
GetEmptyTensor({bytes}, paddle::DataType::INT8, place);
|
||||
int8_t *ws_ptr = ws_ptr_tensor.data<int8_t>();
|
||||
int64_t bytes =
|
||||
getWorkspaceSize<T>(num_rows, hidden_size, inter_size, num_experts, k);
|
||||
|
||||
const int64_t buf_size = AlignTo16(k * num_rows * hidden_size);
|
||||
const int64_t interbuf_size = AlignTo16(k * num_rows * inter_size);
|
||||
const int64_t padded_experts = AlignTo16(num_experts);
|
||||
const int64_t num_moe_inputs = AlignTo16(k * num_rows);
|
||||
// Pointers
|
||||
int *expert_for_source_row;
|
||||
int *source_rows_;
|
||||
int *permuted_rows_;
|
||||
int *permuted_experts_;
|
||||
int *expanded_source_row_to_expanded_dest_row;
|
||||
|
||||
expert_for_source_row = reinterpret_cast<int *>(ws_ptr);
|
||||
source_rows_ = expert_for_source_row + num_moe_inputs;
|
||||
permuted_rows_ = source_rows_ + num_moe_inputs;
|
||||
permuted_experts_ = permuted_rows_ + num_moe_inputs;
|
||||
expanded_source_row_to_expanded_dest_row =
|
||||
permuted_experts_ + num_moe_inputs;
|
||||
permuted_data_ = reinterpret_cast<T *>(
|
||||
expanded_source_row_to_expanded_dest_row + num_moe_inputs);
|
||||
total_rows_before_expert_ =
|
||||
reinterpret_cast<int32_t *>(permuted_data_ + buf_size);
|
||||
fc1_result_ =
|
||||
reinterpret_cast<T *>(total_rows_before_expert_ + padded_experts);
|
||||
T *permuted_data_;
|
||||
int32_t *total_rows_before_expert_;
|
||||
T *fc1_result_;
|
||||
float *softmax_out_;
|
||||
|
||||
const bool is_pow_2 =
|
||||
(num_experts != 0) && ((num_experts & (num_experts - 1)) == 0);
|
||||
if (!is_pow_2 || num_experts > 256) {
|
||||
softmax_out_ = reinterpret_cast<float *>(fc1_result_ + interbuf_size);
|
||||
} else {
|
||||
softmax_out_ = nullptr;
|
||||
}
|
||||
paddle::Tensor ws_ptr_tensor =
|
||||
GetEmptyTensor({bytes}, paddle::DataType::INT8, place);
|
||||
int8_t *ws_ptr = ws_ptr_tensor.data<int8_t>();
|
||||
|
||||
paddle::Tensor expert_scales_float_tensor =
|
||||
GetEmptyTensor({num_rows, moe_topk}, paddle::DataType::FLOAT32, place);
|
||||
float *expert_scales_float = expert_scales_float_tensor.data<float>();
|
||||
const int64_t buf_size = AlignTo16(k * num_rows * hidden_size);
|
||||
const int64_t interbuf_size = AlignTo16(k * num_rows * inter_size);
|
||||
const int64_t padded_experts = AlignTo16(num_experts);
|
||||
const int64_t num_moe_inputs = AlignTo16(k * num_rows);
|
||||
|
||||
float *softmax_max_prob = nullptr;
|
||||
if (group_moe) {
|
||||
paddle::Tensor softmax_max_prob_tensor = GetEmptyTensor(
|
||||
{num_rows, moe_topk}, paddle::DataType::FLOAT32, place);
|
||||
// (TODO: check fill success ?)
|
||||
paddle::experimental::fill(softmax_max_prob_tensor, 0.f);
|
||||
softmax_max_prob = softmax_max_prob_tensor.data<float>();
|
||||
}
|
||||
expert_for_source_row = reinterpret_cast<int *>(ws_ptr);
|
||||
source_rows_ = expert_for_source_row + num_moe_inputs;
|
||||
permuted_rows_ = source_rows_ + num_moe_inputs;
|
||||
permuted_experts_ = permuted_rows_ + num_moe_inputs;
|
||||
expanded_source_row_to_expanded_dest_row =
|
||||
permuted_experts_ + num_moe_inputs;
|
||||
permuted_data_ = reinterpret_cast<T *>(
|
||||
expanded_source_row_to_expanded_dest_row + num_moe_inputs);
|
||||
total_rows_before_expert_ =
|
||||
reinterpret_cast<int32_t *>(permuted_data_ + buf_size);
|
||||
fc1_result_ =
|
||||
reinterpret_cast<T *>(total_rows_before_expert_ + padded_experts);
|
||||
|
||||
paddle::Tensor fc1_out_tensor =
|
||||
GetEmptyTensor({num_rows * k, inter_size}, input_type, place);
|
||||
T *fc1_out = fc1_out_tensor.data<T>();
|
||||
const bool is_pow_2 =
|
||||
(num_experts != 0) && ((num_experts & (num_experts - 1)) == 0);
|
||||
if (!is_pow_2 || num_experts > 256) {
|
||||
softmax_out_ = reinterpret_cast<float *>(fc1_result_ + interbuf_size);
|
||||
} else {
|
||||
softmax_out_ = nullptr;
|
||||
}
|
||||
|
||||
auto input_cast_tensor =
|
||||
paddle::experimental::cast(*input, paddle::DataType::FLOAT32);
|
||||
auto gate_tensor =
|
||||
paddle::experimental::matmul(input_cast_tensor, *gate_weight);
|
||||
float *gating_output = gate_tensor.data<float>();
|
||||
paddle::Tensor expert_scales_float_tensor =
|
||||
GetEmptyTensor({num_rows, moe_topk}, paddle::DataType::FLOAT32, place);
|
||||
float *expert_scales_float = expert_scales_float_tensor.data<float>();
|
||||
|
||||
if (moe_token_type_ids) {
|
||||
auto *moe_token_type_ids_out = moe_token_type_ids->data<int>();
|
||||
moe_token_type_ids_kernelLauncher<float>(gating_output,
|
||||
moe_token_type_ids_out,
|
||||
num_rows,
|
||||
num_experts,
|
||||
k,
|
||||
stream);
|
||||
}
|
||||
float *softmax_max_prob = nullptr;
|
||||
if (group_moe) {
|
||||
paddle::Tensor softmax_max_prob_tensor = GetEmptyTensor(
|
||||
{num_rows, moe_topk}, paddle::DataType::FLOAT32, place);
|
||||
// (TODO: check fill success ?)
|
||||
paddle::experimental::fill(softmax_max_prob_tensor, 0.f);
|
||||
softmax_max_prob = softmax_max_prob_tensor.data<float>();
|
||||
}
|
||||
|
||||
topk_gating_softmax_kernelLauncher<float>(gating_output,
|
||||
expert_scales_float,
|
||||
softmax_out_,
|
||||
expert_for_source_row,
|
||||
source_rows_,
|
||||
softmax_max_prob,
|
||||
num_rows,
|
||||
num_experts,
|
||||
k,
|
||||
group_moe,
|
||||
stream);
|
||||
paddle::Tensor fc1_out_tensor =
|
||||
GetEmptyTensor({num_rows * k, inter_size}, input_type, place);
|
||||
T *fc1_out = fc1_out_tensor.data<T>();
|
||||
|
||||
const int64_t sorter_ws_size_bytes =
|
||||
AlignTo16(sorter_.getWorkspaceSize(int64_t(k * num_rows)));
|
||||
auto input_cast_tensor =
|
||||
paddle::experimental::cast(*input, paddle::DataType::FLOAT32);
|
||||
auto gate_tensor =
|
||||
paddle::experimental::matmul(input_cast_tensor, *gate_weight);
|
||||
float *gating_output = gate_tensor.data<float>();
|
||||
|
||||
sorter_.run(fc1_result_,
|
||||
sorter_ws_size_bytes,
|
||||
expert_for_source_row,
|
||||
permuted_experts_,
|
||||
source_rows_,
|
||||
permuted_rows_,
|
||||
k * num_rows,
|
||||
false,
|
||||
stream);
|
||||
if (moe_token_type_ids) {
|
||||
auto *moe_token_type_ids_out = moe_token_type_ids->data<int>();
|
||||
moe_token_type_ids_kernelLauncher<float>(gating_output,
|
||||
moe_token_type_ids_out,
|
||||
num_rows,
|
||||
num_experts,
|
||||
k,
|
||||
stream);
|
||||
}
|
||||
|
||||
initialize_moe_routing_kernelLauncher(
|
||||
input_activations,
|
||||
permuted_data_,
|
||||
permuted_rows_,
|
||||
expanded_source_row_to_expanded_dest_row,
|
||||
num_rows,
|
||||
num_rows,
|
||||
hidden_size,
|
||||
k,
|
||||
stream);
|
||||
topk_gating_softmax_kernelLauncher<float>(gating_output,
|
||||
expert_scales_float,
|
||||
softmax_out_,
|
||||
expert_for_source_row,
|
||||
source_rows_,
|
||||
softmax_max_prob,
|
||||
num_rows,
|
||||
num_experts,
|
||||
k,
|
||||
group_moe,
|
||||
stream);
|
||||
|
||||
const int64_t expanded_active_expert_rows = k * num_rows;
|
||||
const int64_t sorter_ws_size_bytes =
|
||||
AlignTo16(sorter_.getWorkspaceSize(int64_t(k * num_rows)));
|
||||
|
||||
compute_total_rows_before_expert(permuted_experts_,
|
||||
expanded_active_expert_rows,
|
||||
num_experts,
|
||||
total_rows_before_expert_,
|
||||
stream);
|
||||
sorter_.run(fc1_result_,
|
||||
sorter_ws_size_bytes,
|
||||
expert_for_source_row,
|
||||
permuted_experts_,
|
||||
source_rows_,
|
||||
permuted_rows_,
|
||||
k * num_rows,
|
||||
false,
|
||||
stream);
|
||||
|
||||
mctlassExOrder_t row_major = mctlassExOrder_t::MCTLASS_EX_ORDER_ROW_MAJOR;
|
||||
mctlassExOrder_t column_major = mctlassExOrder_t::MCTLASS_EX_ORDER_COLUMN_MAJOR;
|
||||
initialize_moe_routing_kernelLauncher(
|
||||
input_activations,
|
||||
permuted_data_,
|
||||
permuted_rows_,
|
||||
expanded_source_row_to_expanded_dest_row,
|
||||
num_rows,
|
||||
num_rows,
|
||||
hidden_size,
|
||||
k,
|
||||
stream);
|
||||
|
||||
const int64_t expanded_active_expert_rows = k * num_rows;
|
||||
|
||||
compute_total_rows_before_expert(permuted_experts_,
|
||||
expanded_active_expert_rows,
|
||||
num_experts,
|
||||
total_rows_before_expert_,
|
||||
stream);
|
||||
|
||||
mctlassExOrder_t row_major = mctlassExOrder_t::MCTLASS_EX_ORDER_ROW_MAJOR;
|
||||
mctlassExOrder_t column_major =
|
||||
mctlassExOrder_t::MCTLASS_EX_ORDER_COLUMN_MAJOR;
|
||||
|
||||
mc_grouped_gemm_basic_kernel<ElementA, ElementB, ElementC>(
|
||||
reinterpret_cast<const ElementA *>(permuted_data_),
|
||||
row_major,
|
||||
reinterpret_cast<const ElementB *>(ffn1_weight->data<ElementB>()),
|
||||
column_major,
|
||||
reinterpret_cast<const ElementA *>(ffn1_scale->data<T>()),
|
||||
reinterpret_cast<const ElementA *>(fc1_expert_biases),
|
||||
reinterpret_cast<ElementC *>(fc1_out),
|
||||
row_major,
|
||||
total_rows_before_expert_,
|
||||
num_experts,
|
||||
expanded_active_expert_rows,
|
||||
inter_size,
|
||||
hidden_size,
|
||||
stream);
|
||||
|
||||
if (moe_type == "ffn") {
|
||||
auto act_out_tensor =
|
||||
paddle::experimental::swiglu(fc1_out_tensor, nullptr);
|
||||
auto act_out = act_out_tensor.data<T>();
|
||||
|
||||
paddle::Tensor fc2_output_tensor =
|
||||
GetEmptyTensor({k * num_rows, hidden_size}, input_type, place);
|
||||
T *fc2_result = fc2_output_tensor.data<T>();
|
||||
|
||||
mc_grouped_gemm_basic_kernel<ElementA, ElementB, ElementC>(
|
||||
reinterpret_cast<const ElementA *>(permuted_data_),
|
||||
reinterpret_cast<const ElementA *>(act_out),
|
||||
row_major,
|
||||
reinterpret_cast<const ElementB *>(ffn1_weight->data<ElementB>()),
|
||||
reinterpret_cast<const ElementB *>(ffn2_weight->data<ElementB>()),
|
||||
column_major,
|
||||
reinterpret_cast<const ElementA *>(ffn1_scale->data<T>()),
|
||||
reinterpret_cast<const ElementA *>(fc1_expert_biases),
|
||||
reinterpret_cast<ElementC *>(fc1_out),
|
||||
reinterpret_cast<const ElementA *>(ffn2_scale->data<T>()),
|
||||
nullptr,
|
||||
reinterpret_cast<ElementC *>(fc2_result),
|
||||
row_major,
|
||||
total_rows_before_expert_,
|
||||
num_experts,
|
||||
expanded_active_expert_rows,
|
||||
inter_size,
|
||||
hidden_size,
|
||||
inter_size / 2,
|
||||
stream);
|
||||
|
||||
if (moe_type == "ffn") {
|
||||
auto act_out_tensor =
|
||||
paddle::experimental::swiglu(fc1_out_tensor, nullptr);
|
||||
auto act_out = act_out_tensor.data<T>();
|
||||
|
||||
paddle::Tensor fc2_output_tensor =
|
||||
GetEmptyTensor({k * num_rows, hidden_size}, input_type, place);
|
||||
T *fc2_result = fc2_output_tensor.data<T>();
|
||||
|
||||
mc_grouped_gemm_basic_kernel<ElementA, ElementB, ElementC>(
|
||||
reinterpret_cast<const ElementA *>(act_out),
|
||||
row_major,
|
||||
reinterpret_cast<const ElementB *>(ffn2_weight->data<ElementB>()),
|
||||
column_major,
|
||||
reinterpret_cast<const ElementA *>(ffn2_scale->data<T>()),
|
||||
nullptr,
|
||||
reinterpret_cast<ElementC *>(fc2_result),
|
||||
row_major,
|
||||
total_rows_before_expert_,
|
||||
num_experts,
|
||||
expanded_active_expert_rows,
|
||||
hidden_size,
|
||||
inter_size / 2,
|
||||
stream);
|
||||
|
||||
finalize_moe_routing_kernelLauncher(
|
||||
fc2_result,
|
||||
output_,
|
||||
fc2_expert_biases,
|
||||
reinterpret_cast<float *>(expert_scales_float),
|
||||
expanded_source_row_to_expanded_dest_row,
|
||||
expert_for_source_row,
|
||||
num_rows,
|
||||
hidden_size,
|
||||
k,
|
||||
static_cast<int>(1),
|
||||
norm_topk_prob,
|
||||
routed_scaling_factor,
|
||||
stream);
|
||||
} else {
|
||||
finalize_moe_routing_kernelLauncher(
|
||||
// fc2_result,
|
||||
fc1_out,
|
||||
output_,
|
||||
fc1_expert_biases, // fc2_expert_biases,
|
||||
reinterpret_cast<float *>(expert_scales_float),
|
||||
expanded_source_row_to_expanded_dest_row,
|
||||
expert_for_source_row,
|
||||
num_rows,
|
||||
inter_size,
|
||||
k,
|
||||
static_cast<int>(0),
|
||||
norm_topk_prob,
|
||||
routed_scaling_factor,
|
||||
stream);
|
||||
}
|
||||
finalize_moe_routing_kernelLauncher(
|
||||
fc2_result,
|
||||
output_,
|
||||
fc2_expert_biases,
|
||||
reinterpret_cast<float *>(expert_scales_float),
|
||||
expanded_source_row_to_expanded_dest_row,
|
||||
expert_for_source_row,
|
||||
num_rows,
|
||||
hidden_size,
|
||||
k,
|
||||
static_cast<int>(1),
|
||||
norm_topk_prob,
|
||||
routed_scaling_factor,
|
||||
stream);
|
||||
} else {
|
||||
finalize_moe_routing_kernelLauncher(
|
||||
// fc2_result,
|
||||
fc1_out,
|
||||
output_,
|
||||
fc1_expert_biases, // fc2_expert_biases,
|
||||
reinterpret_cast<float *>(expert_scales_float),
|
||||
expanded_source_row_to_expanded_dest_row,
|
||||
expert_for_source_row,
|
||||
num_rows,
|
||||
inter_size,
|
||||
k,
|
||||
static_cast<int>(0),
|
||||
norm_topk_prob,
|
||||
routed_scaling_factor,
|
||||
stream);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
private:
|
||||
std::string gemm_method_;
|
||||
CubKeyValueSorter sorter_;
|
||||
};
|
||||
|
||||
@@ -12,7 +12,6 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
|
||||
#pragma GCC diagnostic push
|
||||
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
|
||||
#pragma GCC diagnostic ignored "-Wunused-function"
|
||||
@@ -24,7 +23,6 @@
|
||||
|
||||
#include "helper.h"
|
||||
|
||||
|
||||
template <paddle::DataType T>
|
||||
void MoeDispatchKernel(const paddle::Tensor& input,
|
||||
const paddle::Tensor& gating_output,
|
||||
@@ -128,7 +126,6 @@ void MoeDispatchKernel(const paddle::Tensor& input,
|
||||
false,
|
||||
stream);
|
||||
|
||||
|
||||
initialize_moe_routing_kernelLauncher(
|
||||
input.data<data_t>(),
|
||||
permute_input->data<data_t>(),
|
||||
@@ -140,16 +137,13 @@ void MoeDispatchKernel(const paddle::Tensor& input,
|
||||
moe_topk,
|
||||
stream);
|
||||
|
||||
|
||||
compute_total_rows_before_expert(
|
||||
permuted_experts_,
|
||||
moe_topk * num_rows,
|
||||
expert_num,
|
||||
tokens_expert_prefix_sum->data<int32_t>(),
|
||||
stream);
|
||||
compute_total_rows_before_expert(permuted_experts_,
|
||||
moe_topk * num_rows,
|
||||
expert_num,
|
||||
tokens_expert_prefix_sum->data<int32_t>(),
|
||||
stream);
|
||||
}
|
||||
|
||||
|
||||
std::vector<paddle::Tensor> MoeExpertDispatch(
|
||||
const paddle::Tensor& input,
|
||||
const paddle::Tensor& gating_output,
|
||||
@@ -184,7 +178,6 @@ std::vector<paddle::Tensor> MoeExpertDispatch(
|
||||
auto permute_indices_per_token =
|
||||
GetEmptyTensor({moe_topk, num_rows}, paddle::DataType::INT32, place);
|
||||
|
||||
|
||||
switch (input_type) {
|
||||
case paddle::DataType::BFLOAT16:
|
||||
MoeDispatchKernel<paddle::DataType::BFLOAT16>(input,
|
||||
@@ -226,7 +219,6 @@ std::vector<paddle::Tensor> MoeExpertDispatch(
|
||||
top_k_indices};
|
||||
}
|
||||
|
||||
|
||||
std::vector<std::vector<int64_t>> MoeExpertDispatchInferShape(
|
||||
const std::vector<int64_t>& input_shape,
|
||||
const std::vector<int64_t>& gating_output_shape,
|
||||
@@ -260,7 +252,6 @@ std::vector<paddle::DataType> MoeExpertDispatchInferDtype(
|
||||
paddle::DataType::INT32};
|
||||
}
|
||||
|
||||
|
||||
PD_BUILD_OP(moe_expert_dispatch)
|
||||
.Inputs({"input", "gating_output"})
|
||||
.Outputs({"permute_input",
|
||||
|
||||
@@ -14,19 +14,22 @@
|
||||
|
||||
// BUILD_MARK
|
||||
#pragma once
|
||||
#include "mc_fused_moe_helper.h"
|
||||
#include "helper.h"
|
||||
#include "mc_fused_moe_helper.h"
|
||||
|
||||
template <paddle::DataType T, typename ElementA, typename ElementB, typename ElementC>
|
||||
template <paddle::DataType T,
|
||||
typename ElementA,
|
||||
typename ElementB,
|
||||
typename ElementC>
|
||||
void McMoeFFNKernel(const paddle::Tensor& permute_input,
|
||||
const paddle::Tensor& tokens_expert_prefix_sum,
|
||||
const paddle::Tensor& ffn1_weight,
|
||||
const paddle::Tensor& ffn2_weight,
|
||||
const paddle::optional<paddle::Tensor>& ffn1_bias,
|
||||
const paddle::optional<paddle::Tensor>& ffn1_scale,
|
||||
const paddle::optional<paddle::Tensor>& ffn2_scale,
|
||||
const std::string& quant_method,
|
||||
paddle::Tensor ffn_out) {
|
||||
const paddle::Tensor& tokens_expert_prefix_sum,
|
||||
const paddle::Tensor& ffn1_weight,
|
||||
const paddle::Tensor& ffn2_weight,
|
||||
const paddle::optional<paddle::Tensor>& ffn1_bias,
|
||||
const paddle::optional<paddle::Tensor>& ffn1_scale,
|
||||
const paddle::optional<paddle::Tensor>& ffn2_scale,
|
||||
const std::string& quant_method,
|
||||
paddle::Tensor ffn_out) {
|
||||
typedef PDTraits<T> traits_;
|
||||
typedef typename traits_::DataType DataType_;
|
||||
typedef typename traits_::data_t data_t;
|
||||
@@ -37,61 +40,65 @@ void McMoeFFNKernel(const paddle::Tensor& permute_input,
|
||||
auto input_type = permute_input.dtype();
|
||||
auto stream = permute_input.stream();
|
||||
|
||||
const int expanded_active_expert_rows = permute_input.dims()[0]; // permute_input.dims(): m, k
|
||||
const int num_experts = ffn1_weight.dims()[0]; // batchsize
|
||||
const int hidden_size = ffn1_weight.dims()[2]; // n
|
||||
int inter_dim = ffn1_weight.dims()[1]; // k
|
||||
const int expanded_active_expert_rows =
|
||||
permute_input.dims()[0]; // permute_input.dims(): m, k
|
||||
const int num_experts = ffn1_weight.dims()[0]; // batchsize
|
||||
const int hidden_size = ffn1_weight.dims()[2]; // n
|
||||
int inter_dim = ffn1_weight.dims()[1]; // k
|
||||
|
||||
const int64_t inter_size = inter_dim; // since weight_only_int_8
|
||||
const int64_t inter_size = inter_dim; // since weight_only_int_8
|
||||
paddle::Tensor fc1_out_tensor = GetEmptyTensor(
|
||||
{expanded_active_expert_rows, inter_size}, input_type, place);
|
||||
auto fc1_out_ptr = fc1_out_tensor.data<data_t>();
|
||||
|
||||
mctlassExOrder_t row_major = mctlassExOrder_t::MCTLASS_EX_ORDER_ROW_MAJOR;
|
||||
mctlassExOrder_t column_major = mctlassExOrder_t::MCTLASS_EX_ORDER_COLUMN_MAJOR;
|
||||
mctlassExOrder_t column_major =
|
||||
mctlassExOrder_t::MCTLASS_EX_ORDER_COLUMN_MAJOR;
|
||||
|
||||
// ffn1
|
||||
auto fc1_expert_biases =
|
||||
ffn1_bias
|
||||
? const_cast<paddle::Tensor*>(ffn1_bias.get_ptr())->data<data_t>()
|
||||
: nullptr;
|
||||
auto fc1_expert_scales = const_cast<paddle::Tensor*>(ffn1_scale.get_ptr())->data<data_t>();
|
||||
ffn1_bias
|
||||
? const_cast<paddle::Tensor*>(ffn1_bias.get_ptr())->data<data_t>()
|
||||
: nullptr;
|
||||
auto fc1_expert_scales =
|
||||
const_cast<paddle::Tensor*>(ffn1_scale.get_ptr())->data<data_t>();
|
||||
mc_grouped_gemm_basic_kernel<ElementA, ElementB, ElementC>(
|
||||
reinterpret_cast<const ElementA *>(permuted_input_ptr),
|
||||
row_major,
|
||||
reinterpret_cast<const ElementB *>(ffn1_weight.data<ElementB>()),
|
||||
column_major,
|
||||
reinterpret_cast<const ElementA *>(fc1_expert_scales),
|
||||
reinterpret_cast<const ElementA *>(fc1_expert_biases),
|
||||
reinterpret_cast<ElementC *>(fc1_out_ptr),
|
||||
row_major,
|
||||
tokens_expert_prefix_sum.data<int>(),
|
||||
num_experts,
|
||||
expanded_active_expert_rows,
|
||||
inter_dim,
|
||||
hidden_size,
|
||||
stream);
|
||||
reinterpret_cast<const ElementA*>(permuted_input_ptr),
|
||||
row_major,
|
||||
reinterpret_cast<const ElementB*>(ffn1_weight.data<ElementB>()),
|
||||
column_major,
|
||||
reinterpret_cast<const ElementA*>(fc1_expert_scales),
|
||||
reinterpret_cast<const ElementA*>(fc1_expert_biases),
|
||||
reinterpret_cast<ElementC*>(fc1_out_ptr),
|
||||
row_major,
|
||||
tokens_expert_prefix_sum.data<int>(),
|
||||
num_experts,
|
||||
expanded_active_expert_rows,
|
||||
inter_dim,
|
||||
hidden_size,
|
||||
stream);
|
||||
|
||||
// swiglu
|
||||
auto act_out_tensor = paddle::experimental::swiglu(fc1_out_tensor, nullptr);
|
||||
auto act_out = act_out_tensor.data<data_t>();
|
||||
|
||||
auto fc2_expert_scales = const_cast<paddle::Tensor*>(ffn2_scale.get_ptr())->data<data_t>();
|
||||
auto fc2_expert_scales =
|
||||
const_cast<paddle::Tensor*>(ffn2_scale.get_ptr())->data<data_t>();
|
||||
mc_grouped_gemm_basic_kernel<ElementA, ElementB, ElementC>(
|
||||
reinterpret_cast<const ElementA *>(act_out),
|
||||
row_major,
|
||||
reinterpret_cast<const ElementB *>(ffn2_weight.data<ElementB>()),
|
||||
column_major,
|
||||
reinterpret_cast<const ElementA *>(fc2_expert_scales),
|
||||
nullptr,
|
||||
reinterpret_cast<ElementC *>(ffn_out_ptr),
|
||||
row_major,
|
||||
tokens_expert_prefix_sum.data<int>(),
|
||||
num_experts,
|
||||
expanded_active_expert_rows,
|
||||
hidden_size,
|
||||
inter_dim / 2,
|
||||
stream);
|
||||
reinterpret_cast<const ElementA*>(act_out),
|
||||
row_major,
|
||||
reinterpret_cast<const ElementB*>(ffn2_weight.data<ElementB>()),
|
||||
column_major,
|
||||
reinterpret_cast<const ElementA*>(fc2_expert_scales),
|
||||
nullptr,
|
||||
reinterpret_cast<ElementC*>(ffn_out_ptr),
|
||||
row_major,
|
||||
tokens_expert_prefix_sum.data<int>(),
|
||||
num_experts,
|
||||
expanded_active_expert_rows,
|
||||
hidden_size,
|
||||
inter_dim / 2,
|
||||
stream);
|
||||
}
|
||||
|
||||
std::vector<paddle::Tensor> MoeExpertFFN(
|
||||
@@ -109,15 +116,18 @@ std::vector<paddle::Tensor> MoeExpertFFN(
|
||||
|
||||
switch (input_type) {
|
||||
case paddle::DataType::BFLOAT16:
|
||||
McMoeFFNKernel<paddle::DataType::BFLOAT16, maca_bfloat16, int8_t, maca_bfloat16>(permute_input,
|
||||
tokens_expert_prefix_sum,
|
||||
ffn1_weight,
|
||||
ffn2_weight,
|
||||
ffn1_bias,
|
||||
ffn1_scale,
|
||||
ffn2_scale,
|
||||
quant_method,
|
||||
ffn_out);
|
||||
McMoeFFNKernel<paddle::DataType::BFLOAT16,
|
||||
maca_bfloat16,
|
||||
int8_t,
|
||||
maca_bfloat16>(permute_input,
|
||||
tokens_expert_prefix_sum,
|
||||
ffn1_weight,
|
||||
ffn2_weight,
|
||||
ffn1_bias,
|
||||
ffn1_scale,
|
||||
ffn2_scale,
|
||||
quant_method,
|
||||
ffn_out);
|
||||
break;
|
||||
// case paddle::DataType::FLOAT16:
|
||||
// MoeFFNKernel<paddle::DataType::FLOAT16>(permute_input,
|
||||
|
||||
@@ -12,12 +12,11 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "helper.h"
|
||||
#include "fused_moe_helper.h"
|
||||
#include "fused_moe_op.h"
|
||||
#include "helper.h"
|
||||
|
||||
template <paddle::DataType T>
|
||||
void MoeReduceKernel(const paddle::Tensor& ffn_out,
|
||||
@@ -52,7 +51,6 @@ void MoeReduceKernel(const paddle::Tensor& ffn_out,
|
||||
stream);
|
||||
}
|
||||
|
||||
|
||||
std::vector<paddle::Tensor> MoeExpertReduce(
|
||||
const paddle::Tensor& ffn_out,
|
||||
const paddle::Tensor& top_k_weight,
|
||||
@@ -106,7 +104,6 @@ std::vector<paddle::Tensor> MoeExpertReduce(
|
||||
return {output};
|
||||
}
|
||||
|
||||
|
||||
std::vector<std::vector<int64_t>> MoeExpertReduceInferShape(
|
||||
const std::vector<int64_t>& ffn_out_shape,
|
||||
const std::vector<int64_t>& top_k_weight_shape,
|
||||
@@ -129,7 +126,6 @@ std::vector<paddle::DataType> MoeExpertReduceInferDtype(
|
||||
return {ffn_out_dtype};
|
||||
}
|
||||
|
||||
|
||||
PD_BUILD_OP(moe_expert_reduce)
|
||||
.Inputs({"ffn_out",
|
||||
"top_k_weight",
|
||||
|
||||
Reference in New Issue
Block a user