Mirror of https://github.com/PaddlePaddle/FastDeploy.git, synced 2025-12-24 13:28:13 +08:00
[Feature] Support fusedmoe on Blackwell (#5325)
* update sm100
* fix
* fix style
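The functional change in this diff is small: the __CUDA_ARCH__ guards and the host-side SM check are widened so that the existing Sm80-tagged fused-MoE kernels also build and dispatch on Blackwell (SM 10.x); the rest is formatting (line wrapping and include reordering). Below is a minimal sketch of the widened device guard; the kernel and main() are stand-ins, not FastDeploy code, and only the guard expression is taken from the diff.

// Sketch only: illustrates the guard style after the change. The old bound of
// 910 excluded anything newer than sm_90; 1010 lets sm_100 (Blackwell) builds
// take the same code path as Ampere and Hopper.
#include <cstdio>
#include <cuda_runtime.h>

__global__ void arch_gated_kernel() {  // hypothetical kernel, not from the diff
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 1010)
  // sm_80 .. sm_100 compile and run this branch.
  if (threadIdx.x == 0) {
    printf("compiled for arch %d\n", __CUDA_ARCH__);
  }
#else
  // Other architectures would hit the CUTLASS_NOT_IMPLEMENTED() path in the real kernel.
#endif
}

int main() {
  arch_gated_kernel<<<1, 32>>>();
  return cudaDeviceSynchronize() == cudaSuccess ? 0 : 1;
}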
@@ -277,7 +277,8 @@ struct MoeFCGemm {
code_scale(const_cast<float*>(code_scale)),
code_zp(const_cast<float*>(code_zp)),
host_problem_sizes(nullptr) {
if (quant_method != WintQuantMethod::kNone || platform::is_same<uint8_t, ElementB>::value ||
if (quant_method != WintQuantMethod::kNone ||
platform::is_same<uint8_t, ElementB>::value ||
platform::is_same<uint4b_t, ElementB>::value) {
assert(weight_scales);
}
@@ -380,7 +381,8 @@ struct MoeFCGemm {
}

static Status can_implement(Arguments const& args) {
if (args.quant_method != WintQuantMethod::kNone || platform::is_same<uint8_t, ElementB>::value ||
if (args.quant_method != WintQuantMethod::kNone ||
platform::is_same<uint8_t, ElementB>::value ||
platform::is_same<uint4b_t, ElementB>::value) {
if (args.weight_scales == nullptr) {
CUTLASS_TRACE_HOST(
@@ -416,7 +418,6 @@ struct MoeFCGemm {

template <typename dummy>
struct KernelRunner<true, dummy> {

CUTLASS_DEVICE
static void run_kernel(Params const& params,
SharedStorage& shared_storage) { // NOLINT
@@ -471,9 +472,13 @@ struct MoeFCGemm {
int64_t rows_to_jump = 0;

if (params.problem_visitor.total_rows < 0) {
rows_to_jump = problem_idx == 0 ? 0 : params.problem_visitor.last_row_for_problem[problem_idx - 1];
rows_to_jump = problem_idx == 0
? 0
: params.problem_visitor
.last_row_for_problem[problem_idx - 1];
} else {
rows_to_jump = problem_idx * (params.problem_visitor.total_rows / params.problem_visitor.problem_count);
rows_to_jump = problem_idx * (params.problem_visitor.total_rows /
params.problem_visitor.problem_count);
}

// begin address offset for A for current tile
@@ -496,11 +501,13 @@ struct MoeFCGemm {
0,
};

// the begin threadblock_offset of B, which holds the same column id with C
// the begin threadblock_offset of B, which holds the same column id
// with C
cutlass::MatrixCoord tb_offset_B{0,
threadblock_offset.n() / kInterleave};

// the begin threadblock_offset of scale, which holds the same column id with C, but with no row id
// the begin threadblock_offset of scale, which holds the same column id
// with C, but with no row id
cutlass::MatrixCoord tb_offset_scale{0, threadblock_offset.n()};

// Compute position within threadblock
@@ -628,7 +635,7 @@ struct MoeFCGemm {
static constexpr bool compile_needed =
platform::is_same<KernelArch, arch::Sm75>::value;
KernelRunner<compile_needed>::run_kernel(params, shared_storage);
#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 910)
#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 1010)
static constexpr bool compile_needed =
platform::is_same<KernelArch, arch::Sm80>::value;
KernelRunner<compile_needed>::run_kernel(params, shared_storage);
@@ -649,9 +656,17 @@ template <typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate
GroupScheduleMode GroupScheduleMode_ ///! Type of scheduling to //
/// NOLINT perform
>
struct Wint2xMoeFCGemm : public MoeFCGemm<Mma_, Epilogue_, ThreadblockSwizzle_, KernelArch, GroupScheduleMode_> {
struct Wint2xMoeFCGemm : public MoeFCGemm<Mma_,
Epilogue_,
ThreadblockSwizzle_,
KernelArch,
GroupScheduleMode_> {
public:
using Base = MoeFCGemm<Mma_, Epilogue_, ThreadblockSwizzle_, KernelArch, GroupScheduleMode_>;
using Base = MoeFCGemm<Mma_,
Epilogue_,
ThreadblockSwizzle_,
KernelArch,
GroupScheduleMode_>;
using Mma = Mma_;
using Epilogue = Epilogue_;
using EpilogueOutputOp = typename Epilogue::OutputOp;
@@ -711,7 +726,11 @@ struct Wint2xMoeFCGemm : public MoeFCGemm<Mma_, Epilogue_, ThreadblockSwizzle_,
//

CUTLASS_HOST_DEVICE
Params() : Base::Params(), local_scale(nullptr), code_scale(nullptr), code_zp(nullptr) {}
Params()
: Base::Params(),
local_scale(nullptr),
code_scale(nullptr),
code_zp(nullptr) {}

CUTLASS_HOST_DEVICE
Params(Arguments const& args,
@@ -738,7 +757,6 @@ struct Wint2xMoeFCGemm : public MoeFCGemm<Mma_, Epilogue_, ThreadblockSwizzle_,
using SharedStorage = typename Base::SharedStorage;

public:

//
// Methods
//
@@ -753,7 +771,8 @@ struct Wint2xMoeFCGemm : public MoeFCGemm<Mma_, Epilogue_, ThreadblockSwizzle_,
return Status::kInvalid;
} else if (args.weight_scales == nullptr || args.local_scale == nullptr) {
CUTLASS_TRACE_HOST(
"Wint2xMoeFCGemm::can_implement() - weight_scales and local_scale is expected to be not nullptr!");
"Wint2xMoeFCGemm::can_implement() - weight_scales and local_scale is "
"expected to be not nullptr!");
return Status::kInvalid;
}
return Status::kSuccess;
@@ -778,38 +797,49 @@ struct Wint2xMoeFCGemm : public MoeFCGemm<Mma_, Epilogue_, ThreadblockSwizzle_,

CUTLASS_DEVICE
static MmaQuantArguments prepare_quant_args(
Params const& params, cutlass::gemm::GemmCoord const& threadblock_offset,
int64_t problem_idx, const int32_t gemm_k, const int32_t gemm_n, const int thread_idx) {
// the begin threadblock_offset of scale, which holds the same column id with C, but with no row id
Params const& params,
cutlass::gemm::GemmCoord const& threadblock_offset,
int64_t problem_idx,
const int32_t gemm_k,
const int32_t gemm_n,
const int thread_idx) {
// the begin threadblock_offset of scale, which holds the same column id
// with C, but with no row id
cutlass::MatrixCoord tb_offset_scale{0, threadblock_offset.n()};
cutlass::MatrixCoord tb_offset_local_scale{0, threadblock_offset.n() * 2};

ElementScale* weight_scale_ptr = params.weight_scales + problem_idx * gemm_n;
typename Mma::QuantParamsAccessor::IteratorSuperScale iterator_super_scale(
Mma::QuantParamsAccessor::LayoutSuperScale(gemm_n),
weight_scale_ptr,
{1, gemm_n},
thread_idx,
tb_offset_scale);
ElementScale* weight_scale_ptr =
params.weight_scales + problem_idx * gemm_n;
typename Mma::QuantParamsAccessor::IteratorSuperScale
iterator_super_scale(
Mma::QuantParamsAccessor::LayoutSuperScale(gemm_n),
weight_scale_ptr,
{1, gemm_n},
thread_idx,
tb_offset_scale);

int local_scale_pointer_offset = ((ThreadblockShape::kK + 127) / 128) * (gemm_n * 2);
int local_scale_pointer_offset =
((ThreadblockShape::kK + 127) / 128) * (gemm_n * 2);
int64_t offset_in_bytes = problem_idx * gemm_k * gemm_n / 128;
uint4b_t *local_scale_ptr = reinterpret_cast<uint4b_t *>(params.local_scale + offset_in_bytes);
uint4b_t* local_scale_ptr =
reinterpret_cast<uint4b_t*>(params.local_scale + offset_in_bytes);

typename Mma::QuantParamsAccessor::IteratorLocalScale iterator_local_scale(
Mma::QuantParamsAccessor::LayoutLocalScale(gemm_n * 2),
local_scale_ptr,
{(gemm_k + 127) / 128, gemm_n * 2},
thread_idx,
tb_offset_local_scale);
typename Mma::QuantParamsAccessor::IteratorLocalScale
iterator_local_scale(
Mma::QuantParamsAccessor::LayoutLocalScale(gemm_n * 2),
local_scale_ptr,
{(gemm_k + 127) / 128, gemm_n * 2},
thread_idx,
tb_offset_local_scale);

float* code_scale_ptr = params.code_scale + problem_idx * gemm_n;
typename Mma::QuantParamsAccessor::IteratorCodeScaleZp iterator_code_scale(
Mma::QuantParamsAccessor::LayoutCodeScaleZp(gemm_n),
code_scale_ptr,
{1, gemm_n},
thread_idx,
tb_offset_scale);
typename Mma::QuantParamsAccessor::IteratorCodeScaleZp
iterator_code_scale(
Mma::QuantParamsAccessor::LayoutCodeScaleZp(gemm_n),
code_scale_ptr,
{1, gemm_n},
thread_idx,
tb_offset_scale);

float* code_zp_ptr = params.code_zp + problem_idx * gemm_n;
typename Mma::QuantParamsAccessor::IteratorCodeScaleZp iterator_code_zp(
@@ -819,8 +849,11 @@ struct Wint2xMoeFCGemm : public MoeFCGemm<Mma_, Epilogue_, ThreadblockSwizzle_,
thread_idx,
tb_offset_scale);

MmaQuantArguments mma_quant_args(
iterator_super_scale, iterator_local_scale, iterator_code_scale, iterator_code_zp, local_scale_pointer_offset);
MmaQuantArguments mma_quant_args(iterator_super_scale,
iterator_local_scale,
iterator_code_scale,
iterator_code_zp,
local_scale_pointer_offset);
return mma_quant_args;
}

@@ -858,9 +891,13 @@ struct Wint2xMoeFCGemm : public MoeFCGemm<Mma_, Epilogue_, ThreadblockSwizzle_,

const int64_t gemm_k = params.problem_visitor.gemm_k;
const int64_t gemm_n = params.problem_visitor.gemm_n;
// wint2.5 and wint2.0 is quantized and packed along k dimension with group_size 64.
const int64_t quant_gemm_k = WintQuantTraits<ElementA, QuantMethod>::CaclPackedDim(gemm_k);
int64_t bytes_per_expert_matrix = (quant_gemm_k * gemm_n / 8) * cutlass::sizeof_bits<QuantElementB>::value;
// wint2.5 and wint2.0 is quantized and packed along k dimension with
// group_size 64.
const int64_t quant_gemm_k =
WintQuantTraits<ElementA, QuantMethod>::CaclPackedDim(gemm_k);
int64_t bytes_per_expert_matrix =
(quant_gemm_k * gemm_n / 8) *
cutlass::sizeof_bits<QuantElementB>::value;

// Outer 'persistent' loop to iterate over tiles
while (problem_visitor.next_tile()) {
@@ -881,9 +918,13 @@ struct Wint2xMoeFCGemm : public MoeFCGemm<Mma_, Epilogue_, ThreadblockSwizzle_,
int64_t rows_to_jump = 0;

if (params.problem_visitor.total_rows < 0) {
rows_to_jump = problem_idx == 0 ? 0 : params.problem_visitor.last_row_for_problem[problem_idx - 1];
rows_to_jump = problem_idx == 0
? 0
: params.problem_visitor
.last_row_for_problem[problem_idx - 1];
} else {
rows_to_jump = problem_idx * (params.problem_visitor.total_rows / params.problem_visitor.problem_count);
rows_to_jump = problem_idx * (params.problem_visitor.total_rows /
params.problem_visitor.problem_count);
}

// begin address offset for A for current tile
@@ -895,7 +936,8 @@ struct Wint2xMoeFCGemm : public MoeFCGemm<Mma_, Epilogue_, ThreadblockSwizzle_,
// the begin threadblock_offset of A, which holds the same row id with C
cutlass::MatrixCoord tb_offset_A{threadblock_offset.m(), 0};

// begin address offset for B for current problem_idx, totally num_experts problems
// begin address offset for B for current problem_idx, totally
// num_experts problems
char* byte_ptr_B = ((char*)params.ptr_B) + // NOLINT
problem_idx * bytes_per_expert_matrix; // NOLINT
ElementB* ptr_B = reinterpret_cast<ElementB*>(byte_ptr_B);
@@ -904,9 +946,12 @@ struct Wint2xMoeFCGemm : public MoeFCGemm<Mma_, Epilogue_, ThreadblockSwizzle_,
? gemm_n
: gemm_k * kInterleave;

// the begin threadblock_offset of B, which holds the same column id with C
cutlass::MatrixCoord tb_offset_B{0, threadblock_offset.n() / kInterleave};
cutlass::MatrixCoord extent_B{problem_size.k() * kInterleave, problem_size.n() / kInterleave};
// the begin threadblock_offset of B, which holds the same column id
// with C
cutlass::MatrixCoord tb_offset_B{0,
threadblock_offset.n() / kInterleave};
cutlass::MatrixCoord extent_B{problem_size.k() * kInterleave,
problem_size.n() / kInterleave};

// Compute position within threadblock
int thread_idx = threadIdx.x;
@@ -919,14 +964,15 @@ struct Wint2xMoeFCGemm : public MoeFCGemm<Mma_, Epilogue_, ThreadblockSwizzle_,
tb_offset_A);

typename Mma::IteratorB iterator_B(
LayoutB(ldm_B),
ptr_B,
extent_B,
thread_idx,
tb_offset_B);
LayoutB(ldm_B), ptr_B, extent_B, thread_idx, tb_offset_B);

MmaQuantArguments mma_quant_args = prepare_quant_args(
params, threadblock_offset, problem_idx, gemm_k, gemm_n, thread_idx);
MmaQuantArguments mma_quant_args =
prepare_quant_args(params,
threadblock_offset,
problem_idx,
gemm_k,
gemm_n,
thread_idx);

typename Mma::FragmentC accumulators;
accumulators.clear();
@@ -965,8 +1011,10 @@ struct Wint2xMoeFCGemm : public MoeFCGemm<Mma_, Epilogue_, ThreadblockSwizzle_,

EpilogueOutputOp output_op(params.output_op);

ElementC* ptr_C =
params.ptr_C ? reinterpret_cast<ElementC*>(params.ptr_C) + problem_idx * gemm_n : nullptr;
ElementC* ptr_C = params.ptr_C
? reinterpret_cast<ElementC*>(params.ptr_C) +
problem_idx * gemm_n
: nullptr;
ElementC* ptr_D =
reinterpret_cast<ElementC*>(params.ptr_D) + rows_to_jump * gemm_n;

@@ -1012,8 +1060,9 @@ struct Wint2xMoeFCGemm : public MoeFCGemm<Mma_, Epilogue_, ThreadblockSwizzle_,
CUTLASS_DEVICE
void operator()(Params const& params,
SharedStorage& shared_storage) { // NOLINT
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 910)
KernelRunner<WintQuantMethod::kWeightOnlyInt2, true>::run_kernel(params, shared_storage);
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 1010)
KernelRunner<WintQuantMethod::kWeightOnlyInt2, true>::run_kernel(
params, shared_storage);
#else
CUTLASS_NOT_IMPLEMENTED();
#endif

@@ -26,10 +26,10 @@
#include <sstream>

#include "cutlass/array.h"
#include "cutlass/trace.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/gemm/device/gemm_grouped.h"
#include "cutlass/gemm/kernel/default_gemm_grouped.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/trace.h"

#include "paddle/common/errors.h"
#include "paddle/phi/core/enforce.h"
@@ -63,24 +63,28 @@ struct CutlassLayoutB<MixedGemmArchTraits, cutlass::WintQuantMethod::kNone> {
using Type = cutlass::layout::RowMajor;
};

template <typename BaseGemmKernel, typename Arch, cutlass::WintQuantMethod Method>
template <typename BaseGemmKernel,
typename Arch,
cutlass::WintQuantMethod Method>
struct CutlassGemmKernel {
using Type =
cutlass::gemm::kernel::MoeFCGemm<typename BaseGemmKernel::Mma,
typename BaseGemmKernel::Epilogue,
typename BaseGemmKernel::ThreadblockSwizzle,
Arch,
BaseGemmKernel::kGroupScheduleMode>;
using Type = cutlass::gemm::kernel::MoeFCGemm<
typename BaseGemmKernel::Mma,
typename BaseGemmKernel::Epilogue,
typename BaseGemmKernel::ThreadblockSwizzle,
Arch,
BaseGemmKernel::kGroupScheduleMode>;
};

template <typename BaseGemmKernel, typename Arch>
struct CutlassGemmKernel<BaseGemmKernel, Arch, cutlass::WintQuantMethod::kWeightOnlyInt2> {
using Type =
cutlass::gemm::kernel::Wint2xMoeFCGemm<typename BaseGemmKernel::Mma,
typename BaseGemmKernel::Epilogue,
typename BaseGemmKernel::ThreadblockSwizzle,
Arch,
BaseGemmKernel::kGroupScheduleMode>;
struct CutlassGemmKernel<BaseGemmKernel,
Arch,
cutlass::WintQuantMethod::kWeightOnlyInt2> {
using Type = cutlass::gemm::kernel::Wint2xMoeFCGemm<
typename BaseGemmKernel::Mma,
typename BaseGemmKernel::Epilogue,
typename BaseGemmKernel::ThreadblockSwizzle,
Arch,
BaseGemmKernel::kGroupScheduleMode>;
};

// ======================= Variable batched Gemm things =======================
@@ -91,21 +95,22 @@ template <typename T,
typename ThreadblockShape,
typename WarpShape,
int Stages>
void generic_moe_gemm_kernelLauncher(const T* A,
const typename WeightQuantTraits::WeightType* B,
const T* weight_scales,
const T* biases,
T* C,
int64_t* total_rows_before_expert,
int64_t total_rows,
int64_t gemm_n,
int64_t gemm_k,
int num_experts,
const typename WeightQuantTraits::Arguments& quant_args_B,
CutlassGemmConfig gemm_config,
const int multi_processor_count,
cudaStream_t stream,
int* kernel_occupancy = nullptr) {
void generic_moe_gemm_kernelLauncher(
const T* A,
const typename WeightQuantTraits::WeightType* B,
const T* weight_scales,
const T* biases,
T* C,
int64_t* total_rows_before_expert,
int64_t total_rows,
int64_t gemm_n,
int64_t gemm_k,
int num_experts,
const typename WeightQuantTraits::Arguments& quant_args_B,
CutlassGemmConfig gemm_config,
const int multi_processor_count,
cudaStream_t stream,
int* kernel_occupancy = nullptr) {
if (gemm_config.split_k_style != SplitKStyle::NO_SPLIT_K) {
throw std::runtime_error("[MoeGemm] Grouped gemm does not support split-k");
}
@@ -128,12 +133,14 @@ void generic_moe_gemm_kernelLauncher(const T* A,
cutlass::platform::is_same<WeightType, uint8_t>::value ||
cutlass::platform::is_same<WeightType, cutlass::uint4b_t>::value ||
cutlass::platform::is_same<WeightType, uint16_t>::value,
"Specialized for bfloat16, half, float, uint8_t (wint8), uint4b_t (wint4), uint16_t (wint2.5)");
"Specialized for bfloat16, half, float, uint8_t (wint8), uint4b_t "
"(wint4), uint16_t (wint2.5)");

// The cutlass type for the input elements. This is needed to convert to
// cutlass::half_t if necessary.
using ElementType = typename cutlass::CutlassDataType<T>::Type;
using CutlassWeightType = typename cutlass::CutlassDataType<typename WeightQuantTraits::WeightType>::Type;
using CutlassWeightType = typename cutlass::CutlassDataType<
typename WeightQuantTraits::WeightType>::Type;
using CutlassMmaWeightType = typename WeightQuantTraits::MmaWeightType;
using CutlassMmaKernelType = typename WeightQuantTraits::MmaKernelType;

@@ -155,7 +162,8 @@ void generic_moe_gemm_kernelLauncher(const T* A,
cutlass::ComplexTransform::kNone,
MixedGemmArchTraits::ElementsPerAccessA,
CutlassMmaKernelType,
typename CutlassLayoutB<MixedGemmArchTraits, WeightQuantTraits::kQuantMethod>::Type,
typename CutlassLayoutB<MixedGemmArchTraits,
WeightQuantTraits::kQuantMethod>::Type,
cutlass::ComplexTransform::kNone,
MixedGemmArchTraits::ElementsPerAccessB,
ElementType,
@@ -172,7 +180,10 @@ void generic_moe_gemm_kernelLauncher(const T* A,
cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly,
typename MixedGemmArchTraits::Operator>::GemmKernel;

using GemmKernel = typename CutlassGemmKernel<BaseGemmKernel, arch, WeightQuantTraits::kQuantMethod>::Type;
using GemmKernel =
typename CutlassGemmKernel<BaseGemmKernel,
arch,
WeightQuantTraits::kQuantMethod>::Type;
using GemmGrouped = cutlass::gemm::device::GemmGrouped<GemmKernel>;

if (kernel_occupancy != nullptr) {
@@ -194,7 +205,8 @@ void generic_moe_gemm_kernelLauncher(const T* A,
const uint8_t* local_scale_B = nullptr;
const float* code_scale_B = nullptr;
const float* code_zp_B = nullptr;
if constexpr (WeightQuantTraits::kQuantMethod == cutlass::WintQuantMethod::kWeightOnlyInt2) {
if constexpr (WeightQuantTraits::kQuantMethod ==
cutlass::WintQuantMethod::kWeightOnlyInt2) {
local_scale_B = quant_args_B.local_scale_ptr;
code_scale_B = quant_args_B.code_scale_ptr;
code_zp_B = quant_args_B.code_zp_ptr;
@@ -253,21 +265,22 @@ template <typename T,
int Stages,
typename Enable = void>
struct dispatch_stages {
static void dispatch(const T* A,
const typename WeightQuantTraits::WeightType* B,
const T* weight_scales,
const T* biases,
T* C,
int64_t* total_rows_before_expert,
int64_t total_rows,
int64_t gemm_n,
int64_t gemm_k,
int num_experts,
const typename WeightQuantTraits::Arguments& quant_args_B,
CutlassGemmConfig gemm_config,
int multi_processor_count,
cudaStream_t stream,
int* occupancy = nullptr) {
static void dispatch(
const T* A,
const typename WeightQuantTraits::WeightType* B,
const T* weight_scales,
const T* biases,
T* C,
int64_t* total_rows_before_expert,
int64_t total_rows,
int64_t gemm_n,
int64_t gemm_k,
int num_experts,
const typename WeightQuantTraits::Arguments& quant_args_B,
CutlassGemmConfig gemm_config,
int multi_processor_count,
cudaStream_t stream,
int* occupancy = nullptr) {
// FT_LOG_DEBUG(__PRETTY_FUNCTION__);
std::string err_msg = "Cutlass fpA_intB gemm. Not instantiates for arch " +
std::to_string(arch::kMinComputeCapability) +
@@ -289,21 +302,22 @@ struct dispatch_stages<T,
ThreadblockShape,
WarpShape,
2> {
static void dispatch(const T* A,
const typename WeightQuantTraits::WeightType* B,
const T* weight_scales,
const T* biases,
T* C,
int64_t* total_rows_before_expert,
int64_t total_rows,
int64_t gemm_n,
int64_t gemm_k,
int num_experts,
const typename WeightQuantTraits::Arguments& quant_args_B,
CutlassGemmConfig gemm_config,
int multi_processor_count,
cudaStream_t stream,
int* occupancy = nullptr) {
static void dispatch(
const T* A,
const typename WeightQuantTraits::WeightType* B,
const T* weight_scales,
const T* biases,
T* C,
int64_t* total_rows_before_expert,
int64_t total_rows,
int64_t gemm_n,
int64_t gemm_k,
int num_experts,
const typename WeightQuantTraits::Arguments& quant_args_B,
CutlassGemmConfig gemm_config,
int multi_processor_count,
cudaStream_t stream,
int* occupancy = nullptr) {
generic_moe_gemm_kernelLauncher<T,
WeightQuantTraits,
arch,
@@ -342,21 +356,22 @@ struct dispatch_stages<T,
WarpShape,
Stages,
typename std::enable_if<(Stages > 2)>::type> {
static void dispatch(const T* A,
const typename WeightQuantTraits::WeightType* B,
const T* weight_scales,
const T* biases,
T* C,
int64_t* total_rows_before_expert,
int64_t total_rows,
int64_t gemm_n,
int64_t gemm_k,
int num_experts,
const typename WeightQuantTraits::Arguments& quant_args_B,
CutlassGemmConfig gemm_config,
int multi_processor_count,
cudaStream_t stream,
int* occupancy = nullptr) {
static void dispatch(
const T* A,
const typename WeightQuantTraits::WeightType* B,
const T* weight_scales,
const T* biases,
T* C,
int64_t* total_rows_before_expert,
int64_t total_rows,
int64_t gemm_n,
int64_t gemm_k,
int num_experts,
const typename WeightQuantTraits::Arguments& quant_args_B,
CutlassGemmConfig gemm_config,
int multi_processor_count,
cudaStream_t stream,
int* occupancy = nullptr) {
generic_moe_gemm_kernelLauncher<T,
WeightQuantTraits,
cutlass::arch::Sm80,
@@ -387,21 +402,22 @@ template <typename T,
typename EpilogueTag,
typename ThreadblockShape,
typename WarpShape>
void dispatch_gemm_config(const T* A,
const typename WeightQuantTraits::WeightType* B,
const T* weight_scales,
const T* biases,
T* C,
int64_t* total_rows_before_expert,
int64_t total_rows,
int64_t gemm_n,
int64_t gemm_k,
int num_experts,
const typename WeightQuantTraits::Arguments& quant_args_B,
CutlassGemmConfig gemm_config,
int multi_processor_count,
cudaStream_t stream,
int* occupancy = nullptr) {
void dispatch_gemm_config(
const T* A,
const typename WeightQuantTraits::WeightType* B,
const T* weight_scales,
const T* biases,
T* C,
int64_t* total_rows_before_expert,
int64_t total_rows,
int64_t gemm_n,
int64_t gemm_k,
int num_experts,
const typename WeightQuantTraits::Arguments& quant_args_B,
CutlassGemmConfig gemm_config,
int multi_processor_count,
cudaStream_t stream,
int* occupancy = nullptr) {
#define dispatch_stages_macro(STAGE) \
case STAGE: \
dispatch_stages<T, \
@@ -468,35 +484,39 @@ void dispatch_gemm_config(const T* A,

// This overload will handle tensorop gemms. It is disabled via SFINAE for fp32.
// This overload is only enabled when T == WeightType.
template <typename T,
typename WeightQuantTraits,
typename arch,
typename EpilogueTag,
typename std::enable_if<!std::is_same<T, float>::value &&
std::is_same<T, typename WeightQuantTraits::WeightType>::value>::type* =
nullptr>
void dispatch_moe_gemm_to_cutlass(const T* A,
const typename WeightQuantTraits::WeightType* B,
const T* weight_scales,
const T* biases,
T* C,
int64_t* total_rows_before_expert,
int64_t total_rows,
int64_t gemm_n,
int64_t gemm_k,
int num_experts,
const typename WeightQuantTraits::Arguments& quant_args_B,
CutlassGemmConfig gemm_config,
int sm_version,
int multi_processor_count,
cudaStream_t stream,
int* occupancy = nullptr) {
template <
typename T,
typename WeightQuantTraits,
typename arch,
typename EpilogueTag,
typename std::enable_if<
!std::is_same<T, float>::value &&
std::is_same<T, typename WeightQuantTraits::WeightType>::value>::type* =
nullptr>
void dispatch_moe_gemm_to_cutlass(
const T* A,
const typename WeightQuantTraits::WeightType* B,
const T* weight_scales,
const T* biases,
T* C,
int64_t* total_rows_before_expert,
int64_t total_rows,
int64_t gemm_n,
int64_t gemm_k,
int num_experts,
const typename WeightQuantTraits::Arguments& quant_args_B,
CutlassGemmConfig gemm_config,
int sm_version,
int multi_processor_count,
cudaStream_t stream,
int* occupancy = nullptr) {
switch (gemm_config.tile_config) {
dispatch_gemm_config_macro(32, 128, 64, 32, 32, 64);
dispatch_gemm_config_macro(64, 128, 64, 32, 64, 64);
dispatch_gemm_config_macro(128, 128, 64, 64, 32, 64);
case CutlassTileConfig::Undefined:
throw std::runtime_error("[dispatch_moe_gemm_to_cutlass] gemm config undefined.");
throw std::runtime_error(
"[dispatch_moe_gemm_to_cutlass] gemm config undefined.");
break;
case CutlassTileConfig::ChooseWithHeuristic:
throw std::runtime_error(
@@ -518,32 +538,36 @@ template <typename T,
typename WeightQuantTraits,
typename arch,
typename EpilogueTag,
typename std::enable_if<!std::is_same<T, float>::value &&
!std::is_same<T, typename WeightQuantTraits::WeightType>::value>::type* =
nullptr>
void dispatch_moe_gemm_to_cutlass(const T* A,
const typename WeightQuantTraits::WeightType* B,
const T* weight_scales,
const T* biases,
T* C,
int64_t* total_rows_before_expert,
int64_t total_rows,
int64_t gemm_n,
int64_t gemm_k,
int num_experts,
const typename WeightQuantTraits::Arguments& quant_args_B,
CutlassGemmConfig gemm_config,
int sm_version,
int multi_processor_count,
cudaStream_t stream,
int* occupancy = nullptr) {
typename std::enable_if<
!std::is_same<T, float>::value &&
!std::is_same<T, typename WeightQuantTraits::WeightType>::value>::
type* = nullptr>
void dispatch_moe_gemm_to_cutlass(
const T* A,
const typename WeightQuantTraits::WeightType* B,
const T* weight_scales,
const T* biases,
T* C,
int64_t* total_rows_before_expert,
int64_t total_rows,
int64_t gemm_n,
int64_t gemm_k,
int num_experts,
const typename WeightQuantTraits::Arguments& quant_args_B,
CutlassGemmConfig gemm_config,
int sm_version,
int multi_processor_count,
cudaStream_t stream,
int* occupancy = nullptr) {
if constexpr (std::is_same<arch, cutlass::arch::Sm70>::value) {
if constexpr (WeightQuantTraits::kQuantMethod != cutlass::WintQuantMethod::kWeightOnlyInt2) {
if constexpr (WeightQuantTraits::kQuantMethod !=
cutlass::WintQuantMethod::kWeightOnlyInt2) {
switch (gemm_config.tile_config) {
dispatch_gemm_config_macro(32, 128, 64, 32, 32, 64);
dispatch_gemm_config_macro(64, 128, 64, 64, 64, 64);
case CutlassTileConfig::Undefined:
throw std::runtime_error("[dispatch_moe_gemm_to_cutlass] gemm config undefined.");
throw std::runtime_error(
"[dispatch_moe_gemm_to_cutlass] gemm config undefined.");
break;
case CutlassTileConfig::ChooseWithHeuristic:
throw std::runtime_error(
@@ -558,7 +582,8 @@ void dispatch_moe_gemm_to_cutlass(const T* A,
}
} else {
throw std::runtime_error(
"[dispatch_moe_gemm_to_cutlass] weight_only_int2 does not support sm70.");
"[dispatch_moe_gemm_to_cutlass] weight_only_int2 does not support "
"sm70.");
}
} else {
switch (gemm_config.tile_config) {
@@ -574,7 +599,8 @@ void dispatch_moe_gemm_to_cutlass(const T* A,
dispatch_gemm_config_macro(64, 128, 64, 64, 32, 64);
dispatch_gemm_config_macro(256, 128, 64, 64, 64, 64);
case CutlassTileConfig::Undefined:
throw std::runtime_error("[dispatch_moe_gemm_to_cutlass] gemm config undefined.");
throw std::runtime_error(
"[dispatch_moe_gemm_to_cutlass] gemm config undefined.");
break;
case CutlassTileConfig::ChooseWithHeuristic:
throw std::runtime_error(
@@ -597,22 +623,23 @@ template <
typename arch,
typename EpilogueTag,
typename std::enable_if<std::is_same<T, float>::value>::type* = nullptr>
void dispatch_moe_gemm_to_cutlass(const T* A,
const typename WeightQuantTraits::WeightType* B,
const T* weight_scales,
const T* biases,
T* C,
int64_t* total_rows_before_expert,
int64_t total_rows,
int64_t gemm_n,
int64_t gemm_k,
int num_experts,
const typename WeightQuantTraits::Arguments& quant_args_B,
CutlassGemmConfig gemm_config,
int sm_version,
int multi_processor_count,
cudaStream_t stream,
int* occupancy = nullptr) {
void dispatch_moe_gemm_to_cutlass(
const T* A,
const typename WeightQuantTraits::WeightType* B,
const T* weight_scales,
const T* biases,
T* C,
int64_t* total_rows_before_expert,
int64_t total_rows,
int64_t gemm_n,
int64_t gemm_k,
int num_experts,
const typename WeightQuantTraits::Arguments& quant_args_B,
CutlassGemmConfig gemm_config,
int sm_version,
int multi_processor_count,
cudaStream_t stream,
int* occupancy = nullptr) {
switch (gemm_config.tile_config) {
dispatch_gemm_config_macro(128, 128, 8, 64, 64, 8);
case CutlassTileConfig::Undefined:
@@ -659,33 +686,34 @@ void MoeGemmRunner<T, WeightQuantTraits>::dispatch_to_arch<EpilogueTag>(
CutlassGemmConfig gemm_config,
cudaStream_t stream,
int* occupancy) {
#define dispatch_moe_gemm_to_cutlass_macro(ARCH) \
#define dispatch_moe_gemm_to_cutlass_macro(ARCH) \
dispatch_moe_gemm_to_cutlass<T, WeightQuantTraits, ARCH, EpilogueTag>( \
A, \
B, \
weight_scales, \
biases, \
C, \
total_rows_before_expert, \
total_rows, \
gemm_n, \
gemm_k, \
num_experts, \
quant_args_B, \
gemm_config, \
sm_, \
multi_processor_count_, \
stream, \
A, \
B, \
weight_scales, \
biases, \
C, \
total_rows_before_expert, \
total_rows, \
gemm_n, \
gemm_k, \
num_experts, \
quant_args_B, \
gemm_config, \
sm_, \
multi_processor_count_, \
stream, \
occupancy);

if (sm_ >= 70 && sm_ < 75) {
dispatch_moe_gemm_to_cutlass_macro(cutlass::arch::Sm70);
} else if (sm_ >= 75 && sm_ < 80) {
dispatch_moe_gemm_to_cutlass_macro(cutlass::arch::Sm75);
} else if (sm_ >= 80 && sm_ < 91) {
} else if (sm_ >= 80 && sm_ < 101) {
dispatch_moe_gemm_to_cutlass_macro(cutlass::arch::Sm80);
} else {
throw std::runtime_error("[MoE][GEMM Dispatch] Arch unsupported for MoE GEMM");
throw std::runtime_error(
"[MoE][GEMM Dispatch] Arch unsupported for MoE GEMM");
}
}

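The host-side dispatch in the hunk above mirrors the device guard: the upper bound of the range check moves from sm_ < 91 to sm_ < 101, so devices reporting compute capability up to 10.0 reuse the Sm80 kernel path instead of throwing. A minimal sketch of that dispatch shape, with stand-in names rather than the real MoeGemmRunner machinery:

// Sketch only, assuming stand-in types; the real code expands
// dispatch_moe_gemm_to_cutlass_macro(cutlass::arch::Sm80) in the 80..100 range.
#include <stdexcept>
#include <string>

enum class ArchTag { Sm70, Sm75, Sm80 };

inline ArchTag select_moe_gemm_arch(int sm) {
  if (sm >= 70 && sm < 75) return ArchTag::Sm70;
  if (sm >= 75 && sm < 80) return ArchTag::Sm75;
  if (sm >= 80 && sm < 101) return ArchTag::Sm80;  // previously sm < 91; now also admits sm_100 (Blackwell)
  throw std::runtime_error("[MoE][GEMM Dispatch] Arch unsupported for MoE GEMM, sm_" +
                           std::to_string(sm));
}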
@@ -705,7 +733,8 @@ void MoeGemmRunner<T, WeightQuantTraits>::run_gemm<EpilogueTag>(
int num_experts,
const typename WeightQuantTraits::Arguments& quant_args_B,
cudaStream_t stream) {
static constexpr bool is_weight_only = !std::is_same<T, typename WeightQuantTraits::WeightType>::value;
static constexpr bool is_weight_only =
!std::is_same<T, typename WeightQuantTraits::WeightType>::value;
static constexpr bool only_simt_configs = std::is_same<T, float>::value;

std::vector<CutlassGemmConfig> candidate_configs =
@@ -776,7 +805,8 @@ void MoeGemmRunner<T, WeightQuantTraits>::run_gemm<EpilogueTag>(
check_cuda_error(cudaEventElapsedTime(&elapsed, start, stop));
check_cuda_error(cudaEventDestroy(start));
check_cuda_error(cudaEventDestroy(stop));
//std::cout << "[TUNING] config: " << ii << ", time: " << elapsed << " ms" << std::endl;
// std::cout << "[TUNING] config: " << ii << ", time: " << elapsed << "
// ms" << std::endl;
if (elapsed < best_time) {
best_id = ii;
best_time = elapsed;
@@ -789,7 +819,8 @@ void MoeGemmRunner<T, WeightQuantTraits>::run_gemm<EpilogueTag>(
}
}
if (find_one) {
//std::cout << "[TUNING] best_config: " << best_id << ", time: " << best_time << " ms" << std::endl;
// std::cout << "[TUNING] best_config: " << best_id << ", time: " <<
// best_time << " ms" << std::endl;
gemmConfigManager.addBestConfig(gemmId, profile_total_rows, best_config);
chosen_config = best_config;
} else {