[Feature] Support fusedmoe on Blackwell (#5325)

* update sm100

* fix

* fix style
This commit is contained in:
Echo-Nie
2025-12-16 11:58:50 +08:00
committed by GitHub
parent 63fff8df70
commit 50100f98d7
2 changed files with 319 additions and 239 deletions

View File

@@ -277,7 +277,8 @@ struct MoeFCGemm {
code_scale(const_cast<float*>(code_scale)),
code_zp(const_cast<float*>(code_zp)),
host_problem_sizes(nullptr) {
if (quant_method != WintQuantMethod::kNone || platform::is_same<uint8_t, ElementB>::value ||
if (quant_method != WintQuantMethod::kNone ||
platform::is_same<uint8_t, ElementB>::value ||
platform::is_same<uint4b_t, ElementB>::value) {
assert(weight_scales);
}
@@ -380,7 +381,8 @@ struct MoeFCGemm {
}
static Status can_implement(Arguments const& args) {
if (args.quant_method != WintQuantMethod::kNone || platform::is_same<uint8_t, ElementB>::value ||
if (args.quant_method != WintQuantMethod::kNone ||
platform::is_same<uint8_t, ElementB>::value ||
platform::is_same<uint4b_t, ElementB>::value) {
if (args.weight_scales == nullptr) {
CUTLASS_TRACE_HOST(
@@ -416,7 +418,6 @@ struct MoeFCGemm {
template <typename dummy>
struct KernelRunner<true, dummy> {
CUTLASS_DEVICE
static void run_kernel(Params const& params,
SharedStorage& shared_storage) { // NOLINT
@@ -471,9 +472,13 @@ struct MoeFCGemm {
int64_t rows_to_jump = 0;
if (params.problem_visitor.total_rows < 0) {
rows_to_jump = problem_idx == 0 ? 0 : params.problem_visitor.last_row_for_problem[problem_idx - 1];
rows_to_jump = problem_idx == 0
? 0
: params.problem_visitor
.last_row_for_problem[problem_idx - 1];
} else {
rows_to_jump = problem_idx * (params.problem_visitor.total_rows / params.problem_visitor.problem_count);
rows_to_jump = problem_idx * (params.problem_visitor.total_rows /
params.problem_visitor.problem_count);
}
// begin address offset for A for current tile
@@ -496,11 +501,13 @@ struct MoeFCGemm {
0,
};
// the begin threadblock_offset of B, which holds the same column id with C
// the begin threadblock_offset of B, which holds the same column id
// with C
cutlass::MatrixCoord tb_offset_B{0,
threadblock_offset.n() / kInterleave};
// the begin threadblock_offset of scale, which holds the same column id with C, but with no row id
// the begin threadblock_offset of scale, which holds the same column id
// with C, but with no row id
cutlass::MatrixCoord tb_offset_scale{0, threadblock_offset.n()};
// Compute position within threadblock
@@ -628,7 +635,7 @@ struct MoeFCGemm {
static constexpr bool compile_needed =
platform::is_same<KernelArch, arch::Sm75>::value;
KernelRunner<compile_needed>::run_kernel(params, shared_storage);
#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 910)
#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 1010)
static constexpr bool compile_needed =
platform::is_same<KernelArch, arch::Sm80>::value;
KernelRunner<compile_needed>::run_kernel(params, shared_storage);
@@ -649,9 +656,17 @@ template <typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate
GroupScheduleMode GroupScheduleMode_ ///! Type of scheduling to //
/// NOLINT perform
>
struct Wint2xMoeFCGemm : public MoeFCGemm<Mma_, Epilogue_, ThreadblockSwizzle_, KernelArch, GroupScheduleMode_> {
struct Wint2xMoeFCGemm : public MoeFCGemm<Mma_,
Epilogue_,
ThreadblockSwizzle_,
KernelArch,
GroupScheduleMode_> {
public:
using Base = MoeFCGemm<Mma_, Epilogue_, ThreadblockSwizzle_, KernelArch, GroupScheduleMode_>;
using Base = MoeFCGemm<Mma_,
Epilogue_,
ThreadblockSwizzle_,
KernelArch,
GroupScheduleMode_>;
using Mma = Mma_;
using Epilogue = Epilogue_;
using EpilogueOutputOp = typename Epilogue::OutputOp;
@@ -711,7 +726,11 @@ struct Wint2xMoeFCGemm : public MoeFCGemm<Mma_, Epilogue_, ThreadblockSwizzle_,
//
CUTLASS_HOST_DEVICE
Params() : Base::Params(), local_scale(nullptr), code_scale(nullptr), code_zp(nullptr) {}
Params()
: Base::Params(),
local_scale(nullptr),
code_scale(nullptr),
code_zp(nullptr) {}
CUTLASS_HOST_DEVICE
Params(Arguments const& args,
@@ -738,7 +757,6 @@ struct Wint2xMoeFCGemm : public MoeFCGemm<Mma_, Epilogue_, ThreadblockSwizzle_,
using SharedStorage = typename Base::SharedStorage;
public:
//
// Methods
//
@@ -753,7 +771,8 @@ struct Wint2xMoeFCGemm : public MoeFCGemm<Mma_, Epilogue_, ThreadblockSwizzle_,
return Status::kInvalid;
} else if (args.weight_scales == nullptr || args.local_scale == nullptr) {
CUTLASS_TRACE_HOST(
"Wint2xMoeFCGemm::can_implement() - weight_scales and local_scale is expected to be not nullptr!");
"Wint2xMoeFCGemm::can_implement() - weight_scales and local_scale is "
"expected to be not nullptr!");
return Status::kInvalid;
}
return Status::kSuccess;
@@ -778,38 +797,49 @@ struct Wint2xMoeFCGemm : public MoeFCGemm<Mma_, Epilogue_, ThreadblockSwizzle_,
CUTLASS_DEVICE
static MmaQuantArguments prepare_quant_args(
Params const& params, cutlass::gemm::GemmCoord const& threadblock_offset,
int64_t problem_idx, const int32_t gemm_k, const int32_t gemm_n, const int thread_idx) {
// the begin threadblock_offset of scale, which holds the same column id with C, but with no row id
Params const& params,
cutlass::gemm::GemmCoord const& threadblock_offset,
int64_t problem_idx,
const int32_t gemm_k,
const int32_t gemm_n,
const int thread_idx) {
// the begin threadblock_offset of scale, which holds the same column id
// with C, but with no row id
cutlass::MatrixCoord tb_offset_scale{0, threadblock_offset.n()};
cutlass::MatrixCoord tb_offset_local_scale{0, threadblock_offset.n() * 2};
ElementScale* weight_scale_ptr = params.weight_scales + problem_idx * gemm_n;
typename Mma::QuantParamsAccessor::IteratorSuperScale iterator_super_scale(
Mma::QuantParamsAccessor::LayoutSuperScale(gemm_n),
weight_scale_ptr,
{1, gemm_n},
thread_idx,
tb_offset_scale);
ElementScale* weight_scale_ptr =
params.weight_scales + problem_idx * gemm_n;
typename Mma::QuantParamsAccessor::IteratorSuperScale
iterator_super_scale(
Mma::QuantParamsAccessor::LayoutSuperScale(gemm_n),
weight_scale_ptr,
{1, gemm_n},
thread_idx,
tb_offset_scale);
int local_scale_pointer_offset = ((ThreadblockShape::kK + 127) / 128) * (gemm_n * 2);
int local_scale_pointer_offset =
((ThreadblockShape::kK + 127) / 128) * (gemm_n * 2);
int64_t offset_in_bytes = problem_idx * gemm_k * gemm_n / 128;
uint4b_t *local_scale_ptr = reinterpret_cast<uint4b_t *>(params.local_scale + offset_in_bytes);
uint4b_t* local_scale_ptr =
reinterpret_cast<uint4b_t*>(params.local_scale + offset_in_bytes);
typename Mma::QuantParamsAccessor::IteratorLocalScale iterator_local_scale(
Mma::QuantParamsAccessor::LayoutLocalScale(gemm_n * 2),
local_scale_ptr,
{(gemm_k + 127) / 128, gemm_n * 2},
thread_idx,
tb_offset_local_scale);
typename Mma::QuantParamsAccessor::IteratorLocalScale
iterator_local_scale(
Mma::QuantParamsAccessor::LayoutLocalScale(gemm_n * 2),
local_scale_ptr,
{(gemm_k + 127) / 128, gemm_n * 2},
thread_idx,
tb_offset_local_scale);
float* code_scale_ptr = params.code_scale + problem_idx * gemm_n;
typename Mma::QuantParamsAccessor::IteratorCodeScaleZp iterator_code_scale(
Mma::QuantParamsAccessor::LayoutCodeScaleZp(gemm_n),
code_scale_ptr,
{1, gemm_n},
thread_idx,
tb_offset_scale);
typename Mma::QuantParamsAccessor::IteratorCodeScaleZp
iterator_code_scale(
Mma::QuantParamsAccessor::LayoutCodeScaleZp(gemm_n),
code_scale_ptr,
{1, gemm_n},
thread_idx,
tb_offset_scale);
float* code_zp_ptr = params.code_zp + problem_idx * gemm_n;
typename Mma::QuantParamsAccessor::IteratorCodeScaleZp iterator_code_zp(
@@ -819,8 +849,11 @@ struct Wint2xMoeFCGemm : public MoeFCGemm<Mma_, Epilogue_, ThreadblockSwizzle_,
thread_idx,
tb_offset_scale);
MmaQuantArguments mma_quant_args(
iterator_super_scale, iterator_local_scale, iterator_code_scale, iterator_code_zp, local_scale_pointer_offset);
MmaQuantArguments mma_quant_args(iterator_super_scale,
iterator_local_scale,
iterator_code_scale,
iterator_code_zp,
local_scale_pointer_offset);
return mma_quant_args;
}
@@ -858,9 +891,13 @@ struct Wint2xMoeFCGemm : public MoeFCGemm<Mma_, Epilogue_, ThreadblockSwizzle_,
const int64_t gemm_k = params.problem_visitor.gemm_k;
const int64_t gemm_n = params.problem_visitor.gemm_n;
// wint2.5 and wint2.0 are quantized and packed along k dimension with group_size 64.
const int64_t quant_gemm_k = WintQuantTraits<ElementA, QuantMethod>::CaclPackedDim(gemm_k);
int64_t bytes_per_expert_matrix = (quant_gemm_k * gemm_n / 8) * cutlass::sizeof_bits<QuantElementB>::value;
// wint2.5 and wint2.0 are quantized and packed along k dimension with
// group_size 64.
const int64_t quant_gemm_k =
WintQuantTraits<ElementA, QuantMethod>::CaclPackedDim(gemm_k);
int64_t bytes_per_expert_matrix =
(quant_gemm_k * gemm_n / 8) *
cutlass::sizeof_bits<QuantElementB>::value;
// Outer 'persistent' loop to iterate over tiles
while (problem_visitor.next_tile()) {
@@ -881,9 +918,13 @@ struct Wint2xMoeFCGemm : public MoeFCGemm<Mma_, Epilogue_, ThreadblockSwizzle_,
int64_t rows_to_jump = 0;
if (params.problem_visitor.total_rows < 0) {
rows_to_jump = problem_idx == 0 ? 0 : params.problem_visitor.last_row_for_problem[problem_idx - 1];
rows_to_jump = problem_idx == 0
? 0
: params.problem_visitor
.last_row_for_problem[problem_idx - 1];
} else {
rows_to_jump = problem_idx * (params.problem_visitor.total_rows / params.problem_visitor.problem_count);
rows_to_jump = problem_idx * (params.problem_visitor.total_rows /
params.problem_visitor.problem_count);
}
// begin address offset for A for current tile
@@ -895,7 +936,8 @@ struct Wint2xMoeFCGemm : public MoeFCGemm<Mma_, Epilogue_, ThreadblockSwizzle_,
// the begin threadblock_offset of A, which holds the same row id with C
cutlass::MatrixCoord tb_offset_A{threadblock_offset.m(), 0};
// begin address offset for B for current problem_idx, totally num_experts problems
// begin address offset for B for current problem_idx, totally
// num_experts problems
char* byte_ptr_B = ((char*)params.ptr_B) + // NOLINT
problem_idx * bytes_per_expert_matrix; // NOLINT
ElementB* ptr_B = reinterpret_cast<ElementB*>(byte_ptr_B);
@@ -904,9 +946,12 @@ struct Wint2xMoeFCGemm : public MoeFCGemm<Mma_, Epilogue_, ThreadblockSwizzle_,
? gemm_n
: gemm_k * kInterleave;
// the begin threadblock_offset of B, which holds the same column id with C
cutlass::MatrixCoord tb_offset_B{0, threadblock_offset.n() / kInterleave};
cutlass::MatrixCoord extent_B{problem_size.k() * kInterleave, problem_size.n() / kInterleave};
// the begin threadblock_offset of B, which holds the same column id
// with C
cutlass::MatrixCoord tb_offset_B{0,
threadblock_offset.n() / kInterleave};
cutlass::MatrixCoord extent_B{problem_size.k() * kInterleave,
problem_size.n() / kInterleave};
// Compute position within threadblock
int thread_idx = threadIdx.x;
@@ -919,14 +964,15 @@ struct Wint2xMoeFCGemm : public MoeFCGemm<Mma_, Epilogue_, ThreadblockSwizzle_,
tb_offset_A);
typename Mma::IteratorB iterator_B(
LayoutB(ldm_B),
ptr_B,
extent_B,
thread_idx,
tb_offset_B);
LayoutB(ldm_B), ptr_B, extent_B, thread_idx, tb_offset_B);
MmaQuantArguments mma_quant_args = prepare_quant_args(
params, threadblock_offset, problem_idx, gemm_k, gemm_n, thread_idx);
MmaQuantArguments mma_quant_args =
prepare_quant_args(params,
threadblock_offset,
problem_idx,
gemm_k,
gemm_n,
thread_idx);
typename Mma::FragmentC accumulators;
accumulators.clear();
@@ -965,8 +1011,10 @@ struct Wint2xMoeFCGemm : public MoeFCGemm<Mma_, Epilogue_, ThreadblockSwizzle_,
EpilogueOutputOp output_op(params.output_op);
ElementC* ptr_C =
params.ptr_C ? reinterpret_cast<ElementC*>(params.ptr_C) + problem_idx * gemm_n : nullptr;
ElementC* ptr_C = params.ptr_C
? reinterpret_cast<ElementC*>(params.ptr_C) +
problem_idx * gemm_n
: nullptr;
ElementC* ptr_D =
reinterpret_cast<ElementC*>(params.ptr_D) + rows_to_jump * gemm_n;
@@ -1012,8 +1060,9 @@ struct Wint2xMoeFCGemm : public MoeFCGemm<Mma_, Epilogue_, ThreadblockSwizzle_,
CUTLASS_DEVICE
void operator()(Params const& params,
SharedStorage& shared_storage) { // NOLINT
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 910)
KernelRunner<WintQuantMethod::kWeightOnlyInt2, true>::run_kernel(params, shared_storage);
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 1010)
KernelRunner<WintQuantMethod::kWeightOnlyInt2, true>::run_kernel(
params, shared_storage);
#else
CUTLASS_NOT_IMPLEMENTED();
#endif

View File

@@ -26,10 +26,10 @@
#include <sstream>
#include "cutlass/array.h"
#include "cutlass/trace.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/gemm/device/gemm_grouped.h"
#include "cutlass/gemm/kernel/default_gemm_grouped.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/trace.h"
#include "paddle/common/errors.h"
#include "paddle/phi/core/enforce.h"
@@ -63,24 +63,28 @@ struct CutlassLayoutB<MixedGemmArchTraits, cutlass::WintQuantMethod::kNone> {
using Type = cutlass::layout::RowMajor;
};
template <typename BaseGemmKernel, typename Arch, cutlass::WintQuantMethod Method>
template <typename BaseGemmKernel,
typename Arch,
cutlass::WintQuantMethod Method>
struct CutlassGemmKernel {
using Type =
cutlass::gemm::kernel::MoeFCGemm<typename BaseGemmKernel::Mma,
typename BaseGemmKernel::Epilogue,
typename BaseGemmKernel::ThreadblockSwizzle,
Arch,
BaseGemmKernel::kGroupScheduleMode>;
using Type = cutlass::gemm::kernel::MoeFCGemm<
typename BaseGemmKernel::Mma,
typename BaseGemmKernel::Epilogue,
typename BaseGemmKernel::ThreadblockSwizzle,
Arch,
BaseGemmKernel::kGroupScheduleMode>;
};
template <typename BaseGemmKernel, typename Arch>
struct CutlassGemmKernel<BaseGemmKernel, Arch, cutlass::WintQuantMethod::kWeightOnlyInt2> {
using Type =
cutlass::gemm::kernel::Wint2xMoeFCGemm<typename BaseGemmKernel::Mma,
typename BaseGemmKernel::Epilogue,
typename BaseGemmKernel::ThreadblockSwizzle,
Arch,
BaseGemmKernel::kGroupScheduleMode>;
struct CutlassGemmKernel<BaseGemmKernel,
Arch,
cutlass::WintQuantMethod::kWeightOnlyInt2> {
using Type = cutlass::gemm::kernel::Wint2xMoeFCGemm<
typename BaseGemmKernel::Mma,
typename BaseGemmKernel::Epilogue,
typename BaseGemmKernel::ThreadblockSwizzle,
Arch,
BaseGemmKernel::kGroupScheduleMode>;
};
// ======================= Variable batched Gemm things =======================
@@ -91,21 +95,22 @@ template <typename T,
typename ThreadblockShape,
typename WarpShape,
int Stages>
void generic_moe_gemm_kernelLauncher(const T* A,
const typename WeightQuantTraits::WeightType* B,
const T* weight_scales,
const T* biases,
T* C,
int64_t* total_rows_before_expert,
int64_t total_rows,
int64_t gemm_n,
int64_t gemm_k,
int num_experts,
const typename WeightQuantTraits::Arguments& quant_args_B,
CutlassGemmConfig gemm_config,
const int multi_processor_count,
cudaStream_t stream,
int* kernel_occupancy = nullptr) {
void generic_moe_gemm_kernelLauncher(
const T* A,
const typename WeightQuantTraits::WeightType* B,
const T* weight_scales,
const T* biases,
T* C,
int64_t* total_rows_before_expert,
int64_t total_rows,
int64_t gemm_n,
int64_t gemm_k,
int num_experts,
const typename WeightQuantTraits::Arguments& quant_args_B,
CutlassGemmConfig gemm_config,
const int multi_processor_count,
cudaStream_t stream,
int* kernel_occupancy = nullptr) {
if (gemm_config.split_k_style != SplitKStyle::NO_SPLIT_K) {
throw std::runtime_error("[MoeGemm] Grouped gemm does not support split-k");
}
@@ -128,12 +133,14 @@ void generic_moe_gemm_kernelLauncher(const T* A,
cutlass::platform::is_same<WeightType, uint8_t>::value ||
cutlass::platform::is_same<WeightType, cutlass::uint4b_t>::value ||
cutlass::platform::is_same<WeightType, uint16_t>::value,
"Specialized for bfloat16, half, float, uint8_t (wint8), uint4b_t (wint4), uint16_t (wint2.5)");
"Specialized for bfloat16, half, float, uint8_t (wint8), uint4b_t "
"(wint4), uint16_t (wint2.5)");
// The cutlass type for the input elements. This is needed to convert to
// cutlass::half_t if necessary.
using ElementType = typename cutlass::CutlassDataType<T>::Type;
using CutlassWeightType = typename cutlass::CutlassDataType<typename WeightQuantTraits::WeightType>::Type;
using CutlassWeightType = typename cutlass::CutlassDataType<
typename WeightQuantTraits::WeightType>::Type;
using CutlassMmaWeightType = typename WeightQuantTraits::MmaWeightType;
using CutlassMmaKernelType = typename WeightQuantTraits::MmaKernelType;
@@ -155,7 +162,8 @@ void generic_moe_gemm_kernelLauncher(const T* A,
cutlass::ComplexTransform::kNone,
MixedGemmArchTraits::ElementsPerAccessA,
CutlassMmaKernelType,
typename CutlassLayoutB<MixedGemmArchTraits, WeightQuantTraits::kQuantMethod>::Type,
typename CutlassLayoutB<MixedGemmArchTraits,
WeightQuantTraits::kQuantMethod>::Type,
cutlass::ComplexTransform::kNone,
MixedGemmArchTraits::ElementsPerAccessB,
ElementType,
@@ -172,7 +180,10 @@ void generic_moe_gemm_kernelLauncher(const T* A,
cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly,
typename MixedGemmArchTraits::Operator>::GemmKernel;
using GemmKernel = typename CutlassGemmKernel<BaseGemmKernel, arch, WeightQuantTraits::kQuantMethod>::Type;
using GemmKernel =
typename CutlassGemmKernel<BaseGemmKernel,
arch,
WeightQuantTraits::kQuantMethod>::Type;
using GemmGrouped = cutlass::gemm::device::GemmGrouped<GemmKernel>;
if (kernel_occupancy != nullptr) {
@@ -194,7 +205,8 @@ void generic_moe_gemm_kernelLauncher(const T* A,
const uint8_t* local_scale_B = nullptr;
const float* code_scale_B = nullptr;
const float* code_zp_B = nullptr;
if constexpr (WeightQuantTraits::kQuantMethod == cutlass::WintQuantMethod::kWeightOnlyInt2) {
if constexpr (WeightQuantTraits::kQuantMethod ==
cutlass::WintQuantMethod::kWeightOnlyInt2) {
local_scale_B = quant_args_B.local_scale_ptr;
code_scale_B = quant_args_B.code_scale_ptr;
code_zp_B = quant_args_B.code_zp_ptr;
@@ -253,21 +265,22 @@ template <typename T,
int Stages,
typename Enable = void>
struct dispatch_stages {
static void dispatch(const T* A,
const typename WeightQuantTraits::WeightType* B,
const T* weight_scales,
const T* biases,
T* C,
int64_t* total_rows_before_expert,
int64_t total_rows,
int64_t gemm_n,
int64_t gemm_k,
int num_experts,
const typename WeightQuantTraits::Arguments& quant_args_B,
CutlassGemmConfig gemm_config,
int multi_processor_count,
cudaStream_t stream,
int* occupancy = nullptr) {
static void dispatch(
const T* A,
const typename WeightQuantTraits::WeightType* B,
const T* weight_scales,
const T* biases,
T* C,
int64_t* total_rows_before_expert,
int64_t total_rows,
int64_t gemm_n,
int64_t gemm_k,
int num_experts,
const typename WeightQuantTraits::Arguments& quant_args_B,
CutlassGemmConfig gemm_config,
int multi_processor_count,
cudaStream_t stream,
int* occupancy = nullptr) {
// FT_LOG_DEBUG(__PRETTY_FUNCTION__);
std::string err_msg = "Cutlass fpA_intB gemm. Not instantiates for arch " +
std::to_string(arch::kMinComputeCapability) +
@@ -289,21 +302,22 @@ struct dispatch_stages<T,
ThreadblockShape,
WarpShape,
2> {
static void dispatch(const T* A,
const typename WeightQuantTraits::WeightType* B,
const T* weight_scales,
const T* biases,
T* C,
int64_t* total_rows_before_expert,
int64_t total_rows,
int64_t gemm_n,
int64_t gemm_k,
int num_experts,
const typename WeightQuantTraits::Arguments& quant_args_B,
CutlassGemmConfig gemm_config,
int multi_processor_count,
cudaStream_t stream,
int* occupancy = nullptr) {
static void dispatch(
const T* A,
const typename WeightQuantTraits::WeightType* B,
const T* weight_scales,
const T* biases,
T* C,
int64_t* total_rows_before_expert,
int64_t total_rows,
int64_t gemm_n,
int64_t gemm_k,
int num_experts,
const typename WeightQuantTraits::Arguments& quant_args_B,
CutlassGemmConfig gemm_config,
int multi_processor_count,
cudaStream_t stream,
int* occupancy = nullptr) {
generic_moe_gemm_kernelLauncher<T,
WeightQuantTraits,
arch,
@@ -342,21 +356,22 @@ struct dispatch_stages<T,
WarpShape,
Stages,
typename std::enable_if<(Stages > 2)>::type> {
static void dispatch(const T* A,
const typename WeightQuantTraits::WeightType* B,
const T* weight_scales,
const T* biases,
T* C,
int64_t* total_rows_before_expert,
int64_t total_rows,
int64_t gemm_n,
int64_t gemm_k,
int num_experts,
const typename WeightQuantTraits::Arguments& quant_args_B,
CutlassGemmConfig gemm_config,
int multi_processor_count,
cudaStream_t stream,
int* occupancy = nullptr) {
static void dispatch(
const T* A,
const typename WeightQuantTraits::WeightType* B,
const T* weight_scales,
const T* biases,
T* C,
int64_t* total_rows_before_expert,
int64_t total_rows,
int64_t gemm_n,
int64_t gemm_k,
int num_experts,
const typename WeightQuantTraits::Arguments& quant_args_B,
CutlassGemmConfig gemm_config,
int multi_processor_count,
cudaStream_t stream,
int* occupancy = nullptr) {
generic_moe_gemm_kernelLauncher<T,
WeightQuantTraits,
cutlass::arch::Sm80,
@@ -387,21 +402,22 @@ template <typename T,
typename EpilogueTag,
typename ThreadblockShape,
typename WarpShape>
void dispatch_gemm_config(const T* A,
const typename WeightQuantTraits::WeightType* B,
const T* weight_scales,
const T* biases,
T* C,
int64_t* total_rows_before_expert,
int64_t total_rows,
int64_t gemm_n,
int64_t gemm_k,
int num_experts,
const typename WeightQuantTraits::Arguments& quant_args_B,
CutlassGemmConfig gemm_config,
int multi_processor_count,
cudaStream_t stream,
int* occupancy = nullptr) {
void dispatch_gemm_config(
const T* A,
const typename WeightQuantTraits::WeightType* B,
const T* weight_scales,
const T* biases,
T* C,
int64_t* total_rows_before_expert,
int64_t total_rows,
int64_t gemm_n,
int64_t gemm_k,
int num_experts,
const typename WeightQuantTraits::Arguments& quant_args_B,
CutlassGemmConfig gemm_config,
int multi_processor_count,
cudaStream_t stream,
int* occupancy = nullptr) {
#define dispatch_stages_macro(STAGE) \
case STAGE: \
dispatch_stages<T, \
@@ -468,35 +484,39 @@ void dispatch_gemm_config(const T* A,
// This overload will handle tensorop gemms. It is disabled via SFINAE for fp32.
// This overload is only enabled when T == WeightType.
template <typename T,
typename WeightQuantTraits,
typename arch,
typename EpilogueTag,
typename std::enable_if<!std::is_same<T, float>::value &&
std::is_same<T, typename WeightQuantTraits::WeightType>::value>::type* =
nullptr>
void dispatch_moe_gemm_to_cutlass(const T* A,
const typename WeightQuantTraits::WeightType* B,
const T* weight_scales,
const T* biases,
T* C,
int64_t* total_rows_before_expert,
int64_t total_rows,
int64_t gemm_n,
int64_t gemm_k,
int num_experts,
const typename WeightQuantTraits::Arguments& quant_args_B,
CutlassGemmConfig gemm_config,
int sm_version,
int multi_processor_count,
cudaStream_t stream,
int* occupancy = nullptr) {
template <
typename T,
typename WeightQuantTraits,
typename arch,
typename EpilogueTag,
typename std::enable_if<
!std::is_same<T, float>::value &&
std::is_same<T, typename WeightQuantTraits::WeightType>::value>::type* =
nullptr>
void dispatch_moe_gemm_to_cutlass(
const T* A,
const typename WeightQuantTraits::WeightType* B,
const T* weight_scales,
const T* biases,
T* C,
int64_t* total_rows_before_expert,
int64_t total_rows,
int64_t gemm_n,
int64_t gemm_k,
int num_experts,
const typename WeightQuantTraits::Arguments& quant_args_B,
CutlassGemmConfig gemm_config,
int sm_version,
int multi_processor_count,
cudaStream_t stream,
int* occupancy = nullptr) {
switch (gemm_config.tile_config) {
dispatch_gemm_config_macro(32, 128, 64, 32, 32, 64);
dispatch_gemm_config_macro(64, 128, 64, 32, 64, 64);
dispatch_gemm_config_macro(128, 128, 64, 64, 32, 64);
case CutlassTileConfig::Undefined:
throw std::runtime_error("[dispatch_moe_gemm_to_cutlass] gemm config undefined.");
throw std::runtime_error(
"[dispatch_moe_gemm_to_cutlass] gemm config undefined.");
break;
case CutlassTileConfig::ChooseWithHeuristic:
throw std::runtime_error(
@@ -518,32 +538,36 @@ template <typename T,
typename WeightQuantTraits,
typename arch,
typename EpilogueTag,
typename std::enable_if<!std::is_same<T, float>::value &&
!std::is_same<T, typename WeightQuantTraits::WeightType>::value>::type* =
nullptr>
void dispatch_moe_gemm_to_cutlass(const T* A,
const typename WeightQuantTraits::WeightType* B,
const T* weight_scales,
const T* biases,
T* C,
int64_t* total_rows_before_expert,
int64_t total_rows,
int64_t gemm_n,
int64_t gemm_k,
int num_experts,
const typename WeightQuantTraits::Arguments& quant_args_B,
CutlassGemmConfig gemm_config,
int sm_version,
int multi_processor_count,
cudaStream_t stream,
int* occupancy = nullptr) {
typename std::enable_if<
!std::is_same<T, float>::value &&
!std::is_same<T, typename WeightQuantTraits::WeightType>::value>::
type* = nullptr>
void dispatch_moe_gemm_to_cutlass(
const T* A,
const typename WeightQuantTraits::WeightType* B,
const T* weight_scales,
const T* biases,
T* C,
int64_t* total_rows_before_expert,
int64_t total_rows,
int64_t gemm_n,
int64_t gemm_k,
int num_experts,
const typename WeightQuantTraits::Arguments& quant_args_B,
CutlassGemmConfig gemm_config,
int sm_version,
int multi_processor_count,
cudaStream_t stream,
int* occupancy = nullptr) {
if constexpr (std::is_same<arch, cutlass::arch::Sm70>::value) {
if constexpr (WeightQuantTraits::kQuantMethod != cutlass::WintQuantMethod::kWeightOnlyInt2) {
if constexpr (WeightQuantTraits::kQuantMethod !=
cutlass::WintQuantMethod::kWeightOnlyInt2) {
switch (gemm_config.tile_config) {
dispatch_gemm_config_macro(32, 128, 64, 32, 32, 64);
dispatch_gemm_config_macro(64, 128, 64, 64, 64, 64);
case CutlassTileConfig::Undefined:
throw std::runtime_error("[dispatch_moe_gemm_to_cutlass] gemm config undefined.");
throw std::runtime_error(
"[dispatch_moe_gemm_to_cutlass] gemm config undefined.");
break;
case CutlassTileConfig::ChooseWithHeuristic:
throw std::runtime_error(
@@ -558,7 +582,8 @@ void dispatch_moe_gemm_to_cutlass(const T* A,
}
} else {
throw std::runtime_error(
"[dispatch_moe_gemm_to_cutlass] weight_only_int2 does not support sm70.");
"[dispatch_moe_gemm_to_cutlass] weight_only_int2 does not support "
"sm70.");
}
} else {
switch (gemm_config.tile_config) {
@@ -574,7 +599,8 @@ void dispatch_moe_gemm_to_cutlass(const T* A,
dispatch_gemm_config_macro(64, 128, 64, 64, 32, 64);
dispatch_gemm_config_macro(256, 128, 64, 64, 64, 64);
case CutlassTileConfig::Undefined:
throw std::runtime_error("[dispatch_moe_gemm_to_cutlass] gemm config undefined.");
throw std::runtime_error(
"[dispatch_moe_gemm_to_cutlass] gemm config undefined.");
break;
case CutlassTileConfig::ChooseWithHeuristic:
throw std::runtime_error(
@@ -597,22 +623,23 @@ template <
typename arch,
typename EpilogueTag,
typename std::enable_if<std::is_same<T, float>::value>::type* = nullptr>
void dispatch_moe_gemm_to_cutlass(const T* A,
const typename WeightQuantTraits::WeightType* B,
const T* weight_scales,
const T* biases,
T* C,
int64_t* total_rows_before_expert,
int64_t total_rows,
int64_t gemm_n,
int64_t gemm_k,
int num_experts,
const typename WeightQuantTraits::Arguments& quant_args_B,
CutlassGemmConfig gemm_config,
int sm_version,
int multi_processor_count,
cudaStream_t stream,
int* occupancy = nullptr) {
void dispatch_moe_gemm_to_cutlass(
const T* A,
const typename WeightQuantTraits::WeightType* B,
const T* weight_scales,
const T* biases,
T* C,
int64_t* total_rows_before_expert,
int64_t total_rows,
int64_t gemm_n,
int64_t gemm_k,
int num_experts,
const typename WeightQuantTraits::Arguments& quant_args_B,
CutlassGemmConfig gemm_config,
int sm_version,
int multi_processor_count,
cudaStream_t stream,
int* occupancy = nullptr) {
switch (gemm_config.tile_config) {
dispatch_gemm_config_macro(128, 128, 8, 64, 64, 8);
case CutlassTileConfig::Undefined:
@@ -659,33 +686,34 @@ void MoeGemmRunner<T, WeightQuantTraits>::dispatch_to_arch<EpilogueTag>(
CutlassGemmConfig gemm_config,
cudaStream_t stream,
int* occupancy) {
#define dispatch_moe_gemm_to_cutlass_macro(ARCH) \
#define dispatch_moe_gemm_to_cutlass_macro(ARCH) \
dispatch_moe_gemm_to_cutlass<T, WeightQuantTraits, ARCH, EpilogueTag>( \
A, \
B, \
weight_scales, \
biases, \
C, \
total_rows_before_expert, \
total_rows, \
gemm_n, \
gemm_k, \
num_experts, \
quant_args_B, \
gemm_config, \
sm_, \
multi_processor_count_, \
stream, \
A, \
B, \
weight_scales, \
biases, \
C, \
total_rows_before_expert, \
total_rows, \
gemm_n, \
gemm_k, \
num_experts, \
quant_args_B, \
gemm_config, \
sm_, \
multi_processor_count_, \
stream, \
occupancy);
if (sm_ >= 70 && sm_ < 75) {
dispatch_moe_gemm_to_cutlass_macro(cutlass::arch::Sm70);
} else if (sm_ >= 75 && sm_ < 80) {
dispatch_moe_gemm_to_cutlass_macro(cutlass::arch::Sm75);
} else if (sm_ >= 80 && sm_ < 91) {
} else if (sm_ >= 80 && sm_ < 101) {
dispatch_moe_gemm_to_cutlass_macro(cutlass::arch::Sm80);
} else {
throw std::runtime_error("[MoE][GEMM Dispatch] Arch unsupported for MoE GEMM");
throw std::runtime_error(
"[MoE][GEMM Dispatch] Arch unsupported for MoE GEMM");
}
}
@@ -705,7 +733,8 @@ void MoeGemmRunner<T, WeightQuantTraits>::run_gemm<EpilogueTag>(
int num_experts,
const typename WeightQuantTraits::Arguments& quant_args_B,
cudaStream_t stream) {
static constexpr bool is_weight_only = !std::is_same<T, typename WeightQuantTraits::WeightType>::value;
static constexpr bool is_weight_only =
!std::is_same<T, typename WeightQuantTraits::WeightType>::value;
static constexpr bool only_simt_configs = std::is_same<T, float>::value;
std::vector<CutlassGemmConfig> candidate_configs =
@@ -776,7 +805,8 @@ void MoeGemmRunner<T, WeightQuantTraits>::run_gemm<EpilogueTag>(
check_cuda_error(cudaEventElapsedTime(&elapsed, start, stop));
check_cuda_error(cudaEventDestroy(start));
check_cuda_error(cudaEventDestroy(stop));
//std::cout << "[TUNING] config: " << ii << ", time: " << elapsed << " ms" << std::endl;
// std::cout << "[TUNING] config: " << ii << ", time: " << elapsed << "
// ms" << std::endl;
if (elapsed < best_time) {
best_id = ii;
best_time = elapsed;
@@ -789,7 +819,8 @@ void MoeGemmRunner<T, WeightQuantTraits>::run_gemm<EpilogueTag>(
}
}
if (find_one) {
//std::cout << "[TUNING] best_config: " << best_id << ", time: " << best_time << " ms" << std::endl;
// std::cout << "[TUNING] best_config: " << best_id << ", time: " <<
// best_time << " ms" << std::endl;
gemmConfigManager.addBestConfig(gemmId, profile_total_rows, best_config);
chosen_config = best_config;
} else {