mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
Sync v2.0 version of code to github repo
This commit is contained in:
@@ -16,106 +16,127 @@
|
||||
#include <mutex>
|
||||
|
||||
#include "cutlass/bfloat16.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/half.h"
|
||||
#include "helper.h"
|
||||
#include "paddle/extension.h"
|
||||
|
||||
template <paddle::DataType D>
|
||||
class CutlassDtypeTraits;
|
||||
/**
|
||||
* Helper function for checking CUTLASS errors
|
||||
*/
|
||||
#define CUTLASS_CHECK(status) \
|
||||
{ \
|
||||
cutlass::Status error = status; \
|
||||
PD_CHECK(error == cutlass::Status::kSuccess, \
|
||||
cutlassGetStatusString(error)); \
|
||||
}
|
||||
|
||||
template <>
|
||||
class CutlassDtypeTraits<paddle::DataType::FLOAT32> {
|
||||
public:
|
||||
typedef float DataType;
|
||||
typedef float data_t;
|
||||
/**
|
||||
* A wrapper for a kernel that is used to guard against compilation on
|
||||
* architectures that will never use the kernel. The purpose of this is to
|
||||
* reduce the size of the compiled binary.
|
||||
* __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef
|
||||
* into code that will be executed on the device where it is defined.
|
||||
*/
|
||||
template <typename Kernel> struct enable_sm90_or_later : Kernel {
|
||||
template <typename... Args> CUTLASS_DEVICE void operator()(Args &&...args) {
|
||||
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 900
|
||||
Kernel::operator()(std::forward<Args>(args)...);
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
class CutlassDtypeTraits<paddle::DataType::FLOAT16> {
|
||||
public:
|
||||
typedef cutlass::half_t DataType;
|
||||
typedef paddle::float16 data_t;
|
||||
template <paddle::DataType D> class CutlassDtypeTraits;
|
||||
|
||||
template <> class CutlassDtypeTraits<paddle::DataType::FLOAT32> {
|
||||
public:
|
||||
typedef float DataType;
|
||||
typedef float data_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
class CutlassDtypeTraits<paddle::DataType::BFLOAT16> {
|
||||
public:
|
||||
typedef cutlass::bfloat16_t DataType;
|
||||
typedef paddle::bfloat16 data_t;
|
||||
template <> class CutlassDtypeTraits<paddle::DataType::FLOAT16> {
|
||||
public:
|
||||
typedef cutlass::half_t DataType;
|
||||
typedef paddle::float16 data_t;
|
||||
};
|
||||
|
||||
template <> class CutlassDtypeTraits<paddle::DataType::BFLOAT16> {
|
||||
public:
|
||||
typedef cutlass::bfloat16_t DataType;
|
||||
typedef paddle::bfloat16 data_t;
|
||||
};
|
||||
|
||||
class CutlassGemmConfigMannager {
|
||||
public:
|
||||
static CutlassGemmConfigMannager& getInstance() {
|
||||
static CutlassGemmConfigMannager instance;
|
||||
return instance;
|
||||
}
|
||||
public:
|
||||
static CutlassGemmConfigMannager &getInstance() {
|
||||
static CutlassGemmConfigMannager instance;
|
||||
return instance;
|
||||
}
|
||||
|
||||
CutlassGemmConfigMannager(const CutlassGemmConfigMannager&) = delete;
|
||||
CutlassGemmConfigMannager& operator=(const CutlassGemmConfigMannager&) =
|
||||
delete;
|
||||
CutlassGemmConfigMannager(const CutlassGemmConfigMannager &) = delete;
|
||||
CutlassGemmConfigMannager &
|
||||
operator=(const CutlassGemmConfigMannager &) = delete;
|
||||
|
||||
void up_date_configs(const nlohmann::json& j) {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
for (auto it = j.begin(); it != j.end(); ++it) {
|
||||
json_[it.key()] = it.value();
|
||||
}
|
||||
void up_date_configs(const nlohmann::json &j) {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
for (auto it = j.begin(); it != j.end(); ++it) {
|
||||
json_[it.key()] = it.value();
|
||||
}
|
||||
}
|
||||
|
||||
nlohmann::json* get_gemm_best_configs(const std::string& config_file_path) {
|
||||
if (!load_initialized_) {
|
||||
std::ifstream file(config_file_path);
|
||||
if (!file.good()) {
|
||||
throw std::runtime_error(
|
||||
"cutlass gemm_best_config can not be found, please set "
|
||||
"gemm_best_config'path as "
|
||||
"FLAGS_use_cutlass_device_best_config_path, or unset "
|
||||
"FLAGS_use_cutlass_device_best_config_path to tune "
|
||||
"gemm_best_config");
|
||||
}
|
||||
json_ = readJsonFromFile(config_file_path);
|
||||
load_initialized_ = true;
|
||||
save_initialized_ = false;
|
||||
}
|
||||
return &json_;
|
||||
nlohmann::json *get_gemm_best_configs(const std::string &config_file_path) {
|
||||
if (!load_initialized_) {
|
||||
std::ifstream file(config_file_path);
|
||||
if (!file.good()) {
|
||||
throw std::runtime_error(
|
||||
"cutlass gemm_best_config can not be found, please set "
|
||||
"gemm_best_config'path as "
|
||||
"FLAGS_use_cutlass_device_best_config_path, or unset "
|
||||
"FLAGS_use_cutlass_device_best_config_path to tune "
|
||||
"gemm_best_config");
|
||||
}
|
||||
json_ = readJsonFromFile(config_file_path);
|
||||
load_initialized_ = true;
|
||||
save_initialized_ = false;
|
||||
}
|
||||
return &json_;
|
||||
}
|
||||
|
||||
private:
|
||||
void save_gemm_best_configs_(const std::string& config_file_path) {
|
||||
std::ifstream file(config_file_path);
|
||||
if (!file.good()) {
|
||||
std::ofstream new_file(config_file_path);
|
||||
new_file << json_.dump(4);
|
||||
new_file.close();
|
||||
} else {
|
||||
nlohmann::json old_json = readJsonFromFile(config_file_path);
|
||||
for (auto it = json_.begin(); it != json_.end(); ++it) {
|
||||
old_json[it.key()] = it.value();
|
||||
}
|
||||
json_ = old_json;
|
||||
std::ofstream new_file(config_file_path,
|
||||
std::ios::out | std::ios::trunc);
|
||||
new_file << json_.dump(4);
|
||||
new_file.close();
|
||||
file.close();
|
||||
}
|
||||
return;
|
||||
private:
|
||||
void save_gemm_best_configs_(const std::string &config_file_path) {
|
||||
std::ifstream file(config_file_path);
|
||||
if (!file.good()) {
|
||||
std::ofstream new_file(config_file_path);
|
||||
new_file << json_.dump(4);
|
||||
new_file.close();
|
||||
} else {
|
||||
nlohmann::json old_json = readJsonFromFile(config_file_path);
|
||||
for (auto it = json_.begin(); it != json_.end(); ++it) {
|
||||
old_json[it.key()] = it.value();
|
||||
}
|
||||
json_ = old_json;
|
||||
std::ofstream new_file(config_file_path, std::ios::out | std::ios::trunc);
|
||||
new_file << json_.dump(4);
|
||||
new_file.close();
|
||||
file.close();
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
CutlassGemmConfigMannager()
|
||||
: json_(nullptr), load_initialized_(false), save_initialized_(true) {}
|
||||
~CutlassGemmConfigMannager() {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
if (save_initialized_) {
|
||||
std::string config_file_path = "fp8_fuse_gemm_config.json";
|
||||
save_gemm_best_configs_(config_file_path);
|
||||
}
|
||||
save_initialized_ = true;
|
||||
load_initialized_ = false;
|
||||
json_.clear();
|
||||
CutlassGemmConfigMannager()
|
||||
: json_(nullptr), load_initialized_(false), save_initialized_(true) {}
|
||||
~CutlassGemmConfigMannager() {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
if (save_initialized_) {
|
||||
std::string config_file_path = "fp8_fuse_gemm_config.json";
|
||||
save_gemm_best_configs_(config_file_path);
|
||||
}
|
||||
mutable std::mutex mutex_;
|
||||
nlohmann::json json_;
|
||||
bool load_initialized_;
|
||||
bool save_initialized_;
|
||||
save_initialized_ = true;
|
||||
load_initialized_ = false;
|
||||
json_.clear();
|
||||
}
|
||||
mutable std::mutex mutex_;
|
||||
nlohmann::json json_;
|
||||
bool load_initialized_;
|
||||
bool save_initialized_;
|
||||
};
|
||||
|
||||
@@ -15,8 +15,8 @@
|
||||
#pragma once
|
||||
|
||||
#include "fp8_common.h"
|
||||
#include "fuse_dual_gemm_swiglu_template.h"
|
||||
#include "fuse_dual_gemm_act_template_3x.h"
|
||||
#include "fuse_dual_gemm_geglu_template.h"
|
||||
#include "fuse_dual_gemm_swiglu_template.h"
|
||||
|
||||
bool fp8_fp8_dual_gemm_scale_bias_act(
|
||||
DualGemmEpilogueAllParams params);
|
||||
bool fp8_fp8_dual_gemm_scale_bias_act(DualGemmEpilogueAllParams params);
|
||||
|
||||
@@ -15,12 +15,13 @@
|
||||
#pragma once
|
||||
|
||||
#include "fp8_common.h"
|
||||
#include "fuse_gemm_gelu_template.h"
|
||||
#include "fuse_gemm_noact_template.h"
|
||||
#include "fuse_gemm_relu_template.h"
|
||||
#include "fuse_gemm_gelu_template.h"
|
||||
|
||||
#include "fuse_block_gemm_act_template_3x.h"
|
||||
#include "fuse_gemm_act_template_3x.h"
|
||||
|
||||
bool fp8_fp8_gemm_scale_bias_act(GemmEpilogueAllParams params);
|
||||
|
||||
bool fp8_fp8_block_gemm_scale_bias_act(GemmEpilogueAllParams params);
|
||||
bool fp8_fp8_block_gemm_scale_bias_act(GemmEpilogueAllParams params);
|
||||
|
||||
@@ -0,0 +1,173 @@
|
||||
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/float8.h"
|
||||
#include "cutlass/gemm/device/gemm_universal.h"
|
||||
#include "fp8_common.h"
|
||||
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/gemm/kernel/tile_scheduler.hpp"
|
||||
#include "cutlass/util/packed_stride.hpp"
|
||||
#include "cutlass_extensions/gemm/collective/collective_builder_gated.hpp"
|
||||
#include "cutlass_extensions/gemm/kernel/gemm_universal_gated.hpp"
|
||||
|
||||
template <typename InputType, typename CTAShape, typename ClusterShape,
|
||||
typename MainloopScheduleType, typename EpilogueScheduleType,
|
||||
typename TileSchedulerType = void,
|
||||
template <class /* ElementCompute */> class Activation =
|
||||
cutlass::epilogue::thread::SiLu,
|
||||
bool SwapAB = true>
|
||||
bool dispatch_dual_gemm_act_sm90(DualGemmEpilogueAllParams params) {
|
||||
using namespace cute;
|
||||
using ElementA = typename std::conditional_t<
|
||||
std::is_same_v<InputType, phi::dtype::float8_e4m3fn>,
|
||||
cutlass::float_e4m3_t, cutlass::float_e5m2_t>;
|
||||
using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand
|
||||
static constexpr int AlignmentA =
|
||||
128 /
|
||||
cutlass::sizeof_bits<
|
||||
ElementA>::value; // Memory access granularity/alignment of A
|
||||
// matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// B matrix configuration
|
||||
using ElementB = ElementA; // Element type for B matrix operand
|
||||
using LayoutB =
|
||||
cutlass::layout::ColumnMajor; // Layout type for B matrix operand
|
||||
static constexpr int AlignmentB =
|
||||
128 /
|
||||
cutlass::sizeof_bits<
|
||||
ElementB>::value; // Memory access granularity/alignment of B
|
||||
// matrix in units of elements (up to 16 bytes)
|
||||
|
||||
using ElementC = ElementA; // Element type for C matrix operands
|
||||
|
||||
using LayoutC = cute::conditional_t<SwapAB, cutlass::layout::ColumnMajor,
|
||||
cutlass::layout::RowMajor>;
|
||||
static constexpr int AlignmentC =
|
||||
128 /
|
||||
cutlass::sizeof_bits<
|
||||
ElementC>::value; // Memory access granularity/alignment of C matrices
|
||||
// in units of elements (up to 16 bytes)
|
||||
|
||||
// Output matrix configuration
|
||||
using ElementOutput = ElementA; // Element type for output matrix operands
|
||||
// using LayoutOutput = cutlass::layout::RowMajor; // Layout type for output
|
||||
// matrix operands
|
||||
using LayoutOutput = cute::conditional_t<SwapAB, cutlass::layout::ColumnMajor,
|
||||
cutlass::layout::RowMajor>;
|
||||
static constexpr int AlignmentOutput =
|
||||
128 / cutlass::sizeof_bits<ElementOutput>::value;
|
||||
|
||||
// Multiply-accumulate blocking/pipelining details
|
||||
using ElementAccumulator = float; // Element type for internal accumulation
|
||||
using ElementCompute = float; // Element type for compute
|
||||
using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that
|
||||
// supports the intended feature
|
||||
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
|
||||
using TileShape = CTAShape; // Threadblock-level tile size
|
||||
using KernelSchedule = MainloopScheduleType;
|
||||
using EpilogueSchedule = EpilogueScheduleType;
|
||||
using TileScheduler = TileSchedulerType;
|
||||
|
||||
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
|
||||
static constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest;
|
||||
using FusionOperation =
|
||||
cutlass::epilogue::fusion::ScaledAcc<ElementOutput, ElementCompute>;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass, TileShape, ClusterShape, EpilogueTileType,
|
||||
ElementAccumulator, ElementAccumulator, ElementC, LayoutC, AlignmentC,
|
||||
ElementOutput, LayoutOutput, AlignmentOutput, EpilogueSchedule,
|
||||
FusionOperation>::CollectiveOp;
|
||||
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilderGated<
|
||||
ArchTag, OperatorClass, ElementA, LayoutA, AlignmentA, ElementB,
|
||||
LayoutB, AlignmentB, ElementAccumulator, TileShape, ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
|
||||
sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||
KernelSchedule, Activation, SwapAB>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversalGated<
|
||||
Shape<int, int, int, int>, // Indicates ProblemShape
|
||||
CollectiveMainloop, CollectiveEpilogue, TileScheduler>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
|
||||
using StrideA = typename Gemm::GemmKernel::StrideA;
|
||||
using StrideB = typename Gemm::GemmKernel::StrideB;
|
||||
using StrideC = typename Gemm::GemmKernel::StrideC;
|
||||
using StrideD = typename Gemm::GemmKernel::StrideD;
|
||||
|
||||
int arg_m = params.M;
|
||||
int arg_n = params.N;
|
||||
ElementA const *ptr_A = reinterpret_cast<ElementA const *>(params.A);
|
||||
ElementB const *ptr_B0 = reinterpret_cast<ElementB const *>(params.B0);
|
||||
ElementB const *ptr_B1 = reinterpret_cast<ElementB const *>(params.B1);
|
||||
if constexpr (SwapAB) {
|
||||
arg_m = params.N;
|
||||
arg_n = params.M;
|
||||
ptr_A = reinterpret_cast<ElementB const *>(params.B0);
|
||||
ptr_B0 = reinterpret_cast<ElementA const *>(params.A);
|
||||
}
|
||||
StrideA stride_A = cutlass::make_cute_packed_stride(
|
||||
StrideA{}, cute::make_shape(arg_m, params.K, params.batch_count));
|
||||
StrideB stride_B = cutlass::make_cute_packed_stride(
|
||||
StrideB{}, cute::make_shape(arg_n, params.K, params.batch_count));
|
||||
StrideC stride_C;
|
||||
StrideD stride_D = cutlass::make_cute_packed_stride(
|
||||
StrideD{}, cute::make_shape(arg_m, arg_n, params.batch_count));
|
||||
|
||||
typename Gemm::Arguments arguments = {
|
||||
cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
{arg_m, arg_n, params.K, params.batch_count},
|
||||
{ptr_A, stride_A, ptr_B0, ptr_B1, stride_B, params.scale0, params.scale1},
|
||||
{{}, // epilogue.thread
|
||||
nullptr,
|
||||
stride_C,
|
||||
reinterpret_cast<ElementOutput *>(params.D),
|
||||
stride_D}};
|
||||
arguments.epilogue.thread.alpha = params.scale_out;
|
||||
|
||||
Gemm gemm_op;
|
||||
|
||||
cutlass::Status status = gemm_op.can_implement(arguments);
|
||||
|
||||
if (status != cutlass::Status::kSuccess) {
|
||||
std::cerr << "Gemm::can_implement() failed" << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
size_t workspace_size = Gemm::get_workspace_size(arguments);
|
||||
phi::Allocator *allocator = paddle::GetAllocator(params.place);
|
||||
auto workspace = allocator->Allocate(workspace_size);
|
||||
|
||||
//
|
||||
// Run the GEMM
|
||||
//
|
||||
status = gemm_op(arguments, workspace->ptr(), params.stream);
|
||||
if (status != cutlass::Status::kSuccess) {
|
||||
std::cerr << "Gemm::run() failed" << std::endl;
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
@@ -0,0 +1,151 @@
|
||||
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
#pragma once
|
||||
|
||||
#include "fp8_common.h"
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/epilogue/collective/default_epilogue.hpp"
|
||||
#include "cutlass/epilogue/thread/linear_combination.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/gemm/kernel/tile_scheduler.hpp"
|
||||
#include "cutlass/util/packed_stride.hpp"
|
||||
|
||||
template <
|
||||
typename InputType,
|
||||
typename OutType,
|
||||
bool hasbias,
|
||||
template <class> typename Activation,
|
||||
typename TileShape,
|
||||
typename ClusterShape,
|
||||
typename KernelSchedule =
|
||||
cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum,
|
||||
typename EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized,
|
||||
typename SM = cutlass::arch::Sm90>
|
||||
bool dispatch_fuse_gemm_act_sm90(GemmEpilogueAllParams params) {
|
||||
using namespace cute;
|
||||
using ElementA = typename std::conditional_t<
|
||||
std::is_same_v<InputType, phi::dtype::float8_e4m3fn>,
|
||||
cutlass::float_e4m3_t, cutlass::float_e5m2_t>;
|
||||
using ElementB = ElementA;
|
||||
using ElementD =
|
||||
typename std::conditional_t<std::is_same_v<OutType, phi::dtype::bfloat16>,
|
||||
cutlass::bfloat16_t, cutlass::half_t>;
|
||||
using ElementC = std::conditional_t<hasbias, ElementD, void>;
|
||||
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
using LayoutD = cutlass::layout::RowMajor;
|
||||
|
||||
using ElementAccumulator = float;
|
||||
using ElementCompute = float;
|
||||
using ElementScalar = float;
|
||||
|
||||
// 16B alignment lets us use TMA
|
||||
static constexpr int AlignmentA = 16 / sizeof(ElementA);
|
||||
static constexpr int AlignmentB = 16 / sizeof(ElementB);
|
||||
static constexpr int AlignmentC = hasbias ? 16 / sizeof(ElementC) : 8;
|
||||
static constexpr int AlignmentD = 16 / sizeof(ElementD);
|
||||
|
||||
static constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest;
|
||||
|
||||
using FusionOperation =
|
||||
cutlass::epilogue::fusion::LinCombEltAct<Activation, ElementD,
|
||||
ElementCompute, ElementC,
|
||||
ElementScalar, RoundStyle>;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
SM, cutlass::arch::OpClassTensorOp, TileShape, ClusterShape,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator,
|
||||
ElementCompute, ElementC, LayoutC, AlignmentC, ElementD, LayoutD,
|
||||
AlignmentD, EpilogueSchedule, FusionOperation>::CollectiveOp;
|
||||
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
SM, cutlass::arch::OpClassTensorOp, ElementA, LayoutA, AlignmentA,
|
||||
ElementB, LayoutB, AlignmentB, ElementAccumulator, TileShape,
|
||||
ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
|
||||
sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||
KernelSchedule>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue,
|
||||
cutlass::gemm::PersistentScheduler>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
|
||||
using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape;
|
||||
|
||||
using StrideA = typename Gemm::GemmKernel::StrideA;
|
||||
using StrideB = typename Gemm::GemmKernel::StrideB;
|
||||
using StrideC = typename Gemm::GemmKernel::StrideC;
|
||||
using StrideD = typename Gemm::GemmKernel::StrideD;
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Initialization
|
||||
StrideA stride_A{params.lda, cute::Int<1>{}, params.M * params.lda};
|
||||
StrideB stride_B{params.ldb, cute::Int<1>{}, params.N * params.ldb};
|
||||
StrideC stride_C{0, cute::Int<1>{}, 0};
|
||||
StrideD stride_D{params.ldd, cute::Int<1>{}, params.ldd * params.M};
|
||||
|
||||
auto a_ptr = reinterpret_cast<ElementA *>(const_cast<void *>(params.A));
|
||||
auto b_ptr = reinterpret_cast<ElementB *>(const_cast<void *>(params.B));
|
||||
auto c_ptr = reinterpret_cast<ElementC *>(const_cast<void *>(params.bias));
|
||||
auto d_ptr = reinterpret_cast<ElementD *>(params.D);
|
||||
|
||||
ProblemShapeType problem_size =
|
||||
ProblemShapeType{params.M, params.N, params.K, params.batch_count};
|
||||
|
||||
typename Gemm::Arguments arguments{cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
problem_size,
|
||||
{a_ptr, stride_A, b_ptr, stride_B},
|
||||
{{params.scale}, // epilogue.thread
|
||||
c_ptr,
|
||||
stride_C,
|
||||
d_ptr,
|
||||
stride_D}};
|
||||
if constexpr (hasbias) {
|
||||
arguments.epilogue.thread.beta = 1.0;
|
||||
}
|
||||
|
||||
Gemm gemm_op;
|
||||
|
||||
cutlass::Status status = gemm_op.can_implement(arguments);
|
||||
if (status != cutlass::Status::kSuccess) {
|
||||
std::cout << "Gemm::can_implement() failed. "
|
||||
<< cutlassGetStatusString(status) << std::endl;
|
||||
return false;
|
||||
}
|
||||
size_t workspace_size = Gemm::get_workspace_size(arguments);
|
||||
phi::Allocator *allocator = paddle::GetAllocator(params.place);
|
||||
auto workspace = allocator->Allocate(workspace_size);
|
||||
|
||||
status = gemm_op(arguments, workspace->ptr(), params.stream);
|
||||
if (status != cutlass::Status::kSuccess) {
|
||||
std::cout << "Gemm::run() failed." << cutlassGetStatusString(status)
|
||||
<< std::endl;
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
@@ -43,7 +43,9 @@
|
||||
#include "cutlass/trace.h"
|
||||
|
||||
#include "cutlass_extensions/gemm/kernel/gemm_moe_problem_visitor.h"
|
||||
#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/tile_interleaved_layout.h"
|
||||
#include "cutlass_extensions/gemm/threadblock/wint2x_tile_dequanter.h"
|
||||
#include "cutlass_extensions/tile_interleaved_layout.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
@@ -156,9 +158,6 @@ struct MoeFCGemm {
|
||||
using LayoutC = typename MapArguments::LayoutC;
|
||||
using ElementScale = ElementC;
|
||||
|
||||
static ComplexTransform const kTransformA = MapArguments::kTransformA;
|
||||
static ComplexTransform const kTransformB = MapArguments::kTransformB;
|
||||
|
||||
// Type definitions about the mainloop.
|
||||
using Operator = typename Mma::Operator;
|
||||
using OperatorClass = typename Mma::Operator::OperatorClass;
|
||||
@@ -209,6 +208,13 @@ struct MoeFCGemm {
|
||||
int64_t gemm_n;
|
||||
int64_t gemm_k;
|
||||
|
||||
WintQuantMethod quant_method;
|
||||
|
||||
// Extra arguments for wint2.0
|
||||
uint8_t* local_scale;
|
||||
float* code_scale;
|
||||
float* code_zp;
|
||||
|
||||
// Only used by device-level operator
|
||||
GemmCoord* host_problem_sizes;
|
||||
|
||||
@@ -230,6 +236,10 @@ struct MoeFCGemm {
|
||||
total_rows(-1),
|
||||
gemm_n(0),
|
||||
gemm_k(0),
|
||||
quant_method(WintQuantMethod::kNone),
|
||||
local_scale(nullptr),
|
||||
code_scale(nullptr),
|
||||
code_zp(nullptr),
|
||||
host_problem_sizes(nullptr) {}
|
||||
|
||||
/// Ctor
|
||||
@@ -246,6 +256,10 @@ struct MoeFCGemm {
|
||||
int64_t total_rows,
|
||||
int64_t gemm_n,
|
||||
int64_t gemm_k,
|
||||
WintQuantMethod quant_method,
|
||||
const uint8_t* local_scale,
|
||||
const float* code_scale,
|
||||
const float* code_zp,
|
||||
GemmCoord* host_problem_sizes = nullptr)
|
||||
: problem_count(problem_count),
|
||||
threadblock_count(threadblock_count),
|
||||
@@ -259,8 +273,12 @@ struct MoeFCGemm {
|
||||
total_rows(total_rows),
|
||||
gemm_n(gemm_n),
|
||||
gemm_k(gemm_k),
|
||||
quant_method(quant_method),
|
||||
local_scale(const_cast<uint8_t*>(local_scale)),
|
||||
code_scale(const_cast<float*>(code_scale)),
|
||||
code_zp(const_cast<float*>(code_zp)),
|
||||
host_problem_sizes(nullptr) {
|
||||
if (platform::is_same<uint8_t, ElementB>::value ||
|
||||
if (quant_method != WintQuantMethod::kNone || platform::is_same<uint8_t, ElementB>::value ||
|
||||
platform::is_same<uint4b_t, ElementB>::value) {
|
||||
assert(weight_scales);
|
||||
}
|
||||
@@ -284,6 +302,8 @@ struct MoeFCGemm {
|
||||
ElementC* ptr_C;
|
||||
ElementC* ptr_D;
|
||||
|
||||
WintQuantMethod quant_method;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
@@ -294,7 +314,8 @@ struct MoeFCGemm {
|
||||
ptr_B(nullptr),
|
||||
weight_scales(nullptr),
|
||||
ptr_C(nullptr),
|
||||
ptr_D(nullptr) {}
|
||||
ptr_D(nullptr),
|
||||
quant_method(WintQuantMethod::kNone) {}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(Arguments const& args,
|
||||
@@ -313,7 +334,8 @@ struct MoeFCGemm {
|
||||
ptr_B(args.ptr_B),
|
||||
weight_scales(args.weight_scales),
|
||||
ptr_C(args.ptr_C),
|
||||
ptr_D(args.ptr_D) {}
|
||||
ptr_D(args.ptr_D),
|
||||
quant_method(args.quant_method) {}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
void update(Arguments const& args,
|
||||
@@ -334,6 +356,7 @@ struct MoeFCGemm {
|
||||
weight_scales = args.weight_scales;
|
||||
ptr_C = args.ptr_C;
|
||||
ptr_D = args.ptr_D;
|
||||
quant_method = args.quant_method;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -358,7 +381,7 @@ struct MoeFCGemm {
|
||||
}
|
||||
|
||||
static Status can_implement(Arguments const& args) {
|
||||
if (platform::is_same<uint8_t, ElementB>::value ||
|
||||
if (args.quant_method != WintQuantMethod::kNone || platform::is_same<uint8_t, ElementB>::value ||
|
||||
platform::is_same<uint4b_t, ElementB>::value) {
|
||||
if (args.weight_scales == nullptr) {
|
||||
CUTLASS_TRACE_HOST(
|
||||
@@ -394,6 +417,7 @@ struct MoeFCGemm {
|
||||
|
||||
template <typename dummy>
|
||||
struct KernelRunner<true, dummy> {
|
||||
|
||||
CUTLASS_DEVICE
|
||||
static void run_kernel(Params const& params,
|
||||
SharedStorage& shared_storage) { // NOLINT
|
||||
@@ -401,12 +425,14 @@ struct MoeFCGemm {
|
||||
// These types shadow the type-level definitions and support the ability
|
||||
// to implement a 'transposed' GEMM that computes the transposed problems.
|
||||
//
|
||||
|
||||
using ElementA = typename Mma::IteratorA::Element;
|
||||
using LayoutA = typename Mma::IteratorA::Layout;
|
||||
using ElementB = typename Mma::IteratorB::Element;
|
||||
using LayoutB = typename Mma::IteratorB::Layout;
|
||||
using ElementC = typename Epilogue::OutputTileIterator::Element;
|
||||
using LayoutC = typename Epilogue::OutputTileIterator::Layout;
|
||||
|
||||
static constexpr int kInterleave =
|
||||
Mma::IteratorB::Shape::kRow / Mma::Shape::kK;
|
||||
static_assert(
|
||||
@@ -435,6 +461,7 @@ struct MoeFCGemm {
|
||||
|
||||
GemmCoord grid_shape = problem_visitor.grid_shape(problem_size);
|
||||
|
||||
// threadblock_offset of C
|
||||
cutlass::gemm::GemmCoord threadblock_offset(
|
||||
int(cta_idx / grid_shape.n()) * Mma::Shape::kM, // NOLINT
|
||||
int(cta_idx % grid_shape.n()) * Mma::Shape::kN, // NOLINT
|
||||
@@ -450,6 +477,7 @@ struct MoeFCGemm {
|
||||
rows_to_jump = problem_idx * (params.problem_visitor.total_rows / params.problem_visitor.problem_count);
|
||||
}
|
||||
|
||||
// begin address offset for A for current tile
|
||||
ElementA* ptr_A =
|
||||
reinterpret_cast<ElementA*>(params.ptr_A) + rows_to_jump * gemm_k;
|
||||
typename LayoutA::LongIndex ldm_A = gemm_k;
|
||||
@@ -463,14 +491,17 @@ struct MoeFCGemm {
|
||||
: gemm_k * kInterleave;
|
||||
|
||||
// Compute initial location in logical coordinates
|
||||
// the begin threadblock_offset of A, which holds the same row id with C
|
||||
cutlass::MatrixCoord tb_offset_A{
|
||||
threadblock_offset.m(),
|
||||
0,
|
||||
};
|
||||
|
||||
// the begin threadblock_offset of B, which holds the same column id with C
|
||||
cutlass::MatrixCoord tb_offset_B{0,
|
||||
threadblock_offset.n() / kInterleave};
|
||||
|
||||
// the begin threadblock_offset of scale, which holds the same column id with C, but with no row id
|
||||
cutlass::MatrixCoord tb_offset_scale{0, threadblock_offset.n()};
|
||||
|
||||
// Compute position within threadblock
|
||||
@@ -610,6 +641,381 @@ struct MoeFCGemm {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate
|
||||
typename Epilogue_, ///! Epilogue
|
||||
typename ThreadblockSwizzle_, ///! Threadblock swizzling function
|
||||
typename KernelArch, ///! The Architecture this kernel is compiled
|
||||
/// for. Used since SIMT kernels lose top-level
|
||||
/// arch.
|
||||
GroupScheduleMode GroupScheduleMode_ ///! Type of scheduling to //
|
||||
/// NOLINT perform
|
||||
>
|
||||
struct Wint2xMoeFCGemm : public MoeFCGemm<Mma_, Epilogue_, ThreadblockSwizzle_, KernelArch, GroupScheduleMode_> {
|
||||
public:
|
||||
using Base = MoeFCGemm<Mma_, Epilogue_, ThreadblockSwizzle_, KernelArch, GroupScheduleMode_>;
|
||||
using Mma = Mma_;
|
||||
using Epilogue = Epilogue_;
|
||||
using EpilogueOutputOp = typename Epilogue::OutputOp;
|
||||
using ThreadblockSwizzle = ThreadblockSwizzle_;
|
||||
static GroupScheduleMode const kGroupScheduleMode = GroupScheduleMode_;
|
||||
static bool const kTransposed = false;
|
||||
|
||||
// Optional transpose
|
||||
using MapArguments = typename Base::MapArguments;
|
||||
|
||||
// Public-facing type definitions related to operand element type, layout, and
|
||||
// complex conjugate operation. Must interact with the 'kTransposed' notion.
|
||||
static_assert(!kTransposed, "Transpose problem not supported");
|
||||
|
||||
using ElementA = typename MapArguments::ElementA;
|
||||
using LayoutA = typename MapArguments::LayoutA;
|
||||
using ElementB = typename MapArguments::ElementB;
|
||||
using LayoutB = typename MapArguments::LayoutB;
|
||||
using ElementC = typename Epilogue::OutputTileIterator::Element;
|
||||
using LayoutC = typename MapArguments::LayoutC;
|
||||
using ElementScale = ElementC;
|
||||
|
||||
// Type definitions about the mainloop.
|
||||
using Operator = typename Mma::Operator;
|
||||
using OperatorClass = typename Mma::Operator::OperatorClass;
|
||||
using ThreadblockShape = typename Mma::Shape;
|
||||
using WarpShape = typename Mma::Operator::Shape;
|
||||
using InstructionShape = typename Mma::Policy::Operator::InstructionShape;
|
||||
using ArchTag = typename Mma::ArchTag;
|
||||
|
||||
static int const kStages = Mma::kStages;
|
||||
static int const kAlignmentA = MapArguments::kAlignmentA;
|
||||
static int const kAlignmentB = MapArguments::kAlignmentB;
|
||||
static int const kAlignmentC =
|
||||
Epilogue::OutputTileIterator::kElementsPerAccess;
|
||||
|
||||
/// Warp count (concept: GemmShape)
|
||||
using WarpCount = typename Mma::WarpCount;
|
||||
static int const kThreadCount = 32 * WarpCount::kCount;
|
||||
|
||||
using ProblemVisitor = typename Base::ProblemVisitor;
|
||||
using Arguments = typename Base::Arguments;
|
||||
|
||||
//
|
||||
// Structure for precomputing values in host memory and passing to kernels
|
||||
//
|
||||
|
||||
/// Parameters structure
|
||||
struct Params : Base::Params {
|
||||
// Extra arguments for wint2.0
|
||||
uint8_t* local_scale;
|
||||
float* code_scale;
|
||||
float* code_zp;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params() : Base::Params(), local_scale(nullptr), code_scale(nullptr), code_zp(nullptr) {}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(Arguments const& args,
|
||||
void* workspace = nullptr,
|
||||
int tile_count = 0) // NOLINT
|
||||
: Base::Params(args, workspace, tile_count),
|
||||
local_scale(args.local_scale),
|
||||
code_scale(args.code_scale),
|
||||
code_zp(args.code_zp) {}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
void update(Arguments const& args,
|
||||
void* workspace = nullptr,
|
||||
int tile_count = 0) {
|
||||
Base::update(args, workspace, tile_count);
|
||||
|
||||
local_scale = args.local_scale;
|
||||
code_scale = args.code_scale;
|
||||
code_zp = args.code_zp;
|
||||
}
|
||||
};
|
||||
|
||||
/// Shared memory storage structure
|
||||
using SharedStorage = typename Base::SharedStorage;
|
||||
|
||||
public:
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
CUTLASS_DEVICE
|
||||
Wint2xMoeFCGemm() {}
|
||||
|
||||
/// Static feasibility check for this kernel.
///
/// Wint2xMoeFCGemm only supports the 2-bit weight-only quantization path and
/// requires both the per-channel weight scales and the packed local (group)
/// scales to be provided; anything else is rejected with Status::kInvalid.
///
/// \param args  host-side arguments to validate
/// \return Status::kSuccess when the configuration is supported,
///         Status::kInvalid otherwise.
static Status can_implement(Arguments const& args) {
  // Guard 1: only the wint2.0 quantization method is implemented here.
  if (args.quant_method != WintQuantMethod::kWeightOnlyInt2) {
    CUTLASS_TRACE_HOST(
        "Wint2xMoeFCGemm::can_implement() - only support weight_only_int2!");
    return Status::kInvalid;
  }
  // Guard 2: both scale tensors are mandatory for dequantization.
  if (args.weight_scales == nullptr || args.local_scale == nullptr) {
    CUTLASS_TRACE_HOST(
        "Wint2xMoeFCGemm::can_implement() - weight_scales and local_scale are expected to be non-null!");
    return Status::kInvalid;
  }
  return Status::kSuccess;
}
|
||||
|
||||
// The dummy template parameter is not used and exists so that we can compile
// this code using a standard earlier than C++17. Prior to C++17, fully
// specialized templates HAD to exist in a namespace.
//
// Primary template: fallback for unsupported <QuantMethod, B> combinations.
// If this ever gets instantiated and called on device, it traps via
// CUTLASS_NOT_IMPLEMENTED(); the real implementation lives in the
// <QuantMethod, true> partial specialization below.
template <WintQuantMethod QuantMethod, bool B, typename dummy = void>
struct KernelRunner {
  CUTLASS_DEVICE
  static void run_kernel(Params const& params,
                         SharedStorage& shared_storage) {  // NOLINT
    CUTLASS_NOT_IMPLEMENTED();
  }
};
|
||||
|
||||
// Partial specialization that actually implements the grouped MoE GEMM for a
// supported quantization method (B == true). Each threadblock repeatedly asks
// the grouped-problem visitor for a tile, dequantizes the corresponding
// 2-bit weight tile on the fly via TileDequanterB, runs the threadblock-scoped
// mainloop, then the epilogue, until no tiles remain.
template <WintQuantMethod QuantMethod, typename dummy>
struct KernelRunner<QuantMethod, true, dummy> {
  using WeightQuantTraits = WintQuantTraits<ElementA, QuantMethod>;
  using QuantArguments = typename WeightQuantTraits::Arguments;

  // Builds the per-expert quantization argument pack for `problem_idx`.
  // For wint2.0: code_scale/code_zp hold one float per output channel
  // (stride gemm_n per expert); local_scale is advanced by
  // gemm_k * gemm_n / 128 bytes per expert — presumably because group
  // scales are packed two-per-byte at group_size 64 (TODO confirm packing).
  CUTLASS_DEVICE
  static QuantArguments get_quant_args(Params const& params, int32_t problem_idx, const int64_t gemm_k, const int64_t gemm_n) {
    QuantArguments quant_args;
    if constexpr (QuantMethod == WintQuantMethod::kWeightOnlyInt2) {
      quant_args.local_scale_ptr = params.local_scale + problem_idx * gemm_k * gemm_n / 128;
      quant_args.code_scale_ptr = params.code_scale + problem_idx * gemm_n;
      quant_args.code_zp_ptr = params.code_zp + problem_idx * gemm_n;
    }
    return quant_args;
  }

  // Persistent device-side loop: one invocation processes every tile the
  // problem visitor assigns to this threadblock.
  CUTLASS_DEVICE
  static void run_kernel(Params const& params,
                         SharedStorage& shared_storage) {  // NOLINT
    //
    // These types shadow the type-level definitions and support the ability
    // to implement a 'transposed' GEMM that computes the transposed problems.
    //
    using ElementA = typename Mma::IteratorA::Element;
    using LayoutA = typename Mma::IteratorA::Layout;
    using ElementB = typename Mma::IteratorB::Element;
    using LayoutB = typename Mma::IteratorB::Layout;
    using ElementC = typename Epilogue::OutputTileIterator::Element;
    using LayoutC = typename Epilogue::OutputTileIterator::Layout;
    // QuantElementB: storage type of the quantized weights in global memory.
    // MmaElementB: element type the mainloop actually multiplies with
    // (i.e. after dequantization).
    using QuantElementB = typename WeightQuantTraits::WeightType;
    using MmaElementB = typename WeightQuantTraits::MmaWeightType;

    // kInterleave > 1 means B is stored column-major interleaved.
    static constexpr int kInterleave =
        Mma::IteratorB::Shape::kRow / Mma::Shape::kK;
    static_assert(
        platform::is_same<LayoutB, layout::RowMajor>::value &&
                kInterleave == 1 ||
            platform::is_same<LayoutB, layout::ColumnMajor>::value &&
                kInterleave >= 1,
        "B must be row major/col major OR col major interleaved.");

    // LayoutB should be RowMajor
    using TileDequanterB = cutlass::gemm::threadblock::TileDequanter<ElementA, ElementScale, ThreadblockShape::kK, ThreadblockShape::kN, kStages, kThreadCount, QuantMethod>;

    //
    // Problem visitor.
    //
    ProblemVisitor problem_visitor(
        params.problem_visitor, shared_storage.problem_visitor, blockIdx.x);

    const int64_t gemm_k = params.problem_visitor.gemm_k;
    const int64_t gemm_n = params.problem_visitor.gemm_n;
    // wint2.5 and wint2.0 is quantized and packed along k dimension with group_size 64.
    const int64_t quant_gemm_k = WintQuantTraits<ElementA, QuantMethod>::CaclPackedDim(gemm_k);
    // NOTE(review): dividing by 8 before multiplying by sizeof_bits assumes
    // quant_gemm_k * gemm_n is a multiple of 8 — confirm, otherwise this
    // truncates.
    int64_t bytes_per_expert_matrix = (quant_gemm_k * gemm_n / 8) * cutlass::sizeof_bits<QuantElementB>::value;

    // Outer 'persistent' loop to iterate over tiles
    while (problem_visitor.next_tile()) {
      GemmCoord problem_size = problem_visitor.problem_size();
      int32_t problem_idx = problem_visitor.problem_index();
      int32_t cta_idx = int32_t(problem_visitor.threadblock_idx());

      GemmCoord grid_shape = problem_visitor.grid_shape(problem_size);

      // threadblock_offset of C
      cutlass::gemm::GemmCoord threadblock_offset(
          int(cta_idx / grid_shape.n()) * Mma::Shape::kM,  // NOLINT
          int(cta_idx % grid_shape.n()) * Mma::Shape::kN,  // NOLINT
          0);

      // begin address offset for weight_scale.
      ElementScale* weight_scale_ptr =
          params.weight_scales ? params.weight_scales + problem_idx * problem_size.n() : nullptr;
      // the begin threadblock_offset of scale, which holds the same column id with C, but with no row id
      cutlass::MatrixCoord tb_offset_scale{0, threadblock_offset.n()};

      // Load element pointers. Exchange pointers and strides if working on
      // the transpose
      // rows_to_jump: number of activation rows belonging to experts before
      // this one. total_rows < 0 signals a variable per-expert row count,
      // looked up from last_row_for_problem; otherwise rows are split evenly.
      int64_t rows_to_jump = 0;

      if (params.problem_visitor.total_rows < 0) {
        rows_to_jump = problem_idx == 0 ? 0 : params.problem_visitor.last_row_for_problem[problem_idx - 1];
      } else {
        rows_to_jump = problem_idx * (params.problem_visitor.total_rows / params.problem_visitor.problem_count);
      }

      // begin address offset for A for current tile
      ElementA* ptr_A =
          reinterpret_cast<ElementA*>(params.ptr_A) + rows_to_jump * gemm_k;
      typename LayoutA::LongIndex ldm_A = gemm_k;

      // Compute initial location in logical coordinates
      // the begin threadblock_offset of A, which holds the same row id with C
      cutlass::MatrixCoord tb_offset_A{
          threadblock_offset.m(),
          0,
      };

      // begin address offset for B for current problem_idx, totally num_experts problems
      char* byte_ptr_B = ((char*)params.ptr_B) +          // NOLINT
                         problem_idx * bytes_per_expert_matrix;  // NOLINT

      typename LayoutB::LongIndex ldm_B =
          platform::is_same<layout::RowMajor, LayoutB>::value
              ? gemm_n
              : gemm_k * kInterleave;
      typename LayoutB::LongIndex ldm_B_shared = TileDequanterB::kColumns;

      // the begin threadblock_offset of B, which holds the same column id with C
      cutlass::MatrixCoord tb_offset_B{0,
                                       threadblock_offset.n() / kInterleave};

      cutlass::MatrixCoord extent_B{problem_size.k() * kInterleave, problem_size.n() / kInterleave};
      cutlass::MatrixCoord extent_B_shared{TileDequanterB::kRows, TileDequanterB::kColumns};

      // For wint2.0 the dequantized B tile is staged in shared memory.
      MmaElementB* smem_unzip_B_ptr = nullptr;
      if constexpr (QuantMethod == WintQuantMethod::kWeightOnlyInt2) {
        smem_unzip_B_ptr = shared_storage.main_loop.operand_unzip_B_ptr();
      }
      QuantArguments quant_args = get_quant_args(params, problem_idx, gemm_k, gemm_n);
      TileDequanterB tile_dequanter_B(smem_unzip_B_ptr,
                                      byte_ptr_B,
                                      ldm_B,
                                      extent_B,
                                      tb_offset_B,
                                      weight_scale_ptr,
                                      tb_offset_scale,
                                      quant_args);
      // ptr_B points at the dequanter's output (shared memory when
      // kUseSharedMemory, otherwise the raw global pointer).
      MmaElementB* ptr_B = tile_dequanter_B.GetOutPtr();

      // Compute position within threadblock
      int thread_idx = threadIdx.x;

      // Construct iterators to A and B operands
      typename Mma::IteratorA iterator_A(LayoutA(ldm_A),
                                         ptr_A,
                                         {problem_size.m(), problem_size.k()},
                                         thread_idx,
                                         tb_offset_A);

      // When B comes from shared memory the iterator walks a single staged
      // tile from coordinate (0, 0); otherwise it walks global memory from
      // this threadblock's offset.
      typename Mma::IteratorB iterator_B(
          LayoutB(TileDequanterB::kUseSharedMemory ? ldm_B_shared : ldm_B),
          ptr_B,
          TileDequanterB::kUseSharedMemory ? extent_B_shared : extent_B,
          thread_idx,
          TileDequanterB::kUseSharedMemory ? cutlass::make_Coord(0, 0) : tb_offset_B);

      typename Mma::FragmentC accumulators;

      accumulators.clear();

      // Broadcast the warp_id computed by lane 0 to ensure dependent code
      // is compiled as warp-uniform.
      int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);

      int lane_idx = threadIdx.x % 32;

      //
      // Matrix multiply phase
      //

      // Construct thread-scoped matrix multiply
      Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);

      // Compute threadblock-scoped matrix multiply-add
      int gemm_k_iterations =
          (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK;

      // Wait for all threads to finish their epilogue phases from the
      // previous tile.
      __syncthreads();

      // Compute threadblock-scoped matrix multiply-add
      mma(gemm_k_iterations,
          accumulators,
          iterator_A,
          iterator_B,
          tile_dequanter_B,
          accumulators);

      //
      // Epilogue
      //

      EpilogueOutputOp output_op(params.output_op);

      // C (the source/bias operand) is indexed per expert; D (the output)
      // is indexed by the expert's starting row, mirroring ptr_A above.
      ElementC* ptr_C =
          params.ptr_C ? reinterpret_cast<ElementC*>(params.ptr_C) + problem_idx * gemm_n : nullptr;
      ElementC* ptr_D =
          reinterpret_cast<ElementC*>(params.ptr_D) + rows_to_jump * gemm_n;

      // Stride 0 on C broadcasts one row (e.g. a bias) across all M rows.
      LayoutC layout_C(0);
      LayoutC layout_D(gemm_n);

      typename Epilogue::OutputTileIterator::Params params_C(layout_C);
      typename Epilogue::OutputTileIterator::Params params_D(layout_D);

      // Tile iterator loading from source tensor.
      typename Epilogue::OutputTileIterator iterator_C(
          params_C,
          ptr_C,
          problem_size.mn(),
          thread_idx,
          threadblock_offset.mn());

      // Tile iterator writing to destination tensor.
      typename Epilogue::OutputTileIterator iterator_D(
          params_D,
          ptr_D,
          problem_size.mn(),
          thread_idx,
          threadblock_offset.mn());

      Epilogue epilogue(
          shared_storage.epilogue, thread_idx, warp_idx, lane_idx);

      // Execute the epilogue operator to update the destination tensor.
      epilogue(output_op, iterator_D, accumulators, iterator_C);

      // Next tile
      problem_visitor.advance(gridDim.x);
    }
  }
};
|
||||
|
||||
/*
  To improve compilation speed, we do not compile the device operator if the
  CUDA_ARCH does not correspond to the ArchTag of the cutlass kernel operator.
*/
/// Executes one GEMM: dispatches to the wint2.0 KernelRunner specialization.
CUTLASS_DEVICE
void operator()(Params const& params,
                SharedStorage& shared_storage) {  // NOLINT
  // Compiled only for SM80-class and newer devices; elsewhere this traps.
  // NOTE(review): the upper bound 910 admits sm_90 (__CUDA_ARCH__ == 900)
  // while excluding anything beyond — confirm this is intentional, since no
  // architecture value 910 exists today.
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 910)
  KernelRunner<WintQuantMethod::kWeightOnlyInt2, true>::run_kernel(params, shared_storage);
#else
  CUTLASS_NOT_IMPLEMENTED();
#endif
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace kernel
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
@@ -15,16 +15,22 @@
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cuda_runtime_api.h>
|
||||
#include <string>
|
||||
|
||||
#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/ft_gemm_configs.h"
|
||||
#include "cutlass_extensions/wint_type_traits.h"
|
||||
|
||||
namespace phi {
|
||||
|
||||
template <typename T, /*The type used for activations/scales/compute*/
|
||||
typename WeightType /* The type for the MoE weights */>
|
||||
typename WeightQuantTraits /* The quant traits for the MoE weights */>
|
||||
class MoeGemmRunner {
|
||||
public:
|
||||
using WeightType = typename WeightQuantTraits::WeightType;
|
||||
using Arguments = typename WeightQuantTraits::Arguments;
|
||||
|
||||
MoeGemmRunner();
|
||||
|
||||
void moe_gemm_bias_act(const T* A,
|
||||
@@ -38,6 +44,7 @@ class MoeGemmRunner {
|
||||
int64_t gemm_n,
|
||||
int64_t gemm_k,
|
||||
int num_experts,
|
||||
const Arguments& quant_args_B,
|
||||
std::string activation_type,
|
||||
cudaStream_t stream);
|
||||
|
||||
@@ -51,6 +58,7 @@ class MoeGemmRunner {
|
||||
int64_t gemm_n,
|
||||
int64_t gemm_k,
|
||||
int num_experts,
|
||||
const Arguments& quant_args_B,
|
||||
cudaStream_t stream);
|
||||
|
||||
private:
|
||||
@@ -65,6 +73,7 @@ class MoeGemmRunner {
|
||||
int64_t gemm_n,
|
||||
int64_t gemm_k,
|
||||
int num_experts,
|
||||
const Arguments& quant_args_B,
|
||||
CutlassGemmConfig gemm_config,
|
||||
cudaStream_t stream,
|
||||
int* occupancy = nullptr);
|
||||
@@ -81,6 +90,7 @@ class MoeGemmRunner {
|
||||
int64_t gemm_n,
|
||||
int64_t gemm_k,
|
||||
int num_experts,
|
||||
const Arguments& quant_args_B,
|
||||
cudaStream_t stream);
|
||||
|
||||
private:
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
|
||||
#pragma once
|
||||
#include "cutlass_kernels/moe_gemm/fused_moe_gemm_kernels.h"
|
||||
#include "cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_template.h"
|
||||
@@ -22,7 +22,8 @@
|
||||
namespace phi {
|
||||
|
||||
#ifdef PADDLE_CUDA_BF16
|
||||
template class MoeGemmRunner<__nv_bfloat16, __nv_bfloat16>;
|
||||
template class MoeGemmRunner<
|
||||
__nv_bfloat16, cutlass::WintQuantTraits<__nv_bfloat16, cutlass::WintQuantMethod::kNone>>;
|
||||
#endif
|
||||
|
||||
} // namespace phi
|
||||
} // namespace phi
|
||||
|
||||
@@ -0,0 +1,30 @@
|
||||
/*
|
||||
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
#include "cutlass_kernels/moe_gemm/fused_moe_gemm_kernels.h"
|
||||
#include "cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_template.h"
|
||||
#include "helper.h"
|
||||
|
||||
namespace phi {
|
||||
|
||||
#ifdef PADDLE_CUDA_BF16
|
||||
template class MoeGemmRunner<
|
||||
__nv_bfloat16,
|
||||
cutlass::WintQuantTraits<__nv_bfloat16, cutlass::WintQuantMethod::kWeightOnlyInt2>>;
|
||||
#endif
|
||||
|
||||
} // namespace phi
|
||||
@@ -21,7 +21,9 @@
|
||||
namespace phi {
|
||||
|
||||
#ifdef PADDLE_CUDA_BF16
|
||||
template class MoeGemmRunner<__nv_bfloat16, cutlass::uint4b_t>;
|
||||
template class MoeGemmRunner<
|
||||
__nv_bfloat16,
|
||||
cutlass::WintQuantTraits<__nv_bfloat16, cutlass::WintQuantMethod::kWeightOnlyInt4>>;
|
||||
#endif
|
||||
|
||||
} // namespace phi
|
||||
} // namespace phi
|
||||
|
||||
@@ -22,8 +22,9 @@
|
||||
namespace phi {
|
||||
|
||||
#ifdef PADDLE_CUDA_BF16
|
||||
template class MoeGemmRunner<__nv_bfloat16, uint8_t>;
|
||||
template class MoeGemmRunner<
|
||||
__nv_bfloat16,
|
||||
cutlass::WintQuantTraits<__nv_bfloat16, cutlass::WintQuantMethod::kWeightOnlyInt8>>;
|
||||
#endif
|
||||
|
||||
} // namespace phi
|
||||
|
||||
} // namespace phi
|
||||
|
||||
@@ -21,6 +21,7 @@
|
||||
|
||||
namespace phi {
|
||||
|
||||
template class MoeGemmRunner<half, half>;
|
||||
template class MoeGemmRunner<half,
|
||||
cutlass::WintQuantTraits<half, cutlass::WintQuantMethod::kNone>>;
|
||||
|
||||
} // namespace phi
|
||||
} // namespace phi
|
||||
|
||||
@@ -0,0 +1,27 @@
|
||||
/*
|
||||
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
#include "cutlass_kernels/moe_gemm/fused_moe_gemm_kernels.h"
|
||||
#include "cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_template.h"
|
||||
#include "helper.h"
|
||||
|
||||
namespace phi {
|
||||
|
||||
template class MoeGemmRunner<
|
||||
half, cutlass::WintQuantTraits<half, cutlass::WintQuantMethod::kWeightOnlyInt2>>;
|
||||
|
||||
} // namespace phi
|
||||
@@ -21,6 +21,7 @@
|
||||
|
||||
namespace phi {
|
||||
|
||||
template class MoeGemmRunner<half, cutlass::uint4b_t>;
|
||||
template class MoeGemmRunner<
|
||||
half, cutlass::WintQuantTraits<half, cutlass::WintQuantMethod::kWeightOnlyInt4>>;
|
||||
|
||||
} // namespace phi
|
||||
} // namespace phi
|
||||
|
||||
@@ -21,6 +21,7 @@
|
||||
|
||||
namespace phi {
|
||||
|
||||
template class MoeGemmRunner<half, uint8_t>;
|
||||
template class MoeGemmRunner<
|
||||
half, cutlass::WintQuantTraits<half, cutlass::WintQuantMethod::kWeightOnlyInt8>>;
|
||||
|
||||
} // namespace phi
|
||||
} // namespace phi
|
||||
|
||||
@@ -24,9 +24,10 @@
|
||||
#include <math.h>
|
||||
#include <optional>
|
||||
#include <sstream>
|
||||
#include "cutlass/array.h"
|
||||
#include "cutlass/numeric_conversion.h"
|
||||
|
||||
#include "cutlass/array.h"
|
||||
#include "cutlass/trace.h"
|
||||
#include "cutlass/numeric_conversion.h"
|
||||
#include "cutlass/gemm/device/gemm_grouped.h"
|
||||
#include "cutlass/gemm/kernel/default_gemm_grouped.h"
|
||||
|
||||
@@ -35,8 +36,11 @@
|
||||
|
||||
#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/compute_occupancy.h"
|
||||
#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/epilogue_helpers.h"
|
||||
#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h"
|
||||
#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/threadblock/default_mma.h"
|
||||
|
||||
#include "cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h"
|
||||
#include "cutlass_extensions/gemm/threadblock/default_mma.h"
|
||||
#include "cutlass_extensions/wint_type_traits.h"
|
||||
|
||||
#include "cutlass_kernels/moe_gemm/fused_moe_cutlass_kernel.h"
|
||||
#include "cutlass_kernels/moe_gemm/fused_moe_gemm_kernels.h"
|
||||
|
||||
@@ -48,17 +52,47 @@
|
||||
#include "helper.h"
|
||||
|
||||
namespace phi {
|
||||
// ============================= Variable batched Gemm things
|
||||
// ===========================
|
||||
|
||||
template <typename MixedGemmArchTraits, cutlass::WintQuantMethod Method>
|
||||
struct CutlassLayoutB {
|
||||
using Type = typename MixedGemmArchTraits::LayoutB;
|
||||
};
|
||||
|
||||
template <typename MixedGemmArchTraits>
|
||||
struct CutlassLayoutB<MixedGemmArchTraits, cutlass::WintQuantMethod::kNone> {
|
||||
using Type = cutlass::layout::RowMajor;
|
||||
};
|
||||
|
||||
template <typename BaseGemmKernel, typename Arch, cutlass::WintQuantMethod Method>
|
||||
struct CutlassGemmKernel {
|
||||
using Type =
|
||||
cutlass::gemm::kernel::MoeFCGemm<typename BaseGemmKernel::Mma,
|
||||
typename BaseGemmKernel::Epilogue,
|
||||
typename BaseGemmKernel::ThreadblockSwizzle,
|
||||
Arch,
|
||||
BaseGemmKernel::kGroupScheduleMode>;
|
||||
};
|
||||
|
||||
template <typename BaseGemmKernel, typename Arch>
|
||||
struct CutlassGemmKernel<BaseGemmKernel, Arch, cutlass::WintQuantMethod::kWeightOnlyInt2> {
|
||||
using Type =
|
||||
cutlass::gemm::kernel::Wint2xMoeFCGemm<typename BaseGemmKernel::Mma,
|
||||
typename BaseGemmKernel::Epilogue,
|
||||
typename BaseGemmKernel::ThreadblockSwizzle,
|
||||
Arch,
|
||||
BaseGemmKernel::kGroupScheduleMode>;
|
||||
};
|
||||
|
||||
// ======================= Variable batched Gemm things =======================
|
||||
template <typename T,
|
||||
typename WeightType,
|
||||
typename WeightQuantTraits,
|
||||
typename arch,
|
||||
typename EpilogueTag,
|
||||
typename ThreadblockShape,
|
||||
typename WarpShape,
|
||||
int Stages>
|
||||
void generic_moe_gemm_kernelLauncher(const T* A,
|
||||
const WeightType* B,
|
||||
const typename WeightQuantTraits::WeightType* B,
|
||||
const T* weight_scales,
|
||||
const T* biases,
|
||||
T* C,
|
||||
@@ -67,6 +101,7 @@ void generic_moe_gemm_kernelLauncher(const T* A,
|
||||
int64_t gemm_n,
|
||||
int64_t gemm_k,
|
||||
int num_experts,
|
||||
const typename WeightQuantTraits::Arguments& quant_args_B,
|
||||
CutlassGemmConfig gemm_config,
|
||||
const int multi_processor_count,
|
||||
cudaStream_t stream,
|
||||
@@ -86,44 +121,26 @@ void generic_moe_gemm_kernelLauncher(const T* A,
|
||||
"Specialized for half, float");
|
||||
#endif
|
||||
|
||||
using WeightType = typename WeightQuantTraits::WeightType;
|
||||
|
||||
static_assert(
|
||||
cutlass::platform::is_same<T, WeightType>::value ||
|
||||
cutlass::platform::is_same<WeightType, uint8_t>::value ||
|
||||
cutlass::platform::is_same<WeightType, cutlass::uint4b_t>::value,
|
||||
"");
|
||||
cutlass::platform::is_same<WeightType, cutlass::uint4b_t>::value ||
|
||||
cutlass::platform::is_same<WeightType, uint16_t>::value,
|
||||
"Specialized for bfloat16, half, float, uint8_t (wint8), uint4b_t (wint4), uint16_t (wint2.5)");
|
||||
|
||||
// The cutlass type for the input elements. This is needed to convert to
|
||||
// cutlass::half_t if necessary.
|
||||
using ElementType_ = typename cutlass::platform::conditional<
|
||||
cutlass::platform::is_same<T, half>::value,
|
||||
cutlass::half_t,
|
||||
T>::type;
|
||||
#ifdef PADDLE_CUDA_BF16
|
||||
using ElementType = typename cutlass::platform::conditional<
|
||||
cutlass::platform::is_same<ElementType_, __nv_bfloat16>::value,
|
||||
cutlass::bfloat16_t,
|
||||
ElementType_>::type;
|
||||
#else
|
||||
using ElementType = ElementType_;
|
||||
#endif
|
||||
|
||||
using CutlassWeightType_ = typename cutlass::platform::conditional<
|
||||
cutlass::platform::is_same<WeightType, half>::value,
|
||||
cutlass::half_t,
|
||||
WeightType>::type;
|
||||
#ifdef PADDLE_CUDA_BF16
|
||||
using CutlassWeightType = typename cutlass::platform::conditional<
|
||||
cutlass::platform::is_same<CutlassWeightType_, __nv_bfloat16>::value,
|
||||
cutlass::bfloat16_t,
|
||||
CutlassWeightType_>::type;
|
||||
#else
|
||||
using CutlassWeightType = CutlassWeightType_;
|
||||
#endif
|
||||
using ElementType = typename cutlass::CutlassDataType<T>::Type;
|
||||
using CutlassWeightType = typename cutlass::CutlassDataType<typename WeightQuantTraits::WeightType>::Type;
|
||||
using CutlassMmaWeightType = typename WeightQuantTraits::MmaWeightType;
|
||||
using CutlassMmaKernelType = typename WeightQuantTraits::MmaKernelType;
|
||||
|
||||
// We need separate config for each architecture since we will target
|
||||
// different tensorcore instructions. For float, we do not target TCs.
|
||||
using MixedGemmArchTraits = cutlass::gemm::kernel::
|
||||
MixedGemmArchTraits<ElementType, CutlassWeightType, arch>;
|
||||
MixedGemmArchTraits<ElementType, CutlassMmaKernelType, arch>;
|
||||
using ElementAccumulator = typename MixedGemmArchTraits::AccType;
|
||||
|
||||
using EpilogueOp = typename Epilogue<ElementType,
|
||||
@@ -132,13 +149,13 @@ void generic_moe_gemm_kernelLauncher(const T* A,
|
||||
EpilogueTag>::Op;
|
||||
|
||||
// Finally, set up the kernel.
|
||||
using GemmKernel_ = typename cutlass::gemm::kernel::DefaultGemmGrouped<
|
||||
using BaseGemmKernel = typename cutlass::gemm::kernel::DefaultGemmGrouped<
|
||||
ElementType,
|
||||
cutlass::layout::RowMajor,
|
||||
cutlass::ComplexTransform::kNone,
|
||||
MixedGemmArchTraits::ElementsPerAccessA,
|
||||
CutlassWeightType,
|
||||
typename MixedGemmArchTraits::LayoutB,
|
||||
CutlassMmaKernelType,
|
||||
typename CutlassLayoutB<MixedGemmArchTraits, WeightQuantTraits::kQuantMethod>::Type,
|
||||
cutlass::ComplexTransform::kNone,
|
||||
MixedGemmArchTraits::ElementsPerAccessB,
|
||||
ElementType,
|
||||
@@ -155,14 +172,7 @@ void generic_moe_gemm_kernelLauncher(const T* A,
|
||||
cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly,
|
||||
typename MixedGemmArchTraits::Operator>::GemmKernel;
|
||||
|
||||
using GemmKernel =
|
||||
cutlass::gemm::kernel::MoeFCGemm<typename GemmKernel_::Mma,
|
||||
typename GemmKernel_::Epilogue,
|
||||
typename GemmKernel_::ThreadblockSwizzle,
|
||||
arch, // Ensure top level arch is used
|
||||
// for dispatch
|
||||
GemmKernel_::kGroupScheduleMode>;
|
||||
|
||||
using GemmKernel = typename CutlassGemmKernel<BaseGemmKernel, arch, WeightQuantTraits::kQuantMethod>::Type;
|
||||
using GemmGrouped = cutlass::gemm::device::GemmGrouped<GemmKernel>;
|
||||
|
||||
if (kernel_occupancy != nullptr) {
|
||||
@@ -181,19 +191,32 @@ void generic_moe_gemm_kernelLauncher(const T* A,
|
||||
typename EpilogueOp::Params epilogue_op(ElementAccumulator(1.f),
|
||||
ElementAccumulator(0.f));
|
||||
|
||||
const uint8_t* local_scale_B = nullptr;
|
||||
const float* code_scale_B = nullptr;
|
||||
const float* code_zp_B = nullptr;
|
||||
if constexpr (WeightQuantTraits::kQuantMethod == cutlass::WintQuantMethod::kWeightOnlyInt2) {
|
||||
local_scale_B = quant_args_B.local_scale_ptr;
|
||||
code_scale_B = quant_args_B.code_scale_ptr;
|
||||
code_zp_B = quant_args_B.code_zp_ptr;
|
||||
}
|
||||
|
||||
typename GemmGrouped::Arguments args(
|
||||
num_experts,
|
||||
threadblock_count,
|
||||
epilogue_op,
|
||||
reinterpret_cast<const ElementType*>(A),
|
||||
reinterpret_cast<const CutlassWeightType*>(B),
|
||||
reinterpret_cast<const CutlassMmaWeightType*>(B),
|
||||
reinterpret_cast<const ElementType*>(weight_scales),
|
||||
reinterpret_cast<const ElementType*>(biases),
|
||||
reinterpret_cast<ElementType*>(C),
|
||||
total_rows_before_expert,
|
||||
total_rows,
|
||||
gemm_n,
|
||||
gemm_k);
|
||||
gemm_k,
|
||||
WeightQuantTraits::kQuantMethod,
|
||||
local_scale_B,
|
||||
code_scale_B,
|
||||
code_zp_B);
|
||||
|
||||
GemmGrouped gemm;
|
||||
|
||||
@@ -222,7 +245,7 @@ void generic_moe_gemm_kernelLauncher(const T* A,
|
||||
}
|
||||
|
||||
template <typename T,
|
||||
typename WeightType,
|
||||
typename WeightQuantTraits,
|
||||
typename arch,
|
||||
typename EpilogueTag,
|
||||
typename ThreadblockShape,
|
||||
@@ -231,7 +254,7 @@ template <typename T,
|
||||
typename Enable = void>
|
||||
struct dispatch_stages {
|
||||
static void dispatch(const T* A,
|
||||
const WeightType* B,
|
||||
const typename WeightQuantTraits::WeightType* B,
|
||||
const T* weight_scales,
|
||||
const T* biases,
|
||||
T* C,
|
||||
@@ -240,6 +263,7 @@ struct dispatch_stages {
|
||||
int64_t gemm_n,
|
||||
int64_t gemm_k,
|
||||
int num_experts,
|
||||
const typename WeightQuantTraits::Arguments& quant_args_B,
|
||||
CutlassGemmConfig gemm_config,
|
||||
int multi_processor_count,
|
||||
cudaStream_t stream,
|
||||
@@ -253,20 +277,20 @@ struct dispatch_stages {
|
||||
};
|
||||
|
||||
template <typename T,
|
||||
typename WeightType,
|
||||
typename WeightQuantTraits,
|
||||
typename arch,
|
||||
typename EpilogueTag,
|
||||
typename ThreadblockShape,
|
||||
typename WarpShape>
|
||||
struct dispatch_stages<T,
|
||||
WeightType,
|
||||
WeightQuantTraits,
|
||||
arch,
|
||||
EpilogueTag,
|
||||
ThreadblockShape,
|
||||
WarpShape,
|
||||
2> {
|
||||
static void dispatch(const T* A,
|
||||
const WeightType* B,
|
||||
const typename WeightQuantTraits::WeightType* B,
|
||||
const T* weight_scales,
|
||||
const T* biases,
|
||||
T* C,
|
||||
@@ -275,12 +299,13 @@ struct dispatch_stages<T,
|
||||
int64_t gemm_n,
|
||||
int64_t gemm_k,
|
||||
int num_experts,
|
||||
const typename WeightQuantTraits::Arguments& quant_args_B,
|
||||
CutlassGemmConfig gemm_config,
|
||||
int multi_processor_count,
|
||||
cudaStream_t stream,
|
||||
int* occupancy = nullptr) {
|
||||
generic_moe_gemm_kernelLauncher<T,
|
||||
WeightType,
|
||||
WeightQuantTraits,
|
||||
arch,
|
||||
EpilogueTag,
|
||||
ThreadblockShape,
|
||||
@@ -295,6 +320,7 @@ struct dispatch_stages<T,
|
||||
gemm_n,
|
||||
gemm_k,
|
||||
num_experts,
|
||||
quant_args_B,
|
||||
gemm_config,
|
||||
multi_processor_count,
|
||||
stream,
|
||||
@@ -303,13 +329,13 @@ struct dispatch_stages<T,
|
||||
};
|
||||
|
||||
template <typename T,
|
||||
typename WeightType,
|
||||
typename WeightQuantTraits,
|
||||
typename EpilogueTag,
|
||||
typename ThreadblockShape,
|
||||
typename WarpShape,
|
||||
int Stages>
|
||||
struct dispatch_stages<T,
|
||||
WeightType,
|
||||
WeightQuantTraits,
|
||||
cutlass::arch::Sm80,
|
||||
EpilogueTag,
|
||||
ThreadblockShape,
|
||||
@@ -317,7 +343,7 @@ struct dispatch_stages<T,
|
||||
Stages,
|
||||
typename std::enable_if<(Stages > 2)>::type> {
|
||||
static void dispatch(const T* A,
|
||||
const WeightType* B,
|
||||
const typename WeightQuantTraits::WeightType* B,
|
||||
const T* weight_scales,
|
||||
const T* biases,
|
||||
T* C,
|
||||
@@ -326,12 +352,13 @@ struct dispatch_stages<T,
|
||||
int64_t gemm_n,
|
||||
int64_t gemm_k,
|
||||
int num_experts,
|
||||
const typename WeightQuantTraits::Arguments& quant_args_B,
|
||||
CutlassGemmConfig gemm_config,
|
||||
int multi_processor_count,
|
||||
cudaStream_t stream,
|
||||
int* occupancy = nullptr) {
|
||||
generic_moe_gemm_kernelLauncher<T,
|
||||
WeightType,
|
||||
WeightQuantTraits,
|
||||
cutlass::arch::Sm80,
|
||||
EpilogueTag,
|
||||
ThreadblockShape,
|
||||
@@ -346,6 +373,7 @@ struct dispatch_stages<T,
|
||||
gemm_n,
|
||||
gemm_k,
|
||||
num_experts,
|
||||
quant_args_B,
|
||||
gemm_config,
|
||||
multi_processor_count,
|
||||
stream,
|
||||
@@ -354,13 +382,13 @@ struct dispatch_stages<T,
|
||||
};
|
||||
|
||||
template <typename T,
|
||||
typename WeightType,
|
||||
typename WeightQuantTraits,
|
||||
typename arch,
|
||||
typename EpilogueTag,
|
||||
typename ThreadblockShape,
|
||||
typename WarpShape>
|
||||
void dispatch_gemm_config(const T* A,
|
||||
const WeightType* B,
|
||||
const typename WeightQuantTraits::WeightType* B,
|
||||
const T* weight_scales,
|
||||
const T* biases,
|
||||
T* C,
|
||||
@@ -369,6 +397,7 @@ void dispatch_gemm_config(const T* A,
|
||||
int64_t gemm_n,
|
||||
int64_t gemm_k,
|
||||
int num_experts,
|
||||
const typename WeightQuantTraits::Arguments& quant_args_B,
|
||||
CutlassGemmConfig gemm_config,
|
||||
int multi_processor_count,
|
||||
cudaStream_t stream,
|
||||
@@ -376,7 +405,7 @@ void dispatch_gemm_config(const T* A,
|
||||
#define dispatch_stages_macro(STAGE) \
|
||||
case STAGE: \
|
||||
dispatch_stages<T, \
|
||||
WeightType, \
|
||||
WeightQuantTraits, \
|
||||
arch, \
|
||||
EpilogueTag, \
|
||||
ThreadblockShape, \
|
||||
@@ -391,6 +420,7 @@ void dispatch_gemm_config(const T* A,
|
||||
gemm_n, \
|
||||
gemm_k, \
|
||||
num_experts, \
|
||||
quant_args_B, \
|
||||
gemm_config, \
|
||||
multi_processor_count, \
|
||||
stream, \
|
||||
@@ -414,7 +444,7 @@ void dispatch_gemm_config(const T* A,
|
||||
case CutlassTileConfig:: \
|
||||
CtaShape##AA##x##BB##x##CC##_WarpShape##DD##x##EE##x##FF: \
|
||||
dispatch_gemm_config<T, \
|
||||
WeightType, \
|
||||
WeightQuantTraits, \
|
||||
arch, \
|
||||
EpilogueTag, \
|
||||
cutlass::gemm::GemmShape<AA, BB, CC>, \
|
||||
@@ -425,10 +455,11 @@ void dispatch_gemm_config(const T* A,
|
||||
biases, \
|
||||
C, \
|
||||
total_rows_before_expert, \
|
||||
total_rows, \
|
||||
total_rows, \
|
||||
gemm_n, \
|
||||
gemm_k, \
|
||||
num_experts, \
|
||||
quant_args_B, \
|
||||
gemm_config, \
|
||||
multi_processor_count, \
|
||||
stream, \
|
||||
@@ -438,14 +469,14 @@ void dispatch_gemm_config(const T* A,
|
||||
// This overload will handle tensorop gemms. It is disabled via SFINAE for fp32.
|
||||
// This overload is only enabled when T == WeightType.
|
||||
template <typename T,
|
||||
typename WeightType,
|
||||
typename WeightQuantTraits,
|
||||
typename arch,
|
||||
typename EpilogueTag,
|
||||
typename std::enable_if<!std::is_same<T, float>::value &&
|
||||
std::is_same<T, WeightType>::value>::type* =
|
||||
std::is_same<T, typename WeightQuantTraits::WeightType>::value>::type* =
|
||||
nullptr>
|
||||
void dispatch_moe_gemm_to_cutlass(const T* A,
|
||||
const WeightType* B,
|
||||
const typename WeightQuantTraits::WeightType* B,
|
||||
const T* weight_scales,
|
||||
const T* biases,
|
||||
T* C,
|
||||
@@ -454,6 +485,7 @@ void dispatch_moe_gemm_to_cutlass(const T* A,
|
||||
int64_t gemm_n,
|
||||
int64_t gemm_k,
|
||||
int num_experts,
|
||||
const typename WeightQuantTraits::Arguments& quant_args_B,
|
||||
CutlassGemmConfig gemm_config,
|
||||
int sm_version,
|
||||
int multi_processor_count,
|
||||
@@ -474,7 +506,7 @@ void dispatch_moe_gemm_to_cutlass(const T* A,
|
||||
default:
|
||||
throw std::runtime_error(
|
||||
"[dispatch_moe_gemm_to_cutlass] Config is invalid for same "
|
||||
"type MoE tensorop GEMM.");
|
||||
"type MoE tensorop GEMM for FP16/BF16.");
|
||||
break;
|
||||
}
|
||||
}
|
||||
@@ -483,14 +515,14 @@ void dispatch_moe_gemm_to_cutlass(const T* A,
|
||||
// Overload for quantize MoE GEMMs. We disable some warp configs here since they
|
||||
// will not be used and we can improve compile time
|
||||
template <typename T,
|
||||
typename WeightType,
|
||||
typename WeightQuantTraits,
|
||||
typename arch,
|
||||
typename EpilogueTag,
|
||||
typename std::enable_if<!std::is_same<T, float>::value &&
|
||||
!std::is_same<T, WeightType>::value>::type* =
|
||||
!std::is_same<T, typename WeightQuantTraits::WeightType>::value>::type* =
|
||||
nullptr>
|
||||
void dispatch_moe_gemm_to_cutlass(const T* A,
|
||||
const WeightType* B,
|
||||
const typename WeightQuantTraits::WeightType* B,
|
||||
const T* weight_scales,
|
||||
const T* biases,
|
||||
T* C,
|
||||
@@ -499,28 +531,34 @@ void dispatch_moe_gemm_to_cutlass(const T* A,
|
||||
int64_t gemm_n,
|
||||
int64_t gemm_k,
|
||||
int num_experts,
|
||||
const typename WeightQuantTraits::Arguments& quant_args_B,
|
||||
CutlassGemmConfig gemm_config,
|
||||
int sm_version,
|
||||
int multi_processor_count,
|
||||
cudaStream_t stream,
|
||||
int* occupancy = nullptr) {
|
||||
if constexpr (std::is_same<arch, cutlass::arch::Sm70>::value) {
|
||||
switch (gemm_config.tile_config) {
|
||||
dispatch_gemm_config_macro(32, 128, 64, 32, 32, 64);
|
||||
dispatch_gemm_config_macro(64, 128, 64, 64, 64, 64);
|
||||
case CutlassTileConfig::Undefined:
|
||||
throw std::runtime_error("[dispatch_moe_gemm_to_cutlass] gemm config undefined.");
|
||||
break;
|
||||
case CutlassTileConfig::ChooseWithHeuristic:
|
||||
throw std::runtime_error(
|
||||
"[dispatch_moe_gemm_to_cutlass] gemm config should have "
|
||||
"already been set by heuristic.");
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error(
|
||||
"[dispatch_moe_gemm_to_cutlass] Config is invalid for "
|
||||
"mixed type tensorop GEMM.");
|
||||
break;
|
||||
if constexpr (WeightQuantTraits::kQuantMethod != cutlass::WintQuantMethod::kWeightOnlyInt2) {
|
||||
switch (gemm_config.tile_config) {
|
||||
dispatch_gemm_config_macro(32, 128, 64, 32, 32, 64);
|
||||
dispatch_gemm_config_macro(64, 128, 64, 64, 64, 64);
|
||||
case CutlassTileConfig::Undefined:
|
||||
throw std::runtime_error("[dispatch_moe_gemm_to_cutlass] gemm config undefined.");
|
||||
break;
|
||||
case CutlassTileConfig::ChooseWithHeuristic:
|
||||
throw std::runtime_error(
|
||||
"[dispatch_moe_gemm_to_cutlass] gemm config should have "
|
||||
"already been set by heuristic.");
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error(
|
||||
"[dispatch_moe_gemm_to_cutlass] Config is invalid for "
|
||||
"mixed type tensorop GEMM for sm70.");
|
||||
break;
|
||||
}
|
||||
} else {
|
||||
throw std::runtime_error(
|
||||
"[dispatch_moe_gemm_to_cutlass] weight_only_int2 does not support sm70.");
|
||||
}
|
||||
} else {
|
||||
switch (gemm_config.tile_config) {
|
||||
@@ -555,12 +593,12 @@ void dispatch_moe_gemm_to_cutlass(const T* A,
|
||||
// This overload will handle simt gemms. It is disabled via SFINAE for tensorop.
|
||||
template <
|
||||
typename T,
|
||||
typename WeightType,
|
||||
typename WeightQuantTraits,
|
||||
typename arch,
|
||||
typename EpilogueTag,
|
||||
typename std::enable_if<std::is_same<T, float>::value>::type* = nullptr>
|
||||
void dispatch_moe_gemm_to_cutlass(const T* A,
|
||||
const WeightType* B,
|
||||
const typename WeightQuantTraits::WeightType* B,
|
||||
const T* weight_scales,
|
||||
const T* biases,
|
||||
T* C,
|
||||
@@ -569,6 +607,7 @@ void dispatch_moe_gemm_to_cutlass(const T* A,
|
||||
int64_t gemm_n,
|
||||
int64_t gemm_k,
|
||||
int num_experts,
|
||||
const typename WeightQuantTraits::Arguments& quant_args_B,
|
||||
CutlassGemmConfig gemm_config,
|
||||
int sm_version,
|
||||
int multi_processor_count,
|
||||
@@ -594,8 +633,8 @@ void dispatch_moe_gemm_to_cutlass(const T* A,
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename WeightType>
|
||||
MoeGemmRunner<T, WeightType>::MoeGemmRunner() {
|
||||
template <typename T, typename WeightQuantTraits>
|
||||
MoeGemmRunner<T, WeightQuantTraits>::MoeGemmRunner() {
|
||||
int device{-1};
|
||||
check_cuda_error(cudaGetDevice(&device));
|
||||
sm_ = getSMVersion();
|
||||
@@ -603,11 +642,11 @@ MoeGemmRunner<T, WeightType>::MoeGemmRunner() {
|
||||
&multi_processor_count_, cudaDevAttrMultiProcessorCount, device));
|
||||
}
|
||||
|
||||
template <typename T, typename WeightType>
|
||||
template <typename T, typename WeightQuantTraits>
|
||||
template <typename EpilogueTag>
|
||||
void MoeGemmRunner<T, WeightType>::dispatch_to_arch<EpilogueTag>(
|
||||
void MoeGemmRunner<T, WeightQuantTraits>::dispatch_to_arch<EpilogueTag>(
|
||||
const T* A,
|
||||
const WeightType* B,
|
||||
const typename WeightQuantTraits::WeightType* B,
|
||||
const T* weight_scales,
|
||||
const T* biases,
|
||||
T* C,
|
||||
@@ -616,11 +655,12 @@ void MoeGemmRunner<T, WeightType>::dispatch_to_arch<EpilogueTag>(
|
||||
int64_t gemm_n,
|
||||
int64_t gemm_k,
|
||||
int num_experts,
|
||||
const typename WeightQuantTraits::Arguments& quant_args_B,
|
||||
CutlassGemmConfig gemm_config,
|
||||
cudaStream_t stream,
|
||||
int* occupancy) {
|
||||
#define dispatch_moe_gemm_to_cutlass_macro(ARCH) \
|
||||
dispatch_moe_gemm_to_cutlass<T, WeightType, ARCH, EpilogueTag>( \
|
||||
dispatch_moe_gemm_to_cutlass<T, WeightQuantTraits, ARCH, EpilogueTag>( \
|
||||
A, \
|
||||
B, \
|
||||
weight_scales, \
|
||||
@@ -631,6 +671,7 @@ void MoeGemmRunner<T, WeightType>::dispatch_to_arch<EpilogueTag>(
|
||||
gemm_n, \
|
||||
gemm_k, \
|
||||
num_experts, \
|
||||
quant_args_B, \
|
||||
gemm_config, \
|
||||
sm_, \
|
||||
multi_processor_count_, \
|
||||
@@ -648,25 +689,28 @@ void MoeGemmRunner<T, WeightType>::dispatch_to_arch<EpilogueTag>(
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename WeightType>
|
||||
template <typename T, typename WeightQuantTraits>
|
||||
template <typename EpilogueTag>
|
||||
void MoeGemmRunner<T, WeightType>::run_gemm<EpilogueTag>(
|
||||
void MoeGemmRunner<T, WeightQuantTraits>::run_gemm<EpilogueTag>(
|
||||
const T* A,
|
||||
const WeightType* B,
|
||||
const typename WeightQuantTraits::WeightType* B,
|
||||
const T* weight_scales,
|
||||
const T* biases,
|
||||
T* C,
|
||||
int64_t* total_rows_before_expert,
|
||||
int64_t total_rows,
|
||||
int64_t tune_total_rows,
|
||||
int64_t actual_total_rows,
|
||||
int64_t gemm_n,
|
||||
int64_t gemm_k,
|
||||
int num_experts,
|
||||
const typename WeightQuantTraits::Arguments& quant_args_B,
|
||||
cudaStream_t stream) {
|
||||
static constexpr bool is_weight_only = !std::is_same<T, WeightType>::value;
|
||||
static constexpr bool is_weight_only = !std::is_same<T, typename WeightQuantTraits::WeightType>::value;
|
||||
static constexpr bool only_simt_configs = std::is_same<T, float>::value;
|
||||
|
||||
std::vector<CutlassGemmConfig> candidate_configs =
|
||||
get_candidate_configs(sm_, -1, is_weight_only, only_simt_configs, true);
|
||||
|
||||
static constexpr int warm_time = 5;
|
||||
static constexpr int test_time = 10;
|
||||
auto& gemmConfigManager = GemmConfigManager::Instance();
|
||||
@@ -676,17 +720,19 @@ void MoeGemmRunner<T, WeightType>::run_gemm<EpilogueTag>(
|
||||
gemm_n, gemm_k, GemmType::MOEGEMM, dtype, wdtype, num_experts};
|
||||
CutlassGemmConfig chosen_config;
|
||||
auto chosen_config_optional =
|
||||
gemmConfigManager.getBestConfig(gemmId, tune_total_rows);
|
||||
gemmConfigManager.getBestConfig(gemmId, actual_total_rows);
|
||||
if (chosen_config_optional != std::nullopt) {
|
||||
chosen_config = chosen_config_optional.value();
|
||||
} else {
|
||||
size_t best_id = -1;
|
||||
float best_time = std::numeric_limits<float>::max();
|
||||
CutlassGemmConfig best_config;
|
||||
int profile_total_rows =
|
||||
std::min(gemmConfigManager.nextPowerOfTwo(tune_total_rows),
|
||||
std::min(gemmConfigManager.nextPowerOfTwo(actual_total_rows),
|
||||
gemmConfigManager.getMaxProfileM());
|
||||
bool find_one = false;
|
||||
for (size_t ii = 0; ii < candidate_configs.size(); ++ii) {
|
||||
size_t num_candidate_configs_size = candidate_configs.size();
|
||||
for (size_t ii = 0; ii < num_candidate_configs_size; ++ii) {
|
||||
try {
|
||||
for (int i = 0; i < warm_time; i++) {
|
||||
dispatch_to_arch<EpilogueTag>(A,
|
||||
@@ -699,6 +745,7 @@ void MoeGemmRunner<T, WeightType>::run_gemm<EpilogueTag>(
|
||||
gemm_n,
|
||||
gemm_k,
|
||||
num_experts,
|
||||
quant_args_B,
|
||||
candidate_configs[ii],
|
||||
stream);
|
||||
}
|
||||
@@ -719,6 +766,7 @@ void MoeGemmRunner<T, WeightType>::run_gemm<EpilogueTag>(
|
||||
gemm_n,
|
||||
gemm_k,
|
||||
num_experts,
|
||||
quant_args_B,
|
||||
candidate_configs[ii],
|
||||
stream);
|
||||
}
|
||||
@@ -728,7 +776,9 @@ void MoeGemmRunner<T, WeightType>::run_gemm<EpilogueTag>(
|
||||
check_cuda_error(cudaEventElapsedTime(&elapsed, start, stop));
|
||||
check_cuda_error(cudaEventDestroy(start));
|
||||
check_cuda_error(cudaEventDestroy(stop));
|
||||
//std::cout << "[TUNING] config: " << ii << ", time: " << elapsed << " ms" << std::endl;
|
||||
if (elapsed < best_time) {
|
||||
best_id = ii;
|
||||
best_time = elapsed;
|
||||
best_config = candidate_configs[ii];
|
||||
}
|
||||
@@ -739,6 +789,7 @@ void MoeGemmRunner<T, WeightType>::run_gemm<EpilogueTag>(
|
||||
}
|
||||
}
|
||||
if (find_one) {
|
||||
//std::cout << "[TUNING] best_config: " << best_id << ", time: " << best_time << " ms" << std::endl;
|
||||
gemmConfigManager.addBestConfig(gemmId, profile_total_rows, best_config);
|
||||
chosen_config = best_config;
|
||||
} else {
|
||||
@@ -756,23 +807,25 @@ void MoeGemmRunner<T, WeightType>::run_gemm<EpilogueTag>(
|
||||
gemm_n,
|
||||
gemm_k,
|
||||
num_experts,
|
||||
quant_args_B,
|
||||
chosen_config,
|
||||
stream);
|
||||
}
|
||||
|
||||
template <typename T, typename WeightType>
|
||||
void MoeGemmRunner<T, WeightType>::moe_gemm_bias_act(
|
||||
template <typename T, typename WeightQuantTraits>
|
||||
void MoeGemmRunner<T, WeightQuantTraits>::moe_gemm_bias_act(
|
||||
const T* A,
|
||||
const WeightType* B,
|
||||
const typename WeightQuantTraits::WeightType* B,
|
||||
const T* weight_scales,
|
||||
const T* biases,
|
||||
T* C,
|
||||
int64_t* total_rows_before_expert,
|
||||
int64_t total_rows,
|
||||
int64_t tune_total_rows,
|
||||
int64_t actual_total_rows,
|
||||
int64_t gemm_n,
|
||||
int64_t gemm_k,
|
||||
int num_experts,
|
||||
const typename WeightQuantTraits::Arguments& quant_args_B,
|
||||
std::string activation_type,
|
||||
cudaStream_t stream) {
|
||||
if (activation_type == "none") {
|
||||
@@ -784,10 +837,11 @@ void MoeGemmRunner<T, WeightType>::moe_gemm_bias_act(
|
||||
C,
|
||||
total_rows_before_expert,
|
||||
total_rows,
|
||||
tune_total_rows,
|
||||
actual_total_rows,
|
||||
gemm_n,
|
||||
gemm_k,
|
||||
num_experts,
|
||||
quant_args_B,
|
||||
stream);
|
||||
} else {
|
||||
run_gemm<EpilogueOpNoBias>(A,
|
||||
@@ -797,27 +851,30 @@ void MoeGemmRunner<T, WeightType>::moe_gemm_bias_act(
|
||||
C,
|
||||
total_rows_before_expert,
|
||||
total_rows,
|
||||
tune_total_rows,
|
||||
actual_total_rows,
|
||||
gemm_n,
|
||||
gemm_k,
|
||||
num_experts,
|
||||
quant_args_B,
|
||||
stream);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename WeightType>
|
||||
void MoeGemmRunner<T, WeightType>::moe_gemm(const T* A,
|
||||
const WeightType* B,
|
||||
const T* weight_scales,
|
||||
T* C,
|
||||
int64_t* total_rows_before_expert,
|
||||
int64_t total_rows,
|
||||
int64_t tune_total_rows,
|
||||
int64_t gemm_n,
|
||||
int64_t gemm_k,
|
||||
int num_experts,
|
||||
cudaStream_t stream) {
|
||||
template <typename T, typename WeightQuantTraits>
|
||||
void MoeGemmRunner<T, WeightQuantTraits>::moe_gemm(
|
||||
const T* A,
|
||||
const typename WeightQuantTraits::WeightType* B,
|
||||
const T* weight_scales,
|
||||
T* C,
|
||||
int64_t* total_rows_before_expert,
|
||||
int64_t total_rows,
|
||||
int64_t actual_total_rows,
|
||||
int64_t gemm_n,
|
||||
int64_t gemm_k,
|
||||
int num_experts,
|
||||
const typename WeightQuantTraits::Arguments& quant_args_B,
|
||||
cudaStream_t stream) {
|
||||
run_gemm<EpilogueOpNoBias>(A,
|
||||
B,
|
||||
weight_scales,
|
||||
@@ -825,10 +882,11 @@ void MoeGemmRunner<T, WeightType>::moe_gemm(const T* A,
|
||||
C,
|
||||
total_rows_before_expert,
|
||||
total_rows,
|
||||
tune_total_rows,
|
||||
actual_total_rows,
|
||||
gemm_n,
|
||||
gemm_k,
|
||||
num_experts,
|
||||
quant_args_B,
|
||||
stream);
|
||||
}
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,102 @@
|
||||
// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/c3x/cutlass_gemm_caller.cuh
|
||||
#pragma once
|
||||
|
||||
// clang-format will break include orders
|
||||
// clang-format off
|
||||
#include "helper.h"
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cute/atom/mma_atom.hpp"
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/util/packed_stride.hpp"
|
||||
#include "cutlass_helper.h"
|
||||
// clang-format on
|
||||
|
||||
namespace fastdeploy::c3x {
|
||||
|
||||
static inline cute::Shape<int, int, int, int>
|
||||
get_problem_shape(paddle::Tensor const &a, paddle::Tensor const &b) {
|
||||
int32_t m = a.dims()[0], n = b.dims()[0], k = a.dims()[1];
|
||||
return {m, n, k, 1};
|
||||
}
|
||||
|
||||
template <typename GemmKernel>
|
||||
void cutlass_gemm_caller(
|
||||
phi::Place device, cute::Shape<int, int, int, int> prob_shape,
|
||||
typename GemmKernel::MainloopArguments mainloop_args,
|
||||
typename GemmKernel::EpilogueArguments epilogue_args,
|
||||
typename GemmKernel::TileSchedulerArguments scheduler = {}) {
|
||||
cutlass::KernelHardwareInfo hw_info;
|
||||
typename GemmKernel::Arguments args{cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
prob_shape,
|
||||
mainloop_args,
|
||||
epilogue_args,
|
||||
hw_info,
|
||||
scheduler};
|
||||
|
||||
// Launch the CUTLASS GEMM kernel.
|
||||
using GemmOp = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
GemmOp gemm_op;
|
||||
CUTLASS_CHECK(gemm_op.can_implement(args));
|
||||
|
||||
size_t workspace_size = gemm_op.get_workspace_size(args);
|
||||
phi::Allocator *allocator = paddle::GetAllocator(device);
|
||||
auto workspace = allocator->Allocate(workspace_size);
|
||||
|
||||
auto stream = paddle::GetCurrentCUDAStream(device)->raw_stream();
|
||||
|
||||
cutlass::Status status = gemm_op.run(args, workspace->ptr(), stream);
|
||||
CUTLASS_CHECK(status);
|
||||
}
|
||||
|
||||
template <typename Gemm, typename... EpilogueArgs>
|
||||
void cutlass_gemm_caller(paddle::Tensor &out, paddle::Tensor const &a,
|
||||
paddle::Tensor const &b,
|
||||
EpilogueArgs &&...epilogue_params) {
|
||||
using ElementAB = typename Gemm::ElementAB;
|
||||
using ElementC = typename Gemm::ElementC;
|
||||
using ElementD = typename Gemm::ElementD;
|
||||
using GemmKernel = typename Gemm::GemmKernel;
|
||||
|
||||
using StrideA = typename Gemm::GemmKernel::StrideA;
|
||||
using StrideB = typename Gemm::GemmKernel::StrideB;
|
||||
using StrideC = typename Gemm::GemmKernel::StrideC;
|
||||
using StrideD = StrideC;
|
||||
using StrideAux = StrideC;
|
||||
|
||||
typename GemmKernel::ProblemShape prob_shape = get_problem_shape(a, b);
|
||||
auto [M, N, K, L] = prob_shape;
|
||||
|
||||
StrideA a_stride =
|
||||
cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L));
|
||||
StrideB b_stride =
|
||||
cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L));
|
||||
StrideC c_stride =
|
||||
cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L));
|
||||
StrideD d_stride =
|
||||
cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L));
|
||||
StrideAux aux_stride = d_stride;
|
||||
|
||||
auto a_ptr = static_cast<ElementAB *>(const_cast<void *>(a.data()));
|
||||
auto b_ptr = static_cast<ElementAB *>(const_cast<void *>(b.data()));
|
||||
typename GemmKernel::MainloopArguments mainloop_args{a_ptr, a_stride, b_ptr,
|
||||
b_stride};
|
||||
|
||||
auto c_ptr = static_cast<ElementD *>(const_cast<void *>(out.data()));
|
||||
typename GemmKernel::EpilogueArguments epilogue_args{
|
||||
Gemm::Epilogue::prepare_args(
|
||||
std::forward<EpilogueArgs>(epilogue_params)...),
|
||||
c_ptr, c_stride, c_ptr, d_stride};
|
||||
|
||||
cutlass_gemm_caller<GemmKernel>(a.place(), prob_shape, mainloop_args,
|
||||
epilogue_args);
|
||||
}
|
||||
|
||||
} // namespace fastdeploy::c3x
|
||||
149
custom_ops/gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm.cuh
Normal file
149
custom_ops/gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm.cuh
Normal file
@@ -0,0 +1,149 @@
|
||||
// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/c3x/scaled_mm.cuh
|
||||
|
||||
#pragma once
|
||||
|
||||
// clang-format will break include orders
|
||||
// clang-format off
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cute/atom/mma_atom.hpp"
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
|
||||
#include "cutlass_helper.h"
|
||||
#include "helper.h"
|
||||
// clang-format on
|
||||
|
||||
/*
|
||||
Epilogues defined in,
|
||||
csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp,
|
||||
must contain a public type named EVTCompute of type Sm90EVT, as well as a
|
||||
static prepare_args function that constructs an EVTCompute::Arguments struct.
|
||||
*/
|
||||
|
||||
using namespace cute;
|
||||
|
||||
namespace fastdeploy {
|
||||
|
||||
template <typename ElementAB_, typename ElementD_,
|
||||
template <typename, typename, typename> typename Epilogue_,
|
||||
typename TileShape, typename ClusterShape, typename KernelSchedule,
|
||||
typename EpilogueSchedule>
|
||||
struct cutlass_3x_gemm {
|
||||
using ElementAB = ElementAB_;
|
||||
using ElementD = ElementD_;
|
||||
using ElementAcc =
|
||||
typename std::conditional<std::is_same_v<ElementAB, int8_t>, int32_t,
|
||||
float>::type;
|
||||
|
||||
using Epilogue = Epilogue_<ElementAcc, ElementD, TileShape>;
|
||||
|
||||
using StrideD = Stride<int64_t, Int<1>, Int<0>>;
|
||||
using ElementC = void;
|
||||
using StrideC = StrideD;
|
||||
|
||||
using EVTCompute = typename Epilogue::EVTCompute;
|
||||
|
||||
// These are the minimum alignments needed for the kernels to compile
|
||||
static constexpr int AlignmentAB =
|
||||
128 / cutlass::sizeof_bits<ElementAB>::value;
|
||||
static constexpr int AlignmentCD = 4;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape,
|
||||
ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
ElementAcc, float, ElementC, StrideC, AlignmentCD, ElementD, StrideD,
|
||||
AlignmentCD, EpilogueSchedule, EVTCompute>::CollectiveOp;
|
||||
|
||||
static constexpr size_t CEStorageSize =
|
||||
sizeof(typename CollectiveEpilogue::SharedStorage);
|
||||
using Stages = typename cutlass::gemm::collective::StageCountAutoCarveout<
|
||||
static_cast<int>(CEStorageSize)>;
|
||||
|
||||
// clang-format off
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
ElementAB, cutlass::layout::RowMajor, AlignmentAB,
|
||||
ElementAB, cutlass::layout::ColumnMajor, AlignmentAB,
|
||||
ElementAcc, TileShape, ClusterShape,
|
||||
Stages,
|
||||
KernelSchedule>::CollectiveOp;
|
||||
// clang-format on
|
||||
|
||||
using KernelType = enable_sm90_or_later<cutlass::gemm::kernel::GemmUniversal<
|
||||
cute::Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue,
|
||||
cutlass::gemm::PersistentScheduler>>;
|
||||
|
||||
struct GemmKernel : public KernelType {};
|
||||
};
|
||||
|
||||
template <typename ElementAB_, typename ElementD_,
|
||||
template <typename, typename, typename> typename Epilogue_,
|
||||
typename TileShape, typename ClusterShape, typename KernelSchedule,
|
||||
typename EpilogueSchedule>
|
||||
struct cutlass_3x_gemm_sm100 {
|
||||
using ElementAB = ElementAB_;
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
static constexpr int AlignmentA =
|
||||
128 / cutlass::sizeof_bits<ElementAB>::value;
|
||||
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
static constexpr int AlignmentB =
|
||||
128 / cutlass::sizeof_bits<ElementAB>::value;
|
||||
|
||||
using ElementC = void;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
static constexpr int AlignmentC =
|
||||
128 / cutlass::sizeof_bits<ElementD_>::value;
|
||||
|
||||
using ElementD = ElementD_;
|
||||
using LayoutD = cutlass::layout::RowMajor;
|
||||
static constexpr int AlignmentD = AlignmentC;
|
||||
|
||||
using ElementAcc =
|
||||
typename std::conditional<std::is_same_v<ElementAB, int8_t>, int32_t,
|
||||
float>::type;
|
||||
using Epilogue = Epilogue_<ElementAcc, ElementD, TileShape>;
|
||||
|
||||
// MMA type
|
||||
using ElementAccumulator = float;
|
||||
|
||||
// Epilogue types
|
||||
using ElementBias = cutlass::half_t;
|
||||
using ElementCompute = float;
|
||||
using ElementAux = ElementD;
|
||||
using LayoutAux = LayoutD;
|
||||
using ElementAmax = float;
|
||||
|
||||
using EVTCompute = typename Epilogue::EVTCompute;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, TileShape,
|
||||
ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
ElementAccumulator, ElementCompute, ElementC, LayoutC, AlignmentC,
|
||||
ElementD, LayoutD, AlignmentD, EpilogueSchedule,
|
||||
EVTCompute>::CollectiveOp;
|
||||
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, ElementAB,
|
||||
LayoutA, AlignmentA, ElementAB, LayoutB, AlignmentB,
|
||||
ElementAccumulator, TileShape, ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
|
||||
sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||
KernelSchedule>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue, void>;
|
||||
};
|
||||
|
||||
} // namespace fastdeploy
|
||||
@@ -0,0 +1,27 @@
|
||||
// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_azp_sm90_int8.cu
|
||||
|
||||
// clang-format will break include orders
|
||||
// clang-format off
|
||||
#include "scaled_mm_kernels.hpp"
|
||||
#include "scaled_mm_sm90_int8_dispatch.cuh"
|
||||
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
|
||||
// clang-format on
|
||||
|
||||
namespace fastdeploy {
|
||||
|
||||
void cutlass_scaled_mm_azp_sm90_int8(
|
||||
paddle::Tensor &out, paddle::Tensor const &a, paddle::Tensor const &b,
|
||||
paddle::Tensor const &a_scales, paddle::Tensor const &b_scales,
|
||||
paddle::Tensor const &azp_adj, paddle::optional<paddle::Tensor> const &azp,
|
||||
paddle::optional<paddle::Tensor> const &bias) {
|
||||
if (azp) {
|
||||
return cutlass_scaled_mm_sm90_int8_epilogue<
|
||||
c3x::ScaledEpilogueBiasAzpToken>(out, a, b, a_scales, b_scales, azp_adj,
|
||||
*azp, bias);
|
||||
} else {
|
||||
return cutlass_scaled_mm_sm90_int8_epilogue<c3x::ScaledEpilogueBiasAzp>(
|
||||
out, a, b, a_scales, b_scales, azp_adj, bias);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace fastdeploy
|
||||
@@ -0,0 +1,34 @@
|
||||
// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_helper.hpp
|
||||
|
||||
#include "helper.h"
|
||||
|
||||
template <typename Fp8Func, typename Int8Func>
|
||||
void dispatch_scaled_mm(paddle::Tensor &c, paddle::Tensor const &a,
|
||||
paddle::Tensor const &b, paddle::Tensor const &a_scales,
|
||||
paddle::Tensor const &b_scales,
|
||||
paddle::optional<paddle::Tensor> const &bias,
|
||||
Fp8Func fp8_func, Int8Func int8_func) {
|
||||
PD_CHECK(a_scales.dtype() == paddle::DataType::FLOAT32);
|
||||
PD_CHECK(b_scales.dtype() == paddle::DataType::FLOAT32);
|
||||
|
||||
int M = a.dims()[0], N = b.dims()[0], K = a.dims()[1];
|
||||
|
||||
if ((a_scales.numel() == 1 || a_scales.numel() == a.dims()[0]) &&
|
||||
(b_scales.numel() == 1 || b_scales.numel() == b.dims()[0])) {
|
||||
// Standard per-tensor/per-token/per-channel scaling
|
||||
PD_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
|
||||
if (a.dtype() == phi::DataType::FLOAT8_E4M3FN) {
|
||||
fp8_func(c, a, b, a_scales, b_scales, bias);
|
||||
} else {
|
||||
PD_CHECK(a.dtype() == paddle::DataType::INT8);
|
||||
if constexpr (!std::is_same_v<Int8Func, std::nullptr_t>) {
|
||||
int8_func(c, a, b, a_scales, b_scales, bias);
|
||||
} else {
|
||||
PD_CHECK(false, "Int8 not supported for this architecture");
|
||||
}
|
||||
}
|
||||
} else {
|
||||
PADDLE_THROW(phi::errors::Unimplemented(
|
||||
"No kernel for this combination of input dtypes is implemented."));
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,35 @@
|
||||
// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_kernels.hpp
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "helper.h"
|
||||
|
||||
namespace fastdeploy {
|
||||
|
||||
void cutlass_scaled_mm_sm90_fp8(paddle::Tensor &out, paddle::Tensor const &a,
|
||||
paddle::Tensor const &b,
|
||||
paddle::Tensor const &a_scales,
|
||||
paddle::Tensor const &b_scales,
|
||||
paddle::optional<paddle::Tensor> const &bias);
|
||||
|
||||
void cutlass_scaled_mm_sm90_int8(paddle::Tensor &out, paddle::Tensor const &a,
|
||||
paddle::Tensor const &b,
|
||||
paddle::Tensor const &a_scales,
|
||||
paddle::Tensor const &b_scales,
|
||||
paddle::optional<paddle::Tensor> const &bias);
|
||||
|
||||
void cutlass_scaled_mm_azp_sm90_int8(paddle::Tensor& out, paddle::Tensor const& a,
|
||||
paddle::Tensor const& b,
|
||||
paddle::Tensor const& a_scales,
|
||||
paddle::Tensor const& b_scales,
|
||||
paddle::Tensor const& azp_adj,
|
||||
paddle::optional<paddle::Tensor> const& azp,
|
||||
paddle::optional<paddle::Tensor> const& bias);
|
||||
|
||||
void cutlass_scaled_mm_sm100_fp8(paddle::Tensor &out, paddle::Tensor const &a,
|
||||
paddle::Tensor const &b,
|
||||
paddle::Tensor const &a_scales,
|
||||
paddle::Tensor const &b_scales,
|
||||
paddle::optional<paddle::Tensor> const &bias);
|
||||
|
||||
} // namespace fastdeploy
|
||||
@@ -0,0 +1,28 @@
|
||||
// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_fp8.cu
|
||||
|
||||
// clang-format will break include orders
|
||||
// clang-format off
|
||||
#include "scaled_mm_kernels.hpp"
|
||||
#include "scaled_mm_sm90_fp8_dispatch.cuh"
|
||||
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
|
||||
// clang-format on
|
||||
|
||||
namespace fastdeploy {
|
||||
|
||||
void cutlass_scaled_mm_sm90_fp8(paddle::Tensor &out, paddle::Tensor const &a,
|
||||
paddle::Tensor const &b,
|
||||
paddle::Tensor const &a_scales,
|
||||
paddle::Tensor const &b_scales,
|
||||
paddle::optional<paddle::Tensor> const &bias) {
|
||||
PD_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
|
||||
if (bias) {
|
||||
PD_CHECK(bias->dtype() == out.dtype(),
|
||||
"currently bias dtype must match output dtype ", out.dtype());
|
||||
return cutlass_scaled_mm_sm90_fp8_epilogue<c3x::ScaledEpilogueBias>(
|
||||
out, a, b, a_scales, b_scales, *bias);
|
||||
} else {
|
||||
return cutlass_scaled_mm_sm90_fp8_epilogue<c3x::ScaledEpilogue>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
}
|
||||
}
|
||||
} // namespace fastdeploy
|
||||
@@ -0,0 +1,125 @@
|
||||
// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_fp8_dispatch.cuh
|
||||
|
||||
#pragma once
|
||||
|
||||
// clang-format will break include orders
|
||||
// clang-format off
|
||||
#include "scaled_mm.cuh"
|
||||
#include "cutlass_gemm_caller.cuh"
|
||||
// clang-format on
|
||||
|
||||
/**
|
||||
* This file defines Gemm kernel configurations for SM90 (fp8) based on the Gemm
|
||||
* shape.
|
||||
*/
|
||||
|
||||
namespace fastdeploy {
|
||||
|
||||
using c3x::cutlass_gemm_caller;
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename, typename> typename Epilogue>
|
||||
struct sm90_fp8_config_default {
|
||||
// M in (128, inf)
|
||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||
using KernelSchedule =
|
||||
cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
|
||||
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
|
||||
using TileShape = Shape<_128, _128, _128>;
|
||||
using ClusterShape = Shape<_2, _1, _1>;
|
||||
using Cutlass3xGemm =
|
||||
cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
|
||||
KernelSchedule, EpilogueSchedule>;
|
||||
};
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename, typename> typename Epilogue>
|
||||
struct sm90_fp8_config_M128 {
|
||||
// M in (64, 128]
|
||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||
using KernelSchedule =
|
||||
cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
|
||||
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
|
||||
using TileShape = Shape<_64, _128, _128>;
|
||||
using ClusterShape = Shape<_2, _1, _1>;
|
||||
using Cutlass3xGemm =
|
||||
cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
|
||||
KernelSchedule, EpilogueSchedule>;
|
||||
};
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename, typename> typename Epilogue>
|
||||
struct sm90_fp8_config_M64 {
|
||||
// M in [1, 64]
|
||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||
using KernelSchedule =
|
||||
cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
|
||||
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
|
||||
using TileShape = Shape<_64, _64, _128>;
|
||||
using ClusterShape = Shape<_1, _8, _1>;
|
||||
|
||||
using Cutlass3xGemm =
|
||||
cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
|
||||
KernelSchedule, EpilogueSchedule>;
|
||||
};
|
||||
|
||||
// Selects a kernel configuration for the SM90 fp8 GEMM based on the problem's
// M dimension, then launches it via cutlass_gemm_caller. The thresholds below
// must match the M ranges documented on the config structs above.
template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue,
          typename... EpilogueArgs>
inline void cutlass_gemm_sm90_fp8_dispatch(paddle::Tensor &out,
                                           paddle::Tensor const &a,
                                           paddle::Tensor const &b,
                                           EpilogueArgs &&...args) {
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  PD_CHECK(a.dtype() == phi::DataType::FLOAT8_E4M3FN);
  PD_CHECK(b.dtype() == phi::DataType::FLOAT8_E4M3FN);

  using Cutlass3xGemmDefault =
      typename sm90_fp8_config_default<InType, OutType,
                                       Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemmM64 =
      typename sm90_fp8_config_M64<InType, OutType, Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemmM128 =
      typename sm90_fp8_config_M128<InType, OutType, Epilogue>::Cutlass3xGemm;

  uint32_t const m = a.dims()[0];
  // Round M up to a power of two (with a floor of 64) so the range checks
  // below reduce to simple threshold comparisons.
  uint32_t const mp2 =
      std::max(static_cast<uint32_t>(64), next_pow_2(m));  // next power of 2

  if (mp2 <= 64) {
    // m in [1, 64]
    return cutlass_gemm_caller<Cutlass3xGemmM64>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  } else if (mp2 <= 128) {
    // m in (64, 128]
    return cutlass_gemm_caller<Cutlass3xGemmM128>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  } else {
    // m in (128, inf)
    return cutlass_gemm_caller<Cutlass3xGemmDefault>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  }
}
|
||||
|
||||
template <template <typename, typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
void cutlass_scaled_mm_sm90_fp8_epilogue(paddle::Tensor &out,
|
||||
paddle::Tensor const &a,
|
||||
paddle::Tensor const &b,
|
||||
EpilogueArgs &&...epilogue_args) {
|
||||
PD_CHECK(a.dtype() == phi::DataType::FLOAT8_E4M3FN);
|
||||
PD_CHECK(b.dtype() == phi::DataType::FLOAT8_E4M3FN);
|
||||
|
||||
if (out.dtype() == paddle::DataType::BFLOAT16) {
|
||||
return cutlass_gemm_sm90_fp8_dispatch<cutlass::float_e4m3_t,
|
||||
cutlass::bfloat16_t, Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
} else {
|
||||
PD_CHECK(out.dtype() == paddle::DataType::FLOAT16);
|
||||
return cutlass_gemm_sm90_fp8_dispatch<cutlass::float_e4m3_t,
|
||||
cutlass::half_t, Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace fastdeploy
|
||||
@@ -0,0 +1,29 @@
|
||||
// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_int8.cu
|
||||
|
||||
// clang-format will break include orders
|
||||
// clang-format off
|
||||
#include "scaled_mm_kernels.hpp"
|
||||
#include "scaled_mm_sm90_int8_dispatch.cuh"
|
||||
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
|
||||
// clang-format on
|
||||
|
||||
namespace fastdeploy {
|
||||
|
||||
/// Entry point for the SM90 int8 scaled GEMM: validates the scale tensors and
/// selects the epilogue (with or without fused bias) before dispatching.
void cutlass_scaled_mm_sm90_int8(paddle::Tensor &out, paddle::Tensor const &a,
                                 paddle::Tensor const &b,
                                 paddle::Tensor const &a_scales,
                                 paddle::Tensor const &b_scales,
                                 paddle::optional<paddle::Tensor> const &bias) {
  // The epilogue visitors read the scales with unit stride.
  PD_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());

  if (!bias) {
    cutlass_scaled_mm_sm90_int8_epilogue<c3x::ScaledEpilogue>(
        out, a, b, a_scales, b_scales);
    return;
  }

  PD_CHECK(bias->dtype() == out.dtype(),
           "currently bias dtype must match output dtype ", out.dtype());
  cutlass_scaled_mm_sm90_int8_epilogue<c3x::ScaledEpilogueBias>(
      out, a, b, a_scales, b_scales, *bias);
}
|
||||
|
||||
} // namespace fastdeploy
|
||||
@@ -0,0 +1,168 @@
|
||||
// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_int8_dispatch.cuh
|
||||
|
||||
#pragma once
|
||||
|
||||
// clang-format will break include orders
|
||||
// clang-format off
|
||||
#include "scaled_mm.cuh"
|
||||
#include "cutlass_gemm_caller.cuh"
|
||||
// clang-format on
|
||||
|
||||
/**
|
||||
* This file defines Gemm kernel configurations for SM90 (int8) based on the
|
||||
* Gemm shape.
|
||||
*/
|
||||
|
||||
namespace fastdeploy {
|
||||
|
||||
using c3x::cutlass_gemm_caller;
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename, typename> typename Epilogue>
|
||||
struct sm90_int8_config_default {
|
||||
// For M > 128 and any N
|
||||
static_assert(std::is_same<InType, int8_t>());
|
||||
using KernelSchedule =
|
||||
typename cutlass::gemm::KernelTmaWarpSpecializedPingpong;
|
||||
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
|
||||
using TileShape = Shape<_128, _128, _128>;
|
||||
using ClusterShape = Shape<_2, _1, _1>;
|
||||
using Cutlass3xGemm =
|
||||
cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
|
||||
KernelSchedule, EpilogueSchedule>;
|
||||
};
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename, typename> typename Epilogue>
|
||||
struct sm90_int8_config_M128 {
|
||||
// For M in (64, 128] and any N
|
||||
static_assert(std::is_same<InType, int8_t>());
|
||||
using KernelSchedule =
|
||||
typename cutlass::gemm::KernelTmaWarpSpecializedPingpong;
|
||||
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
|
||||
using TileShape = Shape<_64, _128, _128>;
|
||||
using ClusterShape = Shape<_2, _1, _1>;
|
||||
using Cutlass3xGemm =
|
||||
cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
|
||||
KernelSchedule, EpilogueSchedule>;
|
||||
};
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename, typename> typename Epilogue>
|
||||
struct sm90_int8_config_M64 {
|
||||
// For M in (32, 64] and any N
|
||||
static_assert(std::is_same<InType, int8_t>());
|
||||
using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized;
|
||||
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
|
||||
using TileShape = Shape<_64, _64, _256>;
|
||||
using ClusterShape = Shape<_1, _1, _1>;
|
||||
using Cutlass3xGemm =
|
||||
cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
|
||||
KernelSchedule, EpilogueSchedule>;
|
||||
};
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename, typename> typename Epilogue>
|
||||
struct sm90_int8_config_M32_NBig {
|
||||
// For M in [1, 32] and N >= 8192
|
||||
static_assert(std::is_same<InType, int8_t>());
|
||||
using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized;
|
||||
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
|
||||
using TileShape = Shape<_64, _128, _256>;
|
||||
using ClusterShape = Shape<_1, _4, _1>;
|
||||
using Cutlass3xGemm =
|
||||
cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
|
||||
KernelSchedule, EpilogueSchedule>;
|
||||
};
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename, typename> typename Epilogue>
|
||||
struct sm90_int8_config_M32_NSmall {
|
||||
// For M in [1, 32] and N < 8192
|
||||
static_assert(std::is_same<InType, int8_t>());
|
||||
using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized;
|
||||
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
|
||||
using TileShape = Shape<_64, _64, _256>;
|
||||
using ClusterShape = Shape<_1, _8, _1>;
|
||||
using Cutlass3xGemm =
|
||||
cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
|
||||
KernelSchedule, EpilogueSchedule>;
|
||||
};
|
||||
|
||||
// Selects a kernel configuration for the SM90 int8 GEMM based on the problem's
// M dimension (and, for small M, on whether N is below 8192), then launches it
// via cutlass_gemm_caller. The thresholds below must match the M/N ranges
// documented on the config structs above.
template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue,
          typename... EpilogueArgs>
inline void cutlass_gemm_sm90_int8_dispatch(paddle::Tensor &out,
                                            paddle::Tensor const &a,
                                            paddle::Tensor const &b,
                                            EpilogueArgs &&...args) {
  static_assert(std::is_same<InType, int8_t>());
  PD_CHECK(a.dtype() == paddle::DataType::INT8);
  PD_CHECK(b.dtype() == paddle::DataType::INT8);

  using Cutlass3xGemmDefault =
      typename sm90_int8_config_default<InType, OutType,
                                        Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemmM128 =
      typename sm90_int8_config_M128<InType, OutType, Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemmM64 =
      typename sm90_int8_config_M64<InType, OutType, Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemmM32NBig =
      typename sm90_int8_config_M32_NBig<InType, OutType,
                                         Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemmM32NSmall =
      typename sm90_int8_config_M32_NSmall<InType, OutType,
                                           Epilogue>::Cutlass3xGemm;

  // The N threshold only matters in the small-M regime; it picks between the
  // two M32 cluster shapes.
  uint32_t const n = out.dims()[1];
  bool const is_small_n = n < 8192;

  uint32_t const m = a.dims()[0];
  // Round M up to a power of two (floor 32) so the range checks reduce to
  // simple threshold comparisons.
  uint32_t const mp2 =
      std::max(static_cast<uint32_t>(32), next_pow_2(m));  // next power of 2

  if (mp2 <= 32) {
    // m in [1, 32]
    if (is_small_n) {
      return cutlass_gemm_caller<Cutlass3xGemmM32NSmall>(
          out, a, b, std::forward<EpilogueArgs>(args)...);
    } else {
      return cutlass_gemm_caller<Cutlass3xGemmM32NBig>(
          out, a, b, std::forward<EpilogueArgs>(args)...);
    }
  } else if (mp2 <= 64) {
    // m in (32, 64]
    return cutlass_gemm_caller<Cutlass3xGemmM64>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  } else if (mp2 <= 128) {
    // m in (64, 128]
    return cutlass_gemm_caller<Cutlass3xGemmM128>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  } else {
    // m in (128, inf)
    return cutlass_gemm_caller<Cutlass3xGemmDefault>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  }
}
|
||||
|
||||
template <template <typename, typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
void cutlass_scaled_mm_sm90_int8_epilogue(paddle::Tensor &out,
|
||||
paddle::Tensor const &a,
|
||||
paddle::Tensor const &b,
|
||||
EpilogueArgs &&...epilogue_args) {
|
||||
PD_CHECK(a.dtype() == paddle::DataType::INT8);
|
||||
PD_CHECK(b.dtype() == paddle::DataType::INT8);
|
||||
|
||||
if (out.dtype() == paddle::DataType::BFLOAT16) {
|
||||
return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::bfloat16_t,
|
||||
Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
} else {
|
||||
PD_CHECK(out.dtype() == paddle::DataType::FLOAT16);
|
||||
return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::half_t, Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace fastdeploy
|
||||
200
custom_ops/gpu_ops/cutlass_kernels/w8a8/scaled_mm_c2x.cu
Normal file
200
custom_ops/gpu_ops/cutlass_kernels/w8a8/scaled_mm_c2x.cu
Normal file
@@ -0,0 +1,200 @@
|
||||
// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu
|
||||
|
||||
#include "helper.h"
|
||||
#include <stddef.h>
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "scaled_mm_c2x.cuh"
|
||||
#include "scaled_mm_c2x_sm75_dispatch.cuh"
|
||||
#include "scaled_mm_c2x_sm80_dispatch.cuh"
|
||||
#include "scaled_mm_c2x_sm89_fp8_dispatch.cuh"
|
||||
#include "scaled_mm_c2x_sm89_int8_dispatch.cuh"
|
||||
|
||||
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp"
|
||||
|
||||
using namespace fastdeploy;
|
||||
|
||||
/*
|
||||
This file defines quantized GEMM operations using the CUTLASS 2.x API, for
|
||||
NVIDIA GPUs with SM versions prior to sm90 (Hopper).
|
||||
*/
|
||||
|
||||
template <template <typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
void cutlass_scaled_mm_sm75_epilogue(paddle::Tensor& out, paddle::Tensor const& a,
|
||||
paddle::Tensor const& b,
|
||||
EpilogueArgs&&... epilogue_args) {
|
||||
PD_CHECK(a.dtype() == paddle::DataType::INT8);
|
||||
PD_CHECK(b.dtype() == paddle::DataType::INT8);
|
||||
|
||||
if (out.dtype() == paddle::DataType::BFLOAT16) {
|
||||
return cutlass_gemm_sm75_dispatch<int8_t, cutlass::bfloat16_t, Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
} else {
|
||||
PD_CHECK(out.dtype() == paddle::DataType::FLOAT16);
|
||||
return cutlass_gemm_sm75_dispatch<int8_t, cutlass::half_t, Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
}
|
||||
}
|
||||
|
||||
/// Entry point for the SM75 scaled GEMM: validates scale dtypes and selects
/// the epilogue (with or without fused bias).
void cutlass_scaled_mm_sm75(paddle::Tensor &out, paddle::Tensor const &a,
                            paddle::Tensor const &b,
                            paddle::Tensor const &a_scales,
                            paddle::Tensor const &b_scales,
                            paddle::optional<paddle::Tensor> const &bias) {
  // Dequantization scales are always fp32.
  PD_CHECK(a_scales.dtype() == paddle::DataType::FLOAT32);
  PD_CHECK(b_scales.dtype() == paddle::DataType::FLOAT32);

  if (!bias) {
    cutlass_scaled_mm_sm75_epilogue<c2x::ScaledEpilogue>(out, a, b, a_scales,
                                                         b_scales);
    return;
  }

  PD_CHECK(bias->dtype() == out.dtype(),
           "currently bias dtype must match output dtype ", out.dtype());
  cutlass_scaled_mm_sm75_epilogue<c2x::ScaledEpilogueBias>(
      out, a, b, a_scales, b_scales, *bias);
}
|
||||
|
||||
/// SM75 scaled GEMM with asymmetric zero-point correction. When `azp` is
/// present the per-token epilogue is used; otherwise the per-tensor variant.
void cutlass_scaled_mm_azp_sm75(paddle::Tensor &out, paddle::Tensor const &a,
                                paddle::Tensor const &b,
                                paddle::Tensor const &a_scales,
                                paddle::Tensor const &b_scales,
                                paddle::Tensor const &azp_adj,
                                paddle::optional<paddle::Tensor> const &azp,
                                paddle::optional<paddle::Tensor> const &bias) {
  // Dequantization scales are always fp32.
  PD_CHECK(a_scales.dtype() == paddle::DataType::FLOAT32);
  PD_CHECK(b_scales.dtype() == paddle::DataType::FLOAT32);

  if (!azp) {
    cutlass_scaled_mm_sm75_epilogue<c2x::ScaledEpilogueBiasAzp>(
        out, a, b, a_scales, b_scales, azp_adj, bias);
    return;
  }
  cutlass_scaled_mm_sm75_epilogue<c2x::ScaledEpilogueBiasAzpToken>(
      out, a, b, a_scales, b_scales, azp_adj, *azp, bias);
}
|
||||
|
||||
template <template <typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
void cutlass_scaled_mm_sm80_epilogue(paddle::Tensor& out, paddle::Tensor const& a,
|
||||
paddle::Tensor const& b,
|
||||
EpilogueArgs&&... epilogue_args) {
|
||||
PD_CHECK(a.dtype() == paddle::DataType::INT8);
|
||||
PD_CHECK(b.dtype() == paddle::DataType::INT8);
|
||||
|
||||
if (out.dtype() == paddle::DataType::BFLOAT16) {
|
||||
return cutlass_gemm_sm80_dispatch<int8_t, cutlass::bfloat16_t, Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
} else {
|
||||
PD_CHECK(out.dtype() == paddle::DataType::FLOAT16);
|
||||
return cutlass_gemm_sm80_dispatch<int8_t, cutlass::half_t, Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
}
|
||||
}
|
||||
|
||||
/// Entry point for the SM80 scaled GEMM: validates scale dtypes and selects
/// the epilogue (with or without fused bias).
void cutlass_scaled_mm_sm80(paddle::Tensor &out, paddle::Tensor const &a,
                            paddle::Tensor const &b,
                            paddle::Tensor const &a_scales,
                            paddle::Tensor const &b_scales,
                            paddle::optional<paddle::Tensor> const &bias) {
  // Dequantization scales are always fp32.
  PD_CHECK(a_scales.dtype() == paddle::DataType::FLOAT32);
  PD_CHECK(b_scales.dtype() == paddle::DataType::FLOAT32);

  if (!bias) {
    cutlass_scaled_mm_sm80_epilogue<c2x::ScaledEpilogue>(out, a, b, a_scales,
                                                         b_scales);
    return;
  }

  PD_CHECK(bias->dtype() == out.dtype(),
           "currently bias dtype must match output dtype ", out.dtype());
  cutlass_scaled_mm_sm80_epilogue<c2x::ScaledEpilogueBias>(
      out, a, b, a_scales, b_scales, *bias);
}
|
||||
|
||||
/// SM80 scaled GEMM with asymmetric zero-point correction. When `azp` is
/// present the per-token epilogue is used; otherwise the per-tensor variant.
void cutlass_scaled_mm_azp_sm80(paddle::Tensor &out, paddle::Tensor const &a,
                                paddle::Tensor const &b,
                                paddle::Tensor const &a_scales,
                                paddle::Tensor const &b_scales,
                                paddle::Tensor const &azp_adj,
                                paddle::optional<paddle::Tensor> const &azp,
                                paddle::optional<paddle::Tensor> const &bias) {
  // Dequantization scales are always fp32.
  PD_CHECK(a_scales.dtype() == paddle::DataType::FLOAT32);
  PD_CHECK(b_scales.dtype() == paddle::DataType::FLOAT32);

  if (!azp) {
    cutlass_scaled_mm_sm80_epilogue<c2x::ScaledEpilogueBiasAzp>(
        out, a, b, a_scales, b_scales, azp_adj, bias);
    return;
  }
  cutlass_scaled_mm_sm80_epilogue<c2x::ScaledEpilogueBiasAzpToken>(
      out, a, b, a_scales, b_scales, azp_adj, *azp, bias);
}
|
||||
|
||||
template <template <typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
void cutlass_scaled_mm_sm89_epilogue(paddle::Tensor& out, paddle::Tensor const& a,
|
||||
paddle::Tensor const& b,
|
||||
EpilogueArgs&&... epilogue_args) {
|
||||
if (a.dtype() == paddle::DataType::INT8) {
|
||||
PD_CHECK(b.dtype() == paddle::DataType::INT8);
|
||||
|
||||
if (out.dtype() == paddle::DataType::BFLOAT16) {
|
||||
return cutlass_gemm_sm89_int8_dispatch<int8_t, cutlass::bfloat16_t,
|
||||
Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
} else {
|
||||
assert(out.dtype() == paddle::DataType::FLOAT16);
|
||||
return cutlass_gemm_sm89_int8_dispatch<int8_t, cutlass::half_t, Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
}
|
||||
} else {
|
||||
PD_CHECK(a.dtype() == paddle::DataType::FLOAT8_E4M3FN);
|
||||
PD_CHECK(b.dtype() == paddle::DataType::FLOAT8_E4M3FN);
|
||||
|
||||
if (out.dtype() == paddle::DataType::BFLOAT16) {
|
||||
return cutlass_gemm_sm89_fp8_dispatch<cutlass::float_e4m3_t,
|
||||
cutlass::bfloat16_t, Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
} else {
|
||||
PD_CHECK(out.dtype() == paddle::DataType::FLOAT16);
|
||||
return cutlass_gemm_sm89_fp8_dispatch<cutlass::float_e4m3_t,
|
||||
cutlass::half_t, Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Entry point for the SM89 scaled GEMM: validates scale dtypes and selects
/// the epilogue (with or without fused bias).
void cutlass_scaled_mm_sm89(paddle::Tensor &out, paddle::Tensor const &a,
                            paddle::Tensor const &b,
                            paddle::Tensor const &a_scales,
                            paddle::Tensor const &b_scales,
                            paddle::optional<paddle::Tensor> const &bias) {
  // Dequantization scales are always fp32.
  PD_CHECK(a_scales.dtype() == paddle::DataType::FLOAT32);
  PD_CHECK(b_scales.dtype() == paddle::DataType::FLOAT32);

  if (!bias) {
    cutlass_scaled_mm_sm89_epilogue<c2x::ScaledEpilogue>(out, a, b, a_scales,
                                                         b_scales);
    return;
  }

  PD_CHECK(bias->dtype() == out.dtype(),
           "currently bias dtype must match output dtype ", out.dtype());
  cutlass_scaled_mm_sm89_epilogue<c2x::ScaledEpilogueBias>(
      out, a, b, a_scales, b_scales, *bias);
}
|
||||
|
||||
/// SM89 scaled GEMM with asymmetric zero-point correction. When `azp` is
/// present the per-token epilogue is used; otherwise the per-tensor variant.
void cutlass_scaled_mm_azp_sm89(paddle::Tensor &out, paddle::Tensor const &a,
                                paddle::Tensor const &b,
                                paddle::Tensor const &a_scales,
                                paddle::Tensor const &b_scales,
                                paddle::Tensor const &azp_adj,
                                paddle::optional<paddle::Tensor> const &azp,
                                paddle::optional<paddle::Tensor> const &bias) {
  // Dequantization scales are always fp32.
  PD_CHECK(a_scales.dtype() == paddle::DataType::FLOAT32);
  PD_CHECK(b_scales.dtype() == paddle::DataType::FLOAT32);

  if (!azp) {
    cutlass_scaled_mm_sm89_epilogue<c2x::ScaledEpilogueBiasAzp>(
        out, a, b, a_scales, b_scales, azp_adj, bias);
    return;
  }
  cutlass_scaled_mm_sm89_epilogue<c2x::ScaledEpilogueBiasAzpToken>(
      out, a, b, a_scales, b_scales, azp_adj, *azp, bias);
}
|
||||
223
custom_ops/gpu_ops/cutlass_kernels/w8a8/scaled_mm_c2x.cuh
Normal file
223
custom_ops/gpu_ops/cutlass_kernels/w8a8/scaled_mm_c2x.cuh
Normal file
@@ -0,0 +1,223 @@
|
||||
// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh
|
||||
|
||||
#pragma once
|
||||
#include <stddef.h>
|
||||
#include "helper.h"
|
||||
|
||||
// clang-format will break include orders
|
||||
// clang-format off
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cute/atom/mma_atom.hpp"
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm_coord.h"
|
||||
#include "cutlass/arch/mma_sm75.h"
|
||||
#include "cutlass/arch/arch.h"
|
||||
#include "cutlass/arch/mma.h"
|
||||
#include "cutlass/gemm/device/gemm.h"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
|
||||
#include "cutlass/epilogue/threadblock/fusion/visitors.hpp"
|
||||
#include "cutlass/gemm/kernel/default_gemm_universal_with_visitor.h"
|
||||
|
||||
#include "cutlass_helper.h"
|
||||
// clang-format on
|
||||
|
||||
/*
|
||||
Epilogues defined in,
|
||||
csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp
|
||||
must contain a public type named EVTCompute of type Sm80EVT,
|
||||
as well as a static prepare_args function that constructs an
|
||||
EVTCompute::Arguments struct.
|
||||
*/
|
||||
|
||||
namespace fastdeploy {
|
||||
|
||||
using namespace cute;
|
||||
|
||||
// Wrappers for the GEMM kernel that is used to guard against compilation on
|
||||
// architectures that will never use the kernel. The purpose of this is to
|
||||
// reduce the size of the compiled binary.
|
||||
// __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef
|
||||
// into code that will be executed on the device where it is defined.
|
||||
// Arch guard: the wrapped kernel body is compiled only for [SM75, SM80).
// On other architectures `invoke` becomes an empty function, keeping those
// kernel instantiations out of the fat binary.
template <typename Kernel>
struct enable_sm75_to_sm80 : Kernel {
  template <typename... Args>
  CUTLASS_DEVICE static void invoke(Args&&... args) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 750 && __CUDA_ARCH__ < 800
    Kernel::invoke(std::forward<Args>(args)...);
#endif
  }
};
|
||||
|
||||
// Arch guard: the wrapped kernel body is compiled only for [SM80, SM89).
template <typename Kernel>
struct enable_sm80_to_sm89 : Kernel {
  template <typename... Args>
  CUTLASS_DEVICE static void invoke(Args&&... args) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800 && __CUDA_ARCH__ < 890
    Kernel::invoke(std::forward<Args>(args)...);
#endif
  }
};
|
||||
|
||||
// Arch guard: the wrapped kernel body is compiled only for [SM89, SM90).
template <typename Kernel>
struct enable_sm89_to_sm90 : Kernel {
  template <typename... Args>
  CUTLASS_DEVICE static void invoke(Args&&... args) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 890 && __CUDA_ARCH__ < 900
    Kernel::invoke(std::forward<Args>(args)...);
#endif
  }
};
|
||||
// Assembles a CUTLASS 2.x GEMM type for quantized w8a8 matmul:
// row-major A x column-major B -> row-major D, with an EVT (epilogue visitor
// tree) epilogue supplied by `Epilogue_`. `ArchGuard` wraps the kernel so its
// body only compiles for the intended SM range (see enable_* above).
template <typename Arch, template <typename> typename ArchGuard,
          typename ElementAB_, typename ElementD_,
          template <typename, typename> typename Epilogue_, typename TileShape,
          typename WarpShape, typename InstructionShape, int32_t MainLoopStages,
          typename FP8MathOperator = cutlass::arch::OpMultiplyAdd>
struct cutlass_2x_gemm {
  using ElementAB = ElementAB_;
  using ElementD = ElementD_;

  // int8 inputs accumulate in int32; fp8 inputs accumulate in fp32.
  using ElementAcc =
      typename std::conditional<std::is_same_v<ElementAB, int8_t>, int32_t,
                                float>::type;

  // int8 uses saturating multiply-add; fp8 uses the caller-selected operator.
  using Operator =
      typename std::conditional<std::is_same_v<ElementAB, int8_t>,
                                cutlass::arch::OpMultiplyAddSaturate,
                                FP8MathOperator>::type;

  using OutputTileThreadMap =
      cutlass::epilogue::threadblock::OutputTileThreadLayout<
          TileShape, WarpShape, float, 4, 1 /* epilogue stages */
          >;

  using Epilogue = Epilogue_<ElementD, OutputTileThreadMap>;
  // Epilogue_ must expose an EVTCompute visitor (see the file-level comment).
  using EVTCompute = typename Epilogue::EVTCompute;

  // Final visitor node: stores the computed tile to D with round-to-nearest.
  using D = cutlass::epilogue::threadblock::VisitorAuxStore<
      OutputTileThreadMap, ElementD, cutlass::FloatRoundStyle::round_to_nearest,
      Stride<int64_t, Int<1>, Int<0>>>;

  using EVTD = cutlass::epilogue::threadblock::Sm80EVT<D, EVTCompute>;

  // These are the minimum alignments needed for the kernels to compile.
  static constexpr int AlignmentAB =
      128 / cutlass::sizeof_bits<ElementAB>::value;
  static constexpr int AlignmentCD = 4;

  // clang-format off
  using RowMajor = typename cutlass::layout::RowMajor;
  using ColumnMajor = typename cutlass::layout::ColumnMajor;
  using KernelType =
    ArchGuard<typename cutlass::gemm::kernel::DefaultGemmWithVisitor<
      ElementAB, RowMajor, cutlass::ComplexTransform::kNone, AlignmentAB,
      ElementAB, ColumnMajor, cutlass::ComplexTransform::kNone, AlignmentAB,
      float, cutlass::layout::RowMajor, AlignmentCD,
      ElementAcc, float, cutlass::arch::OpClassTensorOp,
      Arch,
      TileShape, WarpShape, InstructionShape,
      EVTD,
      cutlass::gemm::threadblock::ThreadblockSwizzleStreamK,
      MainLoopStages, Operator,
      1 /* epilogue stages */
      >::GemmKernel>;
  // clang-format on

  using Op = cutlass::gemm::device::GemmUniversalAdapter<KernelType>;
};
|
||||
|
||||
// Builds the CUTLASS 2.x GemmUniversal argument struct from the paddle
// tensors and launches the kernel on the input tensor's stream. Extents and
// leading dimensions are taken directly from the tensors: A is (m, k)
// row-major, B is stored so that b.dims()[0] == n, D is (m, n) row-major.
template <typename Gemm, typename... EpilogueArgs>
inline void cutlass_gemm_caller(paddle::Tensor& out, paddle::Tensor const& a,
                                paddle::Tensor const& b,
                                EpilogueArgs&&... epilogue_params) {
  using ElementAB = typename Gemm::ElementAB;
  using ElementD = typename Gemm::ElementD;

  int32_t m = a.dims()[0];
  int32_t n = b.dims()[0];
  int32_t k = a.dims()[1];
  cutlass::gemm::GemmCoord problem_size{m, n, k};

  // Leading dimensions come from the row strides of each tensor.
  int64_t lda = a.strides()[0];
  int64_t ldb = b.strides()[0];
  int64_t ldc = out.strides()[0];

  using StrideC = Stride<int64_t, Int<1>, Int<0>>;
  StrideC c_stride{ldc, Int<1>{}, Int<0>{}};

  auto a_ptr = static_cast<ElementAB const*>(a.data());
  auto b_ptr = static_cast<ElementAB const*>(b.data());
  auto c_ptr = static_cast<ElementD*>(out.data());

  typename Gemm::D::Arguments d_args{c_ptr, c_stride};

  // Let the epilogue translate the user-level args (scales, bias, ...) into
  // its visitor-tree argument struct.
  using Epilogue = typename Gemm::Epilogue;
  auto evt_args =
      Epilogue::prepare_args(std::forward<EpilogueArgs>(epilogue_params)...);

  typename Gemm::EVTD::Arguments epilogue_args{
      evt_args,
      d_args,
  };

  // NOTE: positional GemmUniversal arguments — order is load-bearing.
  typename Gemm::Op::Arguments args{
      cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel,  // universal mode
      problem_size,                                           // problem size
      1,                                                      // batch count
      epilogue_args,
      a_ptr,
      b_ptr,
      nullptr,  // C pointer unused: output is written by the EVT store node
      nullptr,
      0,  // batch strides (batch count is 1)
      0,
      0,
      0,
      lda,
      ldb,
      ldc,
      ldc};

  // Launch the CUTLASS GEMM kernel.
  typename Gemm::Op gemm_op;
  size_t workspace_size = gemm_op.get_workspace_size(args);
  phi::Allocator *allocator = paddle::GetAllocator(a.place());
  auto workspace = allocator->Allocate(workspace_size);

  auto stream = a.stream();

  CUTLASS_CHECK(gemm_op.can_implement(args));
  cutlass::Status status = gemm_op(args, workspace->ptr(), stream);
  CUTLASS_CHECK(status);
}
|
||||
|
||||
// Launches `Gemm` if its shared-memory requirement fits the device's opt-in
// per-block limit; otherwise falls back to `FallbackGemm` (which must fit,
// enforced by PD_CHECK).
template <typename Gemm, typename FallbackGemm, typename... EpilogueArgs>
inline void fallback_cutlass_gemm_caller(paddle::Tensor& out,
                                         paddle::Tensor const& a,
                                         paddle::Tensor const& b,
                                         EpilogueArgs&&... args) {
  // In some cases, the GPU isn't able to accommodate the
  // shared memory requirements of the Gemm. In such cases, use
  // the FallbackGemm instead.
  // NOTE(review): queries device 0 and caches the result for the process —
  // assumes a homogeneous single-device context; confirm for multi-GPU.
  static const int max_shared_mem_per_block_opt_in =
      get_cuda_max_shared_memory_per_block_opt_in(0);

  size_t const gemm_shared_mem_size =
      sizeof(typename Gemm::KernelType::SharedStorage);
  size_t const fallback_gemm_shared_mem_size =
      sizeof(typename FallbackGemm::KernelType::SharedStorage);

  if (gemm_shared_mem_size <= max_shared_mem_per_block_opt_in) {
    return cutlass_gemm_caller<Gemm>(out, a, b,
                                     std::forward<EpilogueArgs>(args)...);
  } else {
    // The fallback config must always fit; fail loudly if it doesn't.
    PD_CHECK(fallback_gemm_shared_mem_size <=
             max_shared_mem_per_block_opt_in);
    return cutlass_gemm_caller<FallbackGemm>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  }
}
|
||||
|
||||
} // namespace fastdeploy
|
||||
@@ -0,0 +1,125 @@
|
||||
// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm75_dispatch.cuh
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "scaled_mm_c2x.cuh"
|
||||
|
||||
/**
|
||||
* This file defines Gemm kernel configurations for SM75 based on the Gemm
|
||||
* shape.
|
||||
*/
|
||||
|
||||
namespace fastdeploy {
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename> typename Epilogue>
|
||||
struct sm75_config_default {
|
||||
// This config is used in 2 cases,
|
||||
// - M in (256, inf]
|
||||
// - M in (64, 128]
|
||||
// Shared memory required by this Gemm 32768
|
||||
static_assert(std::is_same<InType, int8_t>());
|
||||
using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
|
||||
using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
using InstructionShape = typename cutlass::gemm::GemmShape<8, 8, 16>;
|
||||
using Cutlass2xGemm =
|
||||
cutlass_2x_gemm<cutlass::arch::Sm75, enable_sm75_to_sm80, InType, OutType,
|
||||
Epilogue, TileShape, WarpShape, InstructionShape, 2>;
|
||||
};
|
||||
|
||||
// SM75 int8 GEMM tile configuration for medium-large M.
template <typename InType, typename OutType,
          template <typename, typename> typename Epilogue>
struct sm75_config_M256 {
  // M in (128, 256]
  // Shared memory required by this Gemm 65536
  static_assert(std::is_same<InType, int8_t>());
  // Deeper K tile (128) than the default config; doubles shared-memory use.
  using TileShape = typename cutlass::gemm::GemmShape<128, 128, 128>;
  using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
  using InstructionShape = typename cutlass::gemm::GemmShape<8, 8, 16>;
  // 2 mainloop pipeline stages.
  using Cutlass2xGemm =
      cutlass_2x_gemm<cutlass::arch::Sm75, enable_sm75_to_sm80, InType, OutType,
                      Epilogue, TileShape, WarpShape, InstructionShape, 2>;
};
|
||||
|
||||
// SM75 int8 GEMM tile configuration for small-medium M.
template <typename InType, typename OutType,
          template <typename, typename> typename Epilogue>
struct sm75_config_M64 {
  // M in (32, 64]
  // Shared memory required by this Gemm 49152
  static_assert(std::is_same<InType, int8_t>());
  // Smaller M tile (64) to avoid wasted work when M is small.
  using TileShape = typename cutlass::gemm::GemmShape<64, 128, 128>;
  using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
  using InstructionShape = typename cutlass::gemm::GemmShape<8, 8, 16>;
  // 2 mainloop pipeline stages.
  using Cutlass2xGemm =
      cutlass_2x_gemm<cutlass::arch::Sm75, enable_sm75_to_sm80, InType, OutType,
                      Epilogue, TileShape, WarpShape, InstructionShape, 2>;
};
|
||||
|
||||
// SM75 int8 GEMM tile configuration for the smallest M bucket.
template <typename InType, typename OutType,
          template <typename, typename> typename Epilogue>
struct sm75_config_M32 {
  // M in [1, 32]
  // Shared memory required by this Gemm 49152
  static_assert(std::is_same<InType, int8_t>());
  using TileShape = typename cutlass::gemm::GemmShape<32, 128, 64>;
  using WarpShape = typename cutlass::gemm::GemmShape<32, 64, 64>;
  using InstructionShape = typename cutlass::gemm::GemmShape<8, 8, 16>;
  // 2 mainloop pipeline stages.
  using Cutlass2xGemm =
      cutlass_2x_gemm<cutlass::arch::Sm75, enable_sm75_to_sm80, InType, OutType,
                      Epilogue, TileShape, WarpShape, InstructionShape, 2>;
};
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
inline void cutlass_gemm_sm75_dispatch(paddle::Tensor& out,
|
||||
paddle::Tensor const& a,
|
||||
paddle::Tensor const& b,
|
||||
EpilogueArgs&&... args) {
|
||||
static_assert(std::is_same<InType, int8_t>());
|
||||
PD_CHECK(a.dtype() == paddle::DataType::INT8);
|
||||
PD_CHECK(b.dtype() == paddle::DataType::INT8);
|
||||
|
||||
using Cutlass2xGemmDefault =
|
||||
typename sm75_config_default<InType, OutType, Epilogue>::Cutlass2xGemm;
|
||||
using Cutlass2xGemmM256 =
|
||||
typename sm75_config_M256<InType, OutType, Epilogue>::Cutlass2xGemm;
|
||||
using Cutlass2xGemmM128 = Cutlass2xGemmDefault;
|
||||
using Cutlass2xGemmM64 =
|
||||
typename sm75_config_M64<InType, OutType, Epilogue>::Cutlass2xGemm;
|
||||
using Cutlass2xGemmM32 =
|
||||
typename sm75_config_M32<InType, OutType, Epilogue>::Cutlass2xGemm;
|
||||
|
||||
// Due to shared memory requirements, some Gemms may fail to run on some
|
||||
// GPUs. As the name indicates, the Fallback Gemm is used as an alternative
|
||||
// in such cases.
|
||||
// sm75_config_default has the least shared-memory requirements.
|
||||
using FallbackGemm = Cutlass2xGemmDefault;
|
||||
|
||||
uint32_t const m = a.dims()[0];;
|
||||
uint32_t const mp2 =
|
||||
std::max(static_cast<uint32_t>(32), next_pow_2(m)); // next power of 2
|
||||
if (mp2 <= 32) {
|
||||
// M in [1, 32]
|
||||
return fallback_cutlass_gemm_caller<Cutlass2xGemmM32, FallbackGemm>(
|
||||
out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else if (mp2 <= 64) {
|
||||
// M in (32, 64]
|
||||
return fallback_cutlass_gemm_caller<Cutlass2xGemmM64, FallbackGemm>(
|
||||
out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else if (mp2 <= 128) {
|
||||
// M in (64, 128]
|
||||
return fallback_cutlass_gemm_caller<Cutlass2xGemmM128, FallbackGemm>(
|
||||
out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else if (mp2 <= 256) {
|
||||
// M in (128, 256]
|
||||
return fallback_cutlass_gemm_caller<Cutlass2xGemmM256, FallbackGemm>(
|
||||
out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else {
|
||||
// M in (256, inf)
|
||||
return fallback_cutlass_gemm_caller<Cutlass2xGemmDefault, FallbackGemm>(
|
||||
out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace fastdeploy
|
||||
@@ -0,0 +1,141 @@
|
||||
// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm80_dispatch.cuh
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "scaled_mm_c2x.cuh"
|
||||
|
||||
/**
|
||||
* This file defines Gemm kernel configurations for SM80 based on the Gemm
|
||||
* shape.
|
||||
*/
|
||||
|
||||
namespace fastdeploy {
|
||||
|
||||
// Default SM80 int8 GEMM tile configuration.
template <typename InType, typename OutType,
          template <typename, typename> typename Epilogue>
struct sm80_config_default {
  // This config is used in 2 cases,
  // - M in (128, inf)
  // - M in (64, 128] and N >= 8192
  // Shared Memory required by this Gemm - 81920 bytes
  static_assert(std::is_same<InType, int8_t>());
  using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
  using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
  // 16x8x32 is the SM80 (Ampere) int8 tensor-core MMA shape.
  using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
  // Final template argument (5) is the number of mainloop pipeline stages.
  using Cutlass2xGemm =
      cutlass_2x_gemm<cutlass::arch::Sm80, enable_sm80_to_sm89, InType, OutType,
                      Epilogue, TileShape, WarpShape, InstructionShape, 5>;
};
|
||||
|
||||
// SM80 int8 GEMM tile configuration for small-medium M (and narrow-N M128).
template <typename InType, typename OutType,
          template <typename, typename> typename Epilogue>
struct sm80_config_M64 {
  // This config is used in 2 cases,
  // - M in (32, 64]
  // - M in (64, 128] and N < 8192
  // Shared Memory required by this Gemm - 122880 bytes
  static_assert(std::is_same<InType, int8_t>());
  using TileShape = typename cutlass::gemm::GemmShape<64, 128, 128>;
  using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
  using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
  // 5 mainloop pipeline stages.
  using Cutlass2xGemm =
      cutlass_2x_gemm<cutlass::arch::Sm80, enable_sm80_to_sm89, InType, OutType,
                      Epilogue, TileShape, WarpShape, InstructionShape, 5>;
};
|
||||
|
||||
// SM80 int8 GEMM tile configuration for small M; also used as the fallback
// GEMM by cutlass_gemm_sm80_dispatch (chosen over M16 based on profiling).
template <typename InType, typename OutType,
          template <typename, typename> typename Epilogue>
struct sm80_config_M32 {
  // M in (16, 32]
  // Shared Memory required by this Gemm - 61440 bytes
  static_assert(std::is_same<InType, int8_t>());
  using TileShape = typename cutlass::gemm::GemmShape<32, 64, 128>;
  using WarpShape = typename cutlass::gemm::GemmShape<32, 64, 64>;
  using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
  // 5 mainloop pipeline stages.
  using Cutlass2xGemm =
      cutlass_2x_gemm<cutlass::arch::Sm80, enable_sm80_to_sm89, InType, OutType,
                      Epilogue, TileShape, WarpShape, InstructionShape, 5>;
};
|
||||
|
||||
// SM80 int8 GEMM tile configuration for the smallest M bucket.
template <typename InType, typename OutType,
          template <typename, typename> typename Epilogue>
struct sm80_config_M16 {
  // M in [1, 16]
  // Shared Memory required by this Gemm - 51200 bytes
  static_assert(std::is_same<InType, int8_t>());
  using TileShape = typename cutlass::gemm::GemmShape<16, 64, 128>;
  using WarpShape = typename cutlass::gemm::GemmShape<16, 64, 64>;
  using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
  // 5 mainloop pipeline stages.
  using Cutlass2xGemm =
      cutlass_2x_gemm<cutlass::arch::Sm80, enable_sm80_to_sm89, InType, OutType,
                      Epilogue, TileShape, WarpShape, InstructionShape, 5>;
};
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
inline void cutlass_gemm_sm80_dispatch(paddle::Tensor& out,
|
||||
paddle::Tensor const& a,
|
||||
paddle::Tensor const& b,
|
||||
EpilogueArgs&&... args) {
|
||||
static_assert(std::is_same<InType, int8_t>());
|
||||
PD_CHECK(a.dtype() == paddle::DataType::INT8);
|
||||
PD_CHECK(b.dtype() == paddle::DataType::INT8);
|
||||
|
||||
using Cutlass2xGemmDefault =
|
||||
typename sm80_config_default<InType, OutType, Epilogue>::Cutlass2xGemm;
|
||||
using Cutlass2xGemmM128BigN =
|
||||
typename sm80_config_default<InType, OutType, Epilogue>::Cutlass2xGemm;
|
||||
using Cutlass2xGemmM128SmallN =
|
||||
typename sm80_config_M64<InType, OutType, Epilogue>::Cutlass2xGemm;
|
||||
using Cutlass2xGemmM64 =
|
||||
typename sm80_config_M64<InType, OutType, Epilogue>::Cutlass2xGemm;
|
||||
using Cutlass2xGemmM32 =
|
||||
typename sm80_config_M32<InType, OutType, Epilogue>::Cutlass2xGemm;
|
||||
using Cutlass2xGemmM16 =
|
||||
typename sm80_config_M16<InType, OutType, Epilogue>::Cutlass2xGemm;
|
||||
|
||||
// Due to shared memory requirements, some Gemms may fail to run on some
|
||||
// GPUs. As the name indicates, the Fallback Gemm is used as an alternative
|
||||
// in such cases.
|
||||
// sm80_config_M16 has the least shared-memory requirement. However,
|
||||
// based on some profiling, we select sm80_config_M32 as a better alternative
|
||||
// performance wise.
|
||||
using FallbackGemm =
|
||||
typename sm80_config_M32<InType, OutType, Epilogue>::Cutlass2xGemm;
|
||||
|
||||
uint32_t const m = a.dims()[0];;
|
||||
uint32_t const mp2 =
|
||||
std::max(static_cast<uint32_t>(16), next_pow_2(m)); // next power of 2
|
||||
if (mp2 <= 16) {
|
||||
// M in [1, 16]
|
||||
return fallback_cutlass_gemm_caller<Cutlass2xGemmM16, FallbackGemm>(
|
||||
out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else if (mp2 <= 32) {
|
||||
// M in (16, 32]
|
||||
return fallback_cutlass_gemm_caller<Cutlass2xGemmM32, FallbackGemm>(
|
||||
out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else if (mp2 <= 64) {
|
||||
// M in (32, 64]
|
||||
return fallback_cutlass_gemm_caller<Cutlass2xGemmM64, FallbackGemm>(
|
||||
out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else if (mp2 <= 128) {
|
||||
// M in (64, 128]
|
||||
uint32_t const n = out.dims()[1];;
|
||||
bool const small_n = n < 8192;
|
||||
if (small_n) {
|
||||
return fallback_cutlass_gemm_caller<Cutlass2xGemmM128SmallN,
|
||||
FallbackGemm>(
|
||||
out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else {
|
||||
return fallback_cutlass_gemm_caller<Cutlass2xGemmM128BigN, FallbackGemm>(
|
||||
out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
}
|
||||
} else {
|
||||
// M in (128, inf)
|
||||
return fallback_cutlass_gemm_caller<Cutlass2xGemmDefault, FallbackGemm>(
|
||||
out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace fastdeploy
|
||||
@@ -0,0 +1,370 @@
|
||||
// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm89_fp8_dispatch.cuh
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "scaled_mm_c2x.cuh"
|
||||
#include "cutlass/float8.h"
|
||||
|
||||
/**
|
||||
* This file defines Gemm kernel configurations for SM89 (FP8) based on the Gemm
|
||||
* shape.
|
||||
*/
|
||||
|
||||
namespace fastdeploy {
|
||||
|
||||
// SM89 FP8 fallback GEMM: used by every sm89_fp8_config_* when the preferred
// config needs more shared memory than the device provides.
template <typename InType, typename OutType,
          template <typename, typename> typename Epilogue>
struct sm89_fp8_fallback_gemm {
  // Shared Memory required by this Gemm - 61440 bytes
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  using TileShape = typename cutlass::gemm::GemmShape<64, 128, 64>;
  using WarpShape = typename cutlass::gemm::GemmShape<32, 64, 64>;
  using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
  // Full-precision accumulation (no fast accum) for the fallback path.
  using FP8MathOperator = typename cutlass::arch::OpMultiplyAdd;
  // 5 mainloop pipeline stages.
  using Cutlass2xGemm =
      cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90, InType, OutType,
                      Epilogue, TileShape, WarpShape, InstructionShape, 5,
                      FP8MathOperator>;
};
|
||||
|
||||
// SM89 FP8 config for large M; picks a tile/stage count from N (rounded up
// to the next power of two).
struct sm89_fp8_config_default {
  // M in (256, inf)
  using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
  using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
  using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum;

  // Runs the GEMM with an N-dependent tile shape, falling back to
  // sm89_fp8_fallback_gemm if shared memory is insufficient.
  template <typename InType, typename OutType,
            template <typename, typename> typename Epilogue,
            typename... EpilogueArgs>
  static void dispatch(paddle::Tensor& out, paddle::Tensor const& a,
                       paddle::Tensor const& b, EpilogueArgs&&... args) {
    static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
    PD_CHECK(a.dtype() == paddle::DataType::FLOAT8_E4M3FN);

    using FallbackGemm =
        typename sm89_fp8_fallback_gemm<InType, OutType,
                                        Epilogue>::Cutlass2xGemm;

    uint32_t const n = out.dims()[1];
    uint32_t const np2 = next_pow_2(n);

    if (np2 <= 4096) {
      using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;

      return fastdeploy::fallback_cutlass_gemm_caller<
          fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89, fastdeploy::enable_sm89_to_sm90,
                          InType, OutType, Epilogue, TileShape, WarpShape,
                          InstructionShape, 5, FP8MathOperator>,
          FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
    } else if (np2 <= 8192) {
      // Wider M tile, fewer stages (3) to fit shared memory.
      using TileShape = typename cutlass::gemm::GemmShape<256, 128, 64>;

      return fastdeploy::fallback_cutlass_gemm_caller<
          fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89, fastdeploy::enable_sm89_to_sm90,
                          InType, OutType, Epilogue, TileShape, WarpShape,
                          InstructionShape, 3, FP8MathOperator>,
          FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);

    } else {
      using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;

      return fastdeploy::fallback_cutlass_gemm_caller<
          fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89, fastdeploy::enable_sm89_to_sm90,
                          InType, OutType, Epilogue, TileShape, WarpShape,
                          InstructionShape, 5, FP8MathOperator>,
          FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
    }
  }
};
|
||||
|
||||
// SM89 FP8 config for M in (128, 256]; tile/stage count selected from N.
struct sm89_fp8_config_M256 {
  // M in (128, 256]
  using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
  using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
  using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum;

  // Runs the GEMM with an N-dependent tile shape, falling back to
  // sm89_fp8_fallback_gemm if shared memory is insufficient.
  template <typename InType, typename OutType,
            template <typename, typename> typename Epilogue,
            typename... EpilogueArgs>
  static void dispatch(paddle::Tensor& out, paddle::Tensor const& a,
                       paddle::Tensor const& b, EpilogueArgs&&... args) {
    static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
    PD_CHECK(a.dtype() == paddle::DataType::FLOAT8_E4M3FN);

    using FallbackGemm =
        typename sm89_fp8_fallback_gemm<InType, OutType,
                                        Epilogue>::Cutlass2xGemm;

    uint32_t const n = out.dims()[1];
    uint32_t const np2 = next_pow_2(n);

    if (np2 <= 4096) {
      // Deeper K tile with fewer stages (3) for narrow N.
      using TileShape = typename cutlass::gemm::GemmShape<64, 128, 128>;

      return fastdeploy::fallback_cutlass_gemm_caller<
          fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89, fastdeploy::enable_sm89_to_sm90,
                          InType, OutType, Epilogue, TileShape, WarpShape,
                          InstructionShape, 3, FP8MathOperator>,
          FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
    } else {
      using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;

      return fastdeploy::fallback_cutlass_gemm_caller<
          fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89, fastdeploy::enable_sm89_to_sm90,
                          InType, OutType, Epilogue, TileShape, WarpShape,
                          InstructionShape, 5, FP8MathOperator>,
          FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
    }
  }
};
|
||||
|
||||
// SM89 FP8 config for M in (64, 128]; tile/stage count selected from N.
struct sm89_fp8_config_M128 {
  // M in (64, 128]
  using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
  using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
  using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum;

  // Runs the GEMM with an N-dependent tile shape, falling back to
  // sm89_fp8_fallback_gemm if shared memory is insufficient.
  template <typename InType, typename OutType,
            template <typename, typename> typename Epilogue,
            typename... EpilogueArgs>
  static void dispatch(paddle::Tensor& out, paddle::Tensor const& a,
                       paddle::Tensor const& b, EpilogueArgs&&... args) {
    static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
    PD_CHECK(a.dtype() == paddle::DataType::FLOAT8_E4M3FN);

    using FallbackGemm =
        typename sm89_fp8_fallback_gemm<InType, OutType,
                                        Epilogue>::Cutlass2xGemm;

    uint32_t const n = out.dims()[1];
    uint32_t const np2 = next_pow_2(n);

    if (np2 <= 8192) {
      using TileShape = typename cutlass::gemm::GemmShape<64, 128, 128>;

      return fastdeploy::fallback_cutlass_gemm_caller<
          fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89, fastdeploy::enable_sm89_to_sm90,
                          InType, OutType, Epilogue, TileShape, WarpShape,
                          InstructionShape, 3, FP8MathOperator>,
          FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);

    } else if (np2 <= 16384) {
      using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;

      return fastdeploy::fallback_cutlass_gemm_caller<
          fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89, fastdeploy::enable_sm89_to_sm90,
                          InType, OutType, Epilogue, TileShape, WarpShape,
                          InstructionShape, 5, FP8MathOperator>,
          FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
    } else {
      using TileShape = typename cutlass::gemm::GemmShape<128, 64, 128>;

      return fastdeploy::fallback_cutlass_gemm_caller<
          fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89, fastdeploy::enable_sm89_to_sm90,
                          InType, OutType, Epilogue, TileShape, WarpShape,
                          InstructionShape, 3, FP8MathOperator>,
          FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
    }
  }
};
|
||||
|
||||
struct sm89_fp8_config_M64 {
|
||||
// M in (32, 64]
|
||||
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
static void dispatch(paddle::Tensor& out, paddle::Tensor const& a,
|
||||
paddle::Tensor const& b, EpilogueArgs&&... args) {
|
||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||
PD_CHECK(a.dtype() == paddle::DataType::FLOAT8_E4M3FN);
|
||||
|
||||
using FallbackGemm =
|
||||
typename sm89_fp8_fallback_gemm<InType, OutType,
|
||||
Epilogue>::Cutlass2xGemm;
|
||||
|
||||
uint32_t const n = out.dims()[1];
|
||||
uint32_t const np2 = next_pow_2(n);
|
||||
|
||||
if (np2 <= 8196) {
|
||||
using TileShape = typename cutlass::gemm::GemmShape<64, 64, 128>;
|
||||
using WarpShape = typename cutlass::gemm::GemmShape<32, 64, 64>;
|
||||
using FP8MathOperator = typename cutlass::arch::OpMultiplyAdd;
|
||||
|
||||
return fastdeploy::fallback_cutlass_gemm_caller<
|
||||
fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89, fastdeploy::enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, 5, FP8MathOperator>,
|
||||
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else if (np2 <= 16384) {
|
||||
using TileShape = typename cutlass::gemm::GemmShape<64, 128, 128>;
|
||||
using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum;
|
||||
|
||||
return fastdeploy::fallback_cutlass_gemm_caller<
|
||||
fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89, fastdeploy::enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, 3, FP8MathOperator>,
|
||||
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else {
|
||||
using TileShape = typename cutlass::gemm::GemmShape<64, 64, 128>;
|
||||
using WarpShape = typename cutlass::gemm::GemmShape<32, 64, 64>;
|
||||
using FP8MathOperator = typename cutlass::arch::OpMultiplyAdd;
|
||||
|
||||
return fastdeploy::fallback_cutlass_gemm_caller<
|
||||
fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89, fastdeploy::enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, 5, FP8MathOperator>,
|
||||
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// SM89 FP8 config for M in (16, 32]; tile/warp shape and stage count are
// selected from N.
struct sm89_fp8_config_M32 {
  // M in (16, 32]
  using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
  using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum;

  // Runs the GEMM with an N-dependent configuration, falling back to
  // sm89_fp8_fallback_gemm if shared memory is insufficient.
  template <typename InType, typename OutType,
            template <typename, typename> typename Epilogue,
            typename... EpilogueArgs>
  static void dispatch(paddle::Tensor& out, paddle::Tensor const& a,
                       paddle::Tensor const& b, EpilogueArgs&&... args) {
    static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
    PD_CHECK(a.dtype() == paddle::DataType::FLOAT8_E4M3FN);

    using FallbackGemm =
        typename sm89_fp8_fallback_gemm<InType, OutType,
                                        Epilogue>::Cutlass2xGemm;

    uint32_t const n = out.dims()[1];
    uint32_t const np2 = next_pow_2(n);

    if (np2 <= 8192) {
      using TileShape = typename cutlass::gemm::GemmShape<32, 64, 128>;
      using WarpShape = typename cutlass::gemm::GemmShape<16, 64, 64>;

      return fastdeploy::fallback_cutlass_gemm_caller<
          fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89, fastdeploy::enable_sm89_to_sm90,
                          InType, OutType, Epilogue, TileShape, WarpShape,
                          InstructionShape, 5, FP8MathOperator>,
          FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
    } else if (np2 <= 16384) {
      using TileShape = typename cutlass::gemm::GemmShape<32, 128, 128>;
      using WarpShape = typename cutlass::gemm::GemmShape<32, 64, 64>;

      return fastdeploy::fallback_cutlass_gemm_caller<
          fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89, fastdeploy::enable_sm89_to_sm90,
                          InType, OutType, Epilogue, TileShape, WarpShape,
                          InstructionShape, 4, FP8MathOperator>,
          FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
    } else {
      using TileShape = typename cutlass::gemm::GemmShape<32, 64, 128>;
      using WarpShape = typename cutlass::gemm::GemmShape<16, 64, 64>;

      return fastdeploy::fallback_cutlass_gemm_caller<
          fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89, fastdeploy::enable_sm89_to_sm90,
                          InType, OutType, Epilogue, TileShape, WarpShape,
                          InstructionShape, 5, FP8MathOperator>,
          FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
    }
  }
};
|
||||
|
||||
// SM89 FP8 config for the smallest M bucket; tile shape selected from N.
struct sm89_fp8_config_M16 {
  // M in [1, 16]
  using WarpShape = typename cutlass::gemm::GemmShape<16, 64, 64>;
  using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
  using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum;
  // All N-branches below share the same mainloop stage count.
  static const int32_t MainLoopStages = 5;

  // Runs the GEMM with an N-dependent tile shape, falling back to
  // sm89_fp8_fallback_gemm if shared memory is insufficient.
  template <typename InType, typename OutType,
            template <typename, typename> typename Epilogue,
            typename... EpilogueArgs>
  static void dispatch(paddle::Tensor& out, paddle::Tensor const& a,
                       paddle::Tensor const& b, EpilogueArgs&&... args) {
    static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
    PD_CHECK(a.dtype() == paddle::DataType::FLOAT8_E4M3FN);

    using FallbackGemm =
        typename sm89_fp8_fallback_gemm<InType, OutType,
                                        Epilogue>::Cutlass2xGemm;

    uint32_t const n = out.dims()[1];
    uint32_t const np2 = next_pow_2(n);

    if (np2 <= 8192) {
      using TileShape = typename cutlass::gemm::GemmShape<16, 64, 128>;

      return fastdeploy::fallback_cutlass_gemm_caller<
          fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89, fastdeploy::enable_sm89_to_sm90,
                          InType, OutType, Epilogue, TileShape, WarpShape,
                          InstructionShape, MainLoopStages,
                          FP8MathOperator>,
          FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
    } else if (np2 <= 24576) {
      using TileShape = typename cutlass::gemm::GemmShape<16, 128, 64>;

      return fastdeploy::fallback_cutlass_gemm_caller<
          fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89, fastdeploy::enable_sm89_to_sm90,
                          InType, OutType, Epilogue, TileShape, WarpShape,
                          InstructionShape, MainLoopStages,
                          FP8MathOperator>,
          FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
    } else {
      using TileShape = typename cutlass::gemm::GemmShape<32, 64, 128>;

      return fastdeploy::fallback_cutlass_gemm_caller<
          fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89, fastdeploy::enable_sm89_to_sm90,
                          InType, OutType, Epilogue, TileShape, WarpShape,
                          InstructionShape, MainLoopStages,
                          FP8MathOperator>,
          FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
    }
  }
};
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
inline void cutlass_gemm_sm89_fp8_dispatch(paddle::Tensor& out,
|
||||
paddle::Tensor const& a,
|
||||
paddle::Tensor const& b,
|
||||
EpilogueArgs&&... args) {
|
||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||
PD_CHECK(a.dtype() == paddle::DataType::FLOAT8_E4M3FN);
|
||||
PD_CHECK(b.dtype() == paddle::DataType::FLOAT8_E4M3FN);
|
||||
|
||||
uint32_t const m = a.dims()[0];;
|
||||
uint32_t const mp2 =
|
||||
std::max(static_cast<uint32_t>(16), next_pow_2(m)); // next power of 2
|
||||
|
||||
if (mp2 <= 16) {
|
||||
// M in [1, 16]
|
||||
return sm89_fp8_config_M16::dispatch<InType, OutType, Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else if (mp2 <= 32) {
|
||||
// M in (16, 32]
|
||||
return sm89_fp8_config_M32::dispatch<InType, OutType, Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else if (mp2 <= 64) {
|
||||
// M in (32, 64]
|
||||
return sm89_fp8_config_M64::dispatch<InType, OutType, Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else if (mp2 <= 128) {
|
||||
// M in (64, 128]
|
||||
return sm89_fp8_config_M128::dispatch<InType, OutType, Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else if (mp2 <= 256) {
|
||||
// M in (128, 256]
|
||||
return sm89_fp8_config_M256::dispatch<InType, OutType, Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else {
|
||||
// M in (256, inf)
|
||||
return sm89_fp8_config_default::dispatch<InType, OutType, Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace fastdeploy
|
||||
@@ -0,0 +1,355 @@
|
||||
// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm89_int8_dispatch.cuh
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "scaled_mm_c2x.cuh"
|
||||
|
||||
/**
|
||||
* This file defines Gemm kernel configurations for SM89 (int8) based on the
|
||||
* Gemm shape.
|
||||
*/
|
||||
|
||||
namespace fastdeploy {
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename> typename Epilogue>
|
||||
struct sm89_int8_fallback_gemm {
|
||||
// Shared mem requirement : 61440
|
||||
static_assert(std::is_same<InType, int8_t>());
|
||||
using TileShape = cutlass::gemm::GemmShape<32, 64, 128>;
|
||||
using WarpShape = cutlass::gemm::GemmShape<16, 64, 64>;
|
||||
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
|
||||
static int32_t const MainLoopStages = 5;
|
||||
|
||||
using Cutlass2xGemm =
|
||||
cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90, InType, OutType,
|
||||
Epilogue, TileShape, WarpShape, InstructionShape, 5>;
|
||||
};
|
||||
|
||||
// SM89 int8 config for large M; tile/stage count selected from N (rounded
// up to the next power of two).
struct sm89_int8_config_default {
  // M in (256, inf)
  using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
  using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;

  // Runs the GEMM with an N-dependent tile shape, falling back to
  // sm89_int8_fallback_gemm if shared memory is insufficient.
  template <typename InType, typename OutType,
            template <typename, typename> typename Epilogue,
            typename... EpilogueArgs>
  static void dispatch(paddle::Tensor& out, paddle::Tensor const& a,
                       paddle::Tensor const& b, EpilogueArgs&&... args) {
    static_assert(std::is_same<InType, int8_t>());
    PD_CHECK(a.dtype() == paddle::DataType::INT8);

    using FallbackGemm =
        typename sm89_int8_fallback_gemm<InType, OutType,
                                         Epilogue>::Cutlass2xGemm;

    uint32_t const n = out.dims()[1];
    uint32_t const np2 = next_pow_2(n);

    if (np2 <= 4096) {
      using TileShape = cutlass::gemm::GemmShape<128, 128, 64>;

      return fastdeploy::fallback_cutlass_gemm_caller<
          fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89, fastdeploy::enable_sm89_to_sm90,
                          InType, OutType, Epilogue, TileShape, WarpShape,
                          InstructionShape, 5>,
          FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
    } else if (np2 <= 8192) {
      // Wider M tile, fewer stages (3) to fit shared memory.
      using TileShape = cutlass::gemm::GemmShape<256, 128, 64>;

      return fastdeploy::fallback_cutlass_gemm_caller<
          fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89, fastdeploy::enable_sm89_to_sm90,
                          InType, OutType, Epilogue, TileShape, WarpShape,
                          InstructionShape, 3>,
          FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
    } else if (np2 <= 16384) {
      using TileShape = cutlass::gemm::GemmShape<128, 128, 64>;

      return fastdeploy::fallback_cutlass_gemm_caller<
          fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89, fastdeploy::enable_sm89_to_sm90,
                          InType, OutType, Epilogue, TileShape, WarpShape,
                          InstructionShape, 5>,
          FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
    } else {
      using TileShape = cutlass::gemm::GemmShape<256, 128, 64>;

      return fastdeploy::fallback_cutlass_gemm_caller<
          fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89, fastdeploy::enable_sm89_to_sm90,
                          InType, OutType, Epilogue, TileShape, WarpShape,
                          InstructionShape, 3>,
          FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
    }
  }
};
|
||||
|
||||
// SM89 int8 config for M in (128, 256]; tile/stage count selected from N.
struct sm89_int8_config_M256 {
  // M in (128, 256]
  using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
  using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;

  // Runs the GEMM with an N-dependent tile shape, falling back to
  // sm89_int8_fallback_gemm if shared memory is insufficient.
  template <typename InType, typename OutType,
            template <typename, typename> typename Epilogue,
            typename... EpilogueArgs>
  static void dispatch(paddle::Tensor& out, paddle::Tensor const& a,
                       paddle::Tensor const& b, EpilogueArgs&&... args) {
    static_assert(std::is_same<InType, int8_t>())
;
    PD_CHECK(a.dtype() == paddle::DataType::INT8);

    using FallbackGemm =
        typename sm89_int8_fallback_gemm<InType, OutType,
                                         Epilogue>::Cutlass2xGemm;

    uint32_t const n = out.dims()[1];
    uint32_t const np2 = next_pow_2(n);

    if (np2 <= 4096) {
      // Deeper K tile with fewer stages (3) for narrow N.
      using TileShape = cutlass::gemm::GemmShape<64, 128, 128>;

      return fastdeploy::fallback_cutlass_gemm_caller<
          fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89, fastdeploy::enable_sm89_to_sm90,
                          InType, OutType, Epilogue, TileShape, WarpShape,
                          InstructionShape, 3>,
          FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
    } else if (np2 <= 8192) {
      using TileShape = cutlass::gemm::GemmShape<128, 128, 64>;

      return fastdeploy::fallback_cutlass_gemm_caller<
          fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89, fastdeploy::enable_sm89_to_sm90,
                          InType, OutType, Epilogue, TileShape, WarpShape,
                          InstructionShape, 5>,
          FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
    } else if (np2 <= 16384) {
      using TileShape = cutlass::gemm::GemmShape<256, 128, 64>;

      return fastdeploy::fallback_cutlass_gemm_caller<
          fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89, fastdeploy::enable_sm89_to_sm90,
                          InType, OutType, Epilogue, TileShape, WarpShape,
                          InstructionShape, 3>,
          FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
    } else {
      using TileShape = cutlass::gemm::GemmShape<128, 128, 64>;

      return fastdeploy::fallback_cutlass_gemm_caller<
          fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89, fastdeploy::enable_sm89_to_sm90,
                          InType, OutType, Epilogue, TileShape, WarpShape,
                          InstructionShape, 5>,
          FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
    }
  }
};
|
||||
|
||||
struct sm89_int8_config_M128 {
|
||||
// M in (64, 128]
|
||||
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
static void dispatch(paddle::Tensor& out, paddle::Tensor const& a,
|
||||
paddle::Tensor const& b, EpilogueArgs&&... args) {
|
||||
static_assert(std::is_same<InType, int8_t>());
|
||||
PD_CHECK(a.dtype() == paddle::DataType::INT8);
|
||||
|
||||
using FallbackGemm =
|
||||
typename sm89_int8_fallback_gemm<InType, OutType,
|
||||
Epilogue>::Cutlass2xGemm;
|
||||
|
||||
uint32_t const n = out.dims()[1];
|
||||
uint32_t const np2 = next_pow_2(n);
|
||||
|
||||
if (np2 <= 8192) {
|
||||
using TileShape = cutlass::gemm::GemmShape<64, 128, 128>;
|
||||
using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
|
||||
return fastdeploy::fallback_cutlass_gemm_caller<
|
||||
fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89, fastdeploy::enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, 3>,
|
||||
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else if (np2 <= 16384) {
|
||||
using TileShape = cutlass::gemm::GemmShape<128, 128, 64>;
|
||||
using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
|
||||
return fastdeploy::fallback_cutlass_gemm_caller<
|
||||
fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89, fastdeploy::enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, 5>,
|
||||
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else {
|
||||
using TileShape = cutlass::gemm::GemmShape<64, 64, 128>;
|
||||
using WarpShape = cutlass::gemm::GemmShape<32, 64, 64>;
|
||||
|
||||
return fastdeploy::fallback_cutlass_gemm_caller<
|
||||
fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89, fastdeploy::enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, 5>,
|
||||
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct sm89_int8_config_M64 {
|
||||
// M in (32, 64]
|
||||
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
static void dispatch(paddle::Tensor& out, paddle::Tensor const& a,
|
||||
paddle::Tensor const& b, EpilogueArgs&&... args) {
|
||||
static_assert(std::is_same<InType, int8_t>());
|
||||
PD_CHECK(a.dtype() == paddle::DataType::INT8);
|
||||
|
||||
using FallbackGemm =
|
||||
typename sm89_int8_fallback_gemm<InType, OutType,
|
||||
Epilogue>::Cutlass2xGemm;
|
||||
|
||||
uint32_t const n = out.dims()[1];
|
||||
uint32_t const np2 = next_pow_2(n);
|
||||
|
||||
if (np2 <= 8192) {
|
||||
using TileShape = cutlass::gemm::GemmShape<64, 64, 128>;
|
||||
using WarpShape = cutlass::gemm::GemmShape<32, 64, 64>;
|
||||
|
||||
return fastdeploy::fallback_cutlass_gemm_caller<
|
||||
fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89, fastdeploy::enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, 5>,
|
||||
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else {
|
||||
using TileShape = cutlass::gemm::GemmShape<64, 128, 128>;
|
||||
using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
|
||||
return fastdeploy::fallback_cutlass_gemm_caller<
|
||||
fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89, fastdeploy::enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, 3>,
|
||||
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct sm89_int8_config_M32 {
|
||||
// M in (16, 32]
|
||||
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
static void dispatch(paddle::Tensor& out, paddle::Tensor const& a,
|
||||
paddle::Tensor const& b, EpilogueArgs&&... args) {
|
||||
static_assert(std::is_same<InType, int8_t>());
|
||||
PD_CHECK(a.dtype() == paddle::DataType::INT8);
|
||||
|
||||
using FallbackGemm =
|
||||
typename sm89_int8_fallback_gemm<InType, OutType,
|
||||
Epilogue>::Cutlass2xGemm;
|
||||
|
||||
uint32_t const n = out.dims()[1];
|
||||
uint32_t const np2 = next_pow_2(n);
|
||||
|
||||
if (np2 <= 8192) {
|
||||
using TileShape = cutlass::gemm::GemmShape<32, 64, 128>;
|
||||
using WarpShape = cutlass::gemm::GemmShape<16, 64, 64>;
|
||||
|
||||
return fastdeploy::fallback_cutlass_gemm_caller<
|
||||
fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89, fastdeploy::enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, 5>,
|
||||
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else {
|
||||
using TileShape = cutlass::gemm::GemmShape<32, 128, 128>;
|
||||
using WarpShape = cutlass::gemm::GemmShape<32, 64, 64>;
|
||||
|
||||
return fastdeploy::fallback_cutlass_gemm_caller<
|
||||
fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89, fastdeploy::enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, 4>,
|
||||
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct sm89_int8_config_M16 {
|
||||
// M in [1, 16]
|
||||
using WarpShape = typename cutlass::gemm::GemmShape<16, 64, 64>;
|
||||
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
static void dispatch(paddle::Tensor& out, paddle::Tensor const& a,
|
||||
paddle::Tensor const& b, EpilogueArgs&&... args) {
|
||||
static_assert(std::is_same<InType, int8_t>());
|
||||
PD_CHECK(a.dtype() == paddle::DataType::INT8);
|
||||
|
||||
using FallbackGemm =
|
||||
typename sm89_int8_fallback_gemm<InType, OutType,
|
||||
Epilogue>::Cutlass2xGemm;
|
||||
|
||||
uint32_t const n = out.dims()[0];
|
||||
uint32_t const np2 = next_pow_2(n);
|
||||
|
||||
if (np2 <= 8192) {
|
||||
using TileShape = cutlass::gemm::GemmShape<16, 64, 128>;
|
||||
|
||||
return fastdeploy::fallback_cutlass_gemm_caller<
|
||||
fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89, fastdeploy::enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, 5>,
|
||||
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else {
|
||||
using TileShape = cutlass::gemm::GemmShape<16, 128, 128>;
|
||||
|
||||
return fastdeploy::fallback_cutlass_gemm_caller<
|
||||
fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89, fastdeploy::enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, 4>,
|
||||
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
inline void cutlass_gemm_sm89_int8_dispatch(paddle::Tensor& out,
|
||||
paddle::Tensor const& a,
|
||||
paddle::Tensor const& b,
|
||||
EpilogueArgs&&... args) {
|
||||
static_assert(std::is_same<InType, int8_t>());
|
||||
PD_CHECK(a.dtype() == paddle::DataType::INT8);
|
||||
PD_CHECK(b.dtype() == paddle::DataType::INT8);
|
||||
|
||||
uint32_t const m = a.dims()[0];
|
||||
uint32_t const mp2 =
|
||||
std::max(static_cast<uint32_t>(16), next_pow_2(m)); // next power of 2
|
||||
|
||||
if (mp2 <= 16) {
|
||||
// M in [1, 16]
|
||||
return sm89_int8_config_M16::dispatch<InType, OutType, Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else if (mp2 <= 32) {
|
||||
// M in (16, 32]
|
||||
return sm89_int8_config_M32::dispatch<InType, OutType, Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else if (mp2 <= 64) {
|
||||
// M in (32, 64]
|
||||
return sm89_int8_config_M64::dispatch<InType, OutType, Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else if (mp2 <= 128) {
|
||||
// M in (64, 128]
|
||||
return sm89_int8_config_M128::dispatch<InType, OutType, Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else if (mp2 <= 256) {
|
||||
// M in (128, 256]
|
||||
return sm89_int8_config_M256::dispatch<InType, OutType, Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else {
|
||||
// M in (256, inf)
|
||||
return sm89_int8_config_default::dispatch<InType, OutType, Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace fastdeploy
|
||||
@@ -0,0 +1,37 @@
|
||||
// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm90.cu
|
||||
|
||||
#include "c3x/scaled_mm_helper.hpp"
|
||||
#include "c3x/scaled_mm_kernels.hpp"
|
||||
|
||||
/*
|
||||
This file defines quantized GEMM operations using the CUTLASS 3.x API, for
|
||||
NVIDIA GPUs with sm90a (Hopper).
|
||||
*/
|
||||
|
||||
#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90
|
||||
|
||||
void cutlass_scaled_mm_sm90(paddle::Tensor &c, paddle::Tensor const &a,
|
||||
paddle::Tensor const &b,
|
||||
paddle::Tensor const &a_scales,
|
||||
paddle::Tensor const &b_scales,
|
||||
paddle::optional<paddle::Tensor> const &bias) {
|
||||
dispatch_scaled_mm(c, a, b, a_scales, b_scales, bias,
|
||||
fastdeploy::cutlass_scaled_mm_sm90_fp8,
|
||||
fastdeploy::cutlass_scaled_mm_sm90_int8);
|
||||
}
|
||||
|
||||
void cutlass_scaled_mm_azp_sm90(paddle::Tensor& out, paddle::Tensor const& a,
|
||||
paddle::Tensor const& b,
|
||||
paddle::Tensor const& a_scales,
|
||||
paddle::Tensor const& b_scales,
|
||||
paddle::Tensor const& azp_adj,
|
||||
paddle::optional<paddle::Tensor> const& azp,
|
||||
paddle::optional<paddle::Tensor> const& bias) {
|
||||
PD_CHECK(a_scales.dtype() == paddle::DataType::FLOAT32);
|
||||
PD_CHECK(b_scales.dtype() == paddle::DataType::FLOAT32);
|
||||
|
||||
fastdeploy::cutlass_scaled_mm_azp_sm90_int8(out, a, b, a_scales, b_scales, azp_adj,
|
||||
azp, bias);
|
||||
}
|
||||
|
||||
#endif
|
||||
224
custom_ops/gpu_ops/cutlass_kernels/w8a8/scaled_mm_entry.cu
Normal file
224
custom_ops/gpu_ops/cutlass_kernels/w8a8/scaled_mm_entry.cu
Normal file
@@ -0,0 +1,224 @@
|
||||
// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu
|
||||
|
||||
#pragma once
|
||||
#include "helper.h"
|
||||
#include <iostream>
|
||||
|
||||
void cutlass_scaled_mm_sm75(paddle::Tensor &c, paddle::Tensor const &a,
|
||||
paddle::Tensor const &b,
|
||||
paddle::Tensor const &a_scales,
|
||||
paddle::Tensor const &b_scales,
|
||||
paddle::optional<paddle::Tensor> const &bias);
|
||||
|
||||
void cutlass_scaled_mm_sm80(paddle::Tensor &c, paddle::Tensor const &a,
|
||||
paddle::Tensor const &b,
|
||||
paddle::Tensor const &a_scales,
|
||||
paddle::Tensor const &b_scales,
|
||||
paddle::optional<paddle::Tensor> const &bias);
|
||||
|
||||
void cutlass_scaled_mm_sm89(paddle::Tensor &c, paddle::Tensor const &a,
|
||||
paddle::Tensor const &b,
|
||||
paddle::Tensor const &a_scales,
|
||||
paddle::Tensor const &b_scales,
|
||||
paddle::optional<paddle::Tensor> const &bias);
|
||||
|
||||
#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90
|
||||
void cutlass_scaled_mm_sm90(paddle::Tensor &c, paddle::Tensor const &a,
|
||||
paddle::Tensor const &b,
|
||||
paddle::Tensor const &a_scales,
|
||||
paddle::Tensor const &b_scales,
|
||||
paddle::optional<paddle::Tensor> const &bias);
|
||||
#endif
|
||||
|
||||
void cutlass_scaled_mm_azp_sm75(paddle::Tensor& c, paddle::Tensor const& a,
|
||||
paddle::Tensor const& b,
|
||||
paddle::Tensor const& a_scales,
|
||||
paddle::Tensor const& b_scales,
|
||||
paddle::Tensor const& azp_adj,
|
||||
paddle::optional<paddle::Tensor> const& azp,
|
||||
paddle::optional<paddle::Tensor> const& bias);
|
||||
|
||||
void cutlass_scaled_mm_azp_sm80(paddle::Tensor& c, paddle::Tensor const& a,
|
||||
paddle::Tensor const& b,
|
||||
paddle::Tensor const& a_scales,
|
||||
paddle::Tensor const& b_scales,
|
||||
paddle::Tensor const& azp_adj,
|
||||
paddle::optional<paddle::Tensor> const& azp,
|
||||
paddle::optional<paddle::Tensor> const& bias);
|
||||
|
||||
void cutlass_scaled_mm_azp_sm89(paddle::Tensor& c, paddle::Tensor const& a,
|
||||
paddle::Tensor const& b,
|
||||
paddle::Tensor const& a_scales,
|
||||
paddle::Tensor const& b_scales,
|
||||
paddle::Tensor const& azp_adj,
|
||||
paddle::optional<paddle::Tensor> const& azp,
|
||||
paddle::optional<paddle::Tensor> const& bias);
|
||||
|
||||
#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90
|
||||
void cutlass_scaled_mm_azp_sm90(paddle::Tensor& c, paddle::Tensor const& a,
|
||||
paddle::Tensor const& b,
|
||||
paddle::Tensor const& a_scales,
|
||||
paddle::Tensor const& b_scales,
|
||||
paddle::Tensor const& azp_adj,
|
||||
paddle::optional<paddle::Tensor> const& azp,
|
||||
paddle::optional<paddle::Tensor> const& bias);
|
||||
#endif
|
||||
|
||||
bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability) {
|
||||
// CUTLASS FP8 kernels need at least
|
||||
// CUDA 12.0 on SM90 systems (Hopper)
|
||||
// CUDA 12.4 on SM89 systems (Lovelace)
|
||||
|
||||
#if defined CUDA_VERSION
|
||||
if (cuda_device_capability >= 90) {
|
||||
return CUDA_VERSION >= 12000;
|
||||
} else if (cuda_device_capability >= 89) {
|
||||
return CUDA_VERSION >= 12040;
|
||||
}
|
||||
#endif
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
void CutlassScaledMm(paddle::Tensor &c, paddle::Tensor const &a,
|
||||
paddle::Tensor const &b, paddle::Tensor const &a_scales,
|
||||
paddle::Tensor const &b_scales,
|
||||
paddle::optional<paddle::Tensor> const &bias) {
|
||||
// Checks for conformality
|
||||
PD_CHECK(a.dims().size() == 2 && b.dims().size() == 2 &&
|
||||
c.dims().size() == 2);
|
||||
PD_CHECK(c.dims()[0] == a.dims()[0] && a.dims()[1] == b.dims()[1] &&
|
||||
b.dims()[0] == c.dims()[1]);
|
||||
|
||||
// Check for strides and alignment
|
||||
PD_CHECK(a.strides()[1] == 1 && c.strides()[1] == 1); // Row-major
|
||||
PD_CHECK(b.strides()[1] == 1); // Column-major
|
||||
PD_CHECK(c.strides()[0] % 16 == 0 &&
|
||||
b.strides()[0] % 16 == 0); // 16 Byte Alignment
|
||||
|
||||
if (bias) {
|
||||
PD_CHECK(bias->numel() == b.dims()[0] && bias->is_contiguous() &&
|
||||
bias->dims().size() == 1);
|
||||
}
|
||||
|
||||
int32_t version_num = GetGPUComputeCapability(a.place().GetDeviceId());
|
||||
|
||||
// Guard against compilation issues for sm90 kernels
|
||||
#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90
|
||||
if (version_num >= 90 && version_num < 100) {
|
||||
// Hopper
|
||||
cutlass_scaled_mm_sm90(c, a, b, a_scales, b_scales, bias);
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined ENABLE_SCALED_MM_C2X && ENABLE_SCALED_MM_C2X
|
||||
if (version_num == 89) {
|
||||
// Ada Lovelace
|
||||
cutlass_scaled_mm_sm89(c, a, b, a_scales, b_scales, bias);
|
||||
return;
|
||||
}
|
||||
|
||||
if (version_num >= 80) {
|
||||
// Ampere
|
||||
cutlass_scaled_mm_sm80(c, a, b, a_scales, b_scales, bias);
|
||||
return;
|
||||
}
|
||||
|
||||
if (version_num >= 75) {
|
||||
// Turing
|
||||
cutlass_scaled_mm_sm75(c, a, b, a_scales, b_scales, bias);
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
|
||||
PADDLE_THROW(phi::errors::Unimplemented(
|
||||
"No compiled cutlass_scaled_mm for a compute capability less than "
|
||||
"CUDA device capability: %d",
|
||||
version_num));
|
||||
}
|
||||
|
||||
void CutlassScaledMmAzp(paddle::Tensor& c, paddle::Tensor const& a,
|
||||
paddle::Tensor const& b,
|
||||
paddle::Tensor const& a_scales,
|
||||
paddle::Tensor const& b_scales,
|
||||
paddle::Tensor const& azp_adj,
|
||||
paddle::optional<paddle::Tensor> const& azp,
|
||||
paddle::optional<paddle::Tensor> const& bias) {
|
||||
// Checks for conformality
|
||||
PD_CHECK(a.dims().size() == 2 && b.dims().size() == 2 &&
|
||||
c.dims().size() == 2);
|
||||
PD_CHECK(c.dims()[0] == a.dims()[0] && a.dims()[1] == b.dims()[1] &&
|
||||
b.dims()[0] == c.dims()[1]);
|
||||
PD_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.dims()[0]);
|
||||
PD_CHECK(b_scales.numel() == 1 || b_scales.numel() == b.dims()[0]);
|
||||
|
||||
// Check for strides and alignment
|
||||
PD_CHECK(a.strides()[1] == 1 && c.strides()[1] == 1); // Row-major
|
||||
PD_CHECK(b.strides()[1] == 1); // Column-major
|
||||
PD_CHECK(c.strides()[0] % 16 == 0 &&
|
||||
b.strides()[0] % 16 == 0); // 16 Byte Alignment
|
||||
PD_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
|
||||
|
||||
// bias, azp, azp_adj are all 1d
|
||||
// bias and azp_adj have n elements, azp has m elements
|
||||
if (bias) {
|
||||
PD_CHECK(bias->numel() == b.dims()[0] && bias->is_contiguous());
|
||||
}
|
||||
if (azp) {
|
||||
PD_CHECK(azp->numel() == a.dims()[0] && azp->is_contiguous());
|
||||
}
|
||||
PD_CHECK(azp_adj.numel() == b.dims()[0] && azp_adj.is_contiguous());
|
||||
|
||||
// azp & bias types
|
||||
PD_CHECK(azp_adj.dtype() == paddle::DataType::INT32);
|
||||
PD_CHECK(!azp || azp->dtype() == paddle::DataType::INT32);
|
||||
PD_CHECK(!bias || bias->dtype() == c.dtype(),
|
||||
"currently bias dtype must match output dtype ", c.dtype());
|
||||
|
||||
int32_t version_num = GetGPUComputeCapability(a.place().GetDeviceId());
|
||||
|
||||
#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90
|
||||
if (version_num >= 90) {
|
||||
cutlass_scaled_mm_azp_sm90(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined ENABLE_SCALED_MM_C2X && ENABLE_SCALED_MM_C2X
|
||||
if (version_num == 89) {
|
||||
// Ada Lovelace
|
||||
cutlass_scaled_mm_azp_sm89(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
|
||||
return;
|
||||
}
|
||||
|
||||
if (version_num >= 80) {
|
||||
// Ampere
|
||||
cutlass_scaled_mm_azp_sm80(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
|
||||
return;
|
||||
}
|
||||
|
||||
// Turing
|
||||
PD_CHECK(version_num >= 75);
|
||||
cutlass_scaled_mm_azp_sm75(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
|
||||
return;
|
||||
#endif
|
||||
|
||||
PADDLE_THROW(phi::errors::Unimplemented(
|
||||
"No compiled cutlass_scaled_mm_azp for a compute capability less than "
|
||||
"CUDA device capability: %d",
|
||||
version_num));
|
||||
}
|
||||
|
||||
|
||||
PD_BUILD_STATIC_OP(cutlass_scaled_mm)
|
||||
.Inputs({"c", "a", "b", "a_scales", "b_scales", paddle::Optional("bias")})
|
||||
.Outputs({"c_out"})
|
||||
.SetInplaceMap({{"c", "c_out"}})
|
||||
.SetKernelFn(PD_KERNEL(CutlassScaledMm));
|
||||
|
||||
PD_BUILD_STATIC_OP(cutlass_scaled_mm_azp)
|
||||
.Inputs({"c", "a", "b", "a_scales", "b_scales", "azp_adj", paddle::Optional("azp"), paddle::Optional("bias")})
|
||||
.Outputs({"c_out"})
|
||||
.SetInplaceMap({{"c", "c_out"}})
|
||||
.SetKernelFn(PD_KERNEL(CutlassScaledMmAzp));
|
||||
Reference in New Issue
Block a user