Sync v2.0 version of code to github repo

This commit is contained in:
Jiang-Jia-Jun
2025-06-29 23:29:37 +00:00
parent d151496038
commit 92c2cfa2e7
597 changed files with 78776 additions and 22905 deletions

View File

@@ -16,106 +16,127 @@
#include <mutex>
#include "cutlass/bfloat16.h"
#include "cutlass/cutlass.h"
#include "cutlass/half.h"
#include "helper.h"
#include "paddle/extension.h"
template <paddle::DataType D>
class CutlassDtypeTraits;
/**
* Helper function for checking CUTLASS errors
*/
#define CUTLASS_CHECK(status) \
{ \
cutlass::Status error = status; \
PD_CHECK(error == cutlass::Status::kSuccess, \
cutlassGetStatusString(error)); \
}
template <>
class CutlassDtypeTraits<paddle::DataType::FLOAT32> {
public:
typedef float DataType;
typedef float data_t;
/**
* A wrapper for a kernel that is used to guard against compilation on
* architectures that will never use the kernel. The purpose of this is to
* reduce the size of the compiled binary.
* __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef
* into code that will be executed on the device where it is defined.
*/
template <typename Kernel> struct enable_sm90_or_later : Kernel {
template <typename... Args> CUTLASS_DEVICE void operator()(Args &&...args) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 900
Kernel::operator()(std::forward<Args>(args)...);
#endif
}
};
template <>
class CutlassDtypeTraits<paddle::DataType::FLOAT16> {
public:
typedef cutlass::half_t DataType;
typedef paddle::float16 data_t;
template <paddle::DataType D> class CutlassDtypeTraits;
template <> class CutlassDtypeTraits<paddle::DataType::FLOAT32> {
public:
typedef float DataType;
typedef float data_t;
};
template <>
class CutlassDtypeTraits<paddle::DataType::BFLOAT16> {
public:
typedef cutlass::bfloat16_t DataType;
typedef paddle::bfloat16 data_t;
template <> class CutlassDtypeTraits<paddle::DataType::FLOAT16> {
public:
typedef cutlass::half_t DataType;
typedef paddle::float16 data_t;
};
template <> class CutlassDtypeTraits<paddle::DataType::BFLOAT16> {
public:
typedef cutlass::bfloat16_t DataType;
typedef paddle::bfloat16 data_t;
};
class CutlassGemmConfigMannager {
public:
static CutlassGemmConfigMannager& getInstance() {
static CutlassGemmConfigMannager instance;
return instance;
}
public:
static CutlassGemmConfigMannager &getInstance() {
static CutlassGemmConfigMannager instance;
return instance;
}
CutlassGemmConfigMannager(const CutlassGemmConfigMannager&) = delete;
CutlassGemmConfigMannager& operator=(const CutlassGemmConfigMannager&) =
delete;
CutlassGemmConfigMannager(const CutlassGemmConfigMannager &) = delete;
CutlassGemmConfigMannager &
operator=(const CutlassGemmConfigMannager &) = delete;
void up_date_configs(const nlohmann::json& j) {
std::lock_guard<std::mutex> lock(mutex_);
for (auto it = j.begin(); it != j.end(); ++it) {
json_[it.key()] = it.value();
}
void up_date_configs(const nlohmann::json &j) {
std::lock_guard<std::mutex> lock(mutex_);
for (auto it = j.begin(); it != j.end(); ++it) {
json_[it.key()] = it.value();
}
}
nlohmann::json* get_gemm_best_configs(const std::string& config_file_path) {
if (!load_initialized_) {
std::ifstream file(config_file_path);
if (!file.good()) {
throw std::runtime_error(
"cutlass gemm_best_config can not be found, please set "
"gemm_best_config'path as "
"FLAGS_use_cutlass_device_best_config_path, or unset "
"FLAGS_use_cutlass_device_best_config_path to tune "
"gemm_best_config");
}
json_ = readJsonFromFile(config_file_path);
load_initialized_ = true;
save_initialized_ = false;
}
return &json_;
nlohmann::json *get_gemm_best_configs(const std::string &config_file_path) {
if (!load_initialized_) {
std::ifstream file(config_file_path);
if (!file.good()) {
throw std::runtime_error(
"cutlass gemm_best_config can not be found, please set "
"gemm_best_config'path as "
"FLAGS_use_cutlass_device_best_config_path, or unset "
"FLAGS_use_cutlass_device_best_config_path to tune "
"gemm_best_config");
}
json_ = readJsonFromFile(config_file_path);
load_initialized_ = true;
save_initialized_ = false;
}
return &json_;
}
private:
void save_gemm_best_configs_(const std::string& config_file_path) {
std::ifstream file(config_file_path);
if (!file.good()) {
std::ofstream new_file(config_file_path);
new_file << json_.dump(4);
new_file.close();
} else {
nlohmann::json old_json = readJsonFromFile(config_file_path);
for (auto it = json_.begin(); it != json_.end(); ++it) {
old_json[it.key()] = it.value();
}
json_ = old_json;
std::ofstream new_file(config_file_path,
std::ios::out | std::ios::trunc);
new_file << json_.dump(4);
new_file.close();
file.close();
}
return;
private:
void save_gemm_best_configs_(const std::string &config_file_path) {
std::ifstream file(config_file_path);
if (!file.good()) {
std::ofstream new_file(config_file_path);
new_file << json_.dump(4);
new_file.close();
} else {
nlohmann::json old_json = readJsonFromFile(config_file_path);
for (auto it = json_.begin(); it != json_.end(); ++it) {
old_json[it.key()] = it.value();
}
json_ = old_json;
std::ofstream new_file(config_file_path, std::ios::out | std::ios::trunc);
new_file << json_.dump(4);
new_file.close();
file.close();
}
return;
}
CutlassGemmConfigMannager()
: json_(nullptr), load_initialized_(false), save_initialized_(true) {}
~CutlassGemmConfigMannager() {
std::lock_guard<std::mutex> lock(mutex_);
if (save_initialized_) {
std::string config_file_path = "fp8_fuse_gemm_config.json";
save_gemm_best_configs_(config_file_path);
}
save_initialized_ = true;
load_initialized_ = false;
json_.clear();
CutlassGemmConfigMannager()
: json_(nullptr), load_initialized_(false), save_initialized_(true) {}
~CutlassGemmConfigMannager() {
std::lock_guard<std::mutex> lock(mutex_);
if (save_initialized_) {
std::string config_file_path = "fp8_fuse_gemm_config.json";
save_gemm_best_configs_(config_file_path);
}
mutable std::mutex mutex_;
nlohmann::json json_;
bool load_initialized_;
bool save_initialized_;
save_initialized_ = true;
load_initialized_ = false;
json_.clear();
}
mutable std::mutex mutex_;
nlohmann::json json_;
bool load_initialized_;
bool save_initialized_;
};

View File

@@ -15,8 +15,8 @@
#pragma once
#include "fp8_common.h"
#include "fuse_dual_gemm_swiglu_template.h"
#include "fuse_dual_gemm_act_template_3x.h"
#include "fuse_dual_gemm_geglu_template.h"
#include "fuse_dual_gemm_swiglu_template.h"
bool fp8_fp8_dual_gemm_scale_bias_act(
DualGemmEpilogueAllParams params);
bool fp8_fp8_dual_gemm_scale_bias_act(DualGemmEpilogueAllParams params);

View File

@@ -15,12 +15,13 @@
#pragma once
#include "fp8_common.h"
#include "fuse_gemm_gelu_template.h"
#include "fuse_gemm_noact_template.h"
#include "fuse_gemm_relu_template.h"
#include "fuse_gemm_gelu_template.h"
#include "fuse_block_gemm_act_template_3x.h"
#include "fuse_gemm_act_template_3x.h"
bool fp8_fp8_gemm_scale_bias_act(GemmEpilogueAllParams params);
bool fp8_fp8_block_gemm_scale_bias_act(GemmEpilogueAllParams params);
bool fp8_fp8_block_gemm_scale_bias_act(GemmEpilogueAllParams params);

View File

@@ -0,0 +1,173 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/float8.h"
#include "cutlass/gemm/device/gemm_universal.h"
#include "fp8_common.h"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/kernel/tile_scheduler.hpp"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass_extensions/gemm/collective/collective_builder_gated.hpp"
#include "cutlass_extensions/gemm/kernel/gemm_universal_gated.hpp"
// Builds and launches a CUTLASS SM90 dual GEMM with a gated-activation
// mainloop: conceptually D = Activation(scale0 * A*B0) gated with
// (scale1 * A*B1), with the epilogue output scaled by params.scale_out.
//
// Template parameters select the FP8 input type (e4m3 vs e5m2), CTA/cluster
// shapes, mainloop/epilogue/tile-scheduler schedules, the activation functor,
// and whether A and B are swapped (SwapAB) to suit the kernel layout.
//
// Returns false (after printing to stderr) if CUTLASS reports the problem
// cannot be implemented or the kernel launch fails; true on success.
template <typename InputType, typename CTAShape, typename ClusterShape,
          typename MainloopScheduleType, typename EpilogueScheduleType,
          typename TileSchedulerType = void,
          template <class /* ElementCompute */> class Activation =
              cutlass::epilogue::thread::SiLu,
          bool SwapAB = true>
bool dispatch_dual_gemm_act_sm90(DualGemmEpilogueAllParams params) {
  using namespace cute;

  // A matrix configuration.
  using ElementA = typename std::conditional_t<
      std::is_same_v<InputType, phi::dtype::float8_e4m3fn>,
      cutlass::float_e4m3_t, cutlass::float_e5m2_t>;
  using LayoutA = cutlass::layout::RowMajor;  // Layout type for A matrix operand
  static constexpr int AlignmentA =
      128 /
      cutlass::sizeof_bits<
          ElementA>::value;  // Memory access granularity/alignment of A
                             // matrix in units of elements (up to 16 bytes)

  // B matrix configuration
  using ElementB = ElementA;  // Element type for B matrix operand
  using LayoutB =
      cutlass::layout::ColumnMajor;  // Layout type for B matrix operand
  static constexpr int AlignmentB =
      128 /
      cutlass::sizeof_bits<
          ElementB>::value;  // Memory access granularity/alignment of B
                             // matrix in units of elements (up to 16 bytes)

  // C matrix configuration: layout follows SwapAB.
  using ElementC = ElementA;  // Element type for C matrix operands
  using LayoutC = cute::conditional_t<SwapAB, cutlass::layout::ColumnMajor,
                                      cutlass::layout::RowMajor>;
  static constexpr int AlignmentC =
      128 /
      cutlass::sizeof_bits<
          ElementC>::value;  // Memory access granularity/alignment of C matrices
                             // in units of elements (up to 16 bytes)

  // Output matrix configuration
  using ElementOutput = ElementA;  // Element type for output matrix operands
  using LayoutOutput = cute::conditional_t<SwapAB, cutlass::layout::ColumnMajor,
                                           cutlass::layout::RowMajor>;
  static constexpr int AlignmentOutput =
      128 / cutlass::sizeof_bits<ElementOutput>::value;

  // Multiply-accumulate blocking/pipelining details
  using ElementAccumulator = float;  // Element type for internal accumulation
  using ElementCompute = float;      // Element type for compute
  using ArchTag = cutlass::arch::Sm90;  // Tag indicating the minimum SM that
                                        // supports the intended feature
  using OperatorClass = cutlass::arch::OpClassTensorOp;  // Operator class tag
  using TileShape = CTAShape;  // Threadblock-level tile size
  using KernelSchedule = MainloopScheduleType;
  using EpilogueSchedule = EpilogueScheduleType;
  using TileScheduler = TileSchedulerType;
  using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
  static constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest;

  // Epilogue applies only a scalar alpha (scale_out) to the accumulator.
  using FusionOperation =
      cutlass::epilogue::fusion::ScaledAcc<ElementOutput, ElementCompute>;
  using CollectiveEpilogue =
      typename cutlass::epilogue::collective::CollectiveBuilder<
          ArchTag, OperatorClass, TileShape, ClusterShape, EpilogueTileType,
          ElementAccumulator, ElementAccumulator, ElementC, LayoutC, AlignmentC,
          ElementOutput, LayoutOutput, AlignmentOutput, EpilogueSchedule,
          FusionOperation>::CollectiveOp;
  // Gated mainloop consumes two B operands (B0, B1) against one A operand.
  using CollectiveMainloop =
      typename cutlass::gemm::collective::CollectiveBuilderGated<
          ArchTag, OperatorClass, ElementA, LayoutA, AlignmentA, ElementB,
          LayoutB, AlignmentB, ElementAccumulator, TileShape, ClusterShape,
          cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
              sizeof(typename CollectiveEpilogue::SharedStorage))>,
          KernelSchedule, Activation, SwapAB>::CollectiveOp;

  using GemmKernel = cutlass::gemm::kernel::GemmUniversalGated<
      Shape<int, int, int, int>,  // Indicates ProblemShape
      CollectiveMainloop, CollectiveEpilogue, TileScheduler>;
  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;

  using StrideA = typename Gemm::GemmKernel::StrideA;
  using StrideB = typename Gemm::GemmKernel::StrideB;
  using StrideC = typename Gemm::GemmKernel::StrideC;
  using StrideD = typename Gemm::GemmKernel::StrideD;

  // With SwapAB, the logical M/N extents and the A/B0 pointers are exchanged.
  // NOTE(review): ptr_B1 intentionally stays params.B1 in the SwapAB branch —
  // presumably the gated mainloop handles B1 the same way either way; confirm.
  int arg_m = params.M;
  int arg_n = params.N;
  ElementA const *ptr_A = reinterpret_cast<ElementA const *>(params.A);
  ElementB const *ptr_B0 = reinterpret_cast<ElementB const *>(params.B0);
  ElementB const *ptr_B1 = reinterpret_cast<ElementB const *>(params.B1);
  if constexpr (SwapAB) {
    arg_m = params.N;
    arg_n = params.M;
    ptr_A = reinterpret_cast<ElementB const *>(params.B0);
    ptr_B0 = reinterpret_cast<ElementA const *>(params.A);
  }

  StrideA stride_A = cutlass::make_cute_packed_stride(
      StrideA{}, cute::make_shape(arg_m, params.K, params.batch_count));
  StrideB stride_B = cutlass::make_cute_packed_stride(
      StrideB{}, cute::make_shape(arg_n, params.K, params.batch_count));
  // C is unused (null source), but value-initialize the stride so no
  // indeterminate bytes are copied into the kernel arguments.
  StrideC stride_C{};
  StrideD stride_D = cutlass::make_cute_packed_stride(
      StrideD{}, cute::make_shape(arg_m, arg_n, params.batch_count));

  typename Gemm::Arguments arguments = {
      cutlass::gemm::GemmUniversalMode::kGemm,
      {arg_m, arg_n, params.K, params.batch_count},
      {ptr_A, stride_A, ptr_B0, ptr_B1, stride_B, params.scale0, params.scale1},
      {{},  // epilogue.thread
       nullptr,
       stride_C,
       reinterpret_cast<ElementOutput *>(params.D),
       stride_D}};
  arguments.epilogue.thread.alpha = params.scale_out;

  Gemm gemm_op;
  cutlass::Status status = gemm_op.can_implement(arguments);
  if (status != cutlass::Status::kSuccess) {
    std::cerr << "Gemm::can_implement() failed" << std::endl;
    return false;
  }

  size_t workspace_size = Gemm::get_workspace_size(arguments);
  phi::Allocator *allocator = paddle::GetAllocator(params.place);
  auto workspace = allocator->Allocate(workspace_size);

  //
  // Run the GEMM
  //
  status = gemm_op(arguments, workspace->ptr(), params.stream);
  if (status != cutlass::Status::kSuccess) {
    std::cerr << "Gemm::run() failed" << std::endl;
    return false;
  }
  return true;
}

View File

@@ -0,0 +1,151 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "fp8_common.h"
#include "cutlass/cutlass.h"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/kernel/tile_scheduler.hpp"
#include "cutlass/util/packed_stride.hpp"
// Builds and launches a CUTLASS SM90 GEMM with a fused elementwise-activation
// epilogue: D = Activation(scale * (A*B) [+ bias]), where the bias term is
// enabled by the `hasbias` template flag (beta is set to 1.0 in that case).
//
// InputType selects the FP8 flavor (e4m3 vs e5m2); OutType selects the 16-bit
// output type (bfloat16 vs half). Returns false (after printing the CUTLASS
// status string) if the problem cannot be implemented or the launch fails.
template <
    typename InputType,
    typename OutType,
    bool hasbias,
    template <class> typename Activation,
    typename TileShape,
    typename ClusterShape,
    typename KernelSchedule =
        cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum,
    typename EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized,
    typename SM = cutlass::arch::Sm90>
bool dispatch_fuse_gemm_act_sm90(GemmEpilogueAllParams params) {
  using namespace cute;

  using ElementA = typename std::conditional_t<
      std::is_same_v<InputType, phi::dtype::float8_e4m3fn>,
      cutlass::float_e4m3_t, cutlass::float_e5m2_t>;
  using ElementB = ElementA;
  using ElementD =
      typename std::conditional_t<std::is_same_v<OutType, phi::dtype::bfloat16>,
                                  cutlass::bfloat16_t, cutlass::half_t>;
  // C (the bias source) is void when there is no bias, which disables the
  // beta term in the epilogue.
  using ElementC = std::conditional_t<hasbias, ElementD, void>;
  using LayoutA = cutlass::layout::RowMajor;
  using LayoutB = cutlass::layout::ColumnMajor;
  using LayoutC = cutlass::layout::RowMajor;
  using LayoutD = cutlass::layout::RowMajor;
  using ElementAccumulator = float;
  using ElementCompute = float;
  using ElementScalar = float;

  // 16B alignment lets us use TMA.
  // Use ElementD for the C alignment: when hasbias, ElementC == ElementD, and
  // when !hasbias, ElementC is void so sizeof(ElementC) would be ill-formed.
  static constexpr int AlignmentA = 16 / sizeof(ElementA);
  static constexpr int AlignmentB = 16 / sizeof(ElementB);
  static constexpr int AlignmentC = hasbias ? 16 / sizeof(ElementD) : 8;
  static constexpr int AlignmentD = 16 / sizeof(ElementD);

  static constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest;
  // Fused epilogue: linear combination followed by the activation functor.
  using FusionOperation =
      cutlass::epilogue::fusion::LinCombEltAct<Activation, ElementD,
                                               ElementCompute, ElementC,
                                               ElementScalar, RoundStyle>;
  using CollectiveEpilogue =
      typename cutlass::epilogue::collective::CollectiveBuilder<
          SM, cutlass::arch::OpClassTensorOp, TileShape, ClusterShape,
          cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator,
          ElementCompute, ElementC, LayoutC, AlignmentC, ElementD, LayoutD,
          AlignmentD, EpilogueSchedule, FusionOperation>::CollectiveOp;
  using CollectiveMainloop =
      typename cutlass::gemm::collective::CollectiveBuilder<
          SM, cutlass::arch::OpClassTensorOp, ElementA, LayoutA, AlignmentA,
          ElementB, LayoutB, AlignmentB, ElementAccumulator, TileShape,
          ClusterShape,
          cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
              sizeof(typename CollectiveEpilogue::SharedStorage))>,
          KernelSchedule>::CollectiveOp;
  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
      Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue,
      cutlass::gemm::PersistentScheduler>;
  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;

  using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape;
  using StrideA = typename Gemm::GemmKernel::StrideA;
  using StrideB = typename Gemm::GemmKernel::StrideB;
  using StrideC = typename Gemm::GemmKernel::StrideC;
  using StrideD = typename Gemm::GemmKernel::StrideD;

  // Strides are built from the leading dimensions in params; the bias (C) is
  // broadcast along rows, hence its zero row/batch strides.
  StrideA stride_A{params.lda, cute::Int<1>{}, params.M * params.lda};
  StrideB stride_B{params.ldb, cute::Int<1>{}, params.N * params.ldb};
  StrideC stride_C{0, cute::Int<1>{}, 0};
  StrideD stride_D{params.ldd, cute::Int<1>{}, params.ldd * params.M};

  auto a_ptr = reinterpret_cast<ElementA *>(const_cast<void *>(params.A));
  auto b_ptr = reinterpret_cast<ElementB *>(const_cast<void *>(params.B));
  auto c_ptr = reinterpret_cast<ElementC *>(const_cast<void *>(params.bias));
  auto d_ptr = reinterpret_cast<ElementD *>(params.D);

  ProblemShapeType problem_size =
      ProblemShapeType{params.M, params.N, params.K, params.batch_count};
  typename Gemm::Arguments arguments{cutlass::gemm::GemmUniversalMode::kGemm,
                                     problem_size,
                                     {a_ptr, stride_A, b_ptr, stride_B},
                                     {{params.scale},  // epilogue.thread
                                      c_ptr,
                                      stride_C,
                                      d_ptr,
                                      stride_D}};
  if constexpr (hasbias) {
    arguments.epilogue.thread.beta = 1.0;  // add the bias with unit weight
  }

  Gemm gemm_op;
  cutlass::Status status = gemm_op.can_implement(arguments);
  if (status != cutlass::Status::kSuccess) {
    std::cout << "Gemm::can_implement() failed. "
              << cutlassGetStatusString(status) << std::endl;
    return false;
  }

  size_t workspace_size = Gemm::get_workspace_size(arguments);
  phi::Allocator *allocator = paddle::GetAllocator(params.place);
  auto workspace = allocator->Allocate(workspace_size);

  status = gemm_op(arguments, workspace->ptr(), params.stream);
  if (status != cutlass::Status::kSuccess) {
    std::cout << "Gemm::run() failed." << cutlassGetStatusString(status)
              << std::endl;
    return false;
  }
  return true;
}

View File

@@ -43,7 +43,9 @@
#include "cutlass/trace.h"
#include "cutlass_extensions/gemm/kernel/gemm_moe_problem_visitor.h"
#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/tile_interleaved_layout.h"
#include "cutlass_extensions/gemm/threadblock/wint2x_tile_dequanter.h"
#include "cutlass_extensions/tile_interleaved_layout.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
@@ -156,9 +158,6 @@ struct MoeFCGemm {
using LayoutC = typename MapArguments::LayoutC;
using ElementScale = ElementC;
static ComplexTransform const kTransformA = MapArguments::kTransformA;
static ComplexTransform const kTransformB = MapArguments::kTransformB;
// Type definitions about the mainloop.
using Operator = typename Mma::Operator;
using OperatorClass = typename Mma::Operator::OperatorClass;
@@ -209,6 +208,13 @@ struct MoeFCGemm {
int64_t gemm_n;
int64_t gemm_k;
WintQuantMethod quant_method;
// Extra arguments for wint2.0
uint8_t* local_scale;
float* code_scale;
float* code_zp;
// Only used by device-level operator
GemmCoord* host_problem_sizes;
@@ -230,6 +236,10 @@ struct MoeFCGemm {
total_rows(-1),
gemm_n(0),
gemm_k(0),
quant_method(WintQuantMethod::kNone),
local_scale(nullptr),
code_scale(nullptr),
code_zp(nullptr),
host_problem_sizes(nullptr) {}
/// Ctor
@@ -246,6 +256,10 @@ struct MoeFCGemm {
int64_t total_rows,
int64_t gemm_n,
int64_t gemm_k,
WintQuantMethod quant_method,
const uint8_t* local_scale,
const float* code_scale,
const float* code_zp,
GemmCoord* host_problem_sizes = nullptr)
: problem_count(problem_count),
threadblock_count(threadblock_count),
@@ -259,8 +273,12 @@ struct MoeFCGemm {
total_rows(total_rows),
gemm_n(gemm_n),
gemm_k(gemm_k),
quant_method(quant_method),
local_scale(const_cast<uint8_t*>(local_scale)),
code_scale(const_cast<float*>(code_scale)),
code_zp(const_cast<float*>(code_zp)),
host_problem_sizes(nullptr) {
if (platform::is_same<uint8_t, ElementB>::value ||
if (quant_method != WintQuantMethod::kNone || platform::is_same<uint8_t, ElementB>::value ||
platform::is_same<uint4b_t, ElementB>::value) {
assert(weight_scales);
}
@@ -284,6 +302,8 @@ struct MoeFCGemm {
ElementC* ptr_C;
ElementC* ptr_D;
WintQuantMethod quant_method;
//
// Methods
//
@@ -294,7 +314,8 @@ struct MoeFCGemm {
ptr_B(nullptr),
weight_scales(nullptr),
ptr_C(nullptr),
ptr_D(nullptr) {}
ptr_D(nullptr),
quant_method(WintQuantMethod::kNone) {}
CUTLASS_HOST_DEVICE
Params(Arguments const& args,
@@ -313,7 +334,8 @@ struct MoeFCGemm {
ptr_B(args.ptr_B),
weight_scales(args.weight_scales),
ptr_C(args.ptr_C),
ptr_D(args.ptr_D) {}
ptr_D(args.ptr_D),
quant_method(args.quant_method) {}
CUTLASS_HOST_DEVICE
void update(Arguments const& args,
@@ -334,6 +356,7 @@ struct MoeFCGemm {
weight_scales = args.weight_scales;
ptr_C = args.ptr_C;
ptr_D = args.ptr_D;
quant_method = args.quant_method;
}
};
@@ -358,7 +381,7 @@ struct MoeFCGemm {
}
static Status can_implement(Arguments const& args) {
if (platform::is_same<uint8_t, ElementB>::value ||
if (args.quant_method != WintQuantMethod::kNone || platform::is_same<uint8_t, ElementB>::value ||
platform::is_same<uint4b_t, ElementB>::value) {
if (args.weight_scales == nullptr) {
CUTLASS_TRACE_HOST(
@@ -394,6 +417,7 @@ struct MoeFCGemm {
template <typename dummy>
struct KernelRunner<true, dummy> {
CUTLASS_DEVICE
static void run_kernel(Params const& params,
SharedStorage& shared_storage) { // NOLINT
@@ -401,12 +425,14 @@ struct MoeFCGemm {
// These types shadow the type-level definitions and support the ability
// to implement a 'transposed' GEMM that computes the transposed problems.
//
using ElementA = typename Mma::IteratorA::Element;
using LayoutA = typename Mma::IteratorA::Layout;
using ElementB = typename Mma::IteratorB::Element;
using LayoutB = typename Mma::IteratorB::Layout;
using ElementC = typename Epilogue::OutputTileIterator::Element;
using LayoutC = typename Epilogue::OutputTileIterator::Layout;
static constexpr int kInterleave =
Mma::IteratorB::Shape::kRow / Mma::Shape::kK;
static_assert(
@@ -435,6 +461,7 @@ struct MoeFCGemm {
GemmCoord grid_shape = problem_visitor.grid_shape(problem_size);
// threadblock_offset of C
cutlass::gemm::GemmCoord threadblock_offset(
int(cta_idx / grid_shape.n()) * Mma::Shape::kM, // NOLINT
int(cta_idx % grid_shape.n()) * Mma::Shape::kN, // NOLINT
@@ -450,6 +477,7 @@ struct MoeFCGemm {
rows_to_jump = problem_idx * (params.problem_visitor.total_rows / params.problem_visitor.problem_count);
}
// begin address offset for A for current tile
ElementA* ptr_A =
reinterpret_cast<ElementA*>(params.ptr_A) + rows_to_jump * gemm_k;
typename LayoutA::LongIndex ldm_A = gemm_k;
@@ -463,14 +491,17 @@ struct MoeFCGemm {
: gemm_k * kInterleave;
// Compute initial location in logical coordinates
// the begin threadblock_offset of A, which holds the same row id with C
cutlass::MatrixCoord tb_offset_A{
threadblock_offset.m(),
0,
};
// the begin threadblock_offset of B, which holds the same column id with C
cutlass::MatrixCoord tb_offset_B{0,
threadblock_offset.n() / kInterleave};
// the begin threadblock_offset of scale, which holds the same column id with C, but with no row id
cutlass::MatrixCoord tb_offset_scale{0, threadblock_offset.n()};
// Compute position within threadblock
@@ -610,6 +641,381 @@ struct MoeFCGemm {
/////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate
typename Epilogue_, ///! Epilogue
typename ThreadblockSwizzle_, ///! Threadblock swizzling function
typename KernelArch, ///! The Architecture this kernel is compiled
/// for. Used since SIMT kernels lose top-level
/// arch.
GroupScheduleMode GroupScheduleMode_ ///! Type of scheduling to //
/// NOLINT perform
>
struct Wint2xMoeFCGemm : public MoeFCGemm<Mma_, Epilogue_, ThreadblockSwizzle_, KernelArch, GroupScheduleMode_> {
public:
using Base = MoeFCGemm<Mma_, Epilogue_, ThreadblockSwizzle_, KernelArch, GroupScheduleMode_>;
using Mma = Mma_;
using Epilogue = Epilogue_;
using EpilogueOutputOp = typename Epilogue::OutputOp;
using ThreadblockSwizzle = ThreadblockSwizzle_;
static GroupScheduleMode const kGroupScheduleMode = GroupScheduleMode_;
static bool const kTransposed = false;
// Optional transpose
using MapArguments = typename Base::MapArguments;
// Public-facing type definitions related to operand element type, layout, and
// complex conjugate operation. Must interact with the 'kTransposed' notion.
static_assert(!kTransposed, "Transpose problem not supported");
using ElementA = typename MapArguments::ElementA;
using LayoutA = typename MapArguments::LayoutA;
using ElementB = typename MapArguments::ElementB;
using LayoutB = typename MapArguments::LayoutB;
using ElementC = typename Epilogue::OutputTileIterator::Element;
using LayoutC = typename MapArguments::LayoutC;
using ElementScale = ElementC;
// Type definitions about the mainloop.
using Operator = typename Mma::Operator;
using OperatorClass = typename Mma::Operator::OperatorClass;
using ThreadblockShape = typename Mma::Shape;
using WarpShape = typename Mma::Operator::Shape;
using InstructionShape = typename Mma::Policy::Operator::InstructionShape;
using ArchTag = typename Mma::ArchTag;
static int const kStages = Mma::kStages;
static int const kAlignmentA = MapArguments::kAlignmentA;
static int const kAlignmentB = MapArguments::kAlignmentB;
static int const kAlignmentC =
Epilogue::OutputTileIterator::kElementsPerAccess;
/// Warp count (concept: GemmShape)
using WarpCount = typename Mma::WarpCount;
static int const kThreadCount = 32 * WarpCount::kCount;
using ProblemVisitor = typename Base::ProblemVisitor;
using Arguments = typename Base::Arguments;
//
// Structure for precomputing values in host memory and passing to kernels
//
/// Parameters structure
struct Params : Base::Params {
// Extra arguments for wint2.0
uint8_t* local_scale;
float* code_scale;
float* code_zp;
//
// Methods
//
CUTLASS_HOST_DEVICE
Params() : Base::Params(), local_scale(nullptr), code_scale(nullptr), code_zp(nullptr) {}
CUTLASS_HOST_DEVICE
Params(Arguments const& args,
void* workspace = nullptr,
int tile_count = 0) // NOLINT
: Base::Params(args, workspace, tile_count),
local_scale(args.local_scale),
code_scale(args.code_scale),
code_zp(args.code_zp) {}
CUTLASS_HOST_DEVICE
void update(Arguments const& args,
void* workspace = nullptr,
int tile_count = 0) {
Base::update(args, workspace, tile_count);
local_scale = args.local_scale;
code_scale = args.code_scale;
code_zp = args.code_zp;
}
};
/// Shared memory storage structure
using SharedStorage = typename Base::SharedStorage;
public:
//
// Methods
//
CUTLASS_DEVICE
Wint2xMoeFCGemm() {}
static Status can_implement(Arguments const& args) {
if (args.quant_method != WintQuantMethod::kWeightOnlyInt2) {
CUTLASS_TRACE_HOST(
"Wint2xMoeFCGemm::can_implement() - only support weight_only_int2!");
return Status::kInvalid;
} else if (args.weight_scales == nullptr || args.local_scale == nullptr) {
CUTLASS_TRACE_HOST(
"Wint2xMoeFCGemm::can_implement() - weight_scales and local_scale is expected to be not nullptr!");
return Status::kInvalid;
}
return Status::kSuccess;
}
// The dummy template parameter is not used and exists so that we can compile
// this code using a standard earlier than C++17. Prior to C++17, fully
// specialized templates HAD to exists in a namespace
template <WintQuantMethod QuantMethod, bool B, typename dummy = void>
struct KernelRunner {
CUTLASS_DEVICE
static void run_kernel(Params const& params,
SharedStorage& shared_storage) { // NOLINT
CUTLASS_NOT_IMPLEMENTED();
}
};
template <WintQuantMethod QuantMethod, typename dummy>
struct KernelRunner<QuantMethod, true, dummy> {
using WeightQuantTraits = WintQuantTraits<ElementA, QuantMethod>;
using QuantArguments = typename WeightQuantTraits::Arguments;
CUTLASS_DEVICE
static QuantArguments get_quant_args(Params const& params, int32_t problem_idx, const int64_t gemm_k, const int64_t gemm_n) {
QuantArguments quant_args;
if constexpr (QuantMethod == WintQuantMethod::kWeightOnlyInt2) {
quant_args.local_scale_ptr = params.local_scale + problem_idx * gemm_k * gemm_n / 128;
quant_args.code_scale_ptr = params.code_scale + problem_idx * gemm_n;
quant_args.code_zp_ptr = params.code_zp + problem_idx * gemm_n;
}
return quant_args;
}
CUTLASS_DEVICE
static void run_kernel(Params const& params,
SharedStorage& shared_storage) { // NOLINT
//
// These types shadow the type-level definitions and support the ability
// to implement a 'transposed' GEMM that computes the transposed problems.
//
using ElementA = typename Mma::IteratorA::Element;
using LayoutA = typename Mma::IteratorA::Layout;
using ElementB = typename Mma::IteratorB::Element;
using LayoutB = typename Mma::IteratorB::Layout;
using ElementC = typename Epilogue::OutputTileIterator::Element;
using LayoutC = typename Epilogue::OutputTileIterator::Layout;
using QuantElementB = typename WeightQuantTraits::WeightType;
using MmaElementB = typename WeightQuantTraits::MmaWeightType;
static constexpr int kInterleave =
Mma::IteratorB::Shape::kRow / Mma::Shape::kK;
static_assert(
platform::is_same<LayoutB, layout::RowMajor>::value &&
kInterleave == 1 ||
platform::is_same<LayoutB, layout::ColumnMajor>::value &&
kInterleave >= 1,
"B must be row major/col major OR col major interleaved.");
// LayoutB should be RowMajor
using TileDequanterB = cutlass::gemm::threadblock::TileDequanter<ElementA, ElementScale, ThreadblockShape::kK, ThreadblockShape::kN, kStages, kThreadCount, QuantMethod>;
//
// Problem visitor.
//
ProblemVisitor problem_visitor(
params.problem_visitor, shared_storage.problem_visitor, blockIdx.x);
const int64_t gemm_k = params.problem_visitor.gemm_k;
const int64_t gemm_n = params.problem_visitor.gemm_n;
// wint2.5 and wint2.0 is quantized and packed along k dimension with group_size 64.
const int64_t quant_gemm_k = WintQuantTraits<ElementA, QuantMethod>::CaclPackedDim(gemm_k);
int64_t bytes_per_expert_matrix = (quant_gemm_k * gemm_n / 8) * cutlass::sizeof_bits<QuantElementB>::value;
// Outer 'persistent' loop to iterate over tiles
while (problem_visitor.next_tile()) {
GemmCoord problem_size = problem_visitor.problem_size();
int32_t problem_idx = problem_visitor.problem_index();
int32_t cta_idx = int32_t(problem_visitor.threadblock_idx());
GemmCoord grid_shape = problem_visitor.grid_shape(problem_size);
// threadblock_offset of C
cutlass::gemm::GemmCoord threadblock_offset(
int(cta_idx / grid_shape.n()) * Mma::Shape::kM, // NOLINT
int(cta_idx % grid_shape.n()) * Mma::Shape::kN, // NOLINT
0);
// begin address offset for weight_scale.
ElementScale* weight_scale_ptr =
params.weight_scales ? params.weight_scales + problem_idx * problem_size.n() : nullptr;
// the begin threadblock_offset of scale, which holds the same column id with C, but with no row id
cutlass::MatrixCoord tb_offset_scale{0, threadblock_offset.n()};
// Load element pointers. Exchange pointers and strides if working on
// the transpose
int64_t rows_to_jump = 0;
if (params.problem_visitor.total_rows < 0) {
rows_to_jump = problem_idx == 0 ? 0 : params.problem_visitor.last_row_for_problem[problem_idx - 1];
} else {
rows_to_jump = problem_idx * (params.problem_visitor.total_rows / params.problem_visitor.problem_count);
}
// begin address offset for A for current tile
ElementA* ptr_A =
reinterpret_cast<ElementA*>(params.ptr_A) + rows_to_jump * gemm_k;
typename LayoutA::LongIndex ldm_A = gemm_k;
// Compute initial location in logical coordinates
// the begin threadblock_offset of A, which holds the same row id with C
cutlass::MatrixCoord tb_offset_A{
threadblock_offset.m(),
0,
};
// begin address offset for B for current problem_idx, totally num_experts problems
char* byte_ptr_B = ((char*)params.ptr_B) + // NOLINT
problem_idx * bytes_per_expert_matrix; // NOLINT
typename LayoutB::LongIndex ldm_B =
platform::is_same<layout::RowMajor, LayoutB>::value
? gemm_n
: gemm_k * kInterleave;
typename LayoutB::LongIndex ldm_B_shared = TileDequanterB::kColumns;
// the begin threadblock_offset of B, which holds the same column id with C
cutlass::MatrixCoord tb_offset_B{0,
threadblock_offset.n() / kInterleave};
cutlass::MatrixCoord extent_B{problem_size.k() * kInterleave, problem_size.n() / kInterleave};
cutlass::MatrixCoord extent_B_shared{TileDequanterB::kRows, TileDequanterB::kColumns};
MmaElementB* smem_unzip_B_ptr = nullptr;
if constexpr (QuantMethod == WintQuantMethod::kWeightOnlyInt2) {
smem_unzip_B_ptr = shared_storage.main_loop.operand_unzip_B_ptr();
}
QuantArguments quant_args = get_quant_args(params, problem_idx, gemm_k, gemm_n);
TileDequanterB tile_dequanter_B(smem_unzip_B_ptr,
byte_ptr_B,
ldm_B,
extent_B,
tb_offset_B,
weight_scale_ptr,
tb_offset_scale,
quant_args);
MmaElementB* ptr_B = tile_dequanter_B.GetOutPtr();
// Compute position within threadblock
int thread_idx = threadIdx.x;
// Construct iterators to A and B operands
typename Mma::IteratorA iterator_A(LayoutA(ldm_A),
ptr_A,
{problem_size.m(), problem_size.k()},
thread_idx,
tb_offset_A);
typename Mma::IteratorB iterator_B(
LayoutB(TileDequanterB::kUseSharedMemory ? ldm_B_shared : ldm_B),
ptr_B,
TileDequanterB::kUseSharedMemory ? extent_B_shared : extent_B,
thread_idx,
TileDequanterB::kUseSharedMemory ? cutlass::make_Coord(0, 0) : tb_offset_B);
typename Mma::FragmentC accumulators;
accumulators.clear();
// Broadcast the warp_id computed by lane 0 to ensure dependent code
// is compiled as warp-uniform.
int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
int lane_idx = threadIdx.x % 32;
//
// Matrix multiply phase
//
// Construct thread-scoped matrix multiply
Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);
// Compute threadblock-scoped matrix multiply-add
int gemm_k_iterations =
(problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK;
// Wait for all threads to finish their epilogue phases from the
// previous tile.
__syncthreads();
// Compute threadblock-scoped matrix multiply-add
mma(gemm_k_iterations,
accumulators,
iterator_A,
iterator_B,
tile_dequanter_B,
accumulators);
//
// Epilogue
//
EpilogueOutputOp output_op(params.output_op);
ElementC* ptr_C =
params.ptr_C ? reinterpret_cast<ElementC*>(params.ptr_C) + problem_idx * gemm_n : nullptr;
ElementC* ptr_D =
reinterpret_cast<ElementC*>(params.ptr_D) + rows_to_jump * gemm_n;
LayoutC layout_C(0);
LayoutC layout_D(gemm_n);
typename Epilogue::OutputTileIterator::Params params_C(layout_C);
typename Epilogue::OutputTileIterator::Params params_D(layout_D);
// Tile iterator loading from source tensor.
typename Epilogue::OutputTileIterator iterator_C(
params_C,
ptr_C,
problem_size.mn(),
thread_idx,
threadblock_offset.mn());
// Tile iterator writing to destination tensor.
typename Epilogue::OutputTileIterator iterator_D(
params_D,
ptr_D,
problem_size.mn(),
thread_idx,
threadblock_offset.mn());
Epilogue epilogue(
shared_storage.epilogue, thread_idx, warp_idx, lane_idx);
// Execute the epilogue operator to update the destination tensor.
epilogue(output_op, iterator_D, accumulators, iterator_C);
// Next tile
problem_visitor.advance(gridDim.x);
}
}
};
/*
To improve compilation speed, we do not compile the device operator if the
CUDA_ARCH does not correspond to the ArchTag of the cutlass kernel operator.
*/
/// Executes one GEMM
/// Device-side entry point of the grouped MoE GEMM kernel. It forwards the
/// kernel parameters and the shared-memory storage to the wint2-specialized
/// KernelRunner, which performs the persistent tile loop (problem visitor,
/// mainloop, epilogue) seen above.
CUTLASS_DEVICE
void operator()(Params const& params,
SharedStorage& shared_storage) { // NOLINT
// Compile the body only when the device pass targets compute capability
// in [8.0, 9.1). __CUDA_ARCH__ is undefined during host compilation, so
// the host pass (and any unsupported device arch) falls through to the
// CUTLASS_NOT_IMPLEMENTED stub — this keeps the heavy kernel out of
// binaries for architectures that can never launch it.
// NOTE(review): the upper bound 910 admits SM90 (Hopper) — confirm the
// wint2 mainloop is actually validated on SM90, since the analogous guard
// in the file header uses ">= 900" for Hopper-only kernels.
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 910)
KernelRunner<WintQuantMethod::kWeightOnlyInt2, true>::run_kernel(params, shared_storage);
#else
CUTLASS_NOT_IMPLEMENTED();
#endif
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace kernel
} // namespace gemm
} // namespace cutlass

View File

@@ -15,16 +15,22 @@
*/
#pragma once
#include <cuda_runtime_api.h>
#include <string>
#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/ft_gemm_configs.h"
#include "cutlass_extensions/wint_type_traits.h"
namespace phi {
template <typename T, /*The type used for activations/scales/compute*/
typename WeightType /* The type for the MoE weights */>
typename WeightQuantTraits /* The quant traits for the MoE weights */>
class MoeGemmRunner {
public:
using WeightType = typename WeightQuantTraits::WeightType;
using Arguments = typename WeightQuantTraits::Arguments;
MoeGemmRunner();
void moe_gemm_bias_act(const T* A,
@@ -38,6 +44,7 @@ class MoeGemmRunner {
int64_t gemm_n,
int64_t gemm_k,
int num_experts,
const Arguments& quant_args_B,
std::string activation_type,
cudaStream_t stream);
@@ -51,6 +58,7 @@ class MoeGemmRunner {
int64_t gemm_n,
int64_t gemm_k,
int num_experts,
const Arguments& quant_args_B,
cudaStream_t stream);
private:
@@ -65,6 +73,7 @@ class MoeGemmRunner {
int64_t gemm_n,
int64_t gemm_k,
int num_experts,
const Arguments& quant_args_B,
CutlassGemmConfig gemm_config,
cudaStream_t stream,
int* occupancy = nullptr);
@@ -81,6 +90,7 @@ class MoeGemmRunner {
int64_t gemm_n,
int64_t gemm_k,
int num_experts,
const Arguments& quant_args_B,
cudaStream_t stream);
private:

View File

@@ -13,7 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "cutlass_kernels/moe_gemm/fused_moe_gemm_kernels.h"
#include "cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_template.h"
@@ -22,7 +22,8 @@
namespace phi {
#ifdef PADDLE_CUDA_BF16
template class MoeGemmRunner<__nv_bfloat16, __nv_bfloat16>;
template class MoeGemmRunner<
__nv_bfloat16, cutlass::WintQuantTraits<__nv_bfloat16, cutlass::WintQuantMethod::kNone>>;
#endif
} // namespace phi
} // namespace phi

View File

@@ -0,0 +1,30 @@
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "cutlass_kernels/moe_gemm/fused_moe_gemm_kernels.h"
#include "cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_template.h"
#include "helper.h"
namespace phi {
#ifdef PADDLE_CUDA_BF16
template class MoeGemmRunner<
__nv_bfloat16,
cutlass::WintQuantTraits<__nv_bfloat16, cutlass::WintQuantMethod::kWeightOnlyInt2>>;
#endif
} // namespace phi

View File

@@ -21,7 +21,9 @@
namespace phi {
#ifdef PADDLE_CUDA_BF16
template class MoeGemmRunner<__nv_bfloat16, cutlass::uint4b_t>;
template class MoeGemmRunner<
__nv_bfloat16,
cutlass::WintQuantTraits<__nv_bfloat16, cutlass::WintQuantMethod::kWeightOnlyInt4>>;
#endif
} // namespace phi
} // namespace phi

View File

@@ -22,8 +22,9 @@
namespace phi {
#ifdef PADDLE_CUDA_BF16
template class MoeGemmRunner<__nv_bfloat16, uint8_t>;
template class MoeGemmRunner<
__nv_bfloat16,
cutlass::WintQuantTraits<__nv_bfloat16, cutlass::WintQuantMethod::kWeightOnlyInt8>>;
#endif
} // namespace phi
} // namespace phi

View File

@@ -21,6 +21,7 @@
namespace phi {
template class MoeGemmRunner<half, half>;
template class MoeGemmRunner<half,
cutlass::WintQuantTraits<half, cutlass::WintQuantMethod::kNone>>;
} // namespace phi
} // namespace phi

View File

@@ -0,0 +1,27 @@
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "cutlass_kernels/moe_gemm/fused_moe_gemm_kernels.h"
#include "cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_template.h"
#include "helper.h"
namespace phi {
template class MoeGemmRunner<
half, cutlass::WintQuantTraits<half, cutlass::WintQuantMethod::kWeightOnlyInt2>>;
} // namespace phi

View File

@@ -21,6 +21,7 @@
namespace phi {
template class MoeGemmRunner<half, cutlass::uint4b_t>;
template class MoeGemmRunner<
half, cutlass::WintQuantTraits<half, cutlass::WintQuantMethod::kWeightOnlyInt4>>;
} // namespace phi
} // namespace phi

View File

@@ -21,6 +21,7 @@
namespace phi {
template class MoeGemmRunner<half, uint8_t>;
template class MoeGemmRunner<
half, cutlass::WintQuantTraits<half, cutlass::WintQuantMethod::kWeightOnlyInt8>>;
} // namespace phi
} // namespace phi

View File

@@ -24,9 +24,10 @@
#include <math.h>
#include <optional>
#include <sstream>
#include "cutlass/array.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/array.h"
#include "cutlass/trace.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/gemm/device/gemm_grouped.h"
#include "cutlass/gemm/kernel/default_gemm_grouped.h"
@@ -35,8 +36,11 @@
#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/compute_occupancy.h"
#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/epilogue_helpers.h"
#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h"
#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/threadblock/default_mma.h"
#include "cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h"
#include "cutlass_extensions/gemm/threadblock/default_mma.h"
#include "cutlass_extensions/wint_type_traits.h"
#include "cutlass_kernels/moe_gemm/fused_moe_cutlass_kernel.h"
#include "cutlass_kernels/moe_gemm/fused_moe_gemm_kernels.h"
@@ -48,17 +52,47 @@
#include "helper.h"
namespace phi {
// ============================= Variable batched Gemm things
// ===========================
template <typename MixedGemmArchTraits, cutlass::WintQuantMethod Method>
struct CutlassLayoutB {
using Type = typename MixedGemmArchTraits::LayoutB;
};
template <typename MixedGemmArchTraits>
struct CutlassLayoutB<MixedGemmArchTraits, cutlass::WintQuantMethod::kNone> {
using Type = cutlass::layout::RowMajor;
};
template <typename BaseGemmKernel, typename Arch, cutlass::WintQuantMethod Method>
struct CutlassGemmKernel {
using Type =
cutlass::gemm::kernel::MoeFCGemm<typename BaseGemmKernel::Mma,
typename BaseGemmKernel::Epilogue,
typename BaseGemmKernel::ThreadblockSwizzle,
Arch,
BaseGemmKernel::kGroupScheduleMode>;
};
template <typename BaseGemmKernel, typename Arch>
struct CutlassGemmKernel<BaseGemmKernel, Arch, cutlass::WintQuantMethod::kWeightOnlyInt2> {
using Type =
cutlass::gemm::kernel::Wint2xMoeFCGemm<typename BaseGemmKernel::Mma,
typename BaseGemmKernel::Epilogue,
typename BaseGemmKernel::ThreadblockSwizzle,
Arch,
BaseGemmKernel::kGroupScheduleMode>;
};
// ======================= Variable batched Gemm things =======================
template <typename T,
typename WeightType,
typename WeightQuantTraits,
typename arch,
typename EpilogueTag,
typename ThreadblockShape,
typename WarpShape,
int Stages>
void generic_moe_gemm_kernelLauncher(const T* A,
const WeightType* B,
const typename WeightQuantTraits::WeightType* B,
const T* weight_scales,
const T* biases,
T* C,
@@ -67,6 +101,7 @@ void generic_moe_gemm_kernelLauncher(const T* A,
int64_t gemm_n,
int64_t gemm_k,
int num_experts,
const typename WeightQuantTraits::Arguments& quant_args_B,
CutlassGemmConfig gemm_config,
const int multi_processor_count,
cudaStream_t stream,
@@ -86,44 +121,26 @@ void generic_moe_gemm_kernelLauncher(const T* A,
"Specialized for half, float");
#endif
using WeightType = typename WeightQuantTraits::WeightType;
static_assert(
cutlass::platform::is_same<T, WeightType>::value ||
cutlass::platform::is_same<WeightType, uint8_t>::value ||
cutlass::platform::is_same<WeightType, cutlass::uint4b_t>::value,
"");
cutlass::platform::is_same<WeightType, cutlass::uint4b_t>::value ||
cutlass::platform::is_same<WeightType, uint16_t>::value,
"Specialized for bfloat16, half, float, uint8_t (wint8), uint4b_t (wint4), uint16_t (wint2.5)");
// The cutlass type for the input elements. This is needed to convert to
// cutlass::half_t if necessary.
using ElementType_ = typename cutlass::platform::conditional<
cutlass::platform::is_same<T, half>::value,
cutlass::half_t,
T>::type;
#ifdef PADDLE_CUDA_BF16
using ElementType = typename cutlass::platform::conditional<
cutlass::platform::is_same<ElementType_, __nv_bfloat16>::value,
cutlass::bfloat16_t,
ElementType_>::type;
#else
using ElementType = ElementType_;
#endif
using CutlassWeightType_ = typename cutlass::platform::conditional<
cutlass::platform::is_same<WeightType, half>::value,
cutlass::half_t,
WeightType>::type;
#ifdef PADDLE_CUDA_BF16
using CutlassWeightType = typename cutlass::platform::conditional<
cutlass::platform::is_same<CutlassWeightType_, __nv_bfloat16>::value,
cutlass::bfloat16_t,
CutlassWeightType_>::type;
#else
using CutlassWeightType = CutlassWeightType_;
#endif
using ElementType = typename cutlass::CutlassDataType<T>::Type;
using CutlassWeightType = typename cutlass::CutlassDataType<typename WeightQuantTraits::WeightType>::Type;
using CutlassMmaWeightType = typename WeightQuantTraits::MmaWeightType;
using CutlassMmaKernelType = typename WeightQuantTraits::MmaKernelType;
// We need separate config for each architecture since we will target
// different tensorcore instructions. For float, we do not target TCs.
using MixedGemmArchTraits = cutlass::gemm::kernel::
MixedGemmArchTraits<ElementType, CutlassWeightType, arch>;
MixedGemmArchTraits<ElementType, CutlassMmaKernelType, arch>;
using ElementAccumulator = typename MixedGemmArchTraits::AccType;
using EpilogueOp = typename Epilogue<ElementType,
@@ -132,13 +149,13 @@ void generic_moe_gemm_kernelLauncher(const T* A,
EpilogueTag>::Op;
// Finally, set up the kernel.
using GemmKernel_ = typename cutlass::gemm::kernel::DefaultGemmGrouped<
using BaseGemmKernel = typename cutlass::gemm::kernel::DefaultGemmGrouped<
ElementType,
cutlass::layout::RowMajor,
cutlass::ComplexTransform::kNone,
MixedGemmArchTraits::ElementsPerAccessA,
CutlassWeightType,
typename MixedGemmArchTraits::LayoutB,
CutlassMmaKernelType,
typename CutlassLayoutB<MixedGemmArchTraits, WeightQuantTraits::kQuantMethod>::Type,
cutlass::ComplexTransform::kNone,
MixedGemmArchTraits::ElementsPerAccessB,
ElementType,
@@ -155,14 +172,7 @@ void generic_moe_gemm_kernelLauncher(const T* A,
cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly,
typename MixedGemmArchTraits::Operator>::GemmKernel;
using GemmKernel =
cutlass::gemm::kernel::MoeFCGemm<typename GemmKernel_::Mma,
typename GemmKernel_::Epilogue,
typename GemmKernel_::ThreadblockSwizzle,
arch, // Ensure top level arch is used
// for dispatch
GemmKernel_::kGroupScheduleMode>;
using GemmKernel = typename CutlassGemmKernel<BaseGemmKernel, arch, WeightQuantTraits::kQuantMethod>::Type;
using GemmGrouped = cutlass::gemm::device::GemmGrouped<GemmKernel>;
if (kernel_occupancy != nullptr) {
@@ -181,19 +191,32 @@ void generic_moe_gemm_kernelLauncher(const T* A,
typename EpilogueOp::Params epilogue_op(ElementAccumulator(1.f),
ElementAccumulator(0.f));
const uint8_t* local_scale_B = nullptr;
const float* code_scale_B = nullptr;
const float* code_zp_B = nullptr;
if constexpr (WeightQuantTraits::kQuantMethod == cutlass::WintQuantMethod::kWeightOnlyInt2) {
local_scale_B = quant_args_B.local_scale_ptr;
code_scale_B = quant_args_B.code_scale_ptr;
code_zp_B = quant_args_B.code_zp_ptr;
}
typename GemmGrouped::Arguments args(
num_experts,
threadblock_count,
epilogue_op,
reinterpret_cast<const ElementType*>(A),
reinterpret_cast<const CutlassWeightType*>(B),
reinterpret_cast<const CutlassMmaWeightType*>(B),
reinterpret_cast<const ElementType*>(weight_scales),
reinterpret_cast<const ElementType*>(biases),
reinterpret_cast<ElementType*>(C),
total_rows_before_expert,
total_rows,
gemm_n,
gemm_k);
gemm_k,
WeightQuantTraits::kQuantMethod,
local_scale_B,
code_scale_B,
code_zp_B);
GemmGrouped gemm;
@@ -222,7 +245,7 @@ void generic_moe_gemm_kernelLauncher(const T* A,
}
template <typename T,
typename WeightType,
typename WeightQuantTraits,
typename arch,
typename EpilogueTag,
typename ThreadblockShape,
@@ -231,7 +254,7 @@ template <typename T,
typename Enable = void>
struct dispatch_stages {
static void dispatch(const T* A,
const WeightType* B,
const typename WeightQuantTraits::WeightType* B,
const T* weight_scales,
const T* biases,
T* C,
@@ -240,6 +263,7 @@ struct dispatch_stages {
int64_t gemm_n,
int64_t gemm_k,
int num_experts,
const typename WeightQuantTraits::Arguments& quant_args_B,
CutlassGemmConfig gemm_config,
int multi_processor_count,
cudaStream_t stream,
@@ -253,20 +277,20 @@ struct dispatch_stages {
};
template <typename T,
typename WeightType,
typename WeightQuantTraits,
typename arch,
typename EpilogueTag,
typename ThreadblockShape,
typename WarpShape>
struct dispatch_stages<T,
WeightType,
WeightQuantTraits,
arch,
EpilogueTag,
ThreadblockShape,
WarpShape,
2> {
static void dispatch(const T* A,
const WeightType* B,
const typename WeightQuantTraits::WeightType* B,
const T* weight_scales,
const T* biases,
T* C,
@@ -275,12 +299,13 @@ struct dispatch_stages<T,
int64_t gemm_n,
int64_t gemm_k,
int num_experts,
const typename WeightQuantTraits::Arguments& quant_args_B,
CutlassGemmConfig gemm_config,
int multi_processor_count,
cudaStream_t stream,
int* occupancy = nullptr) {
generic_moe_gemm_kernelLauncher<T,
WeightType,
WeightQuantTraits,
arch,
EpilogueTag,
ThreadblockShape,
@@ -295,6 +320,7 @@ struct dispatch_stages<T,
gemm_n,
gemm_k,
num_experts,
quant_args_B,
gemm_config,
multi_processor_count,
stream,
@@ -303,13 +329,13 @@ struct dispatch_stages<T,
};
template <typename T,
typename WeightType,
typename WeightQuantTraits,
typename EpilogueTag,
typename ThreadblockShape,
typename WarpShape,
int Stages>
struct dispatch_stages<T,
WeightType,
WeightQuantTraits,
cutlass::arch::Sm80,
EpilogueTag,
ThreadblockShape,
@@ -317,7 +343,7 @@ struct dispatch_stages<T,
Stages,
typename std::enable_if<(Stages > 2)>::type> {
static void dispatch(const T* A,
const WeightType* B,
const typename WeightQuantTraits::WeightType* B,
const T* weight_scales,
const T* biases,
T* C,
@@ -326,12 +352,13 @@ struct dispatch_stages<T,
int64_t gemm_n,
int64_t gemm_k,
int num_experts,
const typename WeightQuantTraits::Arguments& quant_args_B,
CutlassGemmConfig gemm_config,
int multi_processor_count,
cudaStream_t stream,
int* occupancy = nullptr) {
generic_moe_gemm_kernelLauncher<T,
WeightType,
WeightQuantTraits,
cutlass::arch::Sm80,
EpilogueTag,
ThreadblockShape,
@@ -346,6 +373,7 @@ struct dispatch_stages<T,
gemm_n,
gemm_k,
num_experts,
quant_args_B,
gemm_config,
multi_processor_count,
stream,
@@ -354,13 +382,13 @@ struct dispatch_stages<T,
};
template <typename T,
typename WeightType,
typename WeightQuantTraits,
typename arch,
typename EpilogueTag,
typename ThreadblockShape,
typename WarpShape>
void dispatch_gemm_config(const T* A,
const WeightType* B,
const typename WeightQuantTraits::WeightType* B,
const T* weight_scales,
const T* biases,
T* C,
@@ -369,6 +397,7 @@ void dispatch_gemm_config(const T* A,
int64_t gemm_n,
int64_t gemm_k,
int num_experts,
const typename WeightQuantTraits::Arguments& quant_args_B,
CutlassGemmConfig gemm_config,
int multi_processor_count,
cudaStream_t stream,
@@ -376,7 +405,7 @@ void dispatch_gemm_config(const T* A,
#define dispatch_stages_macro(STAGE) \
case STAGE: \
dispatch_stages<T, \
WeightType, \
WeightQuantTraits, \
arch, \
EpilogueTag, \
ThreadblockShape, \
@@ -391,6 +420,7 @@ void dispatch_gemm_config(const T* A,
gemm_n, \
gemm_k, \
num_experts, \
quant_args_B, \
gemm_config, \
multi_processor_count, \
stream, \
@@ -414,7 +444,7 @@ void dispatch_gemm_config(const T* A,
case CutlassTileConfig:: \
CtaShape##AA##x##BB##x##CC##_WarpShape##DD##x##EE##x##FF: \
dispatch_gemm_config<T, \
WeightType, \
WeightQuantTraits, \
arch, \
EpilogueTag, \
cutlass::gemm::GemmShape<AA, BB, CC>, \
@@ -425,10 +455,11 @@ void dispatch_gemm_config(const T* A,
biases, \
C, \
total_rows_before_expert, \
total_rows, \
total_rows, \
gemm_n, \
gemm_k, \
num_experts, \
quant_args_B, \
gemm_config, \
multi_processor_count, \
stream, \
@@ -438,14 +469,14 @@ void dispatch_gemm_config(const T* A,
// This overload will handle tensorop gemms. It is disabled via SFINAE for fp32.
// This overload is only enabled when T == WeightType.
template <typename T,
typename WeightType,
typename WeightQuantTraits,
typename arch,
typename EpilogueTag,
typename std::enable_if<!std::is_same<T, float>::value &&
std::is_same<T, WeightType>::value>::type* =
std::is_same<T, typename WeightQuantTraits::WeightType>::value>::type* =
nullptr>
void dispatch_moe_gemm_to_cutlass(const T* A,
const WeightType* B,
const typename WeightQuantTraits::WeightType* B,
const T* weight_scales,
const T* biases,
T* C,
@@ -454,6 +485,7 @@ void dispatch_moe_gemm_to_cutlass(const T* A,
int64_t gemm_n,
int64_t gemm_k,
int num_experts,
const typename WeightQuantTraits::Arguments& quant_args_B,
CutlassGemmConfig gemm_config,
int sm_version,
int multi_processor_count,
@@ -474,7 +506,7 @@ void dispatch_moe_gemm_to_cutlass(const T* A,
default:
throw std::runtime_error(
"[dispatch_moe_gemm_to_cutlass] Config is invalid for same "
"type MoE tensorop GEMM.");
"type MoE tensorop GEMM for FP16/BF16.");
break;
}
}
@@ -483,14 +515,14 @@ void dispatch_moe_gemm_to_cutlass(const T* A,
// Overload for quantize MoE GEMMs. We disable some warp configs here since they
// will not be used and we can improve compile time
template <typename T,
typename WeightType,
typename WeightQuantTraits,
typename arch,
typename EpilogueTag,
typename std::enable_if<!std::is_same<T, float>::value &&
!std::is_same<T, WeightType>::value>::type* =
!std::is_same<T, typename WeightQuantTraits::WeightType>::value>::type* =
nullptr>
void dispatch_moe_gemm_to_cutlass(const T* A,
const WeightType* B,
const typename WeightQuantTraits::WeightType* B,
const T* weight_scales,
const T* biases,
T* C,
@@ -499,28 +531,34 @@ void dispatch_moe_gemm_to_cutlass(const T* A,
int64_t gemm_n,
int64_t gemm_k,
int num_experts,
const typename WeightQuantTraits::Arguments& quant_args_B,
CutlassGemmConfig gemm_config,
int sm_version,
int multi_processor_count,
cudaStream_t stream,
int* occupancy = nullptr) {
if constexpr (std::is_same<arch, cutlass::arch::Sm70>::value) {
switch (gemm_config.tile_config) {
dispatch_gemm_config_macro(32, 128, 64, 32, 32, 64);
dispatch_gemm_config_macro(64, 128, 64, 64, 64, 64);
case CutlassTileConfig::Undefined:
throw std::runtime_error("[dispatch_moe_gemm_to_cutlass] gemm config undefined.");
break;
case CutlassTileConfig::ChooseWithHeuristic:
throw std::runtime_error(
"[dispatch_moe_gemm_to_cutlass] gemm config should have "
"already been set by heuristic.");
break;
default:
throw std::runtime_error(
"[dispatch_moe_gemm_to_cutlass] Config is invalid for "
"mixed type tensorop GEMM.");
break;
if constexpr (WeightQuantTraits::kQuantMethod != cutlass::WintQuantMethod::kWeightOnlyInt2) {
switch (gemm_config.tile_config) {
dispatch_gemm_config_macro(32, 128, 64, 32, 32, 64);
dispatch_gemm_config_macro(64, 128, 64, 64, 64, 64);
case CutlassTileConfig::Undefined:
throw std::runtime_error("[dispatch_moe_gemm_to_cutlass] gemm config undefined.");
break;
case CutlassTileConfig::ChooseWithHeuristic:
throw std::runtime_error(
"[dispatch_moe_gemm_to_cutlass] gemm config should have "
"already been set by heuristic.");
break;
default:
throw std::runtime_error(
"[dispatch_moe_gemm_to_cutlass] Config is invalid for "
"mixed type tensorop GEMM for sm70.");
break;
}
} else {
throw std::runtime_error(
"[dispatch_moe_gemm_to_cutlass] weight_only_int2 does not support sm70.");
}
} else {
switch (gemm_config.tile_config) {
@@ -555,12 +593,12 @@ void dispatch_moe_gemm_to_cutlass(const T* A,
// This overload will handle simt gemms. It is disabled via SFINAE for tensorop.
template <
typename T,
typename WeightType,
typename WeightQuantTraits,
typename arch,
typename EpilogueTag,
typename std::enable_if<std::is_same<T, float>::value>::type* = nullptr>
void dispatch_moe_gemm_to_cutlass(const T* A,
const WeightType* B,
const typename WeightQuantTraits::WeightType* B,
const T* weight_scales,
const T* biases,
T* C,
@@ -569,6 +607,7 @@ void dispatch_moe_gemm_to_cutlass(const T* A,
int64_t gemm_n,
int64_t gemm_k,
int num_experts,
const typename WeightQuantTraits::Arguments& quant_args_B,
CutlassGemmConfig gemm_config,
int sm_version,
int multi_processor_count,
@@ -594,8 +633,8 @@ void dispatch_moe_gemm_to_cutlass(const T* A,
}
}
template <typename T, typename WeightType>
MoeGemmRunner<T, WeightType>::MoeGemmRunner() {
template <typename T, typename WeightQuantTraits>
MoeGemmRunner<T, WeightQuantTraits>::MoeGemmRunner() {
int device{-1};
check_cuda_error(cudaGetDevice(&device));
sm_ = getSMVersion();
@@ -603,11 +642,11 @@ MoeGemmRunner<T, WeightType>::MoeGemmRunner() {
&multi_processor_count_, cudaDevAttrMultiProcessorCount, device));
}
template <typename T, typename WeightType>
template <typename T, typename WeightQuantTraits>
template <typename EpilogueTag>
void MoeGemmRunner<T, WeightType>::dispatch_to_arch<EpilogueTag>(
void MoeGemmRunner<T, WeightQuantTraits>::dispatch_to_arch<EpilogueTag>(
const T* A,
const WeightType* B,
const typename WeightQuantTraits::WeightType* B,
const T* weight_scales,
const T* biases,
T* C,
@@ -616,11 +655,12 @@ void MoeGemmRunner<T, WeightType>::dispatch_to_arch<EpilogueTag>(
int64_t gemm_n,
int64_t gemm_k,
int num_experts,
const typename WeightQuantTraits::Arguments& quant_args_B,
CutlassGemmConfig gemm_config,
cudaStream_t stream,
int* occupancy) {
#define dispatch_moe_gemm_to_cutlass_macro(ARCH) \
dispatch_moe_gemm_to_cutlass<T, WeightType, ARCH, EpilogueTag>( \
dispatch_moe_gemm_to_cutlass<T, WeightQuantTraits, ARCH, EpilogueTag>( \
A, \
B, \
weight_scales, \
@@ -631,6 +671,7 @@ void MoeGemmRunner<T, WeightType>::dispatch_to_arch<EpilogueTag>(
gemm_n, \
gemm_k, \
num_experts, \
quant_args_B, \
gemm_config, \
sm_, \
multi_processor_count_, \
@@ -648,25 +689,28 @@ void MoeGemmRunner<T, WeightType>::dispatch_to_arch<EpilogueTag>(
}
}
template <typename T, typename WeightType>
template <typename T, typename WeightQuantTraits>
template <typename EpilogueTag>
void MoeGemmRunner<T, WeightType>::run_gemm<EpilogueTag>(
void MoeGemmRunner<T, WeightQuantTraits>::run_gemm<EpilogueTag>(
const T* A,
const WeightType* B,
const typename WeightQuantTraits::WeightType* B,
const T* weight_scales,
const T* biases,
T* C,
int64_t* total_rows_before_expert,
int64_t total_rows,
int64_t tune_total_rows,
int64_t actual_total_rows,
int64_t gemm_n,
int64_t gemm_k,
int num_experts,
const typename WeightQuantTraits::Arguments& quant_args_B,
cudaStream_t stream) {
static constexpr bool is_weight_only = !std::is_same<T, WeightType>::value;
static constexpr bool is_weight_only = !std::is_same<T, typename WeightQuantTraits::WeightType>::value;
static constexpr bool only_simt_configs = std::is_same<T, float>::value;
std::vector<CutlassGemmConfig> candidate_configs =
get_candidate_configs(sm_, -1, is_weight_only, only_simt_configs, true);
static constexpr int warm_time = 5;
static constexpr int test_time = 10;
auto& gemmConfigManager = GemmConfigManager::Instance();
@@ -676,17 +720,19 @@ void MoeGemmRunner<T, WeightType>::run_gemm<EpilogueTag>(
gemm_n, gemm_k, GemmType::MOEGEMM, dtype, wdtype, num_experts};
CutlassGemmConfig chosen_config;
auto chosen_config_optional =
gemmConfigManager.getBestConfig(gemmId, tune_total_rows);
gemmConfigManager.getBestConfig(gemmId, actual_total_rows);
if (chosen_config_optional != std::nullopt) {
chosen_config = chosen_config_optional.value();
} else {
size_t best_id = -1;
float best_time = std::numeric_limits<float>::max();
CutlassGemmConfig best_config;
int profile_total_rows =
std::min(gemmConfigManager.nextPowerOfTwo(tune_total_rows),
std::min(gemmConfigManager.nextPowerOfTwo(actual_total_rows),
gemmConfigManager.getMaxProfileM());
bool find_one = false;
for (size_t ii = 0; ii < candidate_configs.size(); ++ii) {
size_t num_candidate_configs_size = candidate_configs.size();
for (size_t ii = 0; ii < num_candidate_configs_size; ++ii) {
try {
for (int i = 0; i < warm_time; i++) {
dispatch_to_arch<EpilogueTag>(A,
@@ -699,6 +745,7 @@ void MoeGemmRunner<T, WeightType>::run_gemm<EpilogueTag>(
gemm_n,
gemm_k,
num_experts,
quant_args_B,
candidate_configs[ii],
stream);
}
@@ -719,6 +766,7 @@ void MoeGemmRunner<T, WeightType>::run_gemm<EpilogueTag>(
gemm_n,
gemm_k,
num_experts,
quant_args_B,
candidate_configs[ii],
stream);
}
@@ -728,7 +776,9 @@ void MoeGemmRunner<T, WeightType>::run_gemm<EpilogueTag>(
check_cuda_error(cudaEventElapsedTime(&elapsed, start, stop));
check_cuda_error(cudaEventDestroy(start));
check_cuda_error(cudaEventDestroy(stop));
//std::cout << "[TUNING] config: " << ii << ", time: " << elapsed << " ms" << std::endl;
if (elapsed < best_time) {
best_id = ii;
best_time = elapsed;
best_config = candidate_configs[ii];
}
@@ -739,6 +789,7 @@ void MoeGemmRunner<T, WeightType>::run_gemm<EpilogueTag>(
}
}
if (find_one) {
//std::cout << "[TUNING] best_config: " << best_id << ", time: " << best_time << " ms" << std::endl;
gemmConfigManager.addBestConfig(gemmId, profile_total_rows, best_config);
chosen_config = best_config;
} else {
@@ -756,23 +807,25 @@ void MoeGemmRunner<T, WeightType>::run_gemm<EpilogueTag>(
gemm_n,
gemm_k,
num_experts,
quant_args_B,
chosen_config,
stream);
}
template <typename T, typename WeightType>
void MoeGemmRunner<T, WeightType>::moe_gemm_bias_act(
template <typename T, typename WeightQuantTraits>
void MoeGemmRunner<T, WeightQuantTraits>::moe_gemm_bias_act(
const T* A,
const WeightType* B,
const typename WeightQuantTraits::WeightType* B,
const T* weight_scales,
const T* biases,
T* C,
int64_t* total_rows_before_expert,
int64_t total_rows,
int64_t tune_total_rows,
int64_t actual_total_rows,
int64_t gemm_n,
int64_t gemm_k,
int num_experts,
const typename WeightQuantTraits::Arguments& quant_args_B,
std::string activation_type,
cudaStream_t stream) {
if (activation_type == "none") {
@@ -784,10 +837,11 @@ void MoeGemmRunner<T, WeightType>::moe_gemm_bias_act(
C,
total_rows_before_expert,
total_rows,
tune_total_rows,
actual_total_rows,
gemm_n,
gemm_k,
num_experts,
quant_args_B,
stream);
} else {
run_gemm<EpilogueOpNoBias>(A,
@@ -797,27 +851,30 @@ void MoeGemmRunner<T, WeightType>::moe_gemm_bias_act(
C,
total_rows_before_expert,
total_rows,
tune_total_rows,
actual_total_rows,
gemm_n,
gemm_k,
num_experts,
quant_args_B,
stream);
}
}
}
template <typename T, typename WeightType>
void MoeGemmRunner<T, WeightType>::moe_gemm(const T* A,
const WeightType* B,
const T* weight_scales,
T* C,
int64_t* total_rows_before_expert,
int64_t total_rows,
int64_t tune_total_rows,
int64_t gemm_n,
int64_t gemm_k,
int num_experts,
cudaStream_t stream) {
template <typename T, typename WeightQuantTraits>
void MoeGemmRunner<T, WeightQuantTraits>::moe_gemm(
const T* A,
const typename WeightQuantTraits::WeightType* B,
const T* weight_scales,
T* C,
int64_t* total_rows_before_expert,
int64_t total_rows,
int64_t actual_total_rows,
int64_t gemm_n,
int64_t gemm_k,
int num_experts,
const typename WeightQuantTraits::Arguments& quant_args_B,
cudaStream_t stream) {
run_gemm<EpilogueOpNoBias>(A,
B,
weight_scales,
@@ -825,10 +882,11 @@ void MoeGemmRunner<T, WeightType>::moe_gemm(const T* A,
C,
total_rows_before_expert,
total_rows,
tune_total_rows,
actual_total_rows,
gemm_n,
gemm_k,
num_experts,
quant_args_B,
stream);
}

View File

@@ -0,0 +1,102 @@
// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/c3x/cutlass_gemm_caller.cuh
#pragma once
// clang-format will break include orders
// clang-format off
#include "helper.h"
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cutlass/numeric_types.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass_helper.h"
// clang-format on
namespace fastdeploy::c3x {

// Builds the CUTLASS problem shape {M, N, K, L=1} from the input tensors.
// M/K come from `a` (M x K) and N from `b`'s leading dimension; batch L is
// fixed to 1 (single GEMM per call).
static inline cute::Shape<int, int, int, int>
get_problem_shape(paddle::Tensor const &a, paddle::Tensor const &b) {
  int32_t m = a.dims()[0], n = b.dims()[0], k = a.dims()[1];
  return {m, n, k, 1};
}

// Low-level launcher: assembles GemmKernel::Arguments, validates them with
// can_implement(), allocates the kernel workspace on `device`, and runs the
// GEMM on the device's current CUDA stream.
//
// @param device        Place whose allocator/stream are used for the launch.
// @param prob_shape    {M, N, K, L} problem shape.
// @param mainloop_args Pointers/strides for A and B.
// @param epilogue_args Epilogue fusion arguments plus C/D pointers/strides.
// @param scheduler     Optional tile-scheduler arguments (defaulted).
template <typename GemmKernel>
void cutlass_gemm_caller(
    phi::Place device, cute::Shape<int, int, int, int> prob_shape,
    typename GemmKernel::MainloopArguments mainloop_args,
    typename GemmKernel::EpilogueArguments epilogue_args,
    typename GemmKernel::TileSchedulerArguments scheduler = {}) {
  cutlass::KernelHardwareInfo hw_info;
  typename GemmKernel::Arguments args{cutlass::gemm::GemmUniversalMode::kGemm,
                                      prob_shape,
                                      mainloop_args,
                                      epilogue_args,
                                      hw_info,
                                      scheduler};

  // Launch the CUTLASS GEMM kernel.
  using GemmOp = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
  GemmOp gemm_op;
  CUTLASS_CHECK(gemm_op.can_implement(args));

  size_t workspace_size = gemm_op.get_workspace_size(args);
  phi::Allocator *allocator = paddle::GetAllocator(device);
  auto workspace = allocator->Allocate(workspace_size);

  auto stream = paddle::GetCurrentCUDAStream(device)->raw_stream();
  cutlass::Status status = gemm_op.run(args, workspace->ptr(), stream);
  CUTLASS_CHECK(status);
}

// Tensor-level entry: derives packed strides and raw pointers from the paddle
// tensors, builds mainloop/epilogue arguments (the epilogue arguments are
// produced by Gemm::Epilogue::prepare_args from the forwarded params), then
// delegates to the low-level launcher above.
//
// Note: the original vLLM code also computed an unused `aux_stride`
// (StrideAux alias of StrideC) and an unused `ElementC` alias; both are
// dropped here to silence unused-variable warnings.
template <typename Gemm, typename... EpilogueArgs>
void cutlass_gemm_caller(paddle::Tensor &out, paddle::Tensor const &a,
                         paddle::Tensor const &b,
                         EpilogueArgs &&...epilogue_params) {
  using ElementAB = typename Gemm::ElementAB;
  using ElementD = typename Gemm::ElementD;
  using GemmKernel = typename Gemm::GemmKernel;
  using StrideA = typename Gemm::GemmKernel::StrideA;
  using StrideB = typename Gemm::GemmKernel::StrideB;
  using StrideC = typename Gemm::GemmKernel::StrideC;
  using StrideD = StrideC;

  typename GemmKernel::ProblemShape prob_shape = get_problem_shape(a, b);
  auto [M, N, K, L] = prob_shape;

  // Packed (contiguous) strides for each operand; B uses (N, K, L) since it
  // is consumed through its column-major layout (see scaled_mm.cuh builders).
  StrideA a_stride =
      cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L));
  StrideB b_stride =
      cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L));
  StrideC c_stride =
      cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L));
  StrideD d_stride =
      cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L));

  auto a_ptr = static_cast<ElementAB *>(const_cast<void *>(a.data()));
  auto b_ptr = static_cast<ElementAB *>(const_cast<void *>(b.data()));
  typename GemmKernel::MainloopArguments mainloop_args{a_ptr, a_stride, b_ptr,
                                                       b_stride};

  // C and D share the output buffer (epilogue reads/writes in place).
  auto c_ptr = static_cast<ElementD *>(const_cast<void *>(out.data()));
  typename GemmKernel::EpilogueArguments epilogue_args{
      Gemm::Epilogue::prepare_args(
          std::forward<EpilogueArgs>(epilogue_params)...),
      c_ptr, c_stride, c_ptr, d_stride};

  cutlass_gemm_caller<GemmKernel>(a.place(), prob_shape, mainloop_args,
                                  epilogue_args);
}

}  // namespace fastdeploy::c3x

View File

@@ -0,0 +1,149 @@
// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/c3x/scaled_mm.cuh
#pragma once
// clang-format will break include orders
// clang-format off
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cutlass/numeric_types.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass_helper.h"
#include "helper.h"
// clang-format on
/*
Epilogues defined in,
csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp,
must contain a public type named EVTCompute of type Sm90EVT, as well as a
static prepare_args function that constructs an EVTCompute::Arguments struct.
*/
using namespace cute;
namespace fastdeploy {

// Generic CUTLASS 3.x (SM90) GEMM type-composer for w8a8 scaled matmul.
// Given element types, an epilogue template, tile/cluster shapes and kernel/
// epilogue schedules, it assembles the collective mainloop + epilogue and
// exposes the final kernel as the nested `GemmKernel` type.
//
// Template parameters:
//   ElementAB_       - quantized input element type (int8_t or fp8 e4m3).
//   ElementD_        - output element type (fp16/bf16).
//   Epilogue_        - epilogue template <ElementAcc, ElementD, TileShape>;
//                      must provide EVTCompute and prepare_args (see header
//                      comment above).
//   TileShape, ClusterShape, KernelSchedule, EpilogueSchedule - CUTLASS
//                      tuning knobs chosen by the per-shape config structs.
template <typename ElementAB_, typename ElementD_,
          template <typename, typename, typename> typename Epilogue_,
          typename TileShape, typename ClusterShape, typename KernelSchedule,
          typename EpilogueSchedule>
struct cutlass_3x_gemm {
  using ElementAB = ElementAB_;
  using ElementD = ElementD_;
  // int8 inputs accumulate in int32; fp8 inputs accumulate in float.
  using ElementAcc =
      typename std::conditional<std::is_same_v<ElementAB, int8_t>, int32_t,
                                float>::type;

  using Epilogue = Epilogue_<ElementAcc, ElementD, TileShape>;

  // Row-major output D; C is unused (void) but shares D's stride type.
  using StrideD = Stride<int64_t, Int<1>, Int<0>>;
  using ElementC = void;
  using StrideC = StrideD;

  using EVTCompute = typename Epilogue::EVTCompute;

  // These are the minimum alignments needed for the kernels to compile
  static constexpr int AlignmentAB =
      128 / cutlass::sizeof_bits<ElementAB>::value;
  static constexpr int AlignmentCD = 4;

  using CollectiveEpilogue =
      typename cutlass::epilogue::collective::CollectiveBuilder<
          cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape,
          ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto,
          ElementAcc, float, ElementC, StrideC, AlignmentCD, ElementD, StrideD,
          AlignmentCD, EpilogueSchedule, EVTCompute>::CollectiveOp;

  // Stage count is auto-derived after carving the epilogue's shared-memory
  // footprint out of the available SMEM budget.
  static constexpr size_t CEStorageSize =
      sizeof(typename CollectiveEpilogue::SharedStorage);
  using Stages = typename cutlass::gemm::collective::StageCountAutoCarveout<
      static_cast<int>(CEStorageSize)>;

  // clang-format off
  using CollectiveMainloop =
      typename cutlass::gemm::collective::CollectiveBuilder<
          cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
          ElementAB, cutlass::layout::RowMajor, AlignmentAB,
          ElementAB, cutlass::layout::ColumnMajor, AlignmentAB,
          ElementAcc, TileShape, ClusterShape,
          Stages,
          KernelSchedule>::CollectiveOp;
  // clang-format on

  // enable_sm90_or_later guards device codegen so this only emits for sm90+.
  using KernelType = enable_sm90_or_later<cutlass::gemm::kernel::GemmUniversal<
      cute::Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue,
      cutlass::gemm::PersistentScheduler>>;

  struct GemmKernel : public KernelType {};
};

// SM100 (Blackwell) variant of the composer above. Differences from the SM90
// version: explicit layout/alignment aliases per operand, Sm100 arch tag in
// both collective builders, and a default (void) tile scheduler.
template <typename ElementAB_, typename ElementD_,
          template <typename, typename, typename> typename Epilogue_,
          typename TileShape, typename ClusterShape, typename KernelSchedule,
          typename EpilogueSchedule>
struct cutlass_3x_gemm_sm100 {
  using ElementAB = ElementAB_;
  using LayoutA = cutlass::layout::RowMajor;
  static constexpr int AlignmentA =
      128 / cutlass::sizeof_bits<ElementAB>::value;

  using LayoutB = cutlass::layout::ColumnMajor;
  static constexpr int AlignmentB =
      128 / cutlass::sizeof_bits<ElementAB>::value;

  // C is unused (void); its layout/alignment are still required by the
  // epilogue builder and mirror D's.
  using ElementC = void;
  using LayoutC = cutlass::layout::RowMajor;
  static constexpr int AlignmentC =
      128 / cutlass::sizeof_bits<ElementD_>::value;

  using ElementD = ElementD_;
  using LayoutD = cutlass::layout::RowMajor;
  static constexpr int AlignmentD = AlignmentC;

  // int8 inputs accumulate in int32; fp8 inputs accumulate in float.
  using ElementAcc =
      typename std::conditional<std::is_same_v<ElementAB, int8_t>, int32_t,
                                float>::type;
  using Epilogue = Epilogue_<ElementAcc, ElementD, TileShape>;

  // MMA type
  using ElementAccumulator = float;

  // Epilogue types
  using ElementBias = cutlass::half_t;
  using ElementCompute = float;
  using ElementAux = ElementD;
  using LayoutAux = LayoutD;
  using ElementAmax = float;

  using EVTCompute = typename Epilogue::EVTCompute;

  using CollectiveEpilogue =
      typename cutlass::epilogue::collective::CollectiveBuilder<
          cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, TileShape,
          ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto,
          ElementAccumulator, ElementCompute, ElementC, LayoutC, AlignmentC,
          ElementD, LayoutD, AlignmentD, EpilogueSchedule,
          EVTCompute>::CollectiveOp;

  using CollectiveMainloop =
      typename cutlass::gemm::collective::CollectiveBuilder<
          cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, ElementAB,
          LayoutA, AlignmentA, ElementAB, LayoutB, AlignmentB,
          ElementAccumulator, TileShape, ClusterShape,
          cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
              sizeof(typename CollectiveEpilogue::SharedStorage))>,
          KernelSchedule>::CollectiveOp;

  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
      Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue, void>;
};

}  // namespace fastdeploy

View File

@@ -0,0 +1,27 @@
// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_azp_sm90_int8.cu
// clang-format will break include orders
// clang-format off
#include "scaled_mm_kernels.hpp"
#include "scaled_mm_sm90_int8_dispatch.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
// clang-format on
namespace fastdeploy {

// SM90 int8 scaled GEMM with azp (asymmetric zero point) adjustment.
// Dispatches on whether an explicit azp tensor is present: with one, the
// token-wise epilogue (ScaledEpilogueBiasAzpToken) is used; without, the
// azp correction is applied solely through `azp_adj` via
// ScaledEpilogueBiasAzp.
void cutlass_scaled_mm_azp_sm90_int8(
    paddle::Tensor &out, paddle::Tensor const &a, paddle::Tensor const &b,
    paddle::Tensor const &a_scales, paddle::Tensor const &b_scales,
    paddle::Tensor const &azp_adj, paddle::optional<paddle::Tensor> const &azp,
    paddle::optional<paddle::Tensor> const &bias) {
  if (!azp) {
    return cutlass_scaled_mm_sm90_int8_epilogue<c3x::ScaledEpilogueBiasAzp>(
        out, a, b, a_scales, b_scales, azp_adj, bias);
  }
  return cutlass_scaled_mm_sm90_int8_epilogue<c3x::ScaledEpilogueBiasAzpToken>(
      out, a, b, a_scales, b_scales, azp_adj, *azp, bias);
}

}  // namespace fastdeploy

View File

@@ -0,0 +1,34 @@
// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_helper.hpp
#include "helper.h"
// Routes a scaled matmul c = scale(a) @ scale(b) to the fp8 or int8
// implementation based on the dtype of the quantized operand `a`.
//
// @param c        Output tensor, written by the selected kernel.
// @param a        Quantized LHS (fp8 e4m3 or int8).
// @param b        Quantized RHS (must match `a`'s quantization family).
// @param a_scales Dequant scales for `a`; fp32, per-tensor or per-token.
// @param b_scales Dequant scales for `b`; fp32, per-tensor or per-channel.
// @param bias     Optional bias forwarded to the kernel.
// @param fp8_func Callable handling the fp8 path.
// @param int8_func Callable handling the int8 path; pass nullptr on
//                  architectures without an int8 kernel.
//
// Fix vs. original: removed locals M/N/K that were computed from the tensor
// dims but never used (triggered -Wunused-variable).
template <typename Fp8Func, typename Int8Func>
void dispatch_scaled_mm(paddle::Tensor &c, paddle::Tensor const &a,
                        paddle::Tensor const &b, paddle::Tensor const &a_scales,
                        paddle::Tensor const &b_scales,
                        paddle::optional<paddle::Tensor> const &bias,
                        Fp8Func fp8_func, Int8Func int8_func) {
  PD_CHECK(a_scales.dtype() == paddle::DataType::FLOAT32);
  PD_CHECK(b_scales.dtype() == paddle::DataType::FLOAT32);

  if ((a_scales.numel() == 1 || a_scales.numel() == a.dims()[0]) &&
      (b_scales.numel() == 1 || b_scales.numel() == b.dims()[0])) {
    // Standard per-tensor/per-token/per-channel scaling
    PD_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
    if (a.dtype() == phi::DataType::FLOAT8_E4M3FN) {
      fp8_func(c, a, b, a_scales, b_scales, bias);
    } else {
      PD_CHECK(a.dtype() == paddle::DataType::INT8);
      // int8_func may be a literal nullptr on architectures without an int8
      // kernel; detect that at compile time rather than calling through it.
      if constexpr (!std::is_same_v<Int8Func, std::nullptr_t>) {
        int8_func(c, a, b, a_scales, b_scales, bias);
      } else {
        PD_CHECK(false, "Int8 not supported for this architecture");
      }
    }
  } else {
    // Block/group-wise scale shapes are not handled by this dispatcher.
    PADDLE_THROW(phi::errors::Unimplemented(
        "No kernel for this combination of input dtypes is implemented."));
  }
}

View File

@@ -0,0 +1,35 @@
// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_kernels.hpp
#pragma once
#include "helper.h"
namespace fastdeploy {

// Forward declarations of the architecture-specific scaled-matmul entry
// points implemented in the per-arch .cu files. All take fp32 scale tensors
// and an optional bias; `out` receives the dequantized result.

// SM90 (Hopper) fp8 e4m3 x fp8 e4m3 scaled GEMM.
void cutlass_scaled_mm_sm90_fp8(paddle::Tensor &out, paddle::Tensor const &a,
                                paddle::Tensor const &b,
                                paddle::Tensor const &a_scales,
                                paddle::Tensor const &b_scales,
                                paddle::optional<paddle::Tensor> const &bias);

// SM90 (Hopper) int8 x int8 scaled GEMM.
void cutlass_scaled_mm_sm90_int8(paddle::Tensor &out, paddle::Tensor const &a,
                                 paddle::Tensor const &b,
                                 paddle::Tensor const &a_scales,
                                 paddle::Tensor const &b_scales,
                                 paddle::optional<paddle::Tensor> const &bias);

// SM90 int8 scaled GEMM with asymmetric-zero-point correction; `azp_adj`
// carries the precomputed adjustment and `azp` the optional per-token
// zero points.
void cutlass_scaled_mm_azp_sm90_int8(paddle::Tensor& out, paddle::Tensor const& a,
                                     paddle::Tensor const& b,
                                     paddle::Tensor const& a_scales,
                                     paddle::Tensor const& b_scales,
                                     paddle::Tensor const& azp_adj,
                                     paddle::optional<paddle::Tensor> const& azp,
                                     paddle::optional<paddle::Tensor> const& bias);

// SM100 (Blackwell) fp8 e4m3 x fp8 e4m3 scaled GEMM.
void cutlass_scaled_mm_sm100_fp8(paddle::Tensor &out, paddle::Tensor const &a,
                                 paddle::Tensor const &b,
                                 paddle::Tensor const &a_scales,
                                 paddle::Tensor const &b_scales,
                                 paddle::optional<paddle::Tensor> const &bias);

}  // namespace fastdeploy

View File

@@ -0,0 +1,28 @@
// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_fp8.cu
// clang-format will break include orders
// clang-format off
#include "scaled_mm_kernels.hpp"
#include "scaled_mm_sm90_fp8_dispatch.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
// clang-format on
namespace fastdeploy {

// SM90 fp8 (e4m3) scaled GEMM entry point. Verifies the scale tensors are
// contiguous, then routes to the bias or plain scaled epilogue.
void cutlass_scaled_mm_sm90_fp8(paddle::Tensor &out, paddle::Tensor const &a,
                                paddle::Tensor const &b,
                                paddle::Tensor const &a_scales,
                                paddle::Tensor const &b_scales,
                                paddle::optional<paddle::Tensor> const &bias) {
  PD_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());

  if (!bias) {
    return cutlass_scaled_mm_sm90_fp8_epilogue<c3x::ScaledEpilogue>(
        out, a, b, a_scales, b_scales);
  }
  PD_CHECK(bias->dtype() == out.dtype(),
           "currently bias dtype must match output dtype ", out.dtype());
  return cutlass_scaled_mm_sm90_fp8_epilogue<c3x::ScaledEpilogueBias>(
      out, a, b, a_scales, b_scales, *bias);
}

}  // namespace fastdeploy

View File

@@ -0,0 +1,125 @@
// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_fp8_dispatch.cuh
#pragma once
// clang-format will break include orders
// clang-format off
#include "scaled_mm.cuh"
#include "cutlass_gemm_caller.cuh"
// clang-format on
/**
* This file defines Gemm kernel configurations for SM90 (fp8) based on the Gemm
* shape.
*/
namespace fastdeploy {

using c3x::cutlass_gemm_caller;

// Tile/cluster/schedule configuration for large-M fp8 GEMMs.
template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_fp8_config_default {
  // M in (128, inf)
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  using KernelSchedule =
      cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
  using TileShape = Shape<_128, _128, _128>;
  using ClusterShape = Shape<_2, _1, _1>;
  using Cutlass3xGemm =
      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                      KernelSchedule, EpilogueSchedule>;
};

// Configuration for mid-size M: smaller M tile, same cluster.
template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_fp8_config_M128 {
  // M in (64, 128]
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  using KernelSchedule =
      cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
  using TileShape = Shape<_64, _128, _128>;
  using ClusterShape = Shape<_2, _1, _1>;
  using Cutlass3xGemm =
      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                      KernelSchedule, EpilogueSchedule>;
};

// Configuration for small M: 64x64 tile with a wide (1x8) cluster.
template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_fp8_config_M64 {
  // M in [1, 64]
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  using KernelSchedule =
      cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
  using TileShape = Shape<_64, _64, _128>;
  using ClusterShape = Shape<_1, _8, _1>;
  using Cutlass3xGemm =
      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                      KernelSchedule, EpilogueSchedule>;
};

// Selects one of the configs above from M (rounded up to a power of two,
// clamped to >= 64) and launches the corresponding kernel.
template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue,
          typename... EpilogueArgs>
inline void cutlass_gemm_sm90_fp8_dispatch(paddle::Tensor &out,
                                           paddle::Tensor const &a,
                                           paddle::Tensor const &b,
                                           EpilogueArgs &&...args) {
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  PD_CHECK(a.dtype() == phi::DataType::FLOAT8_E4M3FN);
  PD_CHECK(b.dtype() == phi::DataType::FLOAT8_E4M3FN);

  using Cutlass3xGemmDefault =
      typename sm90_fp8_config_default<InType, OutType,
                                       Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemmM64 =
      typename sm90_fp8_config_M64<InType, OutType, Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemmM128 =
      typename sm90_fp8_config_M128<InType, OutType, Epilogue>::Cutlass3xGemm;

  uint32_t const m = a.dims()[0];
  uint32_t const mp2 =
      std::max(static_cast<uint32_t>(64), next_pow_2(m));  // next power of 2

  if (mp2 <= 64) {
    // m in [1, 64]
    return cutlass_gemm_caller<Cutlass3xGemmM64>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  } else if (mp2 <= 128) {
    // m in (64, 128]
    return cutlass_gemm_caller<Cutlass3xGemmM128>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  } else {
    // m in (128, inf)
    return cutlass_gemm_caller<Cutlass3xGemmDefault>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  }
}

// Public epilogue-templated entry: validates fp8 inputs, then instantiates
// the dispatch above for the requested output dtype (bf16 or fp16).
template <template <typename, typename, typename> typename Epilogue,
          typename... EpilogueArgs>
void cutlass_scaled_mm_sm90_fp8_epilogue(paddle::Tensor &out,
                                         paddle::Tensor const &a,
                                         paddle::Tensor const &b,
                                         EpilogueArgs &&...epilogue_args) {
  PD_CHECK(a.dtype() == phi::DataType::FLOAT8_E4M3FN);
  PD_CHECK(b.dtype() == phi::DataType::FLOAT8_E4M3FN);
  if (out.dtype() == paddle::DataType::BFLOAT16) {
    return cutlass_gemm_sm90_fp8_dispatch<cutlass::float_e4m3_t,
                                          cutlass::bfloat16_t, Epilogue>(
        out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
  } else {
    PD_CHECK(out.dtype() == paddle::DataType::FLOAT16);
    return cutlass_gemm_sm90_fp8_dispatch<cutlass::float_e4m3_t,
                                          cutlass::half_t, Epilogue>(
        out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
  }
}

}  // namespace fastdeploy

View File

@@ -0,0 +1,29 @@
// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_int8.cu
// clang-format will break include orders
// clang-format off
#include "scaled_mm_kernels.hpp"
#include "scaled_mm_sm90_int8_dispatch.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
// clang-format on
namespace fastdeploy {

// SM90 int8 scaled GEMM entry point. Verifies the scale tensors are
// contiguous, then routes to the bias or plain scaled epilogue.
void cutlass_scaled_mm_sm90_int8(paddle::Tensor &out, paddle::Tensor const &a,
                                 paddle::Tensor const &b,
                                 paddle::Tensor const &a_scales,
                                 paddle::Tensor const &b_scales,
                                 paddle::optional<paddle::Tensor> const &bias) {
  PD_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());

  if (!bias) {
    return cutlass_scaled_mm_sm90_int8_epilogue<c3x::ScaledEpilogue>(
        out, a, b, a_scales, b_scales);
  }
  PD_CHECK(bias->dtype() == out.dtype(),
           "currently bias dtype must match output dtype ", out.dtype());
  return cutlass_scaled_mm_sm90_int8_epilogue<c3x::ScaledEpilogueBias>(
      out, a, b, a_scales, b_scales, *bias);
}

}  // namespace fastdeploy

View File

@@ -0,0 +1,168 @@
// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_int8_dispatch.cuh
#pragma once
// clang-format will break include orders
// clang-format off
#include "scaled_mm.cuh"
#include "cutlass_gemm_caller.cuh"
// clang-format on
/**
* This file defines Gemm kernel configurations for SM90 (int8) based on the
* Gemm shape.
*/
namespace fastdeploy {

using c3x::cutlass_gemm_caller;

// Tile/cluster/schedule configuration for large-M int8 GEMMs.
template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_int8_config_default {
  // For M > 128 and any N
  static_assert(std::is_same<InType, int8_t>());
  using KernelSchedule =
      typename cutlass::gemm::KernelTmaWarpSpecializedPingpong;
  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
  using TileShape = Shape<_128, _128, _128>;
  using ClusterShape = Shape<_2, _1, _1>;
  using Cutlass3xGemm =
      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                      KernelSchedule, EpilogueSchedule>;
};

// Configuration for mid-size M: smaller M tile, same pingpong schedule.
template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_int8_config_M128 {
  // For M in (64, 128] and any N
  static_assert(std::is_same<InType, int8_t>());
  using KernelSchedule =
      typename cutlass::gemm::KernelTmaWarpSpecializedPingpong;
  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
  using TileShape = Shape<_64, _128, _128>;
  using ClusterShape = Shape<_2, _1, _1>;
  using Cutlass3xGemm =
      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                      KernelSchedule, EpilogueSchedule>;
};

// Configuration for small M: deeper K tile, non-pingpong schedule.
template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_int8_config_M64 {
  // For M in (32, 64] and any N
  static_assert(std::is_same<InType, int8_t>());
  using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized;
  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
  using TileShape = Shape<_64, _64, _256>;
  using ClusterShape = Shape<_1, _1, _1>;
  using Cutlass3xGemm =
      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                      KernelSchedule, EpilogueSchedule>;
};

// Tiny-M configuration for wide outputs: larger N tile with a 1x4 cluster.
template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_int8_config_M32_NBig {
  // For M in [1, 32] and N >= 8192
  static_assert(std::is_same<InType, int8_t>());
  using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized;
  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
  using TileShape = Shape<_64, _128, _256>;
  using ClusterShape = Shape<_1, _4, _1>;
  using Cutlass3xGemm =
      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                      KernelSchedule, EpilogueSchedule>;
};

// Tiny-M configuration for narrow outputs: small tiles, wide (1x8) cluster.
template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_int8_config_M32_NSmall {
  // For M in [1, 32] and N < 8192
  static_assert(std::is_same<InType, int8_t>());
  using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized;
  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
  using TileShape = Shape<_64, _64, _256>;
  using ClusterShape = Shape<_1, _8, _1>;
  using Cutlass3xGemm =
      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                      KernelSchedule, EpilogueSchedule>;
};

// Selects one of the configs above from M (rounded up to a power of two,
// clamped to >= 32) and, for tiny M, from the output width N; then launches
// the corresponding kernel.
template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue,
          typename... EpilogueArgs>
inline void cutlass_gemm_sm90_int8_dispatch(paddle::Tensor &out,
                                            paddle::Tensor const &a,
                                            paddle::Tensor const &b,
                                            EpilogueArgs &&...args) {
  static_assert(std::is_same<InType, int8_t>());
  PD_CHECK(a.dtype() == paddle::DataType::INT8);
  PD_CHECK(b.dtype() == paddle::DataType::INT8);

  using Cutlass3xGemmDefault =
      typename sm90_int8_config_default<InType, OutType,
                                        Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemmM128 =
      typename sm90_int8_config_M128<InType, OutType, Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemmM64 =
      typename sm90_int8_config_M64<InType, OutType, Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemmM32NBig =
      typename sm90_int8_config_M32_NBig<InType, OutType,
                                         Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemmM32NSmall =
      typename sm90_int8_config_M32_NSmall<InType, OutType,
                                           Epilogue>::Cutlass3xGemm;

  uint32_t const n = out.dims()[1];
  bool const is_small_n = n < 8192;

  uint32_t const m = a.dims()[0];
  uint32_t const mp2 =
      std::max(static_cast<uint32_t>(32), next_pow_2(m));  // next power of 2

  if (mp2 <= 32) {
    // m in [1, 32]
    if (is_small_n) {
      return cutlass_gemm_caller<Cutlass3xGemmM32NSmall>(
          out, a, b, std::forward<EpilogueArgs>(args)...);
    } else {
      return cutlass_gemm_caller<Cutlass3xGemmM32NBig>(
          out, a, b, std::forward<EpilogueArgs>(args)...);
    }
  } else if (mp2 <= 64) {
    // m in (32, 64]
    return cutlass_gemm_caller<Cutlass3xGemmM64>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  } else if (mp2 <= 128) {
    // m in (64, 128]
    return cutlass_gemm_caller<Cutlass3xGemmM128>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  } else {
    // m in (128, inf)
    return cutlass_gemm_caller<Cutlass3xGemmDefault>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  }
}

// Public epilogue-templated entry: validates int8 inputs, then instantiates
// the dispatch above for the requested output dtype (bf16 or fp16).
template <template <typename, typename, typename> typename Epilogue,
          typename... EpilogueArgs>
void cutlass_scaled_mm_sm90_int8_epilogue(paddle::Tensor &out,
                                          paddle::Tensor const &a,
                                          paddle::Tensor const &b,
                                          EpilogueArgs &&...epilogue_args) {
  PD_CHECK(a.dtype() == paddle::DataType::INT8);
  PD_CHECK(b.dtype() == paddle::DataType::INT8);

  if (out.dtype() == paddle::DataType::BFLOAT16) {
    return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::bfloat16_t,
                                           Epilogue>(
        out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
  } else {
    PD_CHECK(out.dtype() == paddle::DataType::FLOAT16);
    return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::half_t, Epilogue>(
        out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
  }
}

}  // namespace fastdeploy

View File

@@ -0,0 +1,200 @@
// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu
#include "helper.h"
#include <stddef.h>
#include "cutlass/cutlass.h"
#include "scaled_mm_c2x.cuh"
#include "scaled_mm_c2x_sm75_dispatch.cuh"
#include "scaled_mm_c2x_sm80_dispatch.cuh"
#include "scaled_mm_c2x_sm89_fp8_dispatch.cuh"
#include "scaled_mm_c2x_sm89_int8_dispatch.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp"
using namespace fastdeploy;
/*
This file defines quantized GEMM operations using the CUTLASS 2.x API, for
NVIDIA GPUs with SM versions prior to sm90 (Hopper).
*/
// SM75 int8 epilogue dispatcher: validates that both inputs are int8 and
// forwards to the gemm dispatch specialized on the output dtype.
template <template <typename, typename> typename Epilogue,
          typename... EpilogueArgs>
void cutlass_scaled_mm_sm75_epilogue(paddle::Tensor& out, paddle::Tensor const& a,
                                     paddle::Tensor const& b,
                                     EpilogueArgs&&... epilogue_args) {
  PD_CHECK(a.dtype() == paddle::DataType::INT8);
  PD_CHECK(b.dtype() == paddle::DataType::INT8);

  // Only bf16 and fp16 outputs are supported.
  if (out.dtype() != paddle::DataType::BFLOAT16) {
    PD_CHECK(out.dtype() == paddle::DataType::FLOAT16);
    return cutlass_gemm_sm75_dispatch<int8_t, cutlass::half_t, Epilogue>(
        out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
  }
  return cutlass_gemm_sm75_dispatch<int8_t, cutlass::bfloat16_t, Epilogue>(
      out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
}
// SM75 int8 scaled matmul entry point: checks fp32 scales, then routes to
// the bias or plain scaled epilogue.
void cutlass_scaled_mm_sm75(paddle::Tensor& out, paddle::Tensor const& a,
                            paddle::Tensor const& b,
                            paddle::Tensor const& a_scales,
                            paddle::Tensor const& b_scales,
                            paddle::optional<paddle::Tensor> const& bias) {
  PD_CHECK(a_scales.dtype() == paddle::DataType::FLOAT32);
  PD_CHECK(b_scales.dtype() == paddle::DataType::FLOAT32);

  if (!bias) {
    return cutlass_scaled_mm_sm75_epilogue<c2x::ScaledEpilogue>(
        out, a, b, a_scales, b_scales);
  }
  PD_CHECK(bias->dtype() == out.dtype(),
           "currently bias dtype must match output dtype ", out.dtype());
  return cutlass_scaled_mm_sm75_epilogue<c2x::ScaledEpilogueBias>(
      out, a, b, a_scales, b_scales, *bias);
}
// SM75 int8 scaled matmul with azp (zero-point) correction. With an explicit
// azp tensor the token-wise epilogue (ScaledEpilogueBiasAzpToken) is used;
// otherwise the correction comes solely from azp_adj.
void cutlass_scaled_mm_azp_sm75(paddle::Tensor& out, paddle::Tensor const& a,
                                paddle::Tensor const& b,
                                paddle::Tensor const& a_scales,
                                paddle::Tensor const& b_scales,
                                paddle::Tensor const& azp_adj,
                                paddle::optional<paddle::Tensor> const& azp,
                                paddle::optional<paddle::Tensor> const& bias) {
  PD_CHECK(a_scales.dtype() == paddle::DataType::FLOAT32);
  PD_CHECK(b_scales.dtype() == paddle::DataType::FLOAT32);

  if (!azp) {
    return cutlass_scaled_mm_sm75_epilogue<c2x::ScaledEpilogueBiasAzp>(
        out, a, b, a_scales, b_scales, azp_adj, bias);
  }
  return cutlass_scaled_mm_sm75_epilogue<c2x::ScaledEpilogueBiasAzpToken>(
      out, a, b, a_scales, b_scales, azp_adj, *azp, bias);
}
// SM80 int8 epilogue dispatcher: validates that both inputs are int8 and
// forwards to the gemm dispatch specialized on the output dtype.
template <template <typename, typename> typename Epilogue,
          typename... EpilogueArgs>
void cutlass_scaled_mm_sm80_epilogue(paddle::Tensor& out, paddle::Tensor const& a,
                                     paddle::Tensor const& b,
                                     EpilogueArgs&&... epilogue_args) {
  PD_CHECK(a.dtype() == paddle::DataType::INT8);
  PD_CHECK(b.dtype() == paddle::DataType::INT8);

  // Only bf16 and fp16 outputs are supported.
  if (out.dtype() != paddle::DataType::BFLOAT16) {
    PD_CHECK(out.dtype() == paddle::DataType::FLOAT16);
    return cutlass_gemm_sm80_dispatch<int8_t, cutlass::half_t, Epilogue>(
        out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
  }
  return cutlass_gemm_sm80_dispatch<int8_t, cutlass::bfloat16_t, Epilogue>(
      out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
}
// SM80 int8 scaled matmul entry point: checks fp32 scales, then routes to
// the bias or plain scaled epilogue.
void cutlass_scaled_mm_sm80(paddle::Tensor& out, paddle::Tensor const& a,
                            paddle::Tensor const& b,
                            paddle::Tensor const& a_scales,
                            paddle::Tensor const& b_scales,
                            paddle::optional<paddle::Tensor> const& bias) {
  PD_CHECK(a_scales.dtype() == paddle::DataType::FLOAT32);
  PD_CHECK(b_scales.dtype() == paddle::DataType::FLOAT32);

  if (!bias) {
    return cutlass_scaled_mm_sm80_epilogue<c2x::ScaledEpilogue>(
        out, a, b, a_scales, b_scales);
  }
  PD_CHECK(bias->dtype() == out.dtype(),
           "currently bias dtype must match output dtype ", out.dtype());
  return cutlass_scaled_mm_sm80_epilogue<c2x::ScaledEpilogueBias>(
      out, a, b, a_scales, b_scales, *bias);
}
// SM80 int8 scaled matmul with azp (zero-point) correction. With an explicit
// azp tensor the token-wise epilogue (ScaledEpilogueBiasAzpToken) is used;
// otherwise the correction comes solely from azp_adj.
void cutlass_scaled_mm_azp_sm80(paddle::Tensor& out, paddle::Tensor const& a,
                                paddle::Tensor const& b,
                                paddle::Tensor const& a_scales,
                                paddle::Tensor const& b_scales,
                                paddle::Tensor const& azp_adj,
                                paddle::optional<paddle::Tensor> const& azp,
                                paddle::optional<paddle::Tensor> const& bias) {
  PD_CHECK(a_scales.dtype() == paddle::DataType::FLOAT32);
  PD_CHECK(b_scales.dtype() == paddle::DataType::FLOAT32);

  if (!azp) {
    return cutlass_scaled_mm_sm80_epilogue<c2x::ScaledEpilogueBiasAzp>(
        out, a, b, a_scales, b_scales, azp_adj, bias);
  }
  return cutlass_scaled_mm_sm80_epilogue<c2x::ScaledEpilogueBiasAzpToken>(
      out, a, b, a_scales, b_scales, azp_adj, *azp, bias);
}
// SM89 (Ada) epilogue dispatcher. Supports two quantized input combinations,
//   * int8 x int8
//   * fp8 (e4m3) x fp8 (e4m3),
// and dispatches on the output dtype (bf16 or fp16) within each.
//
// Fix vs. original: the fp16-output check on the int8 path used a plain
// C `assert()`, which is compiled out under NDEBUG, silently dropping the
// dtype validation in release builds. It is now a PD_CHECK, consistent with
// every other dispatcher in this file.
template <template <typename, typename> typename Epilogue,
          typename... EpilogueArgs>
void cutlass_scaled_mm_sm89_epilogue(paddle::Tensor& out, paddle::Tensor const& a,
                                     paddle::Tensor const& b,
                                     EpilogueArgs&&... epilogue_args) {
  if (a.dtype() == paddle::DataType::INT8) {
    PD_CHECK(b.dtype() == paddle::DataType::INT8);

    if (out.dtype() == paddle::DataType::BFLOAT16) {
      return cutlass_gemm_sm89_int8_dispatch<int8_t, cutlass::bfloat16_t,
                                             Epilogue>(
          out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
    } else {
      PD_CHECK(out.dtype() == paddle::DataType::FLOAT16);
      return cutlass_gemm_sm89_int8_dispatch<int8_t, cutlass::half_t, Epilogue>(
          out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
    }
  } else {
    PD_CHECK(a.dtype() == paddle::DataType::FLOAT8_E4M3FN);
    PD_CHECK(b.dtype() == paddle::DataType::FLOAT8_E4M3FN);

    if (out.dtype() == paddle::DataType::BFLOAT16) {
      return cutlass_gemm_sm89_fp8_dispatch<cutlass::float_e4m3_t,
                                            cutlass::bfloat16_t, Epilogue>(
          out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
    } else {
      PD_CHECK(out.dtype() == paddle::DataType::FLOAT16);
      return cutlass_gemm_sm89_fp8_dispatch<cutlass::float_e4m3_t,
                                            cutlass::half_t, Epilogue>(
          out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
    }
  }
}
// SM89 scaled matmul entry point (int8 or fp8 inputs): checks fp32 scales,
// then routes to the bias or plain scaled epilogue.
void cutlass_scaled_mm_sm89(paddle::Tensor& out, paddle::Tensor const& a,
                            paddle::Tensor const& b,
                            paddle::Tensor const& a_scales,
                            paddle::Tensor const& b_scales,
                            paddle::optional<paddle::Tensor> const& bias) {
  PD_CHECK(a_scales.dtype() == paddle::DataType::FLOAT32);
  PD_CHECK(b_scales.dtype() == paddle::DataType::FLOAT32);

  if (!bias) {
    return cutlass_scaled_mm_sm89_epilogue<c2x::ScaledEpilogue>(
        out, a, b, a_scales, b_scales);
  }
  PD_CHECK(bias->dtype() == out.dtype(),
           "currently bias dtype must match output dtype ", out.dtype());
  return cutlass_scaled_mm_sm89_epilogue<c2x::ScaledEpilogueBias>(
      out, a, b, a_scales, b_scales, *bias);
}
// SM89 scaled matmul with asymmetric zero-point correction. `azp_adj` is
// always applied; when a per-token `azp` tensor is present the per-token
// epilogue variant is used instead.
void cutlass_scaled_mm_azp_sm89(paddle::Tensor& out, paddle::Tensor const& a,
                                paddle::Tensor const& b,
                                paddle::Tensor const& a_scales,
                                paddle::Tensor const& b_scales,
                                paddle::Tensor const& azp_adj,
                                paddle::optional<paddle::Tensor> const& azp,
                                paddle::optional<paddle::Tensor> const& bias) {
  PD_CHECK(a_scales.dtype() == paddle::DataType::FLOAT32);
  PD_CHECK(b_scales.dtype() == paddle::DataType::FLOAT32);
  if (!azp) {
    cutlass_scaled_mm_sm89_epilogue<c2x::ScaledEpilogueBiasAzp>(
        out, a, b, a_scales, b_scales, azp_adj, bias);
    return;
  }
  cutlass_scaled_mm_sm89_epilogue<c2x::ScaledEpilogueBiasAzpToken>(
      out, a, b, a_scales, b_scales, azp_adj, *azp, bias);
}

View File

@@ -0,0 +1,223 @@
// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh
#pragma once
#include <stddef.h>
#include "helper.h"
// clang-format will break include orders
// clang-format off
#include "cute/tensor.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cutlass/numeric_types.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm_coord.h"
#include "cutlass/arch/mma_sm75.h"
#include "cutlass/arch/arch.h"
#include "cutlass/arch/mma.h"
#include "cutlass/gemm/device/gemm.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/epilogue/threadblock/fusion/visitors.hpp"
#include "cutlass/gemm/kernel/default_gemm_universal_with_visitor.h"
#include "cutlass_helper.h"
// clang-format on
/*
Epilogues defined in,
csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp
must contain a public type named EVTCompute of type Sm80EVT,
as well as a static prepare_args function that constructs an
EVTCompute::Arguments struct.
*/
namespace fastdeploy {
using namespace cute;
// Wrappers for the GEMM kernel that is used to guard against compilation on
// architectures that will never use the kernel. The purpose of this is to
// reduce the size of the compiled binary.
// __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef
// into code that will be executed on the device where it is defined.
// Arch guard: forwards to Kernel::invoke only when the device compile pass
// targets an SM75-class architecture (750 <= __CUDA_ARCH__ < 800); on any
// other pass the call compiles to nothing, shrinking the fat binary.
template <typename Kernel>
struct enable_sm75_to_sm80 : Kernel {
  template <typename... ArgTs>
  CUTLASS_DEVICE static void invoke(ArgTs&&... fwd_args) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 750 && __CUDA_ARCH__ < 800
    Kernel::invoke(std::forward<ArgTs>(fwd_args)...);
#endif
  }
};
// Arch guard: forwards to Kernel::invoke only when compiling device code
// for 800 <= __CUDA_ARCH__ < 890; otherwise the body is an empty no-op.
template <typename Kernel>
struct enable_sm80_to_sm89 : Kernel {
  template <typename... ArgTs>
  CUTLASS_DEVICE static void invoke(ArgTs&&... fwd_args) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800 && __CUDA_ARCH__ < 890
    Kernel::invoke(std::forward<ArgTs>(fwd_args)...);
#endif
  }
};
// Arch guard: forwards to Kernel::invoke only when compiling device code
// for 890 <= __CUDA_ARCH__ < 900; otherwise the body is an empty no-op.
template <typename Kernel>
struct enable_sm89_to_sm90 : Kernel {
  template <typename... ArgTs>
  CUTLASS_DEVICE static void invoke(ArgTs&&... fwd_args) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 890 && __CUDA_ARCH__ < 900
    Kernel::invoke(std::forward<ArgTs>(fwd_args)...);
#endif
  }
};
// Assembles a CUTLASS 2.x GEMM type from the target architecture, element
// types, epilogue, and tile configuration. The resulting `Op` is a
// device::GemmUniversalAdapter whose epilogue visitor tree computes the
// scaled result and stores it to D. `ArchGuard` is one of the
// enable_smXX_to_smYY wrappers above, preventing the kernel body from being
// compiled for architectures that will never run it.
template <typename Arch, template <typename> typename ArchGuard,
          typename ElementAB_, typename ElementD_,
          template <typename, typename> typename Epilogue_, typename TileShape,
          typename WarpShape, typename InstructionShape, int32_t MainLoopStages,
          typename FP8MathOperator = cutlass::arch::OpMultiplyAdd>
struct cutlass_2x_gemm {
  using ElementAB = ElementAB_;
  using ElementD = ElementD_;
  // int8 inputs accumulate in int32; all other input types accumulate in
  // float.
  using ElementAcc =
      typename std::conditional<std::is_same_v<ElementAB, int8_t>, int32_t,
                                float>::type;
  // int8 uses the saturating multiply-add; fp8 uses the caller-supplied
  // math operator (plain or fast-accum multiply-add).
  using Operator =
      typename std::conditional<std::is_same_v<ElementAB, int8_t>,
                                cutlass::arch::OpMultiplyAddSaturate,
                                FP8MathOperator>::type;
  using OutputTileThreadMap =
      cutlass::epilogue::threadblock::OutputTileThreadLayout<
          TileShape, WarpShape, float, 4, 1 /* epilogue stages */
          >;
  // The epilogue must expose an `EVTCompute` visitor type and a static
  // `prepare_args` (see the header comment of this file).
  using Epilogue = Epilogue_<ElementD, OutputTileThreadMap>;
  using EVTCompute = typename Epilogue::EVTCompute;
  // Final visitor node: stores the epilogue result as ElementD with
  // round-to-nearest, using a row-major stride with a runtime leading dim.
  using D = cutlass::epilogue::threadblock::VisitorAuxStore<
      OutputTileThreadMap, ElementD, cutlass::FloatRoundStyle::round_to_nearest,
      Stride<int64_t, Int<1>, Int<0>>>;
  using EVTD = cutlass::epilogue::threadblock::Sm80EVT<D, EVTCompute>;
  // These are the minimum alignments needed for the kernels to compile
  static constexpr int AlignmentAB =
      128 / cutlass::sizeof_bits<ElementAB>::value;
  static constexpr int AlignmentCD = 4;
  // clang-format off
  using RowMajor = typename cutlass::layout::RowMajor;
  using ColumnMajor = typename cutlass::layout::ColumnMajor;
  // A is row-major, B is column-major; C/D are row-major.
  using KernelType =
    ArchGuard<typename cutlass::gemm::kernel::DefaultGemmWithVisitor<
      ElementAB, RowMajor, cutlass::ComplexTransform::kNone, AlignmentAB,
      ElementAB, ColumnMajor, cutlass::ComplexTransform::kNone, AlignmentAB,
      float, cutlass::layout::RowMajor, AlignmentCD,
      ElementAcc, float, cutlass::arch::OpClassTensorOp,
      Arch,
      TileShape, WarpShape, InstructionShape,
      EVTD,
      cutlass::gemm::threadblock::ThreadblockSwizzleStreamK,
      MainLoopStages, Operator,
      1 /* epilogue stages */
      >::GemmKernel>;
  // clang-format on
  using Op = cutlass::gemm::device::GemmUniversalAdapter<KernelType>;
};
// Launches a single CUTLASS 2.x GEMM described by `Gemm` (a cutlass_2x_gemm
// instantiation): out = epilogue(a @ b). Extra arguments are forwarded to
// Gemm::Epilogue::prepare_args to build the epilogue visitor arguments.
template <typename Gemm, typename... EpilogueArgs>
inline void cutlass_gemm_caller(paddle::Tensor& out, paddle::Tensor const& a,
                                paddle::Tensor const& b,
                                EpilogueArgs&&... epilogue_params) {
  using ElementAB = typename Gemm::ElementAB;
  using ElementD = typename Gemm::ElementD;
  // Problem shape: a is (m, k); b's first dim is n (b is declared
  // column-major in cutlass_2x_gemm). NOTE(review): assumes 2-D tensors
  // with a dense innermost dim -- confirm against callers.
  int32_t m = a.dims()[0];
  int32_t n = b.dims()[0];
  int32_t k = a.dims()[1];
  cutlass::gemm::GemmCoord problem_size{m, n, k};
  // Leading dimensions come from each tensor's outermost stride.
  int64_t lda = a.strides()[0];
  int64_t ldb = b.strides()[0];
  int64_t ldc = out.strides()[0];
  using StrideC = Stride<int64_t, Int<1>, Int<0>>;
  StrideC c_stride{ldc, Int<1>{}, Int<0>{}};
  auto a_ptr = static_cast<ElementAB const*>(a.data());
  auto b_ptr = static_cast<ElementAB const*>(b.data());
  auto c_ptr = static_cast<ElementD*>(out.data());
  // Arguments for the final store node (D) of the epilogue visitor tree:
  // the output pointer and its row-major stride.
  typename Gemm::D::Arguments d_args{c_ptr, c_stride};
  using Epilogue = typename Gemm::Epilogue;
  auto evt_args =
      Epilogue::prepare_args(std::forward<EpilogueArgs>(epilogue_params)...);
  typename Gemm::EVTD::Arguments epilogue_args{
      evt_args,
      d_args,
  };
  // NOTE(review): universal mode is kGemmSplitKParallel with a batch/split
  // count of 1, i.e. no actual split-K reduction; the StreamK swizzle in
  // cutlass_2x_gemm handles work distribution. The two nullptrs and the
  // four zeroes are presumably the unused C/D pointers and batch strides
  // (output is stored via the visitor tree instead) -- verify against the
  // GemmUniversal argument order for this CUTLASS version.
  typename Gemm::Op::Arguments args{
      cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel,  // universal mode
      problem_size,                                           // problem size
      1,                                                      // batch count
      epilogue_args,
      a_ptr,
      b_ptr,
      nullptr,
      nullptr,
      0,
      0,
      0,
      0,
      lda,
      ldb,
      ldc,
      ldc};
  // Launch the CUTLASS GEMM kernel.
  typename Gemm::Op gemm_op;
  size_t workspace_size = gemm_op.get_workspace_size(args);
  phi::Allocator *allocator = paddle::GetAllocator(a.place());
  auto workspace = allocator->Allocate(workspace_size);
  auto stream = a.stream();
  // Check implementability before launching so misconfigurations surface
  // as a PD_CHECK failure rather than a silent kernel error.
  CUTLASS_CHECK(gemm_op.can_implement(args));
  cutlass::Status status = gemm_op(args, workspace->ptr(), stream);
  CUTLASS_CHECK(status);
}
// Launches `Gemm` when the GPU's opt-in shared-memory limit can hold its
// SharedStorage; otherwise launches `FallbackGemm` (which must fit, or the
// PD_CHECK fires).
template <typename Gemm, typename FallbackGemm, typename... EpilogueArgs>
inline void fallback_cutlass_gemm_caller(paddle::Tensor& out,
                                         paddle::Tensor const& a,
                                         paddle::Tensor const& b,
                                         EpilogueArgs&&... args) {
  // Queried once per process; device 0's opt-in limit.
  static const int max_shared_mem_per_block_opt_in =
      get_cuda_max_shared_memory_per_block_opt_in(0);
  size_t const primary_smem_bytes =
      sizeof(typename Gemm::KernelType::SharedStorage);
  if (primary_smem_bytes <= max_shared_mem_per_block_opt_in) {
    cutlass_gemm_caller<Gemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
    return;
  }
  // Primary config does not fit on this GPU: use the fallback, which must.
  size_t const fallback_smem_bytes =
      sizeof(typename FallbackGemm::KernelType::SharedStorage);
  PD_CHECK(fallback_smem_bytes <= max_shared_mem_per_block_opt_in);
  cutlass_gemm_caller<FallbackGemm>(out, a, b,
                                    std::forward<EpilogueArgs>(args)...);
}
} // namespace fastdeploy

View File

@@ -0,0 +1,125 @@
// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm75_dispatch.cuh
#pragma once
#include "scaled_mm_c2x.cuh"
/**
* This file defines Gemm kernel configurations for SM75 based on the Gemm
* shape.
*/
namespace fastdeploy {
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue>
struct sm75_config_default {
// This config is used in 2 cases,
// - M in (256, inf]
// - M in (64, 128]
// Shared memory required by this Gemm 32768
static_assert(std::is_same<InType, int8_t>());
using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
using InstructionShape = typename cutlass::gemm::GemmShape<8, 8, 16>;
using Cutlass2xGemm =
cutlass_2x_gemm<cutlass::arch::Sm75, enable_sm75_to_sm80, InType, OutType,
Epilogue, TileShape, WarpShape, InstructionShape, 2>;
};
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue>
struct sm75_config_M256 {
// M in (128, 256]
// Shared memory required by this Gemm 65536
static_assert(std::is_same<InType, int8_t>());
using TileShape = typename cutlass::gemm::GemmShape<128, 128, 128>;
using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
using InstructionShape = typename cutlass::gemm::GemmShape<8, 8, 16>;
using Cutlass2xGemm =
cutlass_2x_gemm<cutlass::arch::Sm75, enable_sm75_to_sm80, InType, OutType,
Epilogue, TileShape, WarpShape, InstructionShape, 2>;
};
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue>
struct sm75_config_M64 {
// M in (32, 64]
// Shared memory required by this Gemm 49152
static_assert(std::is_same<InType, int8_t>());
using TileShape = typename cutlass::gemm::GemmShape<64, 128, 128>;
using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
using InstructionShape = typename cutlass::gemm::GemmShape<8, 8, 16>;
using Cutlass2xGemm =
cutlass_2x_gemm<cutlass::arch::Sm75, enable_sm75_to_sm80, InType, OutType,
Epilogue, TileShape, WarpShape, InstructionShape, 2>;
};
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue>
struct sm75_config_M32 {
// M in [1, 32]
// Shared memory required by this Gemm 49152
static_assert(std::is_same<InType, int8_t>());
using TileShape = typename cutlass::gemm::GemmShape<32, 128, 64>;
using WarpShape = typename cutlass::gemm::GemmShape<32, 64, 64>;
using InstructionShape = typename cutlass::gemm::GemmShape<8, 8, 16>;
using Cutlass2xGemm =
cutlass_2x_gemm<cutlass::arch::Sm75, enable_sm75_to_sm80, InType, OutType,
Epilogue, TileShape, WarpShape, InstructionShape, 2>;
};
// Chooses an SM75 int8 kernel configuration from the (power-of-two padded)
// M dimension and launches it. Every branch goes through
// fallback_cutlass_gemm_caller so that GPUs whose shared memory cannot hold
// the chosen config fall back to sm75_config_default, which has the
// smallest shared-memory footprint.
template <typename InType, typename OutType,
          template <typename, typename> typename Epilogue,
          typename... EpilogueArgs>
inline void cutlass_gemm_sm75_dispatch(paddle::Tensor& out,
                                       paddle::Tensor const& a,
                                       paddle::Tensor const& b,
                                       EpilogueArgs&&... args) {
  static_assert(std::is_same<InType, int8_t>());
  PD_CHECK(a.dtype() == paddle::DataType::INT8);
  PD_CHECK(b.dtype() == paddle::DataType::INT8);
  using Cutlass2xGemmDefault =
      typename sm75_config_default<InType, OutType, Epilogue>::Cutlass2xGemm;
  using Cutlass2xGemmM256 =
      typename sm75_config_M256<InType, OutType, Epilogue>::Cutlass2xGemm;
  // M in (64, 128] reuses the default configuration.
  using Cutlass2xGemmM128 = Cutlass2xGemmDefault;
  using Cutlass2xGemmM64 =
      typename sm75_config_M64<InType, OutType, Epilogue>::Cutlass2xGemm;
  using Cutlass2xGemmM32 =
      typename sm75_config_M32<InType, OutType, Epilogue>::Cutlass2xGemm;
  // Due to shared memory requirements, some Gemms may fail to run on some
  // GPUs. As the name indicates, the Fallback Gemm is used as an
  // alternative in such cases.
  // sm75_config_default has the least shared-memory requirements.
  using FallbackGemm = Cutlass2xGemmDefault;
  uint32_t const m = a.dims()[0];  // fixed: stray double semicolon removed
  uint32_t const mp2 =
      std::max(static_cast<uint32_t>(32), next_pow_2(m));  // next power of 2
  if (mp2 <= 32) {
    // M in [1, 32]
    return fallback_cutlass_gemm_caller<Cutlass2xGemmM32, FallbackGemm>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  } else if (mp2 <= 64) {
    // M in (32, 64]
    return fallback_cutlass_gemm_caller<Cutlass2xGemmM64, FallbackGemm>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  } else if (mp2 <= 128) {
    // M in (64, 128]
    return fallback_cutlass_gemm_caller<Cutlass2xGemmM128, FallbackGemm>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  } else if (mp2 <= 256) {
    // M in (128, 256]
    return fallback_cutlass_gemm_caller<Cutlass2xGemmM256, FallbackGemm>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  } else {
    // M in (256, inf)
    return fallback_cutlass_gemm_caller<Cutlass2xGemmDefault, FallbackGemm>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  }
}
} // namespace fastdeploy

View File

@@ -0,0 +1,141 @@
// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm80_dispatch.cuh
#pragma once
#include "scaled_mm_c2x.cuh"
/**
* This file defines Gemm kernel configurations for SM80 based on the Gemm
* shape.
*/
namespace fastdeploy {
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue>
struct sm80_config_default {
// This config is used in 2 cases,
// - M in (128, inf)
// - M in (64, 128] and N >= 8192
// Shared Memory required by this Gemm - 81920 bytes
static_assert(std::is_same<InType, int8_t>());
using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
using Cutlass2xGemm =
cutlass_2x_gemm<cutlass::arch::Sm80, enable_sm80_to_sm89, InType, OutType,
Epilogue, TileShape, WarpShape, InstructionShape, 5>;
};
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue>
struct sm80_config_M64 {
// This config is used in 2 cases,
// - M in (32, 64]
// - M in (64, 128] and N < 8192
// Shared Memory required by this Gemm - 122880 bytes
static_assert(std::is_same<InType, int8_t>());
using TileShape = typename cutlass::gemm::GemmShape<64, 128, 128>;
using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
using Cutlass2xGemm =
cutlass_2x_gemm<cutlass::arch::Sm80, enable_sm80_to_sm89, InType, OutType,
Epilogue, TileShape, WarpShape, InstructionShape, 5>;
};
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue>
struct sm80_config_M32 {
// M in (16, 32]
// Shared Memory required by this Gemm - 61440 bytes
static_assert(std::is_same<InType, int8_t>());
using TileShape = typename cutlass::gemm::GemmShape<32, 64, 128>;
using WarpShape = typename cutlass::gemm::GemmShape<32, 64, 64>;
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
using Cutlass2xGemm =
cutlass_2x_gemm<cutlass::arch::Sm80, enable_sm80_to_sm89, InType, OutType,
Epilogue, TileShape, WarpShape, InstructionShape, 5>;
};
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue>
struct sm80_config_M16 {
// M in [1, 16]
// Shared Memory required by this Gemm - 51200 bytes
static_assert(std::is_same<InType, int8_t>());
using TileShape = typename cutlass::gemm::GemmShape<16, 64, 128>;
using WarpShape = typename cutlass::gemm::GemmShape<16, 64, 64>;
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
using Cutlass2xGemm =
cutlass_2x_gemm<cutlass::arch::Sm80, enable_sm80_to_sm89, InType, OutType,
Epilogue, TileShape, WarpShape, InstructionShape, 5>;
};
// Chooses an SM80 int8 kernel configuration from the (power-of-two padded)
// M dimension -- and, for mid-size M, the N dimension -- launching through
// fallback_cutlass_gemm_caller so that configs exceeding the GPU's shared
// memory fall back to sm80_config_M32.
template <typename InType, typename OutType,
          template <typename, typename> typename Epilogue,
          typename... EpilogueArgs>
inline void cutlass_gemm_sm80_dispatch(paddle::Tensor& out,
                                       paddle::Tensor const& a,
                                       paddle::Tensor const& b,
                                       EpilogueArgs&&... args) {
  static_assert(std::is_same<InType, int8_t>());
  PD_CHECK(a.dtype() == paddle::DataType::INT8);
  PD_CHECK(b.dtype() == paddle::DataType::INT8);
  using Cutlass2xGemmDefault =
      typename sm80_config_default<InType, OutType, Epilogue>::Cutlass2xGemm;
  using Cutlass2xGemmM128BigN =
      typename sm80_config_default<InType, OutType, Epilogue>::Cutlass2xGemm;
  using Cutlass2xGemmM128SmallN =
      typename sm80_config_M64<InType, OutType, Epilogue>::Cutlass2xGemm;
  using Cutlass2xGemmM64 =
      typename sm80_config_M64<InType, OutType, Epilogue>::Cutlass2xGemm;
  using Cutlass2xGemmM32 =
      typename sm80_config_M32<InType, OutType, Epilogue>::Cutlass2xGemm;
  using Cutlass2xGemmM16 =
      typename sm80_config_M16<InType, OutType, Epilogue>::Cutlass2xGemm;
  // Due to shared memory requirements, some Gemms may fail to run on some
  // GPUs. As the name indicates, the Fallback Gemm is used as an
  // alternative in such cases.
  // sm80_config_M16 has the least shared-memory requirement. However,
  // based on some profiling, we select sm80_config_M32 as a better
  // alternative performance wise.
  using FallbackGemm =
      typename sm80_config_M32<InType, OutType, Epilogue>::Cutlass2xGemm;
  uint32_t const m = a.dims()[0];  // fixed: stray double semicolon removed
  uint32_t const mp2 =
      std::max(static_cast<uint32_t>(16), next_pow_2(m));  // next power of 2
  if (mp2 <= 16) {
    // M in [1, 16]
    return fallback_cutlass_gemm_caller<Cutlass2xGemmM16, FallbackGemm>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  } else if (mp2 <= 32) {
    // M in (16, 32]
    return fallback_cutlass_gemm_caller<Cutlass2xGemmM32, FallbackGemm>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  } else if (mp2 <= 64) {
    // M in (32, 64]
    return fallback_cutlass_gemm_caller<Cutlass2xGemmM64, FallbackGemm>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  } else if (mp2 <= 128) {
    // M in (64, 128]: narrow N uses the 64-row tile, wide N the default.
    uint32_t const n = out.dims()[1];  // fixed: stray double semicolon removed
    bool const small_n = n < 8192;
    if (small_n) {
      return fallback_cutlass_gemm_caller<Cutlass2xGemmM128SmallN,
                                          FallbackGemm>(
          out, a, b, std::forward<EpilogueArgs>(args)...);
    } else {
      return fallback_cutlass_gemm_caller<Cutlass2xGemmM128BigN, FallbackGemm>(
          out, a, b, std::forward<EpilogueArgs>(args)...);
    }
  } else {
    // M in (128, inf)
    return fallback_cutlass_gemm_caller<Cutlass2xGemmDefault, FallbackGemm>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  }
}
} // namespace fastdeploy

View File

@@ -0,0 +1,370 @@
// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm89_fp8_dispatch.cuh
#pragma once
#include "scaled_mm_c2x.cuh"
#include "cutlass/float8.h"
/**
* This file defines Gemm kernel configurations for SM89 (FP8) based on the Gemm
* shape.
*/
namespace fastdeploy {
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue>
struct sm89_fp8_fallback_gemm {
// Shared Memory required by this Gemm - 61440 bytes
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
using TileShape = typename cutlass::gemm::GemmShape<64, 128, 64>;
using WarpShape = typename cutlass::gemm::GemmShape<32, 64, 64>;
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
using FP8MathOperator = typename cutlass::arch::OpMultiplyAdd;
using Cutlass2xGemm =
cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90, InType, OutType,
Epilogue, TileShape, WarpShape, InstructionShape, 5,
FP8MathOperator>;
};
// SM89 fp8 tile configuration for M in (256, inf); selects the threadblock
// tile and pipeline stage count from the power-of-two-padded N dimension.
struct sm89_fp8_config_default {
  // M in (256, inf)
  using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
  using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
  using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum;
  // Launches the GEMM; `args` are forwarded to the epilogue's prepare_args.
  // Every branch routes through fallback_cutlass_gemm_caller so GPUs that
  // cannot fit the chosen config fall back to sm89_fp8_fallback_gemm.
  template <typename InType, typename OutType,
            template <typename, typename> typename Epilogue,
            typename... EpilogueArgs>
  static void dispatch(paddle::Tensor& out, paddle::Tensor const& a,
                       paddle::Tensor const& b, EpilogueArgs&&... args) {
    static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
    PD_CHECK(a.dtype() == paddle::DataType::FLOAT8_E4M3FN);
    using FallbackGemm =
        typename sm89_fp8_fallback_gemm<InType, OutType,
                                        Epilogue>::Cutlass2xGemm;
    uint32_t const n = out.dims()[1];
    uint32_t const np2 = next_pow_2(n);
    if (np2 <= 4096) {
      // N <= 4096: 128x128x64 tile, 5 stages.
      using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
      return fastdeploy::fallback_cutlass_gemm_caller<
          fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89, fastdeploy::enable_sm89_to_sm90,
                          InType, OutType, Epilogue, TileShape, WarpShape,
                          InstructionShape, 5, FP8MathOperator>,
          FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
    } else if (np2 <= 8192) {
      // N in (4096, 8192]: taller 256x128x64 tile, 3 stages.
      using TileShape = typename cutlass::gemm::GemmShape<256, 128, 64>;
      return fastdeploy::fallback_cutlass_gemm_caller<
          fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89, fastdeploy::enable_sm89_to_sm90,
                          InType, OutType, Epilogue, TileShape, WarpShape,
                          InstructionShape, 3, FP8MathOperator>,
          FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
    } else {
      // N > 8192: 128x128x64 tile, 5 stages.
      using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
      return fastdeploy::fallback_cutlass_gemm_caller<
          fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89, fastdeploy::enable_sm89_to_sm90,
                          InType, OutType, Epilogue, TileShape, WarpShape,
                          InstructionShape, 5, FP8MathOperator>,
          FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
    }
  }
};
// SM89 fp8 tile configuration for M in (128, 256]; selects the threadblock
// tile and stage count from the power-of-two-padded N dimension.
struct sm89_fp8_config_M256 {
  // M in (128, 256]
  using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
  using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
  using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum;
  // Launches the GEMM; `args` are forwarded to the epilogue's prepare_args.
  template <typename InType, typename OutType,
            template <typename, typename> typename Epilogue,
            typename... EpilogueArgs>
  static void dispatch(paddle::Tensor& out, paddle::Tensor const& a,
                       paddle::Tensor const& b, EpilogueArgs&&... args) {
    static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
    PD_CHECK(a.dtype() == paddle::DataType::FLOAT8_E4M3FN);
    using FallbackGemm =
        typename sm89_fp8_fallback_gemm<InType, OutType,
                                        Epilogue>::Cutlass2xGemm;
    uint32_t const n = out.dims()[1];
    uint32_t const np2 = next_pow_2(n);
    if (np2 <= 4096) {
      // N <= 4096: 64x128x128 tile, 3 stages.
      using TileShape = typename cutlass::gemm::GemmShape<64, 128, 128>;
      return fastdeploy::fallback_cutlass_gemm_caller<
          fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89, fastdeploy::enable_sm89_to_sm90,
                          InType, OutType, Epilogue, TileShape, WarpShape,
                          InstructionShape, 3, FP8MathOperator>,
          FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
    } else {
      // N > 4096: 128x128x64 tile, 5 stages.
      using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
      return fastdeploy::fallback_cutlass_gemm_caller<
          fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89, fastdeploy::enable_sm89_to_sm90,
                          InType, OutType, Epilogue, TileShape, WarpShape,
                          InstructionShape, 5, FP8MathOperator>,
          FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
    }
  }
};
// SM89 fp8 tile configuration for M in (64, 128]; selects the threadblock
// tile and stage count from the power-of-two-padded N dimension.
struct sm89_fp8_config_M128 {
  // M in (64, 128]
  using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
  using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
  using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum;
  // Launches the GEMM; `args` are forwarded to the epilogue's prepare_args.
  template <typename InType, typename OutType,
            template <typename, typename> typename Epilogue,
            typename... EpilogueArgs>
  static void dispatch(paddle::Tensor& out, paddle::Tensor const& a,
                       paddle::Tensor const& b, EpilogueArgs&&... args) {
    static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
    PD_CHECK(a.dtype() == paddle::DataType::FLOAT8_E4M3FN);
    using FallbackGemm =
        typename sm89_fp8_fallback_gemm<InType, OutType,
                                        Epilogue>::Cutlass2xGemm;
    uint32_t const n = out.dims()[1];
    uint32_t const np2 = next_pow_2(n);
    if (np2 <= 8192) {
      // N <= 8192: 64x128x128 tile, 3 stages.
      using TileShape = typename cutlass::gemm::GemmShape<64, 128, 128>;
      return fastdeploy::fallback_cutlass_gemm_caller<
          fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89, fastdeploy::enable_sm89_to_sm90,
                          InType, OutType, Epilogue, TileShape, WarpShape,
                          InstructionShape, 3, FP8MathOperator>,
          FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
    } else if (np2 <= 16384) {
      // N in (8192, 16384]: 128x128x64 tile, 5 stages.
      using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
      return fastdeploy::fallback_cutlass_gemm_caller<
          fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89, fastdeploy::enable_sm89_to_sm90,
                          InType, OutType, Epilogue, TileShape, WarpShape,
                          InstructionShape, 5, FP8MathOperator>,
          FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
    } else {
      // N > 16384: 128x64x128 tile, 3 stages.
      using TileShape = typename cutlass::gemm::GemmShape<128, 64, 128>;
      return fastdeploy::fallback_cutlass_gemm_caller<
          fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89, fastdeploy::enable_sm89_to_sm90,
                          InType, OutType, Epilogue, TileShape, WarpShape,
                          InstructionShape, 3, FP8MathOperator>,
          FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
    }
  }
};
struct sm89_fp8_config_M64 {
// M in (32, 64]
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue,
typename... EpilogueArgs>
static void dispatch(paddle::Tensor& out, paddle::Tensor const& a,
paddle::Tensor const& b, EpilogueArgs&&... args) {
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
PD_CHECK(a.dtype() == paddle::DataType::FLOAT8_E4M3FN);
using FallbackGemm =
typename sm89_fp8_fallback_gemm<InType, OutType,
Epilogue>::Cutlass2xGemm;
uint32_t const n = out.dims()[1];
uint32_t const np2 = next_pow_2(n);
if (np2 <= 8196) {
using TileShape = typename cutlass::gemm::GemmShape<64, 64, 128>;
using WarpShape = typename cutlass::gemm::GemmShape<32, 64, 64>;
using FP8MathOperator = typename cutlass::arch::OpMultiplyAdd;
return fastdeploy::fallback_cutlass_gemm_caller<
fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89, fastdeploy::enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 5, FP8MathOperator>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
} else if (np2 <= 16384) {
using TileShape = typename cutlass::gemm::GemmShape<64, 128, 128>;
using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum;
return fastdeploy::fallback_cutlass_gemm_caller<
fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89, fastdeploy::enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 3, FP8MathOperator>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
} else {
using TileShape = typename cutlass::gemm::GemmShape<64, 64, 128>;
using WarpShape = typename cutlass::gemm::GemmShape<32, 64, 64>;
using FP8MathOperator = typename cutlass::arch::OpMultiplyAdd;
return fastdeploy::fallback_cutlass_gemm_caller<
fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89, fastdeploy::enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 5, FP8MathOperator>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
}
}
};
// SM89 fp8 tile configuration for M in (16, 32]; each N-branch picks its
// own tile and warp shape; all branches use the fast-accum fp8 operator.
struct sm89_fp8_config_M32 {
  // M in (16, 32]
  using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
  using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum;
  // Launches the GEMM; `args` are forwarded to the epilogue's prepare_args.
  template <typename InType, typename OutType,
            template <typename, typename> typename Epilogue,
            typename... EpilogueArgs>
  static void dispatch(paddle::Tensor& out, paddle::Tensor const& a,
                       paddle::Tensor const& b, EpilogueArgs&&... args) {
    static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
    PD_CHECK(a.dtype() == paddle::DataType::FLOAT8_E4M3FN);
    using FallbackGemm =
        typename sm89_fp8_fallback_gemm<InType, OutType,
                                        Epilogue>::Cutlass2xGemm;
    uint32_t const n = out.dims()[1];
    uint32_t const np2 = next_pow_2(n);
    if (np2 <= 8192) {
      // N <= 8192: 32x64x128 tile, 5 stages.
      using TileShape = typename cutlass::gemm::GemmShape<32, 64, 128>;
      using WarpShape = typename cutlass::gemm::GemmShape<16, 64, 64>;
      return fastdeploy::fallback_cutlass_gemm_caller<
          fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89, fastdeploy::enable_sm89_to_sm90,
                          InType, OutType, Epilogue, TileShape, WarpShape,
                          InstructionShape, 5, FP8MathOperator>,
          FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
    } else if (np2 <= 16384) {
      // N in (8192, 16384]: 32x128x128 tile, 4 stages.
      using TileShape = typename cutlass::gemm::GemmShape<32, 128, 128>;
      using WarpShape = typename cutlass::gemm::GemmShape<32, 64, 64>;
      return fastdeploy::fallback_cutlass_gemm_caller<
          fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89, fastdeploy::enable_sm89_to_sm90,
                          InType, OutType, Epilogue, TileShape, WarpShape,
                          InstructionShape, 4, FP8MathOperator>,
          FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
    } else {
      // N > 16384: 32x64x128 tile, 5 stages.
      using TileShape = typename cutlass::gemm::GemmShape<32, 64, 128>;
      using WarpShape = typename cutlass::gemm::GemmShape<16, 64, 64>;
      return fastdeploy::fallback_cutlass_gemm_caller<
          fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89, fastdeploy::enable_sm89_to_sm90,
                          InType, OutType, Epilogue, TileShape, WarpShape,
                          InstructionShape, 5, FP8MathOperator>,
          FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
    }
  }
};
// SM89 fp8 tile configuration for M in [1, 16]; a fixed warp shape and
// stage count, with the threadblock tile chosen from the padded N.
struct sm89_fp8_config_M16 {
  // M in [1, 16]
  using WarpShape = typename cutlass::gemm::GemmShape<16, 64, 64>;
  using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
  using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum;
  static const int32_t MainLoopStages = 5;
  // Launches the GEMM; `args` are forwarded to the epilogue's prepare_args.
  template <typename InType, typename OutType,
            template <typename, typename> typename Epilogue,
            typename... EpilogueArgs>
  static void dispatch(paddle::Tensor& out, paddle::Tensor const& a,
                       paddle::Tensor const& b, EpilogueArgs&&... args) {
    static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
    PD_CHECK(a.dtype() == paddle::DataType::FLOAT8_E4M3FN);
    using FallbackGemm =
        typename sm89_fp8_fallback_gemm<InType, OutType,
                                        Epilogue>::Cutlass2xGemm;
    uint32_t const n = out.dims()[1];
    uint32_t const np2 = next_pow_2(n);
    if (np2 <= 8192) {
      // N <= 8192: 16x64x128 tile.
      using TileShape = typename cutlass::gemm::GemmShape<16, 64, 128>;
      return fastdeploy::fallback_cutlass_gemm_caller<
          fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89, fastdeploy::enable_sm89_to_sm90,
                          InType, OutType, Epilogue, TileShape, WarpShape,
                          InstructionShape, MainLoopStages,
                          FP8MathOperator>,
          FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
    } else if (np2 <= 24576) {
      // N in (8192, 24576]: 16x128x64 tile.
      using TileShape = typename cutlass::gemm::GemmShape<16, 128, 64>;
      return fastdeploy::fallback_cutlass_gemm_caller<
          fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89, fastdeploy::enable_sm89_to_sm90,
                          InType, OutType, Epilogue, TileShape, WarpShape,
                          InstructionShape, MainLoopStages,
                          FP8MathOperator>,
          FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
    } else {
      // N > 24576: 32x64x128 tile.
      using TileShape = typename cutlass::gemm::GemmShape<32, 64, 128>;
      return fastdeploy::fallback_cutlass_gemm_caller<
          fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89, fastdeploy::enable_sm89_to_sm90,
                          InType, OutType, Epilogue, TileShape, WarpShape,
                          InstructionShape, MainLoopStages,
                          FP8MathOperator>,
          FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
    }
  }
};
// Top-level SM89 fp8 dispatch: buckets the power-of-two-padded M dimension
// and delegates to the matching sm89_fp8_config_* struct, which further
// dispatches on N.
template <typename InType, typename OutType,
          template <typename, typename> typename Epilogue,
          typename... EpilogueArgs>
inline void cutlass_gemm_sm89_fp8_dispatch(paddle::Tensor& out,
                                           paddle::Tensor const& a,
                                           paddle::Tensor const& b,
                                           EpilogueArgs&&... args) {
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  PD_CHECK(a.dtype() == paddle::DataType::FLOAT8_E4M3FN);
  PD_CHECK(b.dtype() == paddle::DataType::FLOAT8_E4M3FN);
  uint32_t const m = a.dims()[0];  // fixed: stray double semicolon removed
  uint32_t const mp2 =
      std::max(static_cast<uint32_t>(16), next_pow_2(m));  // next power of 2
  if (mp2 <= 16) {
    // M in [1, 16]
    return sm89_fp8_config_M16::dispatch<InType, OutType, Epilogue>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  } else if (mp2 <= 32) {
    // M in (16, 32]
    return sm89_fp8_config_M32::dispatch<InType, OutType, Epilogue>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  } else if (mp2 <= 64) {
    // M in (32, 64]
    return sm89_fp8_config_M64::dispatch<InType, OutType, Epilogue>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  } else if (mp2 <= 128) {
    // M in (64, 128]
    return sm89_fp8_config_M128::dispatch<InType, OutType, Epilogue>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  } else if (mp2 <= 256) {
    // M in (128, 256]
    return sm89_fp8_config_M256::dispatch<InType, OutType, Epilogue>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  } else {
    // M in (256, inf)
    return sm89_fp8_config_default::dispatch<InType, OutType, Epilogue>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  }
}
} // namespace fastdeploy

View File

@@ -0,0 +1,355 @@
// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm89_int8_dispatch.cuh
#pragma once
#include "scaled_mm_c2x.cuh"
/**
* This file defines Gemm kernel configurations for SM89 (int8) based on the
* Gemm shape.
*/
namespace fastdeploy {
// Fallback SM89 int8 GEMM used by the tuned sm89_int8_config_* structs when
// their preferred kernel cannot run (e.g. insufficient shared memory).
// Shared mem requirement : 61440
template <typename InType, typename OutType,
          template <typename, typename> typename Epilogue>
struct sm89_int8_fallback_gemm {
  static_assert(std::is_same<InType, int8_t>());
  using TileShape = cutlass::gemm::GemmShape<32, 64, 128>;
  using WarpShape = cutlass::gemm::GemmShape<16, 64, 64>;
  using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
  static int32_t const MainLoopStages = 5;
  // Use the named MainLoopStages constant instead of a duplicated literal '5'
  // so the stage count is defined in exactly one place.
  using Cutlass2xGemm =
      cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90, InType, OutType,
                      Epilogue, TileShape, WarpShape, InstructionShape,
                      MainLoopStages>;
};
// Tuned SM89 int8 GEMM configuration for M in (256, inf).
// The tile shape and pipeline stage count are chosen per N bucket (next power
// of two of the output's N dimension); if the selected kernel cannot launch,
// fallback_cutlass_gemm_caller retries with sm89_int8_fallback_gemm.
struct sm89_int8_config_default {
  // M in (256, inf)
  using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
  using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;

  // Select a cutlass_2x_gemm instantiation for this N range and run it,
  // forwarding epilogue arguments (scales/bias) untouched.
  template <typename InType, typename OutType,
            template <typename, typename> typename Epilogue,
            typename... EpilogueArgs>
  static void dispatch(paddle::Tensor& out, paddle::Tensor const& a,
                       paddle::Tensor const& b, EpilogueArgs&&... args) {
    static_assert(std::is_same<InType, int8_t>());
    PD_CHECK(a.dtype() == paddle::DataType::INT8);
    using FallbackGemm =
        typename sm89_int8_fallback_gemm<InType, OutType,
                                         Epilogue>::Cutlass2xGemm;
    uint32_t const n = out.dims()[1];
    uint32_t const np2 = next_pow_2(n);
    if (np2 <= 4096) {
      // Smaller tile with a deeper (5-stage) pipeline for narrow N.
      using TileShape = cutlass::gemm::GemmShape<128, 128, 64>;
      return fastdeploy::fallback_cutlass_gemm_caller<
          fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89,
                                      fastdeploy::enable_sm89_to_sm90, InType,
                                      OutType, Epilogue, TileShape, WarpShape,
                                      InstructionShape, 5>,
          FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
    } else if (np2 <= 8192) {
      // Wider tile, fewer stages (3) to stay within shared memory.
      using TileShape = cutlass::gemm::GemmShape<256, 128, 64>;
      return fastdeploy::fallback_cutlass_gemm_caller<
          fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89,
                                      fastdeploy::enable_sm89_to_sm90, InType,
                                      OutType, Epilogue, TileShape, WarpShape,
                                      InstructionShape, 3>,
          FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
    } else if (np2 <= 16384) {
      using TileShape = cutlass::gemm::GemmShape<128, 128, 64>;
      return fastdeploy::fallback_cutlass_gemm_caller<
          fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89,
                                      fastdeploy::enable_sm89_to_sm90, InType,
                                      OutType, Epilogue, TileShape, WarpShape,
                                      InstructionShape, 5>,
          FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
    } else {
      using TileShape = cutlass::gemm::GemmShape<256, 128, 64>;
      return fastdeploy::fallback_cutlass_gemm_caller<
          fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89,
                                      fastdeploy::enable_sm89_to_sm90, InType,
                                      OutType, Epilogue, TileShape, WarpShape,
                                      InstructionShape, 3>,
          FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
    }
  }
};
// Tuned SM89 int8 GEMM configuration for M in (128, 256].
// Tile shape / stage count are picked per N bucket; on launch failure the
// fallback GEMM from sm89_int8_fallback_gemm is used.
struct sm89_int8_config_M256 {
  // M in (128, 256]
  using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
  using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;

  // Select and run the cutlass_2x_gemm instantiation for this N range.
  template <typename InType, typename OutType,
            template <typename, typename> typename Epilogue,
            typename... EpilogueArgs>
  static void dispatch(paddle::Tensor& out, paddle::Tensor const& a,
                       paddle::Tensor const& b, EpilogueArgs&&... args) {
    static_assert(std::is_same<InType, int8_t>());
    PD_CHECK(a.dtype() == paddle::DataType::INT8);
    using FallbackGemm =
        typename sm89_int8_fallback_gemm<InType, OutType,
                                         Epilogue>::Cutlass2xGemm;
    uint32_t const n = out.dims()[1];
    uint32_t const np2 = next_pow_2(n);
    if (np2 <= 4096) {
      using TileShape = cutlass::gemm::GemmShape<64, 128, 128>;
      return fastdeploy::fallback_cutlass_gemm_caller<
          fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89,
                                      fastdeploy::enable_sm89_to_sm90, InType,
                                      OutType, Epilogue, TileShape, WarpShape,
                                      InstructionShape, 3>,
          FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
    } else if (np2 <= 8192) {
      using TileShape = cutlass::gemm::GemmShape<128, 128, 64>;
      return fastdeploy::fallback_cutlass_gemm_caller<
          fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89,
                                      fastdeploy::enable_sm89_to_sm90, InType,
                                      OutType, Epilogue, TileShape, WarpShape,
                                      InstructionShape, 5>,
          FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
    } else if (np2 <= 16384) {
      using TileShape = cutlass::gemm::GemmShape<256, 128, 64>;
      return fastdeploy::fallback_cutlass_gemm_caller<
          fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89,
                                      fastdeploy::enable_sm89_to_sm90, InType,
                                      OutType, Epilogue, TileShape, WarpShape,
                                      InstructionShape, 3>,
          FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
    } else {
      using TileShape = cutlass::gemm::GemmShape<128, 128, 64>;
      return fastdeploy::fallback_cutlass_gemm_caller<
          fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89,
                                      fastdeploy::enable_sm89_to_sm90, InType,
                                      OutType, Epilogue, TileShape, WarpShape,
                                      InstructionShape, 5>,
          FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
    }
  }
};
// Tuned SM89 int8 GEMM configuration for M in (64, 128].
// Unlike the larger-M configs, both TileShape and WarpShape vary per N bucket,
// so they are declared inside each branch rather than as struct members.
struct sm89_int8_config_M128 {
  // M in (64, 128]
  using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;

  // Select and run the cutlass_2x_gemm instantiation for this N range.
  template <typename InType, typename OutType,
            template <typename, typename> typename Epilogue,
            typename... EpilogueArgs>
  static void dispatch(paddle::Tensor& out, paddle::Tensor const& a,
                       paddle::Tensor const& b, EpilogueArgs&&... args) {
    static_assert(std::is_same<InType, int8_t>());
    PD_CHECK(a.dtype() == paddle::DataType::INT8);
    using FallbackGemm =
        typename sm89_int8_fallback_gemm<InType, OutType,
                                         Epilogue>::Cutlass2xGemm;
    uint32_t const n = out.dims()[1];
    uint32_t const np2 = next_pow_2(n);
    if (np2 <= 8192) {
      using TileShape = cutlass::gemm::GemmShape<64, 128, 128>;
      using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>;
      return fastdeploy::fallback_cutlass_gemm_caller<
          fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89,
                                      fastdeploy::enable_sm89_to_sm90, InType,
                                      OutType, Epilogue, TileShape, WarpShape,
                                      InstructionShape, 3>,
          FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
    } else if (np2 <= 16384) {
      using TileShape = cutlass::gemm::GemmShape<128, 128, 64>;
      using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>;
      return fastdeploy::fallback_cutlass_gemm_caller<
          fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89,
                                      fastdeploy::enable_sm89_to_sm90, InType,
                                      OutType, Epilogue, TileShape, WarpShape,
                                      InstructionShape, 5>,
          FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
    } else {
      using TileShape = cutlass::gemm::GemmShape<64, 64, 128>;
      using WarpShape = cutlass::gemm::GemmShape<32, 64, 64>;
      return fastdeploy::fallback_cutlass_gemm_caller<
          fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89,
                                      fastdeploy::enable_sm89_to_sm90, InType,
                                      OutType, Epilogue, TileShape, WarpShape,
                                      InstructionShape, 5>,
          FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
    }
  }
};
// Tuned SM89 int8 GEMM configuration for M in (32, 64].
// TileShape and WarpShape vary per N bucket and are declared per branch.
struct sm89_int8_config_M64 {
  // M in (32, 64]
  using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;

  // Select and run the cutlass_2x_gemm instantiation for this N range.
  template <typename InType, typename OutType,
            template <typename, typename> typename Epilogue,
            typename... EpilogueArgs>
  static void dispatch(paddle::Tensor& out, paddle::Tensor const& a,
                       paddle::Tensor const& b, EpilogueArgs&&... args) {
    static_assert(std::is_same<InType, int8_t>());
    PD_CHECK(a.dtype() == paddle::DataType::INT8);
    using FallbackGemm =
        typename sm89_int8_fallback_gemm<InType, OutType,
                                         Epilogue>::Cutlass2xGemm;
    uint32_t const n = out.dims()[1];
    uint32_t const np2 = next_pow_2(n);
    if (np2 <= 8192) {
      using TileShape = cutlass::gemm::GemmShape<64, 64, 128>;
      using WarpShape = cutlass::gemm::GemmShape<32, 64, 64>;
      return fastdeploy::fallback_cutlass_gemm_caller<
          fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89,
                                      fastdeploy::enable_sm89_to_sm90, InType,
                                      OutType, Epilogue, TileShape, WarpShape,
                                      InstructionShape, 5>,
          FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
    } else {
      using TileShape = cutlass::gemm::GemmShape<64, 128, 128>;
      using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>;
      return fastdeploy::fallback_cutlass_gemm_caller<
          fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89,
                                      fastdeploy::enable_sm89_to_sm90, InType,
                                      OutType, Epilogue, TileShape, WarpShape,
                                      InstructionShape, 3>,
          FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
    }
  }
};
// Tuned SM89 int8 GEMM configuration for M in (16, 32].
// TileShape and WarpShape vary per N bucket and are declared per branch.
struct sm89_int8_config_M32 {
  // M in (16, 32]
  using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;

  // Select and run the cutlass_2x_gemm instantiation for this N range.
  template <typename InType, typename OutType,
            template <typename, typename> typename Epilogue,
            typename... EpilogueArgs>
  static void dispatch(paddle::Tensor& out, paddle::Tensor const& a,
                       paddle::Tensor const& b, EpilogueArgs&&... args) {
    static_assert(std::is_same<InType, int8_t>());
    PD_CHECK(a.dtype() == paddle::DataType::INT8);
    using FallbackGemm =
        typename sm89_int8_fallback_gemm<InType, OutType,
                                         Epilogue>::Cutlass2xGemm;
    uint32_t const n = out.dims()[1];
    uint32_t const np2 = next_pow_2(n);
    if (np2 <= 8192) {
      using TileShape = cutlass::gemm::GemmShape<32, 64, 128>;
      using WarpShape = cutlass::gemm::GemmShape<16, 64, 64>;
      return fastdeploy::fallback_cutlass_gemm_caller<
          fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89,
                                      fastdeploy::enable_sm89_to_sm90, InType,
                                      OutType, Epilogue, TileShape, WarpShape,
                                      InstructionShape, 5>,
          FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
    } else {
      using TileShape = cutlass::gemm::GemmShape<32, 128, 128>;
      using WarpShape = cutlass::gemm::GemmShape<32, 64, 64>;
      return fastdeploy::fallback_cutlass_gemm_caller<
          fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89,
                                      fastdeploy::enable_sm89_to_sm90, InType,
                                      OutType, Epilogue, TileShape, WarpShape,
                                      InstructionShape, 4>,
          FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
    }
  }
};
// Tuned SM89 int8 GEMM configuration for M in [1, 16].
// Like the other sm89_int8_config_* structs, the tile shape / stage count are
// chosen by bucketing the output's N dimension; on launch failure the
// fallback GEMM from sm89_int8_fallback_gemm is used.
struct sm89_int8_config_M16 {
  // M in [1, 16]
  using WarpShape = typename cutlass::gemm::GemmShape<16, 64, 64>;
  using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;

  // Select and run the cutlass_2x_gemm instantiation for this N range.
  template <typename InType, typename OutType,
            template <typename, typename> typename Epilogue,
            typename... EpilogueArgs>
  static void dispatch(paddle::Tensor& out, paddle::Tensor const& a,
                       paddle::Tensor const& b, EpilogueArgs&&... args) {
    static_assert(std::is_same<InType, int8_t>());
    PD_CHECK(a.dtype() == paddle::DataType::INT8);
    using FallbackGemm =
        typename sm89_int8_fallback_gemm<InType, OutType,
                                         Epilogue>::Cutlass2xGemm;
    // BUG FIX: bucket on the output's N dimension (dims()[1]) like every
    // sibling config; this previously read dims()[0] (M), so tile selection
    // keyed off the wrong extent.
    uint32_t const n = out.dims()[1];
    uint32_t const np2 = next_pow_2(n);
    if (np2 <= 8192) {
      using TileShape = cutlass::gemm::GemmShape<16, 64, 128>;
      return fastdeploy::fallback_cutlass_gemm_caller<
          fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89,
                                      fastdeploy::enable_sm89_to_sm90, InType,
                                      OutType, Epilogue, TileShape, WarpShape,
                                      InstructionShape, 5>,
          FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
    } else {
      using TileShape = cutlass::gemm::GemmShape<16, 128, 128>;
      return fastdeploy::fallback_cutlass_gemm_caller<
          fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89,
                                      fastdeploy::enable_sm89_to_sm90, InType,
                                      OutType, Epilogue, TileShape, WarpShape,
                                      InstructionShape, 4>,
          FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
    }
  }
};
/// Top-level SM89 int8 scaled-GEMM dispatch.
///
/// Rounds the M dimension of `a` up to a power of two (minimum 16) and hands
/// the call to the sm89_int8_config_* struct tuned for that M bucket; all
/// epilogue arguments are perfectly forwarded.
template <typename InType, typename OutType,
          template <typename, typename> typename Epilogue,
          typename... EpilogueArgs>
inline void cutlass_gemm_sm89_int8_dispatch(paddle::Tensor& out,
                                            paddle::Tensor const& a,
                                            paddle::Tensor const& b,
                                            EpilogueArgs&&... args) {
  static_assert(std::is_same<InType, int8_t>());
  PD_CHECK(a.dtype() == paddle::DataType::INT8);
  PD_CHECK(b.dtype() == paddle::DataType::INT8);

  uint32_t const num_rows = a.dims()[0];
  // Clamp to >= 16 so all of M in [1, 16] lands in the M16 bucket.
  uint32_t const m_bucket =
      std::max(static_cast<uint32_t>(16), next_pow_2(num_rows));

  if (m_bucket <= 16) {
    // M in [1, 16]
    return sm89_int8_config_M16::dispatch<InType, OutType, Epilogue>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  }
  if (m_bucket <= 32) {
    // M in (16, 32]
    return sm89_int8_config_M32::dispatch<InType, OutType, Epilogue>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  }
  if (m_bucket <= 64) {
    // M in (32, 64]
    return sm89_int8_config_M64::dispatch<InType, OutType, Epilogue>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  }
  if (m_bucket <= 128) {
    // M in (64, 128]
    return sm89_int8_config_M128::dispatch<InType, OutType, Epilogue>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  }
  if (m_bucket <= 256) {
    // M in (128, 256]
    return sm89_int8_config_M256::dispatch<InType, OutType, Epilogue>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  }
  // M in (256, inf)
  return sm89_int8_config_default::dispatch<InType, OutType, Epilogue>(
      out, a, b, std::forward<EpilogueArgs>(args)...);
}
} // namespace fastdeploy

View File

@@ -0,0 +1,37 @@
// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm90.cu
#include "c3x/scaled_mm_helper.hpp"
#include "c3x/scaled_mm_kernels.hpp"
/*
This file defines quantized GEMM operations using the CUTLASS 3.x API, for
NVIDIA GPUs with sm90a (Hopper).
*/
#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90
// SM90 (Hopper) scaled-GEMM entry point. Delegates to dispatch_scaled_mm,
// which selects between the FP8 and INT8 CUTLASS 3.x kernels passed here
// (selection logic lives inside dispatch_scaled_mm).
void cutlass_scaled_mm_sm90(paddle::Tensor &c, paddle::Tensor const &a,
                            paddle::Tensor const &b,
                            paddle::Tensor const &a_scales,
                            paddle::Tensor const &b_scales,
                            paddle::optional<paddle::Tensor> const &bias) {
  dispatch_scaled_mm(c, a, b, a_scales, b_scales, bias,
                     fastdeploy::cutlass_scaled_mm_sm90_fp8,
                     fastdeploy::cutlass_scaled_mm_sm90_int8);
}
// SM90 (Hopper) scaled GEMM with asymmetric zero-point adjustment.
// Validates that both scale tensors are float32, then forwards directly to
// the int8 azp kernel (only int8 inputs are routed here).
void cutlass_scaled_mm_azp_sm90(paddle::Tensor& out, paddle::Tensor const& a,
                                paddle::Tensor const& b,
                                paddle::Tensor const& a_scales,
                                paddle::Tensor const& b_scales,
                                paddle::Tensor const& azp_adj,
                                paddle::optional<paddle::Tensor> const& azp,
                                paddle::optional<paddle::Tensor> const& bias) {
  PD_CHECK(a_scales.dtype() == paddle::DataType::FLOAT32);
  PD_CHECK(b_scales.dtype() == paddle::DataType::FLOAT32);
  fastdeploy::cutlass_scaled_mm_azp_sm90_int8(out, a, b, a_scales, b_scales,
                                              azp_adj, azp, bias);
}
#endif

View File

@@ -0,0 +1,224 @@
// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu
#pragma once
#include "helper.h"
#include <iostream>
void cutlass_scaled_mm_sm75(paddle::Tensor &c, paddle::Tensor const &a,
paddle::Tensor const &b,
paddle::Tensor const &a_scales,
paddle::Tensor const &b_scales,
paddle::optional<paddle::Tensor> const &bias);
void cutlass_scaled_mm_sm80(paddle::Tensor &c, paddle::Tensor const &a,
paddle::Tensor const &b,
paddle::Tensor const &a_scales,
paddle::Tensor const &b_scales,
paddle::optional<paddle::Tensor> const &bias);
void cutlass_scaled_mm_sm89(paddle::Tensor &c, paddle::Tensor const &a,
paddle::Tensor const &b,
paddle::Tensor const &a_scales,
paddle::Tensor const &b_scales,
paddle::optional<paddle::Tensor> const &bias);
#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90
void cutlass_scaled_mm_sm90(paddle::Tensor &c, paddle::Tensor const &a,
paddle::Tensor const &b,
paddle::Tensor const &a_scales,
paddle::Tensor const &b_scales,
paddle::optional<paddle::Tensor> const &bias);
#endif
void cutlass_scaled_mm_azp_sm75(paddle::Tensor& c, paddle::Tensor const& a,
paddle::Tensor const& b,
paddle::Tensor const& a_scales,
paddle::Tensor const& b_scales,
paddle::Tensor const& azp_adj,
paddle::optional<paddle::Tensor> const& azp,
paddle::optional<paddle::Tensor> const& bias);
void cutlass_scaled_mm_azp_sm80(paddle::Tensor& c, paddle::Tensor const& a,
paddle::Tensor const& b,
paddle::Tensor const& a_scales,
paddle::Tensor const& b_scales,
paddle::Tensor const& azp_adj,
paddle::optional<paddle::Tensor> const& azp,
paddle::optional<paddle::Tensor> const& bias);
void cutlass_scaled_mm_azp_sm89(paddle::Tensor& c, paddle::Tensor const& a,
paddle::Tensor const& b,
paddle::Tensor const& a_scales,
paddle::Tensor const& b_scales,
paddle::Tensor const& azp_adj,
paddle::optional<paddle::Tensor> const& azp,
paddle::optional<paddle::Tensor> const& bias);
#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90
void cutlass_scaled_mm_azp_sm90(paddle::Tensor& c, paddle::Tensor const& a,
paddle::Tensor const& b,
paddle::Tensor const& a_scales,
paddle::Tensor const& b_scales,
paddle::Tensor const& azp_adj,
paddle::optional<paddle::Tensor> const& azp,
paddle::optional<paddle::Tensor> const& bias);
#endif
// Report whether the CUTLASS FP8 scaled-GEMM path is usable on a device of
// the given compute capability. FP8 kernels require at least:
//   - CUDA 12.0 on SM90 (Hopper)
//   - CUDA 12.4 on SM89 (Ada Lovelace)
// Builds without CUDA_VERSION defined never support FP8.
bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability) {
#if defined CUDA_VERSION
  if (cuda_device_capability >= 90) return CUDA_VERSION >= 12000;  // Hopper
  if (cuda_device_capability >= 89) return CUDA_VERSION >= 12040;  // Ada
#endif
  return false;
}
/**
 * Kernel entry for the `cutlass_scaled_mm` custom op: c = (a * b) scaled by
 * a_scales/b_scales, with optional bias. Validates the operands, then routes
 * to the architecture-specific implementation based on the device's compute
 * capability (SM90 / SM89 / SM80 / SM75).
 *
 * a and c are row-major; b is column-major (b.dims() is {N, K}).
 * Throws Unimplemented if no compiled kernel matches the device.
 */
void CutlassScaledMm(paddle::Tensor &c, paddle::Tensor const &a,
                     paddle::Tensor const &b, paddle::Tensor const &a_scales,
                     paddle::Tensor const &b_scales,
                     paddle::optional<paddle::Tensor> const &bias) {
  // Checks for conformality
  PD_CHECK(a.dims().size() == 2 && b.dims().size() == 2 &&
           c.dims().size() == 2);
  PD_CHECK(c.dims()[0] == a.dims()[0] && a.dims()[1] == b.dims()[1] &&
           b.dims()[0] == c.dims()[1]);
  // Scales are either a single scalar or per-row-of-a / per-column-of-c,
  // and must be contiguous (same contract enforced by CutlassScaledMmAzp).
  PD_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.dims()[0]);
  PD_CHECK(b_scales.numel() == 1 || b_scales.numel() == b.dims()[0]);
  PD_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
  // Check for strides and alignment
  PD_CHECK(a.strides()[1] == 1 && c.strides()[1] == 1);  // Row-major
  PD_CHECK(b.strides()[1] == 1);                         // Column-major
  PD_CHECK(c.strides()[0] % 16 == 0 &&
           b.strides()[0] % 16 == 0);  // 16 Byte Alignment
  if (bias) {
    // Bias is a 1-D vector with one entry per output column (N).
    PD_CHECK(bias->numel() == b.dims()[0] && bias->is_contiguous() &&
             bias->dims().size() == 1);
  }
  int32_t version_num = GetGPUComputeCapability(a.place().GetDeviceId());
  // Guard against compilation issues for sm90 kernels
#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90
  if (version_num >= 90 && version_num < 100) {
    // Hopper
    cutlass_scaled_mm_sm90(c, a, b, a_scales, b_scales, bias);
    return;
  }
#endif
#if defined ENABLE_SCALED_MM_C2X && ENABLE_SCALED_MM_C2X
  if (version_num == 89) {
    // Ada Lovelace
    cutlass_scaled_mm_sm89(c, a, b, a_scales, b_scales, bias);
    return;
  }
  if (version_num >= 80) {
    // Ampere
    cutlass_scaled_mm_sm80(c, a, b, a_scales, b_scales, bias);
    return;
  }
  if (version_num >= 75) {
    // Turing
    cutlass_scaled_mm_sm75(c, a, b, a_scales, b_scales, bias);
    return;
  }
#endif
  PADDLE_THROW(phi::errors::Unimplemented(
      "No compiled cutlass_scaled_mm for a compute capability less than "
      "CUDA device capability: %d",
      version_num));
}
/**
 * Kernel entry for the `cutlass_scaled_mm_azp` custom op: scaled int8 GEMM
 * with asymmetric zero-point correction. Validates operands, scales, azp
 * tensors and bias, then routes to the architecture-specific implementation
 * by compute capability (SM90 / SM89 / SM80 / SM75).
 *
 * a and c are row-major; b is column-major (b.dims() is {N, K}).
 * azp_adj and bias have N elements; azp has M elements (all 1-D).
 */
void CutlassScaledMmAzp(paddle::Tensor& c, paddle::Tensor const& a,
                        paddle::Tensor const& b,
                        paddle::Tensor const& a_scales,
                        paddle::Tensor const& b_scales,
                        paddle::Tensor const& azp_adj,
                        paddle::optional<paddle::Tensor> const& azp,
                        paddle::optional<paddle::Tensor> const& bias) {
  // Checks for conformality
  PD_CHECK(a.dims().size() == 2 && b.dims().size() == 2 &&
           c.dims().size() == 2);
  PD_CHECK(c.dims()[0] == a.dims()[0] && a.dims()[1] == b.dims()[1] &&
           b.dims()[0] == c.dims()[1]);
  // Scales are a scalar or per-row (a) / per-column (b).
  PD_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.dims()[0]);
  PD_CHECK(b_scales.numel() == 1 || b_scales.numel() == b.dims()[0]);
  // Check for strides and alignment
  PD_CHECK(a.strides()[1] == 1 && c.strides()[1] == 1);  // Row-major
  PD_CHECK(b.strides()[1] == 1);                         // Column-major
  PD_CHECK(c.strides()[0] % 16 == 0 &&
           b.strides()[0] % 16 == 0);  // 16 Byte Alignment
  PD_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
  // bias, azp, azp_adj are all 1d
  // bias and azp_adj have n elements, azp has m elements
  if (bias) {
    PD_CHECK(bias->numel() == b.dims()[0] && bias->is_contiguous());
  }
  if (azp) {
    PD_CHECK(azp->numel() == a.dims()[0] && azp->is_contiguous());
  }
  PD_CHECK(azp_adj.numel() == b.dims()[0] && azp_adj.is_contiguous());
  // azp & bias types
  PD_CHECK(azp_adj.dtype() == paddle::DataType::INT32);
  PD_CHECK(!azp || azp->dtype() == paddle::DataType::INT32);
  PD_CHECK(!bias || bias->dtype() == c.dtype(),
           "currently bias dtype must match output dtype ", c.dtype());
  int32_t version_num = GetGPUComputeCapability(a.place().GetDeviceId());
#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90
  if (version_num >= 90) {
    cutlass_scaled_mm_azp_sm90(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
    return;
  }
#endif
#if defined ENABLE_SCALED_MM_C2X && ENABLE_SCALED_MM_C2X
  if (version_num == 89) {
    // Ada Lovelace
    cutlass_scaled_mm_azp_sm89(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
    return;
  }
  if (version_num >= 80) {
    // Ampere
    cutlass_scaled_mm_azp_sm80(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
    return;
  }
  // Turing
  PD_CHECK(version_num >= 75);
  cutlass_scaled_mm_azp_sm75(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
  return;
  // NOTE: when C2X kernels are compiled in, the Turing path above returns
  // unconditionally, so the throw below is only reachable when neither
  // SM90 nor C2X support was built.
#endif
  PADDLE_THROW(phi::errors::Unimplemented(
      "No compiled cutlass_scaled_mm_azp for a compute capability less than "
      "CUDA device capability: %d",
      version_num));
}
// Register the custom operators with Paddle. Both ops write their result into
// the pre-allocated `c` tensor, hence the in-place mapping c -> c_out.
PD_BUILD_STATIC_OP(cutlass_scaled_mm)
    .Inputs({"c", "a", "b", "a_scales", "b_scales", paddle::Optional("bias")})
    .Outputs({"c_out"})
    .SetInplaceMap({{"c", "c_out"}})
    .SetKernelFn(PD_KERNEL(CutlassScaledMm));

// Variant with asymmetric zero-point correction (azp_adj required, azp and
// bias optional).
PD_BUILD_STATIC_OP(cutlass_scaled_mm_azp)
    .Inputs({"c", "a", "b", "a_scales", "b_scales", "azp_adj",
             paddle::Optional("azp"), paddle::Optional("bias")})
    .Outputs({"c_out"})
    .SetInplaceMap({{"c", "c_out"}})
    .SetKernelFn(PD_KERNEL(CutlassScaledMmAzp));