Sync v2.0 version of code to github repo

Jiang-Jia-Jun
2025-06-29 23:29:37 +00:00
parent d151496038
commit 92c2cfa2e7
597 changed files with 78776 additions and 22905 deletions

File diff suppressed because it is too large.


@@ -17,15 +17,12 @@
#include "paddle/phi/core/memory/memcpy.h"
template <int THREADBLOCK_SIZE>
__global__ void GetMaxLenKernel(const int *seq_lens,
const int *seq_lens_this_time,
const int *seq_lens_encoder,
const int *seq_lens_this_time_merged,
const int *seq_lens_encoder_merged,
const int *seq_mapping,
const int *system_lens,
int *max_lens,
const int batch_size) {
__global__ void
GetMaxLenKernel(const int *seq_lens, const int *seq_lens_this_time,
const int *seq_lens_encoder,
const int *seq_lens_this_time_merged,
const int *seq_lens_encoder_merged, const int *seq_mapping,
const int *system_lens, int *max_lens, const int batch_size) {
const int tid = threadIdx.x;
typedef cub::BlockReduce<int, THREADBLOCK_SIZE> BlockReduce;
@@ -41,43 +38,61 @@ __global__ void GetMaxLenKernel(const int *seq_lens,
int max_dec_len_without_system_this_thread = 0;
for (int i = tid; i < batch_size; i += blockDim.x) {
const int seq_len_this_time = seq_lens_this_time[i];
max_len_this_time_this_thread = max(seq_len_this_time,
max_len_this_time_this_thread);
max_len_encoder_this_thread = max(seq_lens_encoder[i],
max_len_encoder_this_thread);
max_len_this_time_this_thread =
max(seq_len_this_time, max_len_this_time_this_thread);
max_len_encoder_this_thread =
max(seq_lens_encoder[i], max_len_encoder_this_thread);
max_len_decoder_this_thread = max(seq_lens[i], max_len_decoder_this_thread);
if (seq_len_this_time <= 0) continue;
if (seq_len_this_time <= 0)
continue;
const int max_just_dec_len_now = seq_lens_encoder[i] > 0 ? 0 : seq_lens[i];
max_len_this_thread = max(seq_lens[i] + seq_len_this_time,
max_len_this_thread);
max_just_dec_len_this_thread = max(max_just_dec_len_this_thread,
max_just_dec_len_now);
max_len_this_thread =
max(seq_lens[i] + seq_len_this_time, max_len_this_thread);
max_just_dec_len_this_thread =
max(max_just_dec_len_this_thread, max_just_dec_len_now);
if (system_lens) {
const int real_bid = seq_mapping[i];
const int system_len_now = system_lens[real_bid];
max_system_len_this_thread = max(max_system_len_this_thread, system_len_now);
max_dec_len_without_system_this_thread = max(max_dec_len_without_system_this_thread,
max_just_dec_len_now - system_len_now);
max_system_len_this_thread =
max(max_system_len_this_thread, system_len_now);
max_dec_len_without_system_this_thread =
max(max_dec_len_without_system_this_thread,
max_just_dec_len_now - system_len_now);
}
}
if (system_lens) {
for (int i = tid; i < batch_size; i += blockDim.x) {
const int ori_seq_len_this_time = seq_lens_this_time_merged[i];
if (ori_seq_len_this_time <= 0) continue;
const int max_just_dec_merged_len_this_time_now = seq_lens_encoder_merged[i] > 0 ?
0 : ori_seq_len_this_time;
max_just_dec_merged_len_this_time_this_thread = max(max_just_dec_merged_len_this_time_this_thread,
max_just_dec_merged_len_this_time_now);
if (ori_seq_len_this_time <= 0)
continue;
const int max_just_dec_merged_len_this_time_now =
seq_lens_encoder_merged[i] > 0 ? 0 : ori_seq_len_this_time;
max_just_dec_merged_len_this_time_this_thread =
max(max_just_dec_merged_len_this_time_this_thread,
max_just_dec_merged_len_this_time_now);
}
}
int total_max_len_this_time = BlockReduce(temp_storage).Reduce(max_len_this_time_this_thread, MaxOp<int>());
int total_max_len_encoder = BlockReduce(temp_storage).Reduce(max_len_encoder_this_thread, MaxOp<int>());
int total_max_len_decoder = BlockReduce(temp_storage).Reduce(max_len_decoder_this_thread, MaxOp<int>());
int total = BlockReduce(temp_storage).Reduce(max_len_this_thread, MaxOp<int>());
int total_just_dec = BlockReduce(temp_storage).Reduce(max_just_dec_len_this_thread, MaxOp<int>());
int total_just_dec_merged = BlockReduce(temp_storage).Reduce(max_just_dec_merged_len_this_time_this_thread, MaxOp<int>());
int total_system_len = BlockReduce(temp_storage).Reduce(max_system_len_this_thread, MaxOp<int>());
int total_dec_len_without_system = BlockReduce(temp_storage).Reduce(max_dec_len_without_system_this_thread, MaxOp<int>());
int total_max_len_this_time =
BlockReduce(temp_storage)
.Reduce(max_len_this_time_this_thread, MaxOp<int>());
int total_max_len_encoder =
BlockReduce(temp_storage)
.Reduce(max_len_encoder_this_thread, MaxOp<int>());
int total_max_len_decoder =
BlockReduce(temp_storage)
.Reduce(max_len_decoder_this_thread, MaxOp<int>());
int total =
BlockReduce(temp_storage).Reduce(max_len_this_thread, MaxOp<int>());
int total_just_dec = BlockReduce(temp_storage)
.Reduce(max_just_dec_len_this_thread, MaxOp<int>());
int total_just_dec_merged =
BlockReduce(temp_storage)
.Reduce(max_just_dec_merged_len_this_time_this_thread, MaxOp<int>());
int total_system_len = BlockReduce(temp_storage)
.Reduce(max_system_len_this_thread, MaxOp<int>());
int total_dec_len_without_system =
BlockReduce(temp_storage)
.Reduce(max_dec_len_without_system_this_thread, MaxOp<int>());
if (tid == 0) {
max_lens[0] = total_max_len_this_time;
max_lens[1] = total_max_len_encoder;
@@ -90,30 +105,22 @@ __global__ void GetMaxLenKernel(const int *seq_lens,
}
}
void GetMaxLen(const paddle::Tensor& seq_lens_tensor,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& seq_lens_encoder,
paddle::Tensor &max_len_tensor,
const int batch_size) {
void GetMaxLen(const paddle::Tensor &seq_lens_tensor,
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &seq_lens_encoder,
paddle::Tensor &max_len_tensor, const int batch_size) {
constexpr int blockSize = 1024;
GetMaxLenKernel<blockSize><<<1, blockSize, 0, seq_lens_encoder.stream()>>>(
seq_lens_tensor.data<int>(),
seq_lens_this_time.data<int>(),
seq_lens_encoder.data<int>(),
nullptr,
nullptr,
nullptr,
nullptr,
max_len_tensor.data<int>(),
batch_size);
seq_lens_tensor.data<int>(), seq_lens_this_time.data<int>(),
seq_lens_encoder.data<int>(), nullptr, nullptr, nullptr, nullptr,
max_len_tensor.data<int>(), batch_size);
}
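
As a reading aid, here is a host-side sketch (not part of the diff) of the reduction GetMaxLenKernel performs with cub::BlockReduce, restricted to the nullptr case used by GetMaxLen above and assuming the documented max_lens layout (slots 0-4: max_len_this_time, max_enc_len_this_time, max_dec_len_this_time, max_enc_dec_len_this_time, max_just_dec_len_this_time):

#include <algorithm>
#include <vector>

// Host-side reference of the per-batch maxima; system_lens/seq_mapping handling
// is omitted because GetMaxLen passes nullptr for those inputs.
void get_max_len_reference(const std::vector<int> &seq_lens,           // decoder lens
                           const std::vector<int> &seq_lens_this_time,
                           const std::vector<int> &seq_lens_encoder,
                           int *max_lens /* at least 5 ints */) {
  int max_this_time = 0, max_enc = 0, max_dec = 0, max_enc_dec = 0, max_just_dec = 0;
  for (size_t i = 0; i < seq_lens.size(); ++i) {
    max_this_time = std::max(max_this_time, seq_lens_this_time[i]);
    max_enc = std::max(max_enc, seq_lens_encoder[i]);
    max_dec = std::max(max_dec, seq_lens[i]);
    if (seq_lens_this_time[i] <= 0) continue;  // inactive slot
    max_enc_dec = std::max(max_enc_dec, seq_lens[i] + seq_lens_this_time[i]);
    // "just decode" means the request has no encoder tokens this step
    max_just_dec = std::max(max_just_dec, seq_lens_encoder[i] > 0 ? 0 : seq_lens[i]);
  }
  max_lens[0] = max_this_time;   // max_len_this_time
  max_lens[1] = max_enc;         // max_enc_len_this_time
  max_lens[2] = max_dec;         // max_dec_len_this_time
  max_lens[3] = max_enc_dec;     // max_enc_dec_len_this_time
  max_lens[4] = max_just_dec;    // max_just_dec_len_this_time
}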
__global__ void split_q_block(const int* __restrict__ seq_lens_q,
const int* __restrict__ seq_lens_encoder,
int* __restrict__ batch_ids,
int* __restrict__ tile_ids_per_batch,
int* __restrict__ num_blocks_x,
const int bsz,
__global__ void split_q_block(const int *__restrict__ seq_lens_q,
const int *__restrict__ seq_lens_encoder,
int *__restrict__ batch_ids,
int *__restrict__ tile_ids_per_batch,
int *__restrict__ num_blocks_x, const int bsz,
const int num_rows_per_block,
const int group_size) {
if (threadIdx.x == 0) {
@@ -124,8 +131,7 @@ __global__ void split_q_block(const int* __restrict__ seq_lens_q,
if (seq_lens_encoder && seq_lens_encoder[bid] > 0) {
seq_len = 0;
}
const int loop_times =
div_up(seq_len * group_size, num_rows_per_block);
const int loop_times = div_up(seq_len * group_size, num_rows_per_block);
for (uint32_t tile_id = 0; tile_id < loop_times; tile_id++) {
batch_ids[index] = bid;
tile_ids_per_batch[index++] = tile_id;
@@ -136,14 +142,12 @@ __global__ void split_q_block(const int* __restrict__ seq_lens_q,
}
}
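
For clarity, a host-side sketch (not part of the diff) of the tile mapping split_q_block builds: each request contributes div_up(seq_len * group_size, num_rows_per_block) tiles, entry k of (batch_ids, tile_ids_per_batch) tells CUDA block k which request and tile to process, and num_blocks_x receives the total tile count.

#include <vector>

struct QBlockSplit {
  std::vector<int> batch_ids;
  std::vector<int> tile_ids_per_batch;
  int num_blocks_x = 0;
};

// seq_lens_encoder may be empty; when present, requests still in the prefill
// (encoder) phase contribute zero decoder tiles, as in the kernel.
QBlockSplit split_q_block_reference(const std::vector<int> &seq_lens_q,
                                    const std::vector<int> &seq_lens_encoder,
                                    int num_rows_per_block, int group_size) {
  QBlockSplit out;
  for (size_t bid = 0; bid < seq_lens_q.size(); ++bid) {
    int seq_len = seq_lens_q[bid];
    if (!seq_lens_encoder.empty() && seq_lens_encoder[bid] > 0) seq_len = 0;
    const int loop_times =
        (seq_len * group_size + num_rows_per_block - 1) / num_rows_per_block;  // div_up
    for (int tile_id = 0; tile_id < loop_times; ++tile_id) {
      out.batch_ids.push_back(static_cast<int>(bid));
      out.tile_ids_per_batch.push_back(tile_id);
    }
    out.num_blocks_x += loop_times;
  }
  return out;
}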
__global__ void split_kv_block(const int* __restrict__ seq_lens_decoder,
const int* __restrict__ seq_lens_encoder,
int* __restrict__ batch_ids,
int* __restrict__ tile_ids_per_batch,
int* __restrict__ num_blocks_x,
const int bsz,
const int pad_len,
const int num_row_per_block) {
__global__ void split_kv_block(const int *__restrict__ seq_lens_decoder,
const int *__restrict__ seq_lens_encoder,
int *__restrict__ batch_ids,
int *__restrict__ tile_ids_per_batch,
int *__restrict__ num_blocks_x, const int bsz,
const int pad_len, const int num_row_per_block) {
if (threadIdx.x == 0) {
int gridx = 0;
int index = 0;
@@ -165,50 +169,46 @@ __global__ void split_kv_block(const int* __restrict__ seq_lens_decoder,
}
template <int THREADBLOCK_SIZE>
__global__ void get_max_len_kv_ernel(int* max_seq_lens_out,
const int* seq_lens_this_time,
const int* seq_lens_decoder,
const int batch_size) {
__global__ void
get_max_len_kv_ernel(int *max_seq_lens_out, const int *seq_lens_this_time,
const int *seq_lens_decoder, const int batch_size) {
const int tid = threadIdx.x;
typedef cub::BlockReduce<int, THREADBLOCK_SIZE> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
int max_len_this_thread = 0;
for (int i = tid; i < batch_size; i += blockDim.x) {
if (seq_lens_decoder[i] == 0) continue;
max_len_this_thread = max(seq_lens_this_time[i] + seq_lens_decoder[i], max_len_this_thread);
if (seq_lens_decoder[i] == 0)
continue;
max_len_this_thread =
max(seq_lens_this_time[i] + seq_lens_decoder[i], max_len_this_thread);
}
int total = BlockReduce(temp_storage).Reduce(max_len_this_thread, MaxOp<int>());
int total =
BlockReduce(temp_storage).Reduce(max_len_this_thread, MaxOp<int>());
if (tid == 0) {
*max_seq_lens_out = total;
}
}
std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& cum_offsets,
const int encoder_block_shape_q,
const int decoder_block_shape_q,
const int group_size,
const int block_size,
const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &seq_lens_this_time, const paddle::Tensor &cum_offsets,
const int encoder_block_shape_q, const int decoder_block_shape_q,
const int group_size, const int block_size,
const int decoder_step_token_num) {
auto stream = seq_lens_encoder.stream();
int bsz = cum_offsets.shape()[0];
auto max_len_tensor =
GetEmptyTensor({8}, paddle::DataType::INT32, seq_lens_encoder.place());
GetMaxLen(
seq_lens_decoder,
seq_lens_this_time,
seq_lens_encoder,
max_len_tensor,
bsz);
GetMaxLen(seq_lens_decoder, seq_lens_this_time, seq_lens_encoder,
max_len_tensor, bsz);
// max_len_this_time, max_enc_len_this_time, max_dec_len_this_time, max_enc_dec_len_this_time,
// max_just_dec_len_this_time, max_just_dec_merged_len_this_time, max_system_len, max_just_dec_len_without_system
// max_len_this_time, max_enc_len_this_time, max_dec_len_this_time,
// max_enc_dec_len_this_time, max_just_dec_len_this_time,
// max_just_dec_merged_len_this_time, max_system_len,
// max_just_dec_len_without_system
auto max_len_cpu = max_len_tensor.copy_to(paddle::CPUPlace(), false);
auto max_len_cpu_ptr = max_len_cpu.data<int>();
int max_len_this_time = max_len_cpu_ptr[0];
@@ -229,67 +229,67 @@ std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
paddle::Tensor decoder_batch_ids;
paddle::Tensor decoder_tile_ids_per_batch;
paddle::Tensor decoder_num_blocks_x_cpu; /*cpu*/
paddle::Tensor max_len_kv_cpu; /*cpu*/
paddle::Tensor max_len_kv_cpu; /*cpu*/
auto max_len_kv =
GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_decoder.place());
get_max_len_kv_ernel<128><<<1, 128, 0, stream>>>(
max_len_kv.data<int>(),
seq_lens_this_time.data<int>(),
seq_lens_decoder.data<int>(),
bsz
);
max_len_kv.data<int>(), seq_lens_this_time.data<int>(),
seq_lens_decoder.data<int>(), bsz);
max_len_kv_cpu =
max_len_kv.copy_to(paddle::CPUPlace(), false);
max_len_kv_cpu = max_len_kv.copy_to(paddle::CPUPlace(), false);
if (max_enc_len_this_time > 0) {
const uint32_t max_tile_size_per_bs_kv = div_up(max_enc_dec_len_this_time, block_size);
kv_batch_ids = GetEmptyTensor({bsz * max_tile_size_per_bs_kv},
paddle::DataType::INT32,
seq_lens_encoder.place());
kv_tile_ids_per_batch = GetEmptyTensor({bsz * max_tile_size_per_bs_kv},
paddle::DataType::INT32,
seq_lens_encoder.place());
const uint32_t max_tile_size_per_bs_kv =
div_up(max_enc_dec_len_this_time, block_size);
kv_batch_ids =
GetEmptyTensor({bsz * max_tile_size_per_bs_kv}, paddle::DataType::INT32,
seq_lens_encoder.place());
kv_tile_ids_per_batch =
GetEmptyTensor({bsz * max_tile_size_per_bs_kv}, paddle::DataType::INT32,
seq_lens_encoder.place());
auto kv_num_blocks_x =
GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_encoder.place());
split_kv_block<<<1, 32, 0, seq_lens_encoder.stream()>>>(
seq_lens_decoder.data<int>(),
// sequence_lengths->data<int>(),
seq_lens_encoder.data<int>(),
kv_batch_ids.data<int>(),
kv_tile_ids_per_batch.data<int>(),
kv_num_blocks_x.data<int>(),
bsz,
block_size,
block_size
);
seq_lens_decoder.data<int>(),
// sequence_lengths->data<int>(),
seq_lens_encoder.data<int>(), kv_batch_ids.data<int>(),
kv_tile_ids_per_batch.data<int>(), kv_num_blocks_x.data<int>(), bsz,
block_size, block_size);
kv_num_blocks_x_cpu = kv_num_blocks_x.copy_to(paddle::CPUPlace(), false);
const uint32_t encoder_max_tile_size_per_bs_q = div_up(
(max_enc_dec_len_this_time * group_size), encoder_block_shape_q);
const uint32_t encoder_max_tile_size_per_bs_q =
div_up((max_enc_dec_len_this_time * group_size), encoder_block_shape_q);
encoder_batch_ids =
GetEmptyTensor({bsz * encoder_max_tile_size_per_bs_q},
paddle::DataType::INT32,
seq_lens_encoder.place());
paddle::DataType::INT32, seq_lens_encoder.place());
encoder_tile_ids_per_batch =
GetEmptyTensor({bsz * encoder_max_tile_size_per_bs_q},
paddle::DataType::INT32,
seq_lens_encoder.place());
paddle::DataType::INT32, seq_lens_encoder.place());
auto encoder_num_blocks_x =
GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_encoder.place());
split_q_block<<<1, 32, 0, stream>>>(seq_lens_encoder.data<int>(),
nullptr,
split_q_block<<<1, 32, 0, stream>>>(seq_lens_encoder.data<int>(), nullptr,
encoder_batch_ids.data<int>(),
encoder_tile_ids_per_batch.data<int>(),
encoder_num_blocks_x.data<int>(),
bsz,
encoder_block_shape_q,
group_size);
encoder_num_blocks_x.data<int>(), bsz,
encoder_block_shape_q, group_size);
encoder_num_blocks_x_cpu =
encoder_num_blocks_x.copy_to(paddle::CPUPlace(), false);
} else {
encoder_batch_ids =
GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place());
encoder_tile_ids_per_batch =
GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place());
encoder_num_blocks_x_cpu =
GetEmptyTensor({0}, paddle::DataType::INT32, paddle::CPUPlace());
kv_batch_ids =
GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place());
kv_tile_ids_per_batch =
GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place());
kv_num_blocks_x_cpu =
GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place());
}
if (max_just_dec_len_this_time > 0) {
const uint32_t decoder_max_tile_size_per_bs_q =
@@ -297,24 +297,26 @@ std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
decoder_batch_ids =
GetEmptyTensor({bsz * decoder_max_tile_size_per_bs_q},
paddle::DataType::INT32,
seq_lens_encoder.place());
paddle::DataType::INT32, seq_lens_encoder.place());
decoder_tile_ids_per_batch =
GetEmptyTensor({bsz * decoder_max_tile_size_per_bs_q},
paddle::DataType::INT32,
seq_lens_encoder.place());
paddle::DataType::INT32, seq_lens_encoder.place());
auto decoder_num_blocks_x =
GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_encoder.place());
split_q_block<<<1, 32, 0, stream>>>(seq_lens_this_time.data<int>(),
seq_lens_encoder.data<int>(),
decoder_batch_ids.data<int>(),
decoder_tile_ids_per_batch.data<int>(),
decoder_num_blocks_x.data<int>(),
bsz,
decoder_block_shape_q,
group_size);
split_q_block<<<1, 32, 0, stream>>>(
seq_lens_this_time.data<int>(), seq_lens_encoder.data<int>(),
decoder_batch_ids.data<int>(), decoder_tile_ids_per_batch.data<int>(),
decoder_num_blocks_x.data<int>(), bsz, decoder_block_shape_q,
group_size);
decoder_num_blocks_x_cpu =
decoder_num_blocks_x.copy_to(paddle::CPUPlace(), false);
} else {
decoder_batch_ids =
GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place());
decoder_tile_ids_per_batch =
GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place());
decoder_num_blocks_x_cpu =
GetEmptyTensor({0}, paddle::DataType::INT32, paddle::CPUPlace());
}
return {encoder_batch_ids,
@@ -331,28 +333,22 @@ std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
}
std::vector<paddle::DataType> GetBlockShapeAndSplitKVBlockInferDtype(
const paddle::DataType& seq_lens_encoder_dtype,
const paddle::DataType& seq_lens_decoder_dtype,
const paddle::DataType& seq_lens_this_time_dtype,
const paddle::DataType& cum_offsets_dtype) {
return {paddle::DataType::INT32,
paddle::DataType::INT32,
paddle::DataType::INT32,
paddle::DataType::INT32,
paddle::DataType::INT32,
paddle::DataType::INT32,
paddle::DataType::INT32,
paddle::DataType::INT32,
paddle::DataType::INT32,
paddle::DataType::INT32,
paddle::DataType::INT32};
const paddle::DataType &seq_lens_encoder_dtype,
const paddle::DataType &seq_lens_decoder_dtype,
const paddle::DataType &seq_lens_this_time_dtype,
const paddle::DataType &cum_offsets_dtype) {
return {
paddle::DataType::INT32, paddle::DataType::INT32, paddle::DataType::INT32,
paddle::DataType::INT32, paddle::DataType::INT32, paddle::DataType::INT32,
paddle::DataType::INT32, paddle::DataType::INT32, paddle::DataType::INT32,
paddle::DataType::INT32, paddle::DataType::INT32};
}
std::vector<std::vector<int64_t>> GetBlockShapeAndSplitKVBlockInferShape(
const std::vector<int64_t>& seq_lens_encoder_shape,
const std::vector<int64_t>& seq_lens_decoder_shape,
const std::vector<int64_t>& seq_lens_this_time_shape,
const std::vector<int64_t>& cum_offsets_shape) {
const std::vector<int64_t> &seq_lens_encoder_shape,
const std::vector<int64_t> &seq_lens_decoder_shape,
const std::vector<int64_t> &seq_lens_this_time_shape,
const std::vector<int64_t> &cum_offsets_shape) {
std::vector<int64_t> dynamic_shape = {-1};
return {dynamic_shape,
@@ -369,9 +365,7 @@ std::vector<std::vector<int64_t>> GetBlockShapeAndSplitKVBlockInferShape(
}
PD_BUILD_STATIC_OP(get_block_shape_and_split_kv_block)
.Inputs({"seq_lens_encoder",
"seq_lens_decoder",
"seq_lens_this_time",
.Inputs({"seq_lens_encoder", "seq_lens_decoder", "seq_lens_this_time",
"cum_offsets"})
.Outputs({paddle::Optional("encoder_batch_ids"),
paddle::Optional("encoder_tile_ids_per_batch"),
@@ -382,12 +376,9 @@ PD_BUILD_STATIC_OP(get_block_shape_and_split_kv_block)
paddle::Optional("decoder_batch_ids"),
paddle::Optional("decoder_tile_ids_per_batch"),
paddle::Optional("decoder_num_blocks"),
paddle::Optional("max_len_kv"),
"set_max_lengths"})
.Attrs({"encoder_block_shape_q: int",
"decoder_block_shape_q: int",
"group_size: int",
"block_size: int",
paddle::Optional("max_len_kv"), "set_max_lengths"})
.Attrs({"encoder_block_shape_q: int", "decoder_block_shape_q: int",
"group_size: int", "block_size: int",
"decoder_step_token_num: int"})
.SetKernelFn(PD_KERNEL(GetBlockShapeAndSplitKVBlock))
.SetInferShapeFn(PD_INFER_SHAPE(GetBlockShapeAndSplitKVBlockInferShape))


@@ -337,6 +337,8 @@ __forceinline__ __host__ __device__ void vec_cast<nv_bfloat16, float>(
} else if (deal_each_time == 64) { \
constexpr size_t DEAL_EACH_TIME = 64; \
__VA_ARGS__ \
} else { \
PD_THROW("not support the deal_each_time", deal_each_time); \
}
#define DISPATCH_NUM_THREADS(num_threads, NUM_THREADS, ...) \
@@ -346,6 +348,8 @@ __forceinline__ __host__ __device__ void vec_cast<nv_bfloat16, float>(
} else if (num_threads == 256) { \
constexpr size_t NUM_THREADS = 256; \
__VA_ARGS__ \
} else { \
PD_THROW("not support the num_threads", num_threads); \
}
#define DISPATCH_GQA_GROUP_SIZE(group_size, GROUP_SIZE, ...) \
@@ -376,6 +380,11 @@ __forceinline__ __host__ __device__ void vec_cast<nv_bfloat16, float>(
} else if (group_size == 12) { \
constexpr size_t GROUP_SIZE = 12; \
__VA_ARGS__ \
} else if (group_size == 16) { \
constexpr size_t GROUP_SIZE = 16; \
__VA_ARGS__ \
} else { \
PD_THROW("not support the group_size", group_size); \
}
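
These DISPATCH_* macros (including DISPATCH_BLOCKSHAPE_Q below) all follow the same runtime-to-compile-time dispatch pattern. A minimal, self-contained sketch of how such a macro is typically consumed; the kernel launcher here is hypothetical and std::runtime_error stands in for PD_THROW:

#include <cstdio>
#include <stdexcept>

#define DISPATCH_NUM_THREADS(num_threads, NUM_THREADS, ...)            \
  if (num_threads == 128) {                                            \
    constexpr size_t NUM_THREADS = 128;                                \
    __VA_ARGS__                                                        \
  } else if (num_threads == 256) {                                     \
    constexpr size_t NUM_THREADS = 256;                                \
    __VA_ARGS__                                                        \
  } else {                                                             \
    throw std::runtime_error("not support the num_threads");           \
  }

// Hypothetical launcher that needs the thread count as a compile-time
// template parameter.
template <size_t kNumThreads>
void launch_kernel() {
  std::printf("instantiated for %zu threads\n", kNumThreads);
}

int main() {
  int num_threads = 256;  // runtime value, e.g. chosen from a config
  DISPATCH_NUM_THREADS(num_threads, NUM_THREADS, { launch_kernel<NUM_THREADS>(); })
  return 0;
}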
#define DISPATCH_BLOCKSHAPE_Q(block_shape_q, BLOCK_SHAPE_Q, NUM_WARP_Q, ...) \


@@ -13,7 +13,7 @@
// limitations under the License.
#include "paddle/extension.h"
#include "pybind11/pybind11.h"
namespace py = pybind11;
// Custom exception class for handling CUDA errors
@@ -125,45 +125,40 @@ paddle::Tensor FusedExpertMoeFunc(
const bool norm_topk_prob, const bool group_moe);
std::vector<paddle::Tensor> MoeExpertDispatch(
const paddle::Tensor& input,
const paddle::Tensor& gating_output,
const paddle::optional<paddle::Tensor>& gating_correction_bias,
const paddle::optional<paddle::Tensor> &w4a8_in_scale,
const int moe_topk,
const bool group_moe,
const bool topk_only_mode);
const paddle::Tensor &input, const paddle::Tensor &gating_output,
const paddle::optional<paddle::Tensor> &gating_correction_bias,
const paddle::optional<paddle::Tensor> &w4a8_in_scale, const int moe_topk,
const bool group_moe, const bool topk_only_mode);
std::vector<paddle::Tensor>
MoETopKSelectKernel(const paddle::Tensor &gating_logits,
const paddle::optional<paddle::Tensor> &bias,
const int moe_topk, const bool apply_norm_weight,
const bool enable_softmax_top_k_fused);
const paddle::optional<paddle::Tensor> &bias,
const int moe_topk, const bool apply_norm_weight,
const bool enable_softmax_top_k_fused);
std::vector<paddle::Tensor> MoERedundantTopKSelectKernel(
const paddle::Tensor& gating_logits,
const paddle::Tensor& expert_id_to_ep_rank_array,
const paddle::Tensor& expert_in_rank_num_list,
paddle::Tensor& tokens_per_expert_stats_list,
const paddle::optional<paddle::Tensor>& bias,
const int moe_topk,
const bool apply_norm_weight,
const bool enable_softmax_top_k_fused,
const int redundant_ep_rank_num_plus_one);
std::vector<paddle::Tensor>
MoERedundantTopKSelectKernel(const paddle::Tensor &gating_logits,
const paddle::Tensor &expert_id_to_ep_rank_array,
const paddle::Tensor &expert_in_rank_num_list,
paddle::Tensor &tokens_per_expert_stats_list,
const paddle::optional<paddle::Tensor> &bias,
const int moe_topk, const bool apply_norm_weight,
const bool enable_softmax_top_k_fused,
const int redundant_ep_rank_num_plus_one);
std::vector<paddle::Tensor>
EPMoeExpertDispatch(const paddle::Tensor &input, const paddle::Tensor &topk_ids,
const paddle::Tensor &topk_weights,
const paddle::optional<paddle::Tensor> &ffn1_in_scale,
const std::vector<int> &token_nums_per_expert,
const int token_nums_this_rank,
const std::string &moe_quant_type);
const paddle::Tensor &topk_weights,
const paddle::optional<paddle::Tensor> &ffn1_in_scale,
const std::vector<int> &token_nums_per_expert,
const int token_nums_this_rank,
const std::string &moe_quant_type);
std::vector<paddle::Tensor> EPMoeExpertDispatchFP8(
const paddle::Tensor &input, const paddle::Tensor &scale,
const paddle::Tensor &topk_ids, const paddle::Tensor &topk_weights,
const std::vector<int> &token_nums_per_expert,
const std::vector<int> &token_nums_per_expert_padded,
const int token_nums_this_rank, const int token_nums_this_rank_padded);
const paddle::Tensor &token_nums_per_expert,
const paddle::Tensor &token_nums_per_expert_padded);
std::vector<paddle::Tensor> PerTokenQuant(paddle::Tensor &input,
const int block_size);
@@ -180,20 +175,35 @@ std::vector<paddle::Tensor> EPMoeExpertCombine(
const paddle::optional<paddle::Tensor> &ffn2_bias,
const bool norm_topk_prob, const float routed_scaling_factor);
std::vector<std::vector<int>> GetExpertTokenNum(
const paddle::Tensor& topk_ids,
const int num_experts);
std::vector<std::vector<int>> GetExpertTokenNum(const paddle::Tensor &topk_ids,
const int num_experts);
paddle::Tensor MoeExpertFFNFunc(
const paddle::Tensor &permute_input,
const paddle::Tensor &tokens_expert_prefix_sum,
const paddle::Tensor &ffn1_weight, const paddle::Tensor &ffn2_weight,
const paddle::optional<paddle::Tensor> &ffn1_bias,
const paddle::optional<paddle::Tensor> &ffn1_scale,
const paddle::optional<paddle::Tensor> &ffn2_scale,
const paddle::optional<paddle::Tensor> &ffn2_in_scale,
const paddle::optional<paddle::Tensor> &expert_idx_per_token,
const std::string &quant_method, const bool used_in_ep_low_latency);
const paddle::Tensor& permute_input,
const paddle::Tensor& tokens_expert_prefix_sum,
const paddle::Tensor& ffn1_weight, const paddle::Tensor& ffn2_weight,
const paddle::optional<paddle::Tensor>& ffn1_bias,
const paddle::optional<paddle::Tensor>& ffn1_scale,
const paddle::optional<paddle::Tensor>& ffn2_scale,
const paddle::optional<paddle::Tensor>& ffn2_in_scale,
const paddle::optional<paddle::Tensor>& expert_idx_per_token,
const std::string& quant_method, const bool used_in_ep_low_latency);
paddle::Tensor MoeExpertFFNWint2Func(
const paddle::Tensor& permute_input,
const paddle::Tensor& tokens_expert_prefix_sum,
const paddle::Tensor& ffn1_weight,
const paddle::Tensor& ffn2_weight,
const paddle::optional<paddle::Tensor>& ffn1_bias,
const paddle::optional<paddle::Tensor>& ffn1_scale,
const paddle::optional<paddle::Tensor>& ffn2_scale,
const paddle::optional<paddle::Tensor>& ffn1_local_scale,
const paddle::optional<paddle::Tensor>& ffn1_code_scale,
const paddle::optional<paddle::Tensor>& ffn1_code_zp,
const paddle::optional<paddle::Tensor>& ffn2_local_scale,
const paddle::optional<paddle::Tensor>& ffn2_code_scale,
const paddle::optional<paddle::Tensor>& ffn2_code_zp,
const bool used_in_ep_low_latency);
paddle::Tensor MoeExpertReduceFunc(
const paddle::Tensor &ffn_out, const paddle::Tensor &top_k_weight,
@@ -205,19 +215,16 @@ paddle::Tensor MoeExpertReduceFunc(
void InitKVSignalPerQuery(const paddle::Tensor &seq_lens_encoder_tensor,
const paddle::Tensor &seq_lens_this_time_tensor,
const paddle::Tensor &seq_lens_decoder_tensor,
const int rank,
const int num_layers);
const int rank, const int num_layers);
void GetOutputKVSignal(const paddle::Tensor& x,
int64_t rank_id,
void GetOutputKVSignal(const paddle::Tensor &x, int64_t rank_id,
bool wait_flag);
paddle::Tensor DequantInt8Func(const paddle::Tensor &input,
const paddle::Tensor &out_scale,
std::string dtype);
paddle::Tensor OpenShmAndGetMetaSignalFunc(const int rank,
paddle::Tensor OpenShmAndGetMetaSignalFunc(const int rank, const int device_id,
const bool keep_pd_step_flag);
paddle::Tensor InitSignalLayerwiseFunc(const paddle::Tensor &kv_signal_metadata,
@@ -286,61 +293,121 @@ std::vector<paddle::Tensor> ExtractTextTokenOutput(
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &cu_seqlens_q, const paddle::Tensor &score_text);
std::vector<paddle::Tensor> MoEDeepGEMMPermute(
const paddle::Tensor& x,
const paddle::Tensor& topk_idx,
const int num_experts,
const int max_tokens_per_expert
);
std::vector<paddle::Tensor> MoEDeepGEMMPermute(const paddle::Tensor &x,
const paddle::Tensor &topk_idx,
const int num_experts,
const int max_tokens_per_expert);
std::vector<paddle::Tensor> MoEDeepGEMMDePermute(
const paddle::Tensor& ffn_out, // [num_experts, max_tokens_per_expert, hidden]
const paddle::Tensor& permute_indices_per_token, // [token_num, topk}]
const paddle::Tensor& topk_idx,
const paddle::Tensor& topk_weights
);
const paddle::Tensor
&ffn_out, // [num_experts, max_tokens_per_expert, hidden]
const paddle::Tensor &permute_indices_per_token, // [token_num, topk}]
const paddle::Tensor &topk_idx, const paddle::Tensor &topk_weights);
void TextImageIndexOut(const paddle::Tensor &token_type_ids,
const paddle::Tensor &text_input,
const paddle::Tensor &image_input);
void TextImageGatherScatter(paddle::Tensor &input, paddle::Tensor &text_input,
paddle::Tensor &image_input,
paddle::Tensor &token_type_ids,
paddle::Tensor &text_index,
paddle::Tensor &image_index, const bool is_scatter);
paddle::Tensor count_tokens_per_expert_func(const paddle::Tensor &topk_ids,
int64_t num_experts);
std::vector<paddle::Tensor> tritonmoe_preprocess_kernel(const paddle::Tensor& topk_ids, int64_t num_experts, int64_t GEMM_BLOCK_SIZE_M);
std::vector<paddle::Tensor> MoeWna16MarlinGemmApi(
const paddle::Tensor& a,
const paddle::optional<paddle::Tensor>& c_or_none,
const paddle::Tensor& b_q_weight,
const paddle::Tensor& b_scales,
const paddle::optional<paddle::Tensor>& global_scale_or_none,
const paddle::optional<paddle::Tensor>& b_zeros_or_none,
const paddle::optional<paddle::Tensor>& g_idx_or_none,
const paddle::optional<paddle::Tensor>& perm_or_none,
const paddle::Tensor& workspace,
const paddle::Tensor& sorted_token_ids,
const paddle::Tensor& expert_ids,
const paddle::Tensor& num_tokens_post_padded,
const paddle::Tensor& topk_weights,
int64_t moe_block_size,
int64_t top_k,
bool mul_topk_weights,
bool is_ep,
const std::string& b_q_type_str,
int64_t size_m,
int64_t size_n,
int64_t size_k,
bool is_k_full,
bool use_atomic_add,
bool use_fp32_reduce,
bool is_zp_float);
void CutlassScaledMm(paddle::Tensor &c, paddle::Tensor const &a,
paddle::Tensor const &b, paddle::Tensor const &a_scales,
paddle::Tensor const &b_scales,
paddle::optional<paddle::Tensor> const &bias);
void CutlassScaledMmAzp(paddle::Tensor& c, paddle::Tensor const& a,
paddle::Tensor const& b,
paddle::Tensor const& a_scales,
paddle::Tensor const& b_scales,
paddle::Tensor const& azp_adj,
paddle::optional<paddle::Tensor> const& azp,
paddle::optional<paddle::Tensor> const& bias);
void StaticScaledFp8Quant(paddle::Tensor &out, paddle::Tensor const &input,
paddle::Tensor const &scale);
void DynamicScaledFp8Quant(paddle::Tensor &out, paddle::Tensor const &input,
paddle::Tensor &scale);
void DynamicPerTokenScaledFp8Quant(paddle::Tensor &out,
paddle::Tensor const &input,
paddle::Tensor &scales, float scale_ub);
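
The three quant entry points above (static, dynamic per-tensor, dynamic per-token FP8) differ only in how the scale is obtained. Below is a hedged CPU sketch of the per-token scheme, assuming the common FP8 E4M3 convention (finite max 448) and treating scale_ub as a cap on the per-token amax; the real op writes FP8 values rather than the clamped floats used here for illustration.

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

void per_token_fp8_quant_reference(const std::vector<std::vector<float>> &input,
                                   std::vector<std::vector<float>> &out,
                                   std::vector<float> &scales, float scale_ub) {
  constexpr float kFp8Max = 448.0f;  // E4M3 finite max
  out.resize(input.size());
  scales.resize(input.size());
  for (size_t row = 0; row < input.size(); ++row) {
    float amax = 0.0f;
    for (float v : input[row]) amax = std::max(amax, std::fabs(v));
    // Assumption: scale_ub caps the per-token amax; the op's exact convention may differ.
    if (scale_ub > 0.0f) amax = std::min(amax, scale_ub);
    const float scale = std::max(amax / kFp8Max, 1e-12f);  // avoid divide-by-zero
    scales[row] = scale;
    out[row].resize(input[row].size());
    for (size_t j = 0; j < input[row].size(); ++j) {
      // The real op stores FP8; floats are kept here to stay standard C++.
      out[row][j] = std::max(-kFp8Max, std::min(kFp8Max, input[row][j] / scale));
    }
  }
}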
PYBIND11_MODULE(fastdeploy_ops, m) {
m.def("get_expert_token_num", &GetExpertTokenNum,
py::arg("topk_ids"), py::arg("num_experts"),
"get expert token num");
m.def("get_expert_token_num", &GetExpertTokenNum, py::arg("topk_ids"),
py::arg("num_experts"), "get expert token num");
/**
* moe/fused_moe/moe_redundant_topk_select.cu
* moe_redundant_topk_select
*/
m.def("f_moe_redundant_topk_select", &MoERedundantTopKSelectKernel,
py::arg("gating_logits"), py::arg("expert_id_to_ep_rank_array"),
py::arg("expert_in_rank_num_list"),
py::arg("tokens_per_expert_stats_list"), py::arg("bias"),
py::arg("moe_topk"), py::arg("apply_norm_weight"),
py::arg("enable_softmax_top_k_fused"),
py::arg("redundant_ep_rank_num_plus_one"),
"moe export RedundantTopKSelect function");
/**
* moe/fused_moe/moe_redundant_topk_select.cu
* moe_redundant_topk_select
*/
m.def("f_moe_redundant_topk_select", &MoERedundantTopKSelectKernel,
py::arg("gating_logits"), py::arg("expert_id_to_ep_rank_array"),
py::arg("expert_in_rank_num_list"), py::arg("tokens_per_expert_stats_list"),
py::arg("bias"), py::arg("moe_topk"), py::arg("apply_norm_weight"),
py::arg("enable_softmax_top_k_fused"), py::arg("redundant_ep_rank_num_plus_one"),
"moe export RedundantTopKSelect function");
/**
* open_shm_and_get_meta_signal.cc
* InitKVSignalPerQuery
*/
m.def("init_kv_signal_per_query", &InitKVSignalPerQuery,
py::arg("seq_lens_encoder_tensor"),
py::arg("seq_lens_this_time_tensor"),
py::arg("seq_lens_decoder_tensor"), py::arg("rank"),
py::arg("num_layers"), "init_kv_signal_per_query function");
/**
* GetOutputKVSignal
*/
m.def("get_output_kv_signal", &GetOutputKVSignal, py::arg("x"),
py::arg("rank_id"), py::arg("wait_flag"),
"get_output_kv_signal function");
/**
* open_shm_and_get_meta_signal.cc
* InitKVSignalPerQuery
*/
m.def("init_kv_signal_per_query", &InitKVSignalPerQuery,
py::arg("seq_lens_encoder_tensor"), py::arg("seq_lens_this_time_tensor"),
py::arg("seq_lens_decoder_tensor"), py::arg("rank"), py::arg("num_layers"),
"init_kv_signal_per_query function");
/**
* GetOutputKVSignal
*/
m.def("get_output_kv_signal", &GetOutputKVSignal,
py::arg("x"), py::arg("rank_id"), py::arg("wait_flag"),
"get_output_kv_signal function");
m.def("moe_deepgemm_permute", &MoEDeepGEMMPermute, "MoEDeepGEMMPermute");
m.def("moe_deepgemm_depermute", &MoEDeepGEMMDePermute, "MoEDeepGEMMDePermute");
m.def("moe_deepgemm_permute", &MoEDeepGEMMPermute, "MoEDeepGEMMPermute");
m.def("moe_deepgemm_depermute", &MoEDeepGEMMDePermute,
"MoEDeepGEMMDePermute");
/**
* alloc_cache_pinned.cc
* cuda_host_alloc
@@ -398,12 +465,7 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
py::arg("token_nums_per_expert"), py::arg("token_nums_this_rank"),
py::arg("moe_quant_type"), "ep moe export dispatch function");
m.def("ep_moe_expert_dispatch_fp8", &EPMoeExpertDispatchFP8, py::arg("input"),
py::arg("scale"), py::arg("topk_ids"), py::arg("topk_weights"),
py::arg("token_nums_per_expert"),
py::arg("token_nums_per_expert_padded"),
py::arg("token_nums_this_rank"), py::arg("token_nums_this_rank_padded"),
"ep moe export dispatch function");
m.def("ep_moe_expert_dispatch_fp8", &EPMoeExpertDispatchFP8);
m.def("ep_moe_expert_combine", &EPMoeExpertCombine, py::arg("ffn_out"),
py::arg("expert_scales_float"), py::arg("permute_indices_per_token"),
@@ -437,6 +499,12 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
*/
m.def("moe_expert_ffn", &MoeExpertFFNFunc, "moe export ffn function");
/**
* moe/fused_moe/moe_ffn_wint2.cu
* moe_expert_ffn_wint2
*/
m.def("moe_expert_ffn_wint2", &MoeExpertFFNWint2Func, "moe export ffn wint2 function");
/**
* moe/fused_moe/moe_expert_reduce.cu
* moe_expert_reduce
@@ -523,4 +591,66 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
m.def("group_swiglu_with_masked", &GroupSwigluWithMasked,
"group_swiglu_with_masked function");
m.def("text_image_index_out", &TextImageIndexOut,
"text_image_index_out function");
m.def("text_image_gather_scatter", &TextImageGatherScatter,
"text_image_gather_scatter function");
m.def("count_tokens_per_expert_func", &count_tokens_per_expert_func);
m.def("tritonmoe_preprocess_func", &tritonmoe_preprocess_kernel);
m.def("MoeWna16MarlinGemmApi", &MoeWna16MarlinGemmApi,
py::arg("a"),
py::arg("c_or_none"),
py::arg("b_q_weight"),
py::arg("b_scales"),
py::arg("global_scale_or_none"),
py::arg("b_zeros_or_none"),
py::arg("g_idx_or_none"),
py::arg("perm_or_none"),
py::arg("workspace"),
py::arg("sorted_token_ids"),
py::arg("expert_ids"),
py::arg("num_tokens_post_padded"),
py::arg("topk_weights"),
py::arg("moe_block_size"),
py::arg("top_k"),
py::arg("mul_topk_weights"),
py::arg("is_ep"),
py::arg("b_q_type_str"),
py::arg("size_m"),
py::arg("size_n"),
py::arg("size_k"),
py::arg("is_k_full"),
py::arg("use_atomic_add"),
py::arg("use_fp32_reduce"),
py::arg("is_zp_float"));
/**
* cutlass_scaled_mm.cu
* cutlass_scaled_mm
* cutlass_scaled_mm_azp
*/
m.def("cutlass_scaled_mm", &CutlassScaledMm, "cutlass_scaled_mm function");
m.def("cutlass_scaled_mm_azp", &CutlassScaledMmAzp, "cutlass_scaled_mm_azp function");
/**
* quantization/common.cu
* static_scaled_fp8_quant
* dynamic_scaled_fp8_quant
* dynamic_per_token_scaled_fp8_quant
*/
m.def("static_scaled_fp8_quant", &StaticScaledFp8Quant, "static_scaled_fp8_quant function",
py::arg("out"), py::arg("input"), py::arg("scale"));
m.def("dynamic_scaled_fp8_quant", &DynamicScaledFp8Quant,
"dynamic_scaled_fp8_quant function",
py::arg("out"), py::arg("input"), py::arg("scale"));
m.def("dynamic_per_token_scaled_fp8_quant", &DynamicPerTokenScaledFp8Quant,
"dynamic_per_token_scaled_fp8_quant function",
py::arg("out"), py::arg("input"), py::arg("scales"), py::arg("scale_ub"));
}
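
For readers unfamiliar with the binding style used above, a minimal self-contained sketch of the same pybind11 pattern (module and function names here are placeholders, not part of fastdeploy_ops):

#include <pybind11/pybind11.h>
namespace py = pybind11;

// Plain C++ function to expose; stands in for the CUDA-op wrappers above.
int scaled_add(int a, int b, int scale) { return a + b * scale; }

PYBIND11_MODULE(example_ops, m) {
  // py::arg makes each C++ parameter usable as a Python keyword argument,
  // exactly as done for the fastdeploy_ops bindings; the trailing string is
  // the docstring shown by help().
  m.def("scaled_add", &scaled_add, py::arg("a"), py::arg("b"), py::arg("scale"),
        "scaled_add(a, b, scale) -> int");
}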


@@ -0,0 +1,250 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Architecture-specific operators on memory added for SM80
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/complex.h"
#include "cutlass/arch/memory.h"
#include "cutlass/arch/memory_sm75.h"
#include "cutlass/arch/memory_sm80.h"
#include "cutlass/arch/cache_operation.h"
namespace cutlass {
namespace arch {
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Initiates an asynchronous copy from global memory to shared memory.
///
/// cp.async
///
template <
/// Size of the access in bytes
int SizeInBytes,
/// Cache operation
CacheOperation::Kind cache_op = CacheOperation::Always,
bool GlobalToShared = true>
struct copy;
/// Initiates an asynchronous copy from global memory to shared memory. Rather than predicate
/// the entire transfer, zeros are written to SMEM if the guard predicate is false.
///
/// cp.async
///
template <
/// Size of the access in bytes
int SizeInBytes,
/// Cache operation
CacheOperation::Kind cache_op = CacheOperation::Always,
bool GlobalToShared = true>
struct copy_zfill;
/// Blocks until all but <N> previous cp.async.commit_group operations have committed.
///
/// cp.async
///
template <int N, bool GlobalToShared = true>
struct copy_wait;
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Partial specialization
template <
/// Size of the access in bytes
int SizeInBytes>
struct copy<SizeInBytes, CacheOperation::Always, true> {
/// Copy
CUTLASS_DEVICE
copy(void *smem_ptr, void const *global_ptr, bool pred_guard = true) {
cp_async<SizeInBytes, CacheOperation::Always>(smem_ptr, global_ptr, pred_guard);
}
};
/// Partial specialization
template <
/// Size of the access in bytes
int SizeInBytes>
struct copy<SizeInBytes, CacheOperation::Always, false> {
/// Copy
CUTLASS_DEVICE
copy(void *smem_ptr, void const *global_ptr, bool pred_guard = true) {
using AccessType = Array<uint8_t, SizeInBytes>;
if (pred_guard) {
*static_cast<AccessType *>(smem_ptr) = *static_cast<AccessType const *>(global_ptr);
}
}
};
/// Partial specialization
template <
/// Size of the access in bytes
int SizeInBytes>
struct copy_zfill<SizeInBytes, CacheOperation::Always, true> {
/// Copy with zero fill
CUTLASS_DEVICE
copy_zfill(void *smem_ptr, void const *global_ptr, bool pred_guard) {
cp_async_zfill<SizeInBytes, CacheOperation::Always>(smem_ptr, global_ptr, pred_guard);
}
};
/// Partial specialization
template <
/// Size of the access in bytes
int SizeInBytes>
struct copy_zfill<SizeInBytes, CacheOperation::Always, false> {
/// Copy with zero fill
CUTLASS_DEVICE
copy_zfill(void *smem_ptr, void const *global_ptr, bool pred_guard) {
using AccessType = Array<uint8_t, SizeInBytes>;
if (pred_guard) {
*static_cast<AccessType *>(smem_ptr) = *static_cast<AccessType const *>(global_ptr);
}
else {
AccessType zeros;
zeros.clear();
*static_cast<AccessType *>(smem_ptr) = zeros;
}
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Partial specialization
template <
/// Size of the access in bytes
int SizeInBytes>
struct copy<SizeInBytes, CacheOperation::Global, true> {
/// Copy
CUTLASS_DEVICE
copy(void *smem_ptr, void const *global_ptr, bool pred_guard = true) {
cp_async<SizeInBytes, CacheOperation::Global>(smem_ptr, global_ptr, pred_guard);
}
};
/// Partial specialization
template <
/// Size of the access in bytes
int SizeInBytes>
struct copy<SizeInBytes, CacheOperation::Global, false> {
/// Copy
CUTLASS_DEVICE
copy(void *smem_ptr, void const *global_ptr, bool pred_guard = true) {
using AccessType = Array<uint8_t, SizeInBytes>;
if (pred_guard) {
*static_cast<AccessType *>(smem_ptr) = *static_cast<AccessType const *>(global_ptr);
}
}
};
/// Partial specialization
template <
/// Size of the access in bytes
int SizeInBytes>
struct copy_zfill<SizeInBytes, CacheOperation::Global, true> {
/// Copy with zero fill
CUTLASS_DEVICE
copy_zfill(void *smem_ptr, void const *global_ptr, bool pred_guard = true) {
cp_async_zfill<SizeInBytes, CacheOperation::Global>(smem_ptr, global_ptr, pred_guard);
}
};
/// Partial specialization
template <
/// Size of the access in bytes
int SizeInBytes>
struct copy_zfill<SizeInBytes, CacheOperation::Global, false> {
/// Copy with zero fill
CUTLASS_DEVICE
copy_zfill(void *smem_ptr, void const *global_ptr, bool pred_guard = true) {
using AccessType = Array<uint8_t, SizeInBytes>;
if (pred_guard) {
*static_cast<AccessType *>(smem_ptr) = *static_cast<AccessType const *>(global_ptr);
}
else {
AccessType zeros;
zeros.clear();
*static_cast<AccessType *>(smem_ptr) = zeros;
}
}
};
/// Establishes an ordering w.r.t previously issued cp.async instructions. Does not block.
template <bool GlobalToShared>
CUTLASS_DEVICE
void copy_fence() {}
template <>
CUTLASS_DEVICE
void copy_fence<true>() {
cp_async_fence();
}
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Partial specialization
template <int N>
struct copy_wait<N, false> {
CUTLASS_DEVICE
copy_wait() {}
};
/// Partial specialization
template <int N>
struct copy_wait<N, true> {
CUTLASS_DEVICE
copy_wait() { cp_async_wait<N>(); }
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace arch
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
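
A hedged usage sketch of the copy / copy_fence / copy_wait wrappers defined above (assumes CUTLASS and this header are on the include path and an SM80+ target; the kernel, buffer names, and include path are illustrative only):

#include <cuda_runtime.h>
// #include "cutlass_extensions/arch/memory_copy_sm80.h"  // hypothetical path to the header above

__global__ void stage_and_copy(const float4 *__restrict__ gmem, float4 *out, int n) {
  __shared__ float4 smem[128];  // assumes blockDim.x == 128
  const int tid = threadIdx.x;
  const bool pred = tid < n;

  // GlobalToShared = true lowers to cp.async (16-byte transfer per thread);
  // with copy<> nothing is written when pred is false (copy_zfill writes zeros).
  cutlass::arch::copy<sizeof(float4), cutlass::arch::CacheOperation::Always, true>(
      &smem[tid], &gmem[tid], pred);

  cutlass::arch::copy_fence<true>();    // cp.async.commit_group
  cutlass::arch::copy_wait<0, true>();  // cp.async.wait_group 0: all groups done
  __syncthreads();

  if (pred) out[tid] = smem[tid];
}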


@@ -0,0 +1,460 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights
*reserved. SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice,
*this list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
*POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
//
// This file is a modified excerpt of
// include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp
// from https://github.com/NVIDIA/cutlass v3.5.0
// It has been modified to support either row/column or scalar broadcasting
// where the tensor being loaded from is always passed in via a device pointer.
// This lets one compiled kernel handle all cases of per-tensor or
// per-channel/per-token quantization.
//
// This interface also allows the scales to be passed in as tensors that
// consistently reside on the device, which avoids an issue with a previous
// implementation where scalars needed to be on the CPU since they
// were passed in via float values. This created a potential performance hazard
// if scales were initially on the device, and caused torch.compile graphs
// breaks when moving scales to the CPU.
//
// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_array_c3x.hpp
#pragma once
// Turn off clang-format for the entire file to keep it close to upstream
// clang-format off
#include "cutlass/cutlass.h"
#include "cutlass/arch/barrier.h"
#include "cute/tensor.hpp"
#include "cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp"
namespace cutlass::epilogue::fusion {
using namespace cute;
using namespace detail;
// Row vector broadcast
template<
int Stages,
class CtaTileShapeMNK,
class Element,
class StrideMNL = Stride<_0,_1,_0>,
int Alignment = 128 / sizeof_bits_v<Element>
>
struct Sm90RowOrScalarBroadcastArray {
static_assert(Stages == 0, "Row broadcast doesn't support smem usage");
static_assert(is_static_v<decltype(take<0,2>(StrideMNL{}))>); // batch stride can be dynamic or static
static_assert(take<0,2>(StrideMNL{}) == Stride<_0,_1>{});
struct SharedStorage {
array_aligned<Element, size<1>(CtaTileShapeMNK{})> smem;
};
// This struct has been modified to have a bool indicating that ptr_row is a
// scalar that must be broadcast, instead of containing a scalar that is
// valid if ptr_row is null.
struct Arguments {
const Element* const* ptr_row_array = nullptr;
bool row_broadcast = true;
StrideMNL dRow = {};
};
using Params = Arguments;
template <class ProblemShape>
static constexpr Params
to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
return args;
}
template <class ProblemShape>
static bool
can_implement(ProblemShape const& problem_shape, Arguments const& args) {
return true;
}
template <class ProblemShape>
static size_t
get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
return 0;
}
template <class ProblemShape>
static cutlass::Status
initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream,
CudaHostAdapter* cuda_adapter = nullptr) {
return cutlass::Status::kSuccess;
}
CUTLASS_HOST_DEVICE
Sm90RowOrScalarBroadcastArray() { }
CUTLASS_HOST_DEVICE
Sm90RowOrScalarBroadcastArray(Params const& params, SharedStorage const& shared_storage)
: params(params)
, smem(const_cast<Element*>(shared_storage.smem.data())) { }
Params params;
Element *smem = nullptr;
CUTLASS_DEVICE bool
is_producer_load_needed() const {
return false;
}
CUTLASS_DEVICE bool
is_C_load_needed() const {
return false;
}
CUTLASS_DEVICE bool
is_zero() const {
return (!params.row_broadcast && *(params.ptr_row_array[group]) == Element(0));
}
template <class... Args>
CUTLASS_DEVICE auto
get_producer_load_callbacks(ProducerLoadArgs<Args...> const& args) {
return EmptyProducerLoadCallbacks{};
}
template <class GS_GTensor, class GS_STensor, class GS_CTensor, class Tiled_G2S, class SR_STensor, class SR_RTensor, class CTensor, class ThrResidue, class ThrNum>
struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks {
CUTLASS_DEVICE
ConsumerStoreCallbacks(
GS_GTensor tGS_gRow_, GS_STensor tGS_sRow_,
GS_CTensor tGS_cRow_, Tiled_G2S tiled_g2s_,
SR_STensor tSR_sRow_, SR_RTensor tSR_rRow_,
CTensor tCcRow_, ThrResidue residue_tCcRow_, ThrNum thr_num_,
int group, Params const& params_)
: tGS_gRow(tGS_gRow_)
, tGS_sRow(tGS_sRow_)
, tGS_cRow(tGS_cRow_)
, tiled_G2S(tiled_g2s_)
, tSR_sRow(tSR_sRow_)
, tSR_rRow(tSR_rRow_)
, tCcRow(tCcRow_)
, residue_tCcRow(residue_tCcRow_)
, group(group)
, params(params_) {}
GS_GTensor tGS_gRow; // (CPY,CPY_M,CPY_N)
GS_STensor tGS_sRow; // (CPY,CPY_M,CPY_N)
GS_CTensor tGS_cRow; // (CPY,CPY_M,CPY_N)
Tiled_G2S tiled_G2S;
SR_STensor tSR_sRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
SR_RTensor tSR_rRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
CTensor tCcRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
ThrResidue residue_tCcRow; // (m, n)
ThrNum thr_num;
int group;
Params const& params;
CUTLASS_DEVICE void
begin() {
if (!params.row_broadcast) {
fill(tSR_rRow, *(params.ptr_row_array[group]));
return;
}
auto synchronize = [&] () { cutlass::arch::NamedBarrier::sync(thr_num, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); };
Tensor tGS_gRow_flt = filter_zeros(tGS_gRow);
Tensor tGS_sRow_flt = filter_zeros(tGS_sRow);
Tensor tGS_cRow_flt = make_tensor(tGS_cRow.data(), make_layout(tGS_gRow_flt.shape(), tGS_cRow.stride()));
for (int i = 0; i < size(tGS_gRow_flt); ++i) {
if (get<1>(tGS_cRow_flt(i)) >= size<1>(CtaTileShapeMNK{})) {
continue; // OOB of SMEM,
}
if (elem_less(tGS_cRow_flt(i), make_coord(get<0>(residue_tCcRow), get<1>(residue_tCcRow)))) {
tGS_sRow_flt(i) = tGS_gRow_flt(i);
}
else {
tGS_sRow_flt(i) = Element(0); // Set to Zero when OOB so LDS could be issue without any preds.
}
}
synchronize();
}
CUTLASS_DEVICE void
begin_loop(int epi_m, int epi_n) {
if (epi_m == 0) { // Assumes M-major subtile loop
if (!params.row_broadcast) return; // Do not issue LDS when row is scalar
Tensor tSR_sRow_flt = filter_zeros(tSR_sRow(_,_,_,epi_m,epi_n));
Tensor tSR_rRow_flt = filter_zeros(tSR_rRow);
copy(tSR_sRow_flt, tSR_rRow_flt);
}
}
template <typename ElementAccumulator, int FragmentSize>
CUTLASS_DEVICE Array<Element, FragmentSize>
visit(Array<ElementAccumulator, FragmentSize> const& frg_acc, int epi_v, int epi_m, int epi_n) {
Array<Element, FragmentSize> frg_row;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < FragmentSize; ++i) {
frg_row[i] = tSR_rRow(epi_v * FragmentSize + i);
}
return frg_row;
}
};
template <
bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy
class... Args
>
CUTLASS_DEVICE auto
get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {
auto [M, N, K, L] = args.problem_shape_mnkl;
auto [m, n, k, l] = args.tile_coord_mnkl;
using ThreadCount = decltype(size(args.tiled_copy));
Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row_array[l]), make_shape(M,N,1), params.dRow);
Tensor gRow = local_tile(mRow(_,_,l), take<0,2>(args.tile_shape_mnk), make_coord(m, n)); // (CTA_M, CTA_N)
Tensor sRow = make_tensor(make_smem_ptr(smem),
make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})), make_shape(_0{}, _1{})); // (CTA_M, CTA_N)
//// G2S: Gmem to Smem
auto tiled_g2s = make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
Layout< Shape<_1, ThreadCount>,
Stride<_0, _1>>{},
Layout<_1>{});
auto thr_g2s = tiled_g2s.get_slice(args.thread_idx);
Tensor tGS_gRow = thr_g2s.partition_S(gRow);
Tensor tGS_sRow = thr_g2s.partition_D(sRow);
//// G2S: Coord
auto cRow = make_identity_tensor(make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})));
Tensor tGS_cRow = thr_g2s.partition_S(cRow);
//// S2R: Smem to Reg
Tensor tSR_sRow = sm90_partition_for_epilogue<ReferenceSrc>(sRow, args.epi_tile, args.tiled_copy, args.thread_idx);
Tensor tSR_rRow = make_tensor_like(take<0,3>(tSR_sRow)); // (CPY,CPY_M,CPY_N)
return ConsumerStoreCallbacks<decltype(tGS_gRow), decltype(tGS_sRow), decltype(tGS_cRow), decltype(tiled_g2s), decltype(tSR_sRow), decltype(tSR_rRow), decltype(args.tCcD), decltype(args.residue_cD), ThreadCount>(
tGS_gRow,
tGS_sRow,
tGS_cRow, tiled_g2s,
tSR_sRow,
tSR_rRow,
args.tCcD,
args.residue_cD,
ThreadCount{},
l,
params);
}
};
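
As the comments above note, Arguments carries a per-group device pointer plus a row_broadcast flag instead of a separate scalar field. The following standalone sketch (names illustrative, plain C++ rather than the CUTLASS epilogue, with the multiply added purely for illustration) shows that pointer-plus-flag convention:

#include <cstddef>

template <typename Element>
struct RowOrScalarArgs {
  const Element *const *ptr_row_array = nullptr;  // one device pointer per group/expert
  bool row_broadcast = true;  // true: per-column row vector, false: broadcast element 0
};

// When row_broadcast is false every column reuses element 0, mirroring
// fill(tSR_rRow, *(params.ptr_row_array[group])); otherwise each column reads
// its own entry of the row vector.
template <typename Element>
void apply_row_scale(const Element *acc_row, Element *out, std::size_t n,
                     const RowOrScalarArgs<Element> &args, int group) {
  const Element *row = args.ptr_row_array[group];
  for (std::size_t j = 0; j < n; ++j) {
    out[j] = acc_row[j] * (args.row_broadcast ? row[j] : row[0]);
  }
}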
/////////////////////////////////////////////////////////////////////////////////////////////////
// Column vector broadcast
template<
int Stages,
class CtaTileShapeMNK,
class Element,
class StrideMNL = Stride<_1,_0,_0>,
int Alignment = 128 / sizeof_bits_v<Element>
>
struct Sm90ColOrScalarBroadcastArray {
static_assert(Stages == 0, "Column broadcast doesn't support smem usage yet");
static_assert(Alignment * sizeof_bits_v<Element> % 128 == 0, "sub-16B alignment not supported yet");
static_assert(
(cute::is_same_v<StrideMNL, Stride<_1,_0, _0>>) || // col vector broadcast, e.g. per-row alpha/bias
(cute::is_same_v<StrideMNL, Stride<_1,_0,int>>)); // batched col vector broadcast, e.g. batched per-row bias
// Accumulator distributes col elements evenly amongst threads so we can just directly load from gmem
struct SharedStorage { };
// This struct has been modified to have a bool indicating that ptr_col is a
// scalar that must be broadcast, instead of containing a scalar that is
// valid if ptr_col is null.
struct Arguments {
const Element* const* ptr_col_array = nullptr;
bool col_broadcast = true;
StrideMNL dCol = {};
};
using Params = Arguments;
template <class ProblemShape>
static constexpr Params
to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
return args;
}
template <class ProblemShape>
static bool
can_implement(ProblemShape const& problem_shape, Arguments const& args) {
return true;
}
template <class ProblemShape>
static size_t
get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
return 0;
}
template <class ProblemShape>
static cutlass::Status
initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream,
CudaHostAdapter* cuda_adapter = nullptr) {
return cutlass::Status::kSuccess;
}
CUTLASS_DEVICE bool
is_producer_load_needed() const {
return false;
}
CUTLASS_DEVICE bool
is_C_load_needed() const {
return false;
}
CUTLASS_DEVICE bool
is_zero() const {
return (!params.col_broadcast && *(params.ptr_col_array[group]) == Element(0));
}
CUTLASS_HOST_DEVICE
Sm90ColOrScalarBroadcastArray() { }
CUTLASS_HOST_DEVICE
Sm90ColOrScalarBroadcastArray(Params const& params, SharedStorage const& shared_storage)
: params(params) { }
Params params;
template <class... Args>
CUTLASS_DEVICE auto
get_producer_load_callbacks(ProducerLoadArgs<Args...> const& args) {
return EmptyProducerLoadCallbacks{};
}
template<class GTensor, class RTensor, class CTensor, class ProblemShape>
struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks {
CUTLASS_DEVICE
ConsumerStoreCallbacks(
GTensor&& tCgCol,
RTensor&& tCrCol,
CTensor&& tCcCol,
ProblemShape problem_shape,
int group,
Params const& params
):
tCgCol(cute::forward<GTensor>(tCgCol)),
tCrCol(cute::forward<RTensor>(tCrCol)),
tCcCol(cute::forward<CTensor>(tCcCol)),
m(get<0>(problem_shape)),
group(group),
params(params) {}
GTensor tCgCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
RTensor tCrCol;
CTensor tCcCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
Params const& params;
int m;
int group;
CUTLASS_DEVICE void
begin() {
Tensor pred = make_tensor<bool>(shape(tCgCol));
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(pred); ++i) {
pred(i) = get<0>(tCcCol(i)) < m;
}
if (!params.col_broadcast) {
fill(tCrCol, *(params.ptr_col_array[group]));
return;
}
// Filter so we don't issue redundant copies over stride-0 modes
// (only works if 0-strides are in same location, which is by construction)
copy_if(pred, filter(tCgCol), filter(tCrCol));
}
template <typename ElementAccumulator, int FragmentSize>
CUTLASS_DEVICE Array<Element, FragmentSize>
visit(Array<ElementAccumulator, FragmentSize> const& frg_acc, int epi_v, int epi_m, int epi_n) {
Array<Element, FragmentSize> frg_col;
Tensor tCrCol_mn = tCrCol(_,_,_,epi_m,epi_n);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < FragmentSize; ++i) {
frg_col[i] = tCrCol_mn(epi_v * FragmentSize + i);
}
return frg_col;
}
};
template <
bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy
class... Args
>
CUTLASS_DEVICE auto
get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {
auto [M, N, K, L] = args.problem_shape_mnkl;
auto [m, n, k, l] = args.tile_coord_mnkl;
Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col_array[l]), make_shape(M,N,1), params.dCol);
Tensor tCgCol = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
mCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx);
Tensor tCrCol = make_tensor_like(tCgCol); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
// Generate an identity tensor matching the shape of the global tensor and
// partition the same way, this will be used to generate the predicate
// tensor for loading
Tensor cCol = make_identity_tensor(mCol.shape());
Tensor tCcCol = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
cCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx);
return ConsumerStoreCallbacks(
cute::move(tCgCol),
cute::move(tCrCol),
cute::move(tCcCol),
args.problem_shape_mnkl,
l,
params
);
}
};
}


@@ -0,0 +1,500 @@
/***************************************************************************************************
 * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice,
 * this list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
//
// This file is a modified excerpt of
// include/cutlass/epilogue/fusion/visitor_load.hpp from
// https://github.com/NVIDIA/cutlass v3.5.0
// It has been modified to support either
// row/column or scalar broadcasting where the tensor being loaded from is
// always passed in via a device pointer. This lets one compiled kernel handle
// all cases of per-tensor or per-channel/per-token quantization.
//
// This interface also allows the scales to be passed in as tensors that
// consistently reside on the device, which avoids an issue with a previous
// implementation where scalars needed to be on the CPU since they
// were passed in via float values. This created a potential performance hazard
// if scales were initially on the device, and caused torch.compile graph
// breaks when moving scales to the CPU.
//
// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_c2x.hpp
#pragma once
// Turn off clang-format for the entire file to keep it close to upstream
// clang-format off
#include "cutlass/epilogue/threadblock/fusion/visitor_2x.hpp"
#include "cutlass/epilogue/threadblock/fusion/visitors.hpp"
#include "cute/tensor.hpp"
namespace cutlass::epilogue::threadblock {
using namespace cute;
using namespace detail;
template<
class ThreadMap,
class Element,
class StrideMNL
>
struct VisitorRowOrScalarBroadcast {
// This struct has been modified to have a bool indicating that ptr_row is a
// scalar that must be broadcast.
struct Arguments {
Element const* ptr_row = nullptr;
bool row_broadcast = true;
StrideMNL dRow = {};
};
using Params = Arguments;
template <class ProblemShape>
static constexpr Params
to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
return args;
}
template <class ProblemShape>
static size_t
get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
return 0;
}
struct SharedStorage {};
// Global load type
static int constexpr vec_bits = ThreadMap::kElementsPerAccess * sizeof_bits<Element>::value;
using VecType = uint_bit_t<cute::min(128, vec_bits)>;
static int constexpr VecLength = sizeof(VecType) / sizeof(Element);
CUTLASS_HOST_DEVICE
VisitorRowOrScalarBroadcast() { }
CUTLASS_HOST_DEVICE
VisitorRowOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage)
: params_ptr(&params) { }
Params const* params_ptr;
template <class GTensor, class RTensor, class CTensor, class ProblemShape>
struct Callbacks : EmptyCallbacks {
CUTLASS_DEVICE
Callbacks(
GTensor&& tC_gRow,
RTensor&& tC_rRow,
CTensor&& tC_cRow,
ProblemShape problem_shape,
Params const* params_ptr
):
tC_gRow(cute::forward<GTensor>(tC_gRow)),
tC_rRow(cute::forward<RTensor>(tC_rRow)),
tC_cRow(cute::forward<CTensor>(tC_cRow)),
n(get<1>(problem_shape)),
params_ptr(params_ptr) { }
GTensor tC_gRow;
RTensor tC_rRow;
CTensor tC_cRow;
Params const* params_ptr;
int n;
// This function is modified from VisitorRowBroadcast
CUTLASS_DEVICE void
begin_epilogue() {
clear(tC_rRow);
auto src_v = filter(tC_gRow);
auto coord_v = filter(tC_cRow);
auto dst_v = filter(tC_rRow);
if (params_ptr->row_broadcast) {
// In this case we are loading from a row vector and broadcasting
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(src_v); ++i) {
bool guard = get<1>(coord_v(i)) < n;
cutlass::arch::global_load<VecType, sizeof(VecType)>(
dst_v(i), (void const*)&src_v(i), guard);
}
} else {
// In this case we are loading from a scalar and broadcasting
VecType filled_vec;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < VecLength; i++) {
reinterpret_cast<Element*>(&filled_vec)[i] = *(params_ptr->ptr_row);
}
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(src_v); ++i) {
if (get<1>(coord_v(i)) < n) {
dst_v(i) = filled_vec;
}
}
}
}
template <class ElementAccumulator, int FragmentSize>
CUTLASS_DEVICE auto // returns an Array
visit(int iter_idx, int row_idx, int column_idx, int frg_idx,
Array<ElementAccumulator, FragmentSize> const& frg_acc) {
Tensor rRow_frg = recast<Array<Element, FragmentSize>>(coalesce(tC_rRow));
return rRow_frg(column_idx);
}
};
template <class ProblemShape>
CUTLASS_DEVICE auto
get_callbacks(
gemm::GemmCoord threadblock_tile_offset,
int thread_idx,
ProblemShape problem_shape
) {
Tensor mRow = make_tensor(
make_gmem_ptr(params_ptr->ptr_row),
problem_shape,
params_ptr->dRow);
// VECTOR, FRAGMENT_COLUMN
Tensor tC_gRow = recast<VecType>(
ThreadMap::partition(mRow, thread_idx, threadblock_tile_offset)
)(_,_,_0{},_0{},_0{},_0{});
Tensor tC_rRow = make_tensor_like(tC_gRow);
// Generate the pred tensor
Tensor cRow = make_identity_tensor(mRow.shape());
Tensor tC_cRow = outer_partition(
ThreadMap::partition(cRow, thread_idx, threadblock_tile_offset)(_,_,_0{},_0{},_0{},_0{}),
Shape<Int<VecLength>>{},
(_0{})
);
return Callbacks<
decltype(tC_gRow), decltype(tC_rRow),
decltype(tC_cRow), ProblemShape>(
cute::move(tC_gRow),
cute::move(tC_rRow),
cute::move(tC_cRow),
problem_shape,
params_ptr
);
}
};
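// --- Illustrative sketch (not part of the upstream header) ---
// Framework-free model of the choice made in begin_epilogue() above: with
// row_broadcast == true, destination element j takes ptr_row[j] (guarded by
// j < n); otherwise every element takes the single scalar *ptr_row. The
// function and parameter names are assumptions used only for illustration.
inline void row_or_scalar_reference(const float* ptr_row, bool row_broadcast,
                                    int n, int fragment_width, float* dst) {
  for (int j = 0; j < fragment_width; ++j) {
    dst[j] = 0.0f;                                     // cleared like tC_rRow
    if (j >= n) continue;                              // column guard, mirrors get<1>(coord) < n
    dst[j] = row_broadcast ? ptr_row[j] : ptr_row[0];  // vector load vs. scalar broadcast
  }
}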
/////////////////////////////////////////////////////////////////////////////////////////////////
// This is a modified RowBroadcast that will broadcast 0 if ptr_row is null
template<
class ThreadMap,
class Element,
class StrideMNL
>
struct VisitorRowOrZeroBroadcast {
// This struct has been modified to remove null_default (because it's always 0)
struct Arguments {
Element const* ptr_row = nullptr;
StrideMNL dRow = {};
};
using Params = Arguments;
template <class ProblemShape>
static constexpr Params
to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
return args;
}
template <class ProblemShape>
static size_t
get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
return 0;
}
struct SharedStorage {};
// Global load type
static int constexpr vec_bits = ThreadMap::kElementsPerAccess * sizeof_bits<Element>::value;
using VecType = uint_bit_t<cute::min(128, vec_bits)>;
static int constexpr VecLength = sizeof(VecType) / sizeof(Element);
CUTLASS_HOST_DEVICE
VisitorRowOrZeroBroadcast() { }
CUTLASS_HOST_DEVICE
VisitorRowOrZeroBroadcast(Params const& params, SharedStorage const& shared_storage)
: params_ptr(&params) { }
Params const* params_ptr;
template <class GTensor, class RTensor, class CTensor, class ProblemShape>
struct Callbacks : EmptyCallbacks {
CUTLASS_DEVICE
Callbacks(
GTensor&& tC_gRow,
RTensor&& tC_rRow,
CTensor&& tC_cRow,
ProblemShape problem_shape,
Params const* params_ptr
):
tC_gRow(cute::forward<GTensor>(tC_gRow)),
tC_rRow(cute::forward<RTensor>(tC_rRow)),
tC_cRow(cute::forward<CTensor>(tC_cRow)),
n(get<1>(problem_shape)),
params_ptr(params_ptr) { }
GTensor tC_gRow;
RTensor tC_rRow;
CTensor tC_cRow;
Params const* params_ptr;
int n;
// This function is modified from VisitorRowBroadcast
CUTLASS_DEVICE void
begin_epilogue() {
clear(tC_rRow);
auto src_v = filter(tC_gRow);
auto coord_v = filter(tC_cRow);
auto dst_v = filter(tC_rRow);
if (params_ptr->ptr_row != nullptr) {
// In this case we are loading from a row vector and broadcasting
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(src_v); ++i) {
bool guard = get<1>(coord_v(i)) < n;
cutlass::arch::global_load<VecType, sizeof(VecType)>(
dst_v(i), (void const*)&src_v(i), guard);
}
} else {
// In this case we are broadcasting 0
VecType filled_vec;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < VecLength; i++) {
reinterpret_cast<Element*>(&filled_vec)[i] = Element{0};
}
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(src_v); ++i) {
if (get<1>(coord_v(i)) < n) {
dst_v(i) = filled_vec;
}
}
}
}
template <class ElementAccumulator, int FragmentSize>
CUTLASS_DEVICE auto // returns an Array
visit(int iter_idx, int row_idx, int column_idx, int frg_idx,
Array<ElementAccumulator, FragmentSize> const& frg_acc) {
Tensor rRow_frg = recast<Array<Element, FragmentSize>>(coalesce(tC_rRow));
return rRow_frg(column_idx);
}
};
template <class ProblemShape>
CUTLASS_DEVICE auto
get_callbacks(
gemm::GemmCoord threadblock_tile_offset,
int thread_idx,
ProblemShape problem_shape
) {
Tensor mRow = make_tensor(
make_gmem_ptr(params_ptr->ptr_row),
problem_shape,
params_ptr->dRow);
// VECTOR, FRAGMENT_COLUMN
Tensor tC_gRow = recast<VecType>(
ThreadMap::partition(mRow, thread_idx, threadblock_tile_offset)
)(_,_,_0{},_0{},_0{},_0{});
Tensor tC_rRow = make_tensor_like(tC_gRow);
// Generate the pred tensor
Tensor cRow = make_identity_tensor(mRow.shape());
Tensor tC_cRow = outer_partition(
ThreadMap::partition(cRow, thread_idx, threadblock_tile_offset)(_,_,_0{},_0{},_0{},_0{}),
Shape<Int<VecLength>>{},
(_0{})
);
return Callbacks<
decltype(tC_gRow), decltype(tC_rRow),
decltype(tC_cRow), ProblemShape>(
cute::move(tC_gRow),
cute::move(tC_rRow),
cute::move(tC_cRow),
problem_shape,
params_ptr
);
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
// Column vector broadcast
template<
class ThreadMap,
class Element,
class StrideMNL = Stride<_1,_0,_0>
>
struct VisitorColOrScalarBroadcast {
// This struct has been modified to have a bool indicating that ptr_col is a
// scalar that must be broadcast.
struct Arguments {
Element const* ptr_col = nullptr;
bool col_broadcast = true;
StrideMNL dCol = {};
};
using Params = Arguments;
template <class ProblemShape>
static constexpr Params
to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
return args;
}
template <class ProblemShape>
static size_t
get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
return 0;
}
struct SharedStorage { };
CUTLASS_HOST_DEVICE
VisitorColOrScalarBroadcast() { }
CUTLASS_HOST_DEVICE
VisitorColOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage)
: params_ptr(&params) { }
Params const* params_ptr;
template <class GTensor, class RTensor, class CTensor, class ProblemShape>
struct Callbacks : EmptyCallbacks {
CUTLASS_DEVICE
Callbacks(
GTensor&& tC_gCol,
RTensor&& tC_rCol,
CTensor&& tC_cCol,
ProblemShape problem_shape,
Params const* params_ptr
):
tC_gCol(cute::forward<GTensor>(tC_gCol)),
tC_rCol(cute::forward<RTensor>(tC_rCol)),
tC_cCol(cute::forward<CTensor>(tC_cCol)),
m(get<0>(problem_shape)),
params_ptr(params_ptr) { }
GTensor tC_gCol;
RTensor tC_rCol;
CTensor tC_cCol;
Params const* params_ptr;
int m;
// This function is modified from VisitorColBroadcast
CUTLASS_DEVICE void
begin_epilogue() {
clear(tC_rCol);
Tensor pred = make_tensor<bool>(shape(tC_gCol));
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(pred); ++i) {
pred(i) = get<0>(tC_cCol(i)) < m;
}
if (params_ptr->col_broadcast) {
// In this case we are loading from a column vector and broadcasting
copy_if(pred, tC_gCol, tC_rCol);
} else {
// In this case we are loading from a scalar and broadcasting
auto dst_v = filter(tC_rCol);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(dst_v); ++i) {
if (pred(i)) {
dst_v(i) = *(params_ptr->ptr_col);
}
}
}
}
template <class ElementAccumulator, int FragmentSize>
CUTLASS_DEVICE auto // returns an Array
visit(int iter_idx, int row_idx, int column_idx, int frg_idx,
Array<ElementAccumulator, FragmentSize> const& frg_acc) {
Array<Element, FragmentSize> frg_col;
frg_col.fill(tC_rCol(row_idx,iter_idx));
return frg_col;
}
};
template <class ProblemShape>
CUTLASS_DEVICE auto
get_callbacks(
gemm::GemmCoord threadblock_tile_offset,
int thread_idx,
ProblemShape problem_shape
) {
Tensor mCol = make_tensor(
make_gmem_ptr(params_ptr->ptr_col),
problem_shape,
params_ptr->dCol);
// VECTOR, FRAGMENT_COLUMN, FRAGMENT_ROW, ITERATION_ROW, ITERATION_GROUP, ITERATION_CLUSTER
Tensor tC_gCol = group_modes<1,4>(
ThreadMap::partition(mCol, thread_idx, threadblock_tile_offset)(_0{},_0{},_,_,_,_));
Tensor tC_rCol = make_tensor_like(tC_gCol);
// Generate the pred tensor
Tensor cCol = make_identity_tensor(mCol.shape());
Tensor tC_cCol = group_modes<1,4>(
ThreadMap::partition(cCol, thread_idx, threadblock_tile_offset)(_0{},_0{},_,_,_,_));
return Callbacks<
decltype(tC_gCol), decltype(tC_rCol),
decltype(tC_cCol), ProblemShape>(
cute::move(tC_gCol),
cute::move(tC_rCol),
cute::move(tC_cCol),
problem_shape,
params_ptr
);
}
};
}

View File

@@ -0,0 +1,450 @@
/***************************************************************************************************
 * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice,
 * this list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
//
// This file is a modified excerpt of
// include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp
// from https://github.com/NVIDIA/cutlass v3.5.0
// It has been modified to support either row/column or scalar broadcasting
// where the tensor being loaded from is always passed in via a device pointer.
// This lets one compiled kernel handle all cases of per-tensor or
// per-channel/per-token quantization.
//
// This interface also allows the scales to be passed in as tensors that
// consistently reside on the device, which avoids an issue with a previous
// implementation where scalars needed to be on the CPU since they
// were passed in via float values. This created a potential performance hazard
// if scales were initially on the device, and caused torch.compile graph
// breaks when moving scales to the CPU.
//
// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp
#pragma once
// Turn off clang-format for the entire file to keep it close to upstream
// clang-format off
#include "cutlass/cutlass.h"
#include "cutlass/arch/barrier.h"
#include "cute/tensor.hpp"
#include "cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp"
namespace cutlass::epilogue::fusion {
using namespace cute;
using namespace detail;
// Row vector broadcast
template<
int Stages,
class CtaTileShapeMNK,
class Element,
class StrideMNL = Stride<_0,_1,_0>,
int Alignment = 128 / sizeof_bits_v<Element>
>
struct Sm90RowOrScalarBroadcast {
static_assert(Stages == 0, "Row broadcast doesn't support smem usage");
static_assert(is_static_v<decltype(take<0,2>(StrideMNL{}))>); // batch stride can be dynamic or static
static_assert(take<0,2>(StrideMNL{}) == Stride<_0,_1>{});
struct SharedStorage {
array_aligned<Element, size<1>(CtaTileShapeMNK{})> smem;
};
// This struct has been modified to have a bool indicating that ptr_row is a
// scalar that must be broadcast, instead of containing a scalar that is
// valid if ptr_row is null.
struct Arguments {
Element const* ptr_row = nullptr;
bool row_broadcast = true;
StrideMNL dRow = {};
};
using Params = Arguments;
template <class ProblemShape>
static constexpr Params
to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
return args;
}
template <class ProblemShape>
static bool
can_implement(ProblemShape const& problem_shape, Arguments const& args) {
return true;
}
template <class ProblemShape>
static size_t
get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
return 0;
}
template <class ProblemShape>
static cutlass::Status
initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream,
CudaHostAdapter* cuda_adapter = nullptr) {
return cutlass::Status::kSuccess;
}
CUTLASS_HOST_DEVICE
Sm90RowOrScalarBroadcast() { }
CUTLASS_HOST_DEVICE
Sm90RowOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage)
: params(params)
, smem(const_cast<Element*>(shared_storage.smem.data())) { }
Params params;
Element *smem = nullptr;
CUTLASS_DEVICE bool
is_producer_load_needed() const {
return false;
}
CUTLASS_DEVICE bool
is_C_load_needed() const {
return false;
}
CUTLASS_DEVICE bool
is_zero() const {
return (!params.row_broadcast && *(params.ptr_row) == Element(0));
}
template <class... Args>
CUTLASS_DEVICE auto
get_producer_load_callbacks(ProducerLoadArgs<Args...> const& args) {
return EmptyProducerLoadCallbacks{};
}
template <class GS_GTensor, class GS_STensor, class GS_CTensor, class Tiled_G2S, class SR_STensor, class SR_RTensor, class CTensor, class ThrResidue, class ThrNum>
struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks {
CUTLASS_DEVICE
ConsumerStoreCallbacks(
GS_GTensor tGS_gRow_, GS_STensor tGS_sRow_,
GS_CTensor tGS_cRow_, Tiled_G2S tiled_g2s_,
SR_STensor tSR_sRow_, SR_RTensor tSR_rRow_,
CTensor tCcRow_, ThrResidue residue_tCcRow_, ThrNum thr_num_, Params const& params_)
: tGS_gRow(tGS_gRow_)
, tGS_sRow(tGS_sRow_)
, tGS_cRow(tGS_cRow_)
, tiled_G2S(tiled_g2s_)
, tSR_sRow(tSR_sRow_)
, tSR_rRow(tSR_rRow_)
, tCcRow(tCcRow_)
, residue_tCcRow(residue_tCcRow_)
, params(params_) {}
GS_GTensor tGS_gRow; // (CPY,CPY_M,CPY_N)
GS_STensor tGS_sRow; // (CPY,CPY_M,CPY_N)
GS_CTensor tGS_cRow; // (CPY,CPY_M,CPY_N)
Tiled_G2S tiled_G2S;
SR_STensor tSR_sRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
SR_RTensor tSR_rRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
CTensor tCcRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
ThrResidue residue_tCcRow; // (m, n)
ThrNum thr_num;
Params const& params;
CUTLASS_DEVICE void
begin() {
if (!params.row_broadcast) {
fill(tSR_rRow, *(params.ptr_row));
return;
}
auto synchronize = [&] () { cutlass::arch::NamedBarrier::sync(thr_num, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); };
Tensor tGS_gRow_flt = filter_zeros(tGS_gRow);
Tensor tGS_sRow_flt = filter_zeros(tGS_sRow);
Tensor tGS_cRow_flt = make_tensor(tGS_cRow.data(), make_layout(tGS_gRow_flt.shape(), tGS_cRow.stride()));
for (int i = 0; i < size(tGS_gRow_flt); ++i) {
if (get<1>(tGS_cRow_flt(i)) >= size<1>(CtaTileShapeMNK{})) {
continue; // OOB of SMEM,
}
if (elem_less(tGS_cRow_flt(i), make_coord(get<0>(residue_tCcRow), get<1>(residue_tCcRow)))) {
tGS_sRow_flt(i) = tGS_gRow_flt(i);
}
else {
        tGS_sRow_flt(i) = Element(0); // Set to zero when OOB so LDS can be issued without any preds.
}
}
synchronize();
}
CUTLASS_DEVICE void
begin_loop(int epi_m, int epi_n) {
if (epi_m == 0) { // Assumes M-major subtile loop
if (!params.row_broadcast) return; // Do not issue LDS when row is scalar
Tensor tSR_sRow_flt = filter_zeros(tSR_sRow(_,_,_,epi_m,epi_n));
Tensor tSR_rRow_flt = filter_zeros(tSR_rRow);
copy(tSR_sRow_flt, tSR_rRow_flt);
}
}
template <typename ElementAccumulator, int FragmentSize>
CUTLASS_DEVICE Array<Element, FragmentSize>
visit(Array<ElementAccumulator, FragmentSize> const& frg_acc, int epi_v, int epi_m, int epi_n) {
Array<Element, FragmentSize> frg_row;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < FragmentSize; ++i) {
frg_row[i] = tSR_rRow(epi_v * FragmentSize + i);
}
return frg_row;
}
};
template <
bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy
class... Args
>
CUTLASS_DEVICE auto
get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {
auto [M, N, K, L] = args.problem_shape_mnkl;
auto [m, n, k, l] = args.tile_coord_mnkl;
using ThreadCount = decltype(size(args.tiled_copy));
Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row), make_shape(M,N,L), params.dRow);
Tensor gRow = local_tile(mRow(_,_,l), take<0,2>(args.tile_shape_mnk), make_coord(m, n)); // (CTA_M, CTA_N)
Tensor sRow = make_tensor(make_smem_ptr(smem),
make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})), make_shape(_0{}, _1{})); // (CTA_M, CTA_N)
//// G2S: Gmem to Smem
auto tiled_g2s = make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
Layout< Shape<_1, ThreadCount>,
Stride<_0, _1>>{},
Layout<_1>{});
auto thr_g2s = tiled_g2s.get_slice(args.thread_idx);
Tensor tGS_gRow = thr_g2s.partition_S(gRow);
Tensor tGS_sRow = thr_g2s.partition_D(sRow);
//// G2S: Coord
auto cRow = make_identity_tensor(make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})));
Tensor tGS_cRow = thr_g2s.partition_S(cRow);
//// S2R: Smem to Reg
Tensor tSR_sRow = sm90_partition_for_epilogue<ReferenceSrc>(sRow, args.epi_tile, args.tiled_copy, args.thread_idx);
Tensor tSR_rRow = make_tensor_like(take<0,3>(tSR_sRow)); // (CPY,CPY_M,CPY_N)
return ConsumerStoreCallbacks<decltype(tGS_gRow), decltype(tGS_sRow), decltype(tGS_cRow), decltype(tiled_g2s), decltype(tSR_sRow), decltype(tSR_rRow), decltype(args.tCcD), decltype(args.residue_cD), ThreadCount>(
tGS_gRow,
tGS_sRow,
tGS_cRow, tiled_g2s,
tSR_sRow,
tSR_rRow,
args.tCcD,
args.residue_cD,
ThreadCount{},
params);
}
};
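// --- Illustrative sketch (not part of the upstream header) ---
// Host-side model of the staging done in begin() above: in-bounds row
// elements are copied into the CTA-wide staging buffer and out-of-bounds
// slots are zero-filled, so the unpredicated smem->register copy issued in
// begin_loop() is always safe (the scalar case bypasses smem and fills
// registers directly). `stage_row_reference` is an assumed name.
inline void stage_row_reference(const float* gmem_row, float* staging,
                                int cta_n, int residue_n) {
  for (int j = 0; j < cta_n; ++j) {
    staging[j] = (j < residue_n) ? gmem_row[j] : 0.0f;  // zero padding for OOB columns
  }
}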
/////////////////////////////////////////////////////////////////////////////////////////////////
// Column vector broadcast
template<
int Stages,
class CtaTileShapeMNK,
class Element,
class StrideMNL = Stride<_1,_0,_0>,
int Alignment = 128 / sizeof_bits_v<Element>
>
struct Sm90ColOrScalarBroadcast {
static_assert(Stages == 0, "Column broadcast doesn't support smem usage yet");
static_assert(Alignment * sizeof_bits_v<Element> % 128 == 0, "sub-16B alignment not supported yet");
static_assert(
(cute::is_same_v<StrideMNL, Stride<_1,_0, _0>>) || // col vector broadcast, e.g. per-row alpha/bias
(cute::is_same_v<StrideMNL, Stride<_1,_0,int>>)); // batched col vector broadcast, e.g. batched per-row bias
// Accumulator distributes col elements evenly amongst threads so we can just directly load from gmem
struct SharedStorage { };
// This struct has been modified to have a bool indicating that ptr_col is a
// scalar that must be broadcast, instead of containing a scalar that is
// valid if ptr_col is null.
struct Arguments {
Element const* ptr_col = nullptr;
bool col_broadcast = true;
StrideMNL dCol = {};
};
using Params = Arguments;
template <class ProblemShape>
static constexpr Params
to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
return args;
}
template <class ProblemShape>
static bool
can_implement(ProblemShape const& problem_shape, Arguments const& args) {
return true;
}
template <class ProblemShape>
static size_t
get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
return 0;
}
template <class ProblemShape>
static cutlass::Status
initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream,
CudaHostAdapter* cuda_adapter = nullptr) {
return cutlass::Status::kSuccess;
}
CUTLASS_DEVICE bool
is_producer_load_needed() const {
return false;
}
CUTLASS_DEVICE bool
is_C_load_needed() const {
return false;
}
CUTLASS_DEVICE bool
is_zero() const {
return (!params.col_broadcast && *(params.ptr_col) == Element(0));
}
CUTLASS_HOST_DEVICE
Sm90ColOrScalarBroadcast() { }
CUTLASS_HOST_DEVICE
Sm90ColOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage)
: params(params) { }
Params params;
template <class... Args>
CUTLASS_DEVICE auto
get_producer_load_callbacks(ProducerLoadArgs<Args...> const& args) {
return EmptyProducerLoadCallbacks{};
}
template<class GTensor, class RTensor, class CTensor, class ProblemShape>
struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks {
CUTLASS_DEVICE
ConsumerStoreCallbacks(
GTensor&& tCgCol,
RTensor&& tCrCol,
CTensor&& tCcCol,
ProblemShape problem_shape,
Params const& params
):
tCgCol(cute::forward<GTensor>(tCgCol)),
tCrCol(cute::forward<RTensor>(tCrCol)),
tCcCol(cute::forward<CTensor>(tCcCol)),
m(get<0>(problem_shape)),
params(params) {}
GTensor tCgCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
RTensor tCrCol;
CTensor tCcCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
Params const& params;
int m;
CUTLASS_DEVICE void
begin() {
Tensor pred = make_tensor<bool>(shape(tCgCol));
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(pred); ++i) {
pred(i) = get<0>(tCcCol(i)) < m;
}
if (!params.col_broadcast) {
fill(tCrCol, *(params.ptr_col));
return;
}
// Filter so we don't issue redundant copies over stride-0 modes
// (only works if 0-strides are in same location, which is by construction)
copy_if(pred, filter(tCgCol), filter(tCrCol));
}
template <typename ElementAccumulator, int FragmentSize>
CUTLASS_DEVICE Array<Element, FragmentSize>
visit(Array<ElementAccumulator, FragmentSize> const& frg_acc, int epi_v, int epi_m, int epi_n) {
Array<Element, FragmentSize> frg_col;
Tensor tCrCol_mn = tCrCol(_,_,_,epi_m,epi_n);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < FragmentSize; ++i) {
frg_col[i] = tCrCol_mn(epi_v * FragmentSize + i);
}
return frg_col;
}
};
template <
bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy
class... Args
>
CUTLASS_DEVICE auto
get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {
auto [M, N, K, L] = args.problem_shape_mnkl;
Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col), make_shape(M,N,L), params.dCol);
Tensor tCgCol = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
mCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx);
Tensor tCrCol = make_tensor_like(tCgCol); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
// Generate an identity tensor matching the shape of the global tensor and
// partition the same way, this will be used to generate the predicate
// tensor for loading
Tensor cCol = make_identity_tensor(mCol.shape());
Tensor tCcCol = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
cCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx);
return ConsumerStoreCallbacks(
cute::move(tCgCol),
cute::move(tCrCol),
cute::move(tCcCol),
args.problem_shape_mnkl,
params
);
}
};
}

View File

@@ -0,0 +1,327 @@
// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp
#pragma once
#include "cutlass_extensions/epilogue/broadcast_load_epilogue_c2x.hpp"
/*
This file defines custom epilogues for fusing channel scales, token scales,
bias, and activation zero-points onto a GEMM operation using the
CUTLASS 2.x API, for sm80 (Ampere) NVIDIA GPUs.
Epilogues must contain a public type named EVTCompute of type Sm80EVT,
as well as a static prepare_args function that constructs an
EVTCompute::Arguments struct.
*/
namespace fastdeploy::c2x {
using namespace cute;
/*
* This class provides the common load descriptors for the
* ScaledEpilogue[...] classes
*/
template <typename ElementD, typename OutputTileThreadMap>
struct ScaledEpilogueBase {
protected:
using Accum = cutlass::epilogue::threadblock::VisitorAccFetch;
template <typename T>
using ColOrScalarLoad =
cutlass::epilogue::threadblock::VisitorColOrScalarBroadcast<
OutputTileThreadMap, T, Stride<Int<1>, Int<0>, Int<0>>>;
template <typename T>
using RowOrScalarLoad =
cutlass::epilogue::threadblock::VisitorRowOrScalarBroadcast<
OutputTileThreadMap, T, Stride<Int<0>, Int<1>, Int<0>>>;
template <typename T>
using ColLoad = cutlass::epilogue::threadblock::VisitorColBroadcast<
OutputTileThreadMap, T, Stride<Int<1>, Int<0>, Int<0>>>;
template <typename T>
using RowLoad = cutlass::epilogue::threadblock::VisitorRowBroadcast<
OutputTileThreadMap, T, Stride<Int<0>, Int<1>, Int<0>>>;
template <typename T>
using RowOrZeroLoad =
cutlass::epilogue::threadblock::VisitorRowOrZeroBroadcast<
OutputTileThreadMap, T, Stride<Int<0>, Int<1>, Int<0>>>;
// This utility function constructs the arguments for the load descriptors
// from a tensor. It can handle both row and column, as well as row/column or
// scalar cases.
template <typename Descriptor, typename T>
static auto args_from_tensor(paddle::Tensor const &tensor) {
using Arguments = typename Descriptor::Arguments;
auto *data_ptr = static_cast<T *>(const_cast<void *>(
tensor.data()));
if constexpr (std::is_same_v<Descriptor,
ColOrScalarLoad<T>> ||
std::is_same_v<Descriptor,
RowOrScalarLoad<T>>) {
return Arguments{data_ptr, tensor.numel() != 1};
}
else {
// it would technically work, but there is no use case since data_ptr is never nullptr
static_assert(!std::is_same_v<Descriptor, RowOrZeroLoad<T>>);
return Arguments{data_ptr};
}
}
// This overload handles the case where there might not be a tensor, in which
// case a nullptr is passed and a constant (0) is used.
template <typename Descriptor, typename T>
static auto args_from_tensor(paddle::optional<paddle::Tensor> const &tensor) {
static_assert(std::is_same_v<Descriptor, RowOrZeroLoad<T>>);
using Arguments = typename Descriptor::Arguments;
auto *data_ptr =
tensor ? static_cast<T *>(const_cast<void *>(tensor->data())) : nullptr;
return Arguments{data_ptr};
}
};
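// --- Illustrative sketch (not part of the upstream header) ---
// Model of the dispatch rule used by args_from_tensor above: a one-element
// scale tensor is treated as a per-tensor scalar (broadcast flag = false),
// anything larger as a per-row/per-column vector (broadcast flag = true).
// `ScaleArgsSketch` and `make_scale_args` are assumed names for illustration.
struct ScaleArgsSketch {
  const float* ptr;
  bool broadcast;  // corresponds to row_broadcast / col_broadcast in the visitors
};

inline ScaleArgsSketch make_scale_args(const float* data, long long numel) {
  return ScaleArgsSketch{data, numel != 1};
}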
/*
This epilogue function defines a quantized GEMM operation similar to
paddle._scaled_mm.
A and B may be both either int8 or fp8_e4m3. A can be quantized per-tensor or
per-row. B can be quantized per-tensor or per-column.
Any combination of per-tensor and per-row or column is supported.
A and B must have symmetric quantization (zero point == 0).
So the GEMM operation is D = (a_scales * A) @ (b_scales * B), where the
scales are applied elementwise with numpy-style broadcasting.
ScaleA and ScaleB define the epilogue functions that apply the scales for
the A and B operands respectively. These scales may be either per-tensor or
per row or column.
*/
template <typename ElementD, typename OutputTileThreadMap>
struct ScaledEpilogue
: private ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
private:
using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
using Accum = typename SUPER::Accum;
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
using Compute0 = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::multiplies, float, float,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTCompute0 =
cutlass::epilogue::threadblock::Sm80EVT<Compute0, ScaleB, Accum>;
using Compute1 = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::multiplies, ElementD, float,
cutlass::FloatRoundStyle::round_to_nearest>;
public:
using EVTCompute =
cutlass::epilogue::threadblock::Sm80EVT<Compute1, ScaleA, EVTCompute0>;
using ArgumentType = typename EVTCompute::Arguments;
static ArgumentType prepare_args(paddle::Tensor const &a_scales,
paddle::Tensor const &b_scales) {
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
typename EVTCompute0::Arguments evt0_args{b_args, {}, {}};
return ArgumentType{a_args, evt0_args, {}};
}
};
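// --- Illustrative sketch (not part of the upstream header) ---
// Framework-free reference of the math this epilogue fuses: the accumulator
// of A @ B is multiplied by a per-row (or per-tensor) a_scale and a
// per-column (or per-tensor) b_scale before conversion to ElementD. The
// function and flags below are an assumed reference, not the EVT itself.
inline void scaled_mm_reference(const int* accum,       // m x n accumulator of A @ B (int32)
                                const float* a_scales,  // length m, or 1 if per-tensor
                                const float* b_scales,  // length n, or 1 if per-tensor
                                bool a_per_row, bool b_per_col,
                                int m, int n, float* d) {
  for (int i = 0; i < m; ++i) {
    const float sa = a_per_row ? a_scales[i] : a_scales[0];
    for (int j = 0; j < n; ++j) {
      const float sb = b_per_col ? b_scales[j] : b_scales[0];
      d[i * n + j] = sa * sb * static_cast<float>(accum[i * n + j]);
    }
  }
}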
/*
* This epilogue performs the same operation as ScaledEpilogue, but adds a bias.
* This bias can also be used in the per-tensor azp case, where the activation
* zero point (azp) is used to compute an azp correction term,
* which is folded into the bias.
*
* The bias tensor must be per-output channel.
* ScaleA and ScaleB can be per-tensor or per-token/per-channel.
*/
template <typename ElementD, typename OutputTileThreadMap>
struct ScaledEpilogueBias
: protected ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
protected:
using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
using Accum = typename SUPER::Accum;
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
using Bias = typename SUPER::template RowLoad<ElementD>;
using Compute0 = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::multiplies, float, float,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTCompute0 =
cutlass::epilogue::threadblock::Sm80EVT<Compute0, ScaleB, Accum>;
using Compute1 = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::multiply_add, ElementD, float,
cutlass::FloatRoundStyle::round_to_nearest>;
public:
using EVTCompute = cutlass::epilogue::threadblock::Sm80EVT<Compute1, ScaleA,
EVTCompute0, Bias>;
using ArgumentType = typename EVTCompute::Arguments;
static ArgumentType prepare_args(paddle::Tensor const &a_scales,
paddle::Tensor const &b_scales,
paddle::Tensor const &bias) {
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
typename EVTCompute0::Arguments evt0_args{b_args, {}, {}};
return ArgumentType{a_args, evt0_args, bias_args, {}};
}
};
/*
* This epilogue directly supports per-tensor azp in int32 form.
* As opposed to the per-token epilogue below, this epilogue only has an azp_adj
* term, which should already be multiplied with the scalar azp.
* The azp_adj term is a 1D tensor of shape (1,n), computed as azp * J @ B.
*
* This epilogue also supports bias, which remains per-channel.
*/
template <typename ElementD, typename OutputTileThreadMap>
struct ScaledEpilogueBiasAzp
: protected ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
private:
using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
using Accum = typename SUPER::Accum;
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
using Bias = typename SUPER::template RowOrZeroLoad<ElementD>;
// This is the full AZP term, azp * J @ B, shape (1,n)
using AzpWithAdj = typename SUPER::template RowLoad<int32_t>;
// Compute float(accum - azp_adj), both operands are int32_t
using ComputeAzp = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::minus, float, int32_t,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTComputeAzp =
cutlass::epilogue::threadblock::Sm80EVT<ComputeAzp, Accum, AzpWithAdj>;
using ComputeScaleB = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::multiplies, float, float,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTComputeScaleB =
cutlass::epilogue::threadblock::Sm80EVT<ComputeScaleB, ScaleB,
EVTComputeAzp>;
using ComputeScaleBiasA = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::multiply_add, ElementD, float,
cutlass::FloatRoundStyle::round_to_nearest>;
public:
using EVTCompute =
cutlass::epilogue::threadblock::Sm80EVT<ComputeScaleBiasA, ScaleA,
EVTComputeScaleB, Bias>;
using ArgumentType = typename EVTCompute::Arguments;
static ArgumentType
prepare_args(paddle::Tensor const &a_scales, paddle::Tensor const &b_scales,
paddle::Tensor const &azp_adj,
paddle::optional<paddle::Tensor> const &bias) {
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
auto azp_adj_args =
SUPER::template args_from_tensor<AzpWithAdj, int32_t>(azp_adj);
typename EVTComputeAzp::Arguments evt_azp_args{{}, azp_adj_args, {}};
typename EVTComputeScaleB::Arguments evt_scale_b_args{
b_args, evt_azp_args, {}};
return ArgumentType{a_args, evt_scale_b_args, bias_args, {}};
}
};
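// --- Illustrative sketch (not part of the upstream header) ---
// Reference of the per-tensor azp correction folded in above: when quantized
// A carries a per-tensor zero point azp, the accumulator contains an extra
// azp * (J @ B) term, so the epilogue subtracts azp_adj[j] = azp * sum_i B[i][j]
// from column j before scaling. The helper materializes that (1,n) azp_adj
// row for reference; names are assumptions.
inline void azp_adj_reference(const int* b_quant,  // k x n quantized B
                              int k, int n, int azp, int* azp_adj) {
  for (int j = 0; j < n; ++j) {
    int col_sum = 0;
    for (int i = 0; i < k; ++i) {
      col_sum += b_quant[i * n + j];  // (J @ B)[j]: column sum of B
    }
    azp_adj[j] = azp * col_sum;       // fold the scalar azp into the adjustment
  }
}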
/*
* This epilogue supports per-token azp by computing and applying
* the correction term using a rank-1 update. If the term were materialized,
* it would require O(m*n) space, and this way it only requires O(m+n) space.
* The azp term is a 1D tensor of shape (m,1), and represents the unscaled zero
* point for each row of A.
* The azp_adj term is a 1D tensor of shape (1,n), computed as J @ B.
*
* This epilogue also supports bias, which remains per-channel.
*/
template <typename ElementD, typename OutputTileThreadMap>
struct ScaledEpilogueBiasAzpToken
: protected ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
private:
using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
using Accum = typename SUPER::Accum;
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
using Bias = typename SUPER::template RowOrZeroLoad<ElementD>;
// Per-token azp term, shape (m,1)
using Azp = typename SUPER::template ColLoad<int32_t>;
// This is the AZP adjustment term, J @ B, shape (1,n)
using AzpAdj = typename SUPER::template RowLoad<int32_t>;
// Compute azp * azp_adj
using ComputeAzp = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::multiplies, int32_t, int32_t,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTComputeAzp =
cutlass::epilogue::threadblock::Sm80EVT<ComputeAzp, Azp, AzpAdj>;
// Compute float(accum - azp*azp_adj), all operands are int32_t
using ComputeAcc = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::minus, float, int32_t,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTComputeAcc =
cutlass::epilogue::threadblock::Sm80EVT<ComputeAcc, Accum, EVTComputeAzp>;
using ComputeScaleB = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::multiplies, float, float,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTComputeScaleB =
cutlass::epilogue::threadblock::Sm80EVT<ComputeScaleB, ScaleB,
EVTComputeAcc>;
using ComputeScaleBiasA = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::multiply_add, ElementD, float,
cutlass::FloatRoundStyle::round_to_nearest>;
public:
using EVTCompute =
cutlass::epilogue::threadblock::Sm80EVT<ComputeScaleBiasA, ScaleA,
EVTComputeScaleB, Bias>;
using ArgumentType = typename EVTCompute::Arguments;
static ArgumentType
prepare_args(paddle::Tensor const &a_scales, paddle::Tensor const &b_scales,
paddle::Tensor const &azp_adj, paddle::Tensor const &azp,
paddle::optional<paddle::Tensor> const &bias) {
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
auto azp_args = SUPER::template args_from_tensor<Azp, int32_t>(azp);
auto azp_adj_args =
SUPER::template args_from_tensor<AzpAdj, int32_t>(azp_adj);
typename EVTComputeAzp::Arguments evt_azp_args{azp_args, azp_adj_args, {}};
typename EVTComputeAcc::Arguments evt_acc_args{{}, evt_azp_args, {}};
typename EVTComputeScaleB::Arguments evt_scale_b_args{
b_args, evt_acc_args, {}};
return ArgumentType{a_args, evt_scale_b_args, bias_args, {}};
}
};
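// --- Illustrative sketch (not part of the upstream header) ---
// The per-token correction above is a rank-1 update: correction(i, j) =
// azp[i] * azp_adj[j] with azp_adj[j] = sum_i B[i][j], so only two vectors
// (O(m + n) storage) are needed instead of the m x n outer product. The
// helper shows the corrected accumulator for one element; names are assumed.
inline float azp_token_corrected_accum(int accum_ij, int azp_i, int azp_adj_j) {
  // Matches ComputeAcc above: float(accum - azp * azp_adj)
  return static_cast<float>(accum_ij - azp_i * azp_adj_j);
}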
}; // namespace fastdeploy::c2x

View File

@@ -0,0 +1,453 @@
// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp
#pragma once
// clang-format will break include orders
// clang-format off
#include "cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp"
#include "cutlass_extensions/epilogue/broadcast_load_epilogue_array_c3x.hpp"
// clang-format on
/*
This file defines custom epilogues for fusing channel scales, token scales,
bias, and activation zero-points onto a GEMM operation using the
CUTLASS 3.x API, for NVIDIA GPUs with sm90a (Hopper) or later.
Epilogues must contain a public type named EVTCompute of type Sm90EVT,
as well as a static prepare_args function that constructs an
EVTCompute::Arguments struct.
*/
namespace fastdeploy::c3x {
using namespace cute;
template <typename T> struct identity {
CUTLASS_HOST_DEVICE
T operator()(T lhs) const { return lhs; }
};
template <typename ElementAcc, typename ElementD, typename TileShape>
struct TrivialEpilogue {
private:
using Accum = cutlass::epilogue::fusion::Sm90AccFetch;
using Compute = cutlass::epilogue::fusion::Sm90Compute<
cutlass::epilogue::thread::Identity, ElementD, ElementAcc,
cutlass::FloatRoundStyle::round_to_nearest>;
public:
using EVTCompute = cutlass::epilogue::fusion::Sm90EVT<Compute, Accum>;
using ArgumentType = typename EVTCompute::Arguments;
template <typename... Args> static ArgumentType prepare_args(Args... args) {
return {};
}
};
/*
* This class provides the common load descriptors for the
* ScaledEpilogue[...] classes
*/
template <typename ElementAcc, typename ElementD, typename TileShape>
struct ScaledEpilogueBase {
protected:
using Accum = cutlass::epilogue::fusion::Sm90AccFetch;
template <typename T>
using ColOrScalarLoad = cutlass::epilogue::fusion::Sm90ColOrScalarBroadcast<
0 /*Stages*/, TileShape, T, Stride<Int<1>, Int<0>, Int<0>>>;
template <typename T>
using RowOrScalarLoad = cutlass::epilogue::fusion::Sm90RowOrScalarBroadcast<
0 /*Stages*/, TileShape, T, Stride<Int<0>, Int<1>, Int<0>>>;
// Don't want to support nullptr by default
template <typename T, bool EnableNullPtr = false>
using ColLoad = cutlass::epilogue::fusion::Sm90ColBroadcast<
0 /*Stages*/, TileShape, T, T, Stride<Int<1>, Int<0>, Int<0>>,
128 / sizeof_bits_v<T>, EnableNullPtr>;
// Don't want to support nullptr by default
template <typename T, bool EnableNullPtr = false>
using RowLoad = cutlass::epilogue::fusion::Sm90RowBroadcast<
0 /*Stages*/, TileShape, T, T, Stride<Int<0>, Int<1>, Int<0>>,
128 / sizeof_bits_v<T>, EnableNullPtr>;
template <typename T>
using ColOrScalarLoadArray =
cutlass::epilogue::fusion::Sm90ColOrScalarBroadcastArray<
0 /*Stages*/, TileShape, T, Stride<Int<1>, Int<0>, Int<0>>>;
template <typename T>
using RowOrScalarLoadArray =
cutlass::epilogue::fusion::Sm90RowOrScalarBroadcastArray<
0 /*Stages*/, TileShape, T, Stride<Int<0>, Int<1>, Int<0>>>;
// This utility function constructs the arguments for the load descriptors
// from a tensor. It can handle both row and column, as well as row/column or
// scalar cases.
template <typename Descriptor, typename T>
static auto args_from_tensor(paddle::Tensor const &tensor) {
using Arguments = typename Descriptor::Arguments;
auto *data_ptr = static_cast<T *>(const_cast<void *>(tensor.data()));
if constexpr (std::is_same_v<Descriptor, ColOrScalarLoad<T>> ||
std::is_same_v<Descriptor, RowOrScalarLoad<T>>) {
return Arguments{data_ptr, tensor.numel() != 1};
} else {
static_assert(!std::is_same_v<Descriptor, ColLoad<T, true>> &&
!std::is_same_v<Descriptor, RowLoad<T, true>>);
return Arguments{data_ptr};
}
}
// This overload handles the case where there might not be a tensor, in which
// case a nullptr is passed and a constant (0) is used.
template <typename Descriptor, typename T>
static auto args_from_tensor(paddle::optional<paddle::Tensor> const &tensor) {
using Arguments = typename Descriptor::Arguments;
auto *data_ptr =
tensor ? static_cast<T *>(const_cast<void *>(tensor->data())) : nullptr;
static_assert(std::is_same_v<Descriptor, ColLoad<T, true>> ||
std::is_same_v<Descriptor, RowLoad<T, true>>);
return Arguments{data_ptr};
}
template <typename Descriptor, typename T>
static auto args_from_tensor(const T *const *data_ptr, bool do_broadcast) {
using Arguments = typename Descriptor::Arguments;
static_assert(std::is_same_v<Descriptor, ColOrScalarLoadArray<T>> ||
std::is_same_v<Descriptor, RowOrScalarLoadArray<T>>);
return Arguments{data_ptr, do_broadcast};
}
};
/*
This epilogue function defines a quantized GEMM operation similar to
paddle.scaled_mm_.
A and B may be both either int8 or fp8_e4m3. A can be
quantized per-tensor or per-row. B can be quantized per-tensor or per-column.
Any combination of per-tensor and per-row or column is supported.
A and B must have symmetric quantization (zero point == 0).
So the GEMM operation is D = (a_scales * A) @ (b_scales * B), where the
scales are applied elementwise with numpy-style broadcasting.
ScaleA and ScaleB define the epilogue functions that apply the scales for
the A and B operands respectively. These scales may be either per-tensor or
per row or column.
*/
template <typename ElementAcc, typename ElementD, typename TileShape>
struct ScaledEpilogue
: private ScaledEpilogueBase<ElementAcc, ElementD, TileShape> {
private:
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, TileShape>;
using Accum = typename SUPER::Accum;
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
cutlass::multiplies, float, float,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTCompute0 =
cutlass::epilogue::fusion::Sm90EVT<Compute0, ScaleB, Accum>;
using Compute1 = cutlass::epilogue::fusion::Sm90Compute<
cutlass::multiplies, ElementD, float,
cutlass::FloatRoundStyle::round_to_nearest>;
public:
using EVTCompute =
cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0>;
using ArgumentType = typename EVTCompute::Arguments;
static ArgumentType prepare_args(paddle::Tensor const &a_scales,
paddle::Tensor const &b_scales) {
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
typename EVTCompute0::Arguments evt0_args{b_args, {}, {}};
return ArgumentType{a_args, evt0_args, {}};
}
};
/*
* This epilogue performs the same operation as ScaledEpilogue, but adds a bias.
* This bias can also be used in the per-tensor azp case, where the activation
* zero point (azp) is used to compute an azp correction term,
* which is folded into the bias.
*
* The bias tensor must be per-output channel.
* ScaleA and ScaleB can be per-tensor or per-token/per-channel.
*/
template <typename ElementAcc, typename ElementD, typename TileShape>
struct ScaledEpilogueBias
: private ScaledEpilogueBase<ElementAcc, ElementD, TileShape> {
private:
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, TileShape>;
using Accum = typename SUPER::Accum;
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
using Bias = typename SUPER::template RowLoad<ElementD>;
using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
cutlass::multiplies, float, float,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTCompute0 =
cutlass::epilogue::fusion::Sm90EVT<Compute0, ScaleB, Accum>;
using Compute1 = cutlass::epilogue::fusion::Sm90Compute<
cutlass::multiply_add, ElementD, float,
cutlass::FloatRoundStyle::round_to_nearest>;
public:
using EVTCompute =
cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0, Bias>;
using ArgumentType = typename EVTCompute::Arguments;
static ArgumentType prepare_args(paddle::Tensor const &a_scales,
paddle::Tensor const &b_scales,
paddle::Tensor const &bias) {
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
typename EVTCompute0::Arguments evt0_args{b_args, {}, {}};
return ArgumentType{a_args, evt0_args, bias_args, {}};
}
};
/*
* This epilogue performs the same operation as ScaledEpilogueBias, but the
* bias is a column vector instead of a row vector. Useful e.g. if we are
* computing a GEMM via C^T += B^T A^T. This happens in the 2:4 sparse kernels.
*/
template <typename ElementAcc, typename ElementD, typename TileShape>
struct ScaledEpilogueColumnBias
: private ScaledEpilogueBase<ElementAcc, ElementD, TileShape> {
private:
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, TileShape>;
using Accum = typename SUPER::Accum;
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
using Bias = typename SUPER::template ColLoad<ElementD>;
using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
cutlass::multiplies, float, float,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTCompute0 =
cutlass::epilogue::fusion::Sm90EVT<Compute0, ScaleB, Accum>;
using Compute1 = cutlass::epilogue::fusion::Sm90Compute<
cutlass::multiply_add, ElementD, float,
cutlass::FloatRoundStyle::round_to_nearest>;
public:
using EVTCompute =
cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0, Bias>;
using ArgumentType = typename EVTCompute::Arguments;
static ArgumentType prepare_args(paddle::Tensor const &a_scales,
paddle::Tensor const &b_scales,
paddle::Tensor const &bias) {
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
typename EVTCompute0::Arguments evt0_args{b_args, {}, {}};
return ArgumentType{a_args, evt0_args, bias_args, {}};
}
};
/*
* This epilogue directly supports per-tensor azp in int32 form.
* As opposed to the per-token epilogue below, this epilogue only has an azp_adj
* term, which should already be multiplied with the scalar azp.
* The azp_adj term is a 1D tensor of shape (1,n), computed as azp * J @ B.
*
* This epilogue also supports bias, which remains per-channel.
*/
template <typename ElementAcc, typename ElementD, typename TileShape>
struct ScaledEpilogueBiasAzp
: private ScaledEpilogueBase<ElementAcc, ElementD, TileShape> {
private:
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, TileShape>;
using Accum = typename SUPER::Accum;
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
using Bias = typename SUPER::template RowLoad<ElementD, true>;
// This is the full AZP term, azp * J @ B, shape (1,n)
using AzpWithAdj = typename SUPER::template RowLoad<int32_t>;
// Compute float(accum - azp_adj), both operands are int32_t
using ComputeAzp = cutlass::epilogue::fusion::Sm90Compute<
cutlass::minus, float, int32_t,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTComputeAzp =
cutlass::epilogue::fusion::Sm90EVT<ComputeAzp, Accum, AzpWithAdj>;
using ComputeScaleB = cutlass::epilogue::fusion::Sm90Compute<
cutlass::multiplies, float, float,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTComputeScaleB =
cutlass::epilogue::fusion::Sm90EVT<ComputeScaleB, ScaleB, EVTComputeAzp>;
using ComputeScaleBiasA = cutlass::epilogue::fusion::Sm90Compute<
cutlass::multiply_add, ElementD, float,
cutlass::FloatRoundStyle::round_to_nearest>;
public:
using EVTCompute =
cutlass::epilogue::fusion::Sm90EVT<ComputeScaleBiasA, ScaleA,
EVTComputeScaleB, Bias>;
using ArgumentType = typename EVTCompute::Arguments;
static ArgumentType
prepare_args(paddle::Tensor const &a_scales, paddle::Tensor const &b_scales,
paddle::Tensor const &azp_adj,
paddle::optional<paddle::Tensor> const &bias) {
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
auto azp_adj_args =
SUPER::template args_from_tensor<AzpWithAdj, int32_t>(azp_adj);
typename EVTComputeAzp::Arguments evt_azp_args{{}, azp_adj_args, {}};
typename EVTComputeScaleB::Arguments evt_scale_b_args{
b_args, evt_azp_args, {}};
return ArgumentType{a_args, evt_scale_b_args, bias_args, {}};
}
};
/*
* This epilogue supports per-token azp by computing and applying
* the correction term using a rank-1 update. If the term were materialized,
* it would require O(m*n) space, and this way it only requires O(m+n) space.
* The azp term is a 1D tensor of shape (m,1), and represents the unscaled zero
* point for each row of A.
* The azp_adj term is a 1D tensor of shape (1,n), computed as J @ B.
*
* This epilogue also supports bias, which remains per-channel.
*/
template <typename ElementAcc, typename ElementD, typename TileShape>
struct ScaledEpilogueBiasAzpToken
: private ScaledEpilogueBase<ElementAcc, ElementD, TileShape> {
private:
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, TileShape>;
using Accum = typename SUPER::Accum;
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
using Bias = typename SUPER::template RowLoad<ElementD, true>;
// Per-token azp term, shape (m,1)
using Azp = typename SUPER::template ColLoad<int32_t>;
// This is the AZP adjustment term, J @ B, shape (1,n)
using AzpAdj = typename SUPER::template RowLoad<int32_t>;
// Compute azp * azp_adj
using ComputeAzp = cutlass::epilogue::fusion::Sm90Compute<
cutlass::multiplies, int32_t, int32_t,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTComputeAzp =
cutlass::epilogue::fusion::Sm90EVT<ComputeAzp, Azp, AzpAdj>;
// Compute float(accum - azp*azp_adj), all operands are int32_t
using ComputeAcc = cutlass::epilogue::fusion::Sm90Compute<
cutlass::minus, float, int32_t,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTComputeAcc =
cutlass::epilogue::fusion::Sm90EVT<ComputeAcc, Accum, EVTComputeAzp>;
using ComputeScaleB = cutlass::epilogue::fusion::Sm90Compute<
cutlass::multiplies, float, float,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTComputeScaleB =
cutlass::epilogue::fusion::Sm90EVT<ComputeScaleB, ScaleB, EVTComputeAcc>;
using ComputeScaleBiasA = cutlass::epilogue::fusion::Sm90Compute<
cutlass::multiply_add, ElementD, float,
cutlass::FloatRoundStyle::round_to_nearest>;
public:
using EVTCompute =
cutlass::epilogue::fusion::Sm90EVT<ComputeScaleBiasA, ScaleA,
EVTComputeScaleB, Bias>;
using ArgumentType = typename EVTCompute::Arguments;
static ArgumentType
prepare_args(paddle::Tensor const &a_scales, paddle::Tensor const &b_scales,
paddle::Tensor const &azp_adj, paddle::Tensor const &azp,
paddle::optional<paddle::Tensor> const &bias) {
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
auto azp_args = SUPER::template args_from_tensor<Azp, int32_t>(azp);
auto azp_adj_args =
SUPER::template args_from_tensor<AzpAdj, int32_t>(azp_adj);
typename EVTComputeAzp::Arguments evt_azp_args{azp_args, azp_adj_args, {}};
typename EVTComputeAcc::Arguments evt_acc_args{{}, evt_azp_args, {}};
typename EVTComputeScaleB::Arguments evt_scale_b_args{
b_args, evt_acc_args, {}};
return ArgumentType{a_args, evt_scale_b_args, bias_args, {}};
}
};
/*
This epilogue works like ScaledEpilogue, but ScaleA and ScaleB are pointers
to arrays containing different scales used in group gemm. The number of
pointers in ScaleA and the number of pointers in ScaleB are equal to the
group size.
*/
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
struct ScaledEpilogueArray
: private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
private:
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
using Accum = typename SUPER::Accum;
using ScaleA = typename SUPER::template ColOrScalarLoadArray<float>;
using ScaleB = typename SUPER::template RowOrScalarLoadArray<float>;
using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
cutlass::multiplies, float, float,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTCompute0 =
cutlass::epilogue::fusion::Sm90EVT<Compute0, ScaleB, Accum>;
using Compute1 = cutlass::epilogue::fusion::Sm90Compute<
cutlass::multiplies, ElementD, float,
cutlass::FloatRoundStyle::round_to_nearest>;
public:
using EVTCompute =
cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0>;
using ArgumentType = typename EVTCompute::Arguments;
using ScaleAArray = typename SUPER::template ColOrScalarLoadArray<float>;
using ScaleBArray = typename SUPER::template RowOrScalarLoadArray<float>;
static ArgumentType prepare_args(float const *const *a_scales_ptr,
float const *const *b_scales_ptr,
bool a_col_broadcast, bool b_row_broadcast) {
auto a_args = SUPER::template args_from_tensor<ScaleAArray, float>(
a_scales_ptr, a_col_broadcast);
auto b_args = SUPER::template args_from_tensor<ScaleBArray, float>(
b_scales_ptr, b_row_broadcast);
typename EVTCompute0::Arguments evt0_args{b_args, {}, {}};
return ArgumentType{a_args, evt0_args, {}};
}
};
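// A hedged usage sketch for the grouped case (array names and the Desc alias
// for the concrete EpilogueDescriptor are hypothetical): each entry of
// a_scales_ptr / b_scales_ptr points at the scale data of one group, and the
// broadcast flags indicate whether those scales are per-row / per-column
// vectors rather than a single scalar per group (exact semantics come from
// ColOrScalarLoadArray / RowOrScalarLoadArray).
//
//   float const *a_scales_ptr[num_groups];   // one pointer per group
//   float const *b_scales_ptr[num_groups];
//   auto args = ScaledEpilogueArray<ElementAcc, ElementD, Desc>::prepare_args(
//       a_scales_ptr, b_scales_ptr,
//       /*a_col_broadcast=*/true, /*b_row_broadcast=*/true);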
}; // namespace fastdeploy::c3x


@@ -0,0 +1,284 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights
*reserved. SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice,
*this list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
*POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/arch/mma.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/collective/builders/sm90_common.inl"
// SM90 Collective Builders should be used only starting CUDA 12.0
#if (__CUDACC_VER_MAJOR__ >= 12)
#define CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED
#endif
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass::gemm::collective {
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace detail {
// Returns the maximum number of smem tiles that can be used with a given smem
// capacity, or overrides with manual count.
template <int CapacityBytes, class ElementA, class ElementB, class TileShapeMNK,
bool SwapAB, int carveout_bytes>
constexpr int compute_stage_count_or_override_gated(
StageCountAutoCarveout<carveout_bytes> stage_count) {
// 32 bytes to account for barriers etc.
constexpr int stage_barrier_bytes = 32;
constexpr int a_bits = static_cast<int>(sizeof_bits<ElementA>::value);
constexpr int b_bits = static_cast<int>(sizeof_bits<ElementB>::value);
constexpr int stage_bytes = [&]() -> int {
if constexpr (SwapAB) {
return (a_bits * size<0>(TileShapeMNK{}) * size<2>(TileShapeMNK{}) * 2) /
8 +
(b_bits * size<1>(TileShapeMNK{}) * size<2>(TileShapeMNK{})) / 8 +
stage_barrier_bytes;
} else {
return (a_bits * size<0>(TileShapeMNK{}) * size<2>(TileShapeMNK{})) / 8 +
(b_bits * size<1>(TileShapeMNK{}) * size<2>(TileShapeMNK{}) * 2) /
8 +
stage_barrier_bytes;
}
}();
return (CapacityBytes - carveout_bytes) / stage_bytes;
}
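// Worked example (illustrative; assumes detail::sm90_smem_capacity_bytes of
// 227 KiB = 232448 bytes and a zero carveout): for 16-bit A/B and
// TileShapeMNK = 128x128x64 with SwapAB == false,
//   A bytes per stage = 16 * 128 * 64 / 8     = 16384
//   B bytes per stage = 16 * 128 * 64 * 2 / 8 = 32768  (two B operands)
//   stage_bytes       = 16384 + 32768 + 32    = 49184
//   stages            = 232448 / 49184        = 4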
} // namespace detail
/////////////////////////////////////////////////////////////////////////////////////////////////
// GMMA_TMA_WS_SS
template <class ElementA, class GmemLayoutA, int AlignmentA, class ElementB,
class GmemLayoutB, int AlignmentB, class ElementAccumulator,
class TileShape_MNK, class ClusterShape_MNK, class StageCountType,
class KernelScheduleType,
template <class /* ElementCompute */> class Activation, bool SwapAB>
struct CollectiveBuilderGated<
arch::Sm90, arch::OpClassTensorOp, ElementA, GmemLayoutA, AlignmentA,
ElementB, GmemLayoutB, AlignmentB, ElementAccumulator, TileShape_MNK,
ClusterShape_MNK, StageCountType, KernelScheduleType, Activation, SwapAB,
cute::enable_if_t<
(cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecialized> ||
cute::is_same_v<KernelScheduleType,
KernelTmaWarpSpecializedPingpong> ||
cute::is_same_v<KernelScheduleType,
KernelTmaWarpSpecializedCooperative> ||
cute::is_same_v<KernelScheduleType,
KernelPtrArrayTmaWarpSpecializedCooperative>) &&
not detail::is_use_rmem_A<ElementA, GmemLayoutA, ElementB,
GmemLayoutB>()>> {
static_assert(is_static<TileShape_MNK>::value);
static_assert(is_static<ClusterShape_MNK>::value);
#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED
static_assert(cutlass::detail::dependent_false<ElementA>,
"Unsupported Toolkit for SM90 Collective Builder\n");
#endif
static_assert(detail::is_aligned<ElementA, AlignmentA, ElementB, AlignmentB,
detail::tma_alignment_bytes>(),
"Should meet TMA alignment requirement\n");
static constexpr bool IsArrayOfPointersGemm =
(cute::is_same_v<KernelScheduleType,
KernelPtrArrayTmaWarpSpecializedCooperative>);
static constexpr bool IsFP8Input = detail::is_input_fp8<ElementA, ElementB>();
static_assert(!IsFP8Input || (IsFP8Input && !IsArrayOfPointersGemm),
"Kernel[Array/Group]TmaWarpSpecializedCooperative is only "
"compatible with FP8 FastAccum version right now\n");
// For fp32 types, map to tf32 MMA value type
using MmaElementA = cute::conditional_t<cute::is_same_v<ElementA, float>,
tfloat32_t, ElementA>;
using MmaElementB = cute::conditional_t<cute::is_same_v<ElementB, float>,
tfloat32_t, ElementB>;
static constexpr cute::GMMA::Major GmmaMajorA =
detail::gmma_ss_tag_to_major_A<MmaElementA, GmemLayoutA>();
static constexpr cute::GMMA::Major GmmaMajorB =
detail::gmma_ss_tag_to_major_B<MmaElementB, GmemLayoutB>();
using AtomLayoutMNK = cute::conditional_t<
cute::is_same_v<KernelScheduleType,
KernelTmaWarpSpecializedCooperative> ||
IsArrayOfPointersGemm,
Layout<Shape<_2, _1, _1>>, Layout<Shape<_1, _1, _1>>>;
using TiledMma = decltype(cute::make_tiled_mma(
cute::GMMA::ss_op_selector<MmaElementA, MmaElementB, ElementAccumulator,
TileShape_MNK, GmmaMajorA, GmmaMajorB>(),
AtomLayoutMNK{}));
using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom(
shape<1>(ClusterShape_MNK{})));
using GmemTiledCopyB = decltype(detail::sm90_cluster_shape_to_tma_atom(
shape<0>(ClusterShape_MNK{})));
using SmemLayoutAtomA =
decltype(detail::ss_smem_selector<
GmmaMajorA, MmaElementA, decltype(cute::get<0>(TileShape_MNK{})),
decltype(cute::get<2>(TileShape_MNK{}))>());
using SmemLayoutAtomB =
decltype(detail::ss_smem_selector<
GmmaMajorB, MmaElementB, decltype(cute::get<1>(TileShape_MNK{})),
decltype(cute::get<2>(TileShape_MNK{}))>());
static constexpr int PipelineStages =
detail::compute_stage_count_or_override_gated<
detail::sm90_smem_capacity_bytes, MmaElementA, MmaElementB,
TileShape_MNK, SwapAB>(StageCountType{});
using DispatchPolicy = cute::conditional_t<
IsArrayOfPointersGemm,
MainloopSm90ArrayTmaGmmaWarpSpecialized<PipelineStages, ClusterShape_MNK,
KernelScheduleType>,
/* For FP8 use a separate mainloop compared to other datatypes */
cute::conditional_t<
IsFP8Input,
MainloopSm90TmaGmmaWarpSpecializedFP8<
PipelineStages, ClusterShape_MNK, KernelScheduleType>,
MainloopSm90TmaGmmaWarpSpecialized<PipelineStages, ClusterShape_MNK,
KernelScheduleType>>>;
using SmemCopyAtomA = void;
using SmemCopyAtomB = void;
using CollectiveOp = CollectiveMmaGated<
DispatchPolicy, TileShape_MNK, ElementA, TagToStrideA_t<GmemLayoutA>,
ElementB, TagToStrideB_t<GmemLayoutB>, TiledMma, GmemTiledCopyA,
SmemLayoutAtomA, SmemCopyAtomA, cute::identity, GmemTiledCopyB,
SmemLayoutAtomB, SmemCopyAtomB, cute::identity, Activation, SwapAB>;
};
/////////////////////////////////////////////////////////////////////////////////////////////////
// GMMA_TMA_WS_FP8_FAST_ACCUM_SS
template <class ElementA, class GmemLayoutA, int AlignmentA, class ElementB,
class GmemLayoutB, int AlignmentB, class ElementAccumulator,
class TileShape_MNK, class ClusterShape_MNK, class StageCountType,
class KernelScheduleType,
template <class /* ElementCompute */> class Activation, bool SwapAB>
struct CollectiveBuilderGated<
arch::Sm90, arch::OpClassTensorOp, ElementA, GmemLayoutA, AlignmentA,
ElementB, GmemLayoutB, AlignmentB, ElementAccumulator, TileShape_MNK,
ClusterShape_MNK, StageCountType, KernelScheduleType, Activation, SwapAB,
cute::enable_if_t<
cute::is_same_v<KernelScheduleType,
KernelTmaWarpSpecializedFP8FastAccum> ||
cute::is_same_v<KernelScheduleType,
KernelTmaWarpSpecializedPingpongFP8FastAccum> ||
cute::is_same_v<KernelScheduleType,
KernelTmaWarpSpecializedCooperativeFP8FastAccum> ||
cute::is_same_v<
KernelScheduleType,
KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum>>> {
static_assert(is_static<TileShape_MNK>::value);
static_assert(is_static<ClusterShape_MNK>::value);
static_assert(detail::is_aligned<ElementA, AlignmentA, ElementB, AlignmentB,
detail::tma_alignment_bytes>(),
"Not meet TMA alignment requirement yet\n");
static_assert(
detail::is_input_fp8<ElementA, ElementB>(),
"Only FP8 datatypes are compatible with these kernel schedules\n");
// Dispatch TN fp8 kernels only to TMA warp specialized FP8 builder
static_assert(
!detail::is_use_rmem_A<ElementA, GmemLayoutA, ElementB, GmemLayoutB>(),
"Not supported for fp8 non-TN warp specialized kernels yet\n");
#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED
static_assert(cutlass::detail::dependent_false<ElementA>,
"Unsupported Toolkit for SM90 Collective Builder\n");
#endif
static constexpr cute::GMMA::Major GmmaMajorA =
detail::gmma_ss_tag_to_major_A<ElementA, GmemLayoutA>();
static constexpr cute::GMMA::Major GmmaMajorB =
detail::gmma_ss_tag_to_major_B<ElementB, GmemLayoutB>();
static constexpr bool IsArrayOfPointersGemm =
(cute::is_same_v<
KernelScheduleType,
KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum>);
using AtomLayoutMNK = cute::conditional_t<
cute::is_same_v<KernelScheduleType,
KernelTmaWarpSpecializedCooperativeFP8FastAccum> ||
IsArrayOfPointersGemm,
Layout<Shape<_2, _1, _1>>, Layout<Shape<_1, _1, _1>>>;
using TiledMma = decltype(cute::make_tiled_mma(
cute::GMMA::ss_op_selector<ElementA, ElementB, ElementAccumulator,
TileShape_MNK, GmmaMajorA, GmmaMajorB>(),
AtomLayoutMNK{}));
using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom(
shape<1>(ClusterShape_MNK{})));
using GmemTiledCopyB = decltype(detail::sm90_cluster_shape_to_tma_atom(
shape<0>(ClusterShape_MNK{})));
using SmemLayoutAtomA =
decltype(detail::ss_smem_selector<
GmmaMajorA, ElementA, decltype(cute::get<0>(TileShape_MNK{})),
decltype(cute::get<2>(TileShape_MNK{}))>());
using SmemLayoutAtomB =
decltype(detail::ss_smem_selector<
GmmaMajorB, ElementB, decltype(cute::get<1>(TileShape_MNK{})),
decltype(cute::get<2>(TileShape_MNK{}))>());
static constexpr int PipelineStages =
detail::compute_stage_count_or_override_gated<
detail::sm90_smem_capacity_bytes, ElementA, ElementB, TileShape_MNK,
SwapAB>(StageCountType{});
using DispatchPolicy = cute::conditional_t<
IsArrayOfPointersGemm,
MainloopSm90ArrayTmaGmmaWarpSpecialized<PipelineStages, ClusterShape_MNK,
KernelScheduleType>,
MainloopSm90TmaGmmaWarpSpecialized<PipelineStages, ClusterShape_MNK,
KernelScheduleType>>;
using SmemCopyAtomA = void;
using SmemCopyAtomB = void;
using CollectiveOp = CollectiveMmaGated<
DispatchPolicy, TileShape_MNK, ElementA, TagToStrideA_t<GmemLayoutA>,
ElementB, TagToStrideB_t<GmemLayoutB>, TiledMma, GmemTiledCopyA,
SmemLayoutAtomA, SmemCopyAtomA, cute::identity, GmemTiledCopyB,
SmemLayoutAtomB, SmemCopyAtomB, cute::identity, Activation, SwapAB>;
};
/////////////////////////////////////////////////////////////////////////////////////////////////
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass::gemm::collective
/////////////////////////////////////////////////////////////////////////////////////////////////


@@ -0,0 +1,60 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights
*reserved. SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice,
*this list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
*POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
/////////////////////////////////////////////////////////////////////////////////////////////////
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass_extensions/gemm/collective/collective_mma_gated.hpp"
namespace cutlass::gemm::collective {
/////////////////////////////////////////////////////////////////////////////////////////////////
template <class ArchTag, class OpClass, class ElementA, class GmemLayoutA,
int AlignmentA, class ElementB, class GmemLayoutB, int AlignmentB,
class ElementAccumulator, class TileShape_MNK, class ClusterShape_MNK,
class StageCountType, class KernelScheduleType,
template <class /* ElementCompute */> class Activation,
bool SwapAB = false, class Enable = void>
struct CollectiveBuilderGated {
static_assert(sizeof(ElementA) == 0,
"Could not build a collective for given parameters.");
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass::gemm::collective
/////////////////////////////////////////////////////////////////////////////////////////////////
#include "cutlass_extensions/gemm/collective/builders/sm90_gmma_builder_gated.inl"
/////////////////////////////////////////////////////////////////////////////////////////////////


@@ -0,0 +1,62 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights
*reserved. SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice,
*this list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
*POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/detail/dependent_false.hpp"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass::gemm::collective {
/////////////////////////////////////////////////////////////////////////////////////////////////
template <class DispatchPolicy, class TileShape, class ElementA, class StrideA,
class ElementB, class StrideB, class TiledMma, class GmemTiledCopyA,
class SmemLayoutAtomA, class SmemCopyAtomA, class TransformA,
class GmemTiledCopyB, class SmemLayoutAtomB, class SmemCopyAtomB,
class TransformB,
template <class /* ElementCompute */> class Activation,
bool SwapAB = false>
struct CollectiveMmaGated {
static_assert(cutlass::detail::dependent_false<ElementA>,
"Could not find a mainloop specialization.");
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass::gemm::collective
/////////////////////////////////////////////////////////////////////////////////////////////////
#include "cutlass_extensions/gemm/collective/sm90_mma_gated_tma_gmma_ss_warpspecialized.hpp"
#include "cutlass_extensions/gemm/collective/sm90_mma_gated_tma_gmma_ss_warpspecialized_fp8.hpp"
/////////////////////////////////////////////////////////////////////////////////////////////////


@@ -0,0 +1,713 @@
/***************************************************************************************************
* Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights
*reserved. SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice,
*this list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
*POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cute/arch/cluster_sm90.hpp"
#include "cute/arch/copy_sm90.hpp"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cute/algorithm/functional.hpp"
#include "cute/algorithm/gemm.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cute/numeric/arithmetic_tuple.hpp"
#include "cute/tensor_predicate.hpp"
#include "cutlass/pipeline/pipeline.hpp"
#include "cutlass/trace.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass::gemm::collective {
using namespace cute;
/////////////////////////////////////////////////////////////////////////////////////////////////
// WarpSpecialized Mainloop
template <int Stages, class ClusterShape, class KernelSchedule,
class TileShape_, class ElementA_, class StrideA_, class ElementB_,
class StrideB_, class TiledMma_, class GmemTiledCopyA_,
class SmemLayoutAtomA_, class SmemCopyAtomA_, class TransformA_,
class GmemTiledCopyB_, class SmemLayoutAtomB_, class SmemCopyAtomB_,
class TransformB_,
template <class /* ElementCompute */> class Activation_, bool SwapAB_>
struct CollectiveMmaGated<
MainloopSm90TmaGmmaWarpSpecialized<Stages, ClusterShape, KernelSchedule>,
TileShape_, ElementA_, StrideA_, ElementB_, StrideB_, TiledMma_,
GmemTiledCopyA_, SmemLayoutAtomA_, SmemCopyAtomA_, TransformA_,
GmemTiledCopyB_, SmemLayoutAtomB_, SmemCopyAtomB_, TransformB_, Activation_,
SwapAB_> {
static constexpr bool isGated = true;
static constexpr bool SwapAB = SwapAB_;
//
// Type Aliases
//
using DispatchPolicy =
MainloopSm90TmaGmmaWarpSpecialized<Stages, ClusterShape, KernelSchedule>;
using TileShape = TileShape_;
using ElementA = ElementA_;
using StrideA = StrideA_;
using ElementB = ElementB_;
using StrideB = StrideB_;
using TiledMma = TiledMma_;
using ElementAccumulator = typename TiledMma::ValTypeC;
using GmemTiledCopyA = GmemTiledCopyA_;
using GmemTiledCopyB = GmemTiledCopyB_;
using SmemLayoutAtomA = SmemLayoutAtomA_;
using SmemLayoutAtomB = SmemLayoutAtomB_;
using SmemCopyAtomA = SmemCopyAtomA_;
using SmemCopyAtomB = SmemCopyAtomB_;
using TransformA = TransformA_;
using TransformB = TransformB_;
using ArchTag = typename DispatchPolicy::ArchTag;
using Activation = Activation_<ElementAccumulator>;
using ElementAux = cute::conditional_t<SwapAB, ElementA_, ElementB_>;
using ValTypeAux = cute::conditional_t<SwapAB, typename TiledMma::ValTypeA,
typename TiledMma::ValTypeB>;
using MainloopPipeline = cutlass::PipelineTmaAsync<DispatchPolicy::Stages>;
using PipelineState = cutlass::PipelineState<DispatchPolicy::Stages>;
using PipelineParams = typename MainloopPipeline::Params;
static_assert(cute::rank(SmemLayoutAtomA{}) == 2,
"SmemLayoutAtom must be rank 2 (M/N, K)");
static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0,
"SmemLayoutAtom must evenly divide tile shape.");
static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0,
"SmemLayoutAtom must evenly divide tile shape.");
static_assert(cute::rank(SmemLayoutAtomB{}) == 2,
"SmemLayoutAtom must be rank 2 (M/N, K)");
static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0,
"SmemLayoutAtom must evenly divide tile shape.");
static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0,
"SmemLayoutAtom must evenly divide tile shape.");
// Tile along modes in a way that maximizes the TMA box size.
using SmemLayoutA = decltype(tile_to_shape(
SmemLayoutAtomA{},
make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}),
Int<DispatchPolicy::Stages>{}),
conditional_t<::cutlass::gemm::detail::is_major<0, StrideA>(),
Step<_2, _1, _3>, Step<_1, _2, _3>>{}));
using SmemLayoutB = decltype(tile_to_shape(
SmemLayoutAtomB{},
make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}),
Int<DispatchPolicy::Stages>{}),
conditional_t<::cutlass::gemm::detail::is_major<0, StrideB>(),
Step<_2, _1, _3>, Step<_1, _2, _3>>{}));
using SmemLayoutAux = cute::conditional_t<SwapAB, SmemLayoutA, SmemLayoutB>;
static_assert(DispatchPolicy::Stages >= 2,
"Specialization requires Stages set to value 2 or more.");
static_assert(cute::is_base_of<cute::GMMA::DescriptorIterator,
typename TiledMma::FrgTypeA>::value &&
cute::is_base_of<cute::GMMA::DescriptorIterator,
typename TiledMma::FrgTypeB>::value,
"MMA atom must source both A and B operand from smem_desc for "
"this mainloop.");
static_assert(cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD> ||
cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD_MULTICAST>,
"GmemTiledCopy - invalid SM90 TMA copy atom specified.");
static_assert(cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD> ||
cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD_MULTICAST>,
"GmemTiledCopy - invalid SM90 TMA copy atom specified.");
// TMA converts f32 input to tf32 when copying from GMEM to SMEM
// For all other types, cast to size equivalent uint type to avoid any
// rounding by TMA.
static constexpr bool ConvertF32toTF32A = cute::is_same_v<float, ElementA>;
static constexpr bool ConvertF32toTF32B = cute::is_same_v<float, ElementB>;
using InternalElementA =
cute::conditional_t<ConvertF32toTF32A, tfloat32_t,
uint_bit_t<sizeof_bits_v<ElementA>>>;
using InternalElementB =
cute::conditional_t<ConvertF32toTF32B, tfloat32_t,
uint_bit_t<sizeof_bits_v<ElementB>>>;
using InternalElementAux =
cute::conditional_t<SwapAB, InternalElementA, InternalElementB>;
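// For instance (illustrative):
//   ElementA = float                 -> InternalElementA = tfloat32_t
//   ElementA = cutlass::half_t       -> InternalElementA = uint16_t (raw bits)
//   ElementA = cutlass::float_e4m3_t -> InternalElementA = uint8_t  (raw bits)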
struct SharedStorage {
struct TensorStorage : cute::aligned_struct<128> {
cute::array_aligned<typename TiledMma::ValTypeA,
cute::cosize_v<SmemLayoutA>>
smem_A;
cute::array_aligned<typename TiledMma::ValTypeB,
cute::cosize_v<SmemLayoutB>>
smem_B;
cute::array_aligned<ValTypeAux, cute::cosize_v<SmemLayoutAux>> smem_Aux;
} tensors;
using PipelineStorage = typename MainloopPipeline::SharedStorage;
PipelineStorage pipeline;
};
using TensorStorage = typename SharedStorage::TensorStorage;
using PipelineStorage = typename SharedStorage::PipelineStorage;
// Host side kernel arguments
struct Arguments {
ElementA const *ptr_A;
StrideA dA;
ElementB const *ptr_B0;
ElementB const *ptr_B1;
StrideB dB;
float scale_d0 = 1.0f;
float scale_d1 = 1.0f;
uint32_t mma_promotion_interval = 4;
};
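// A minimal host-side sketch (pointer and stride names are hypothetical) of
// how these arguments might be populated for a gated GEMM, where ptr_B0 and
// ptr_B1 are the two N x K operands (e.g. up- and gate-projection weights)
// sharing a single stride:
//
//   Arguments args{/*ptr_A=*/d_a,   /*dA=*/stride_a,
//                  /*ptr_B0=*/d_b0, /*ptr_B1=*/d_b1, /*dB=*/stride_b,
//                  /*scale_d0=*/1.0f, /*scale_d1=*/1.0f};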
// Device side kernel params
struct Params {
// Assumption: StrideA is congruent with Problem_MK
using TMA_A = decltype(make_tma_copy(
GmemTiledCopyA{},
make_tensor(static_cast<InternalElementA const *>(nullptr),
repeat_like(StrideA{}, int32_t(0)), StrideA{}),
SmemLayoutA{}(_, _, cute::Int<0>{}),
make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})),
size<1>(ClusterShape{}))); // mcast along N mode for this M load, if any
// Assumption: StrideB is congruent with Problem_NK
using TMA_B = decltype(make_tma_copy(
GmemTiledCopyB{},
make_tensor(static_cast<InternalElementB const *>(nullptr),
repeat_like(StrideB{}, int32_t(0)), StrideB{}),
SmemLayoutB{}(_, _, cute::Int<0>{}),
make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})),
size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any
using TMA_Aux = cute::conditional_t<SwapAB, TMA_A, TMA_B>;
TMA_A tma_load_a;
TMA_B tma_load_b;
TMA_Aux tma_load_aux;
float scale_d0 = 1.0f;
float scale_d1 = 1.0f;
};
//
// Methods
//
template <class ProblemShape>
static constexpr Params
to_underlying_arguments(ProblemShape const &problem_shape,
Arguments const &args, void *workspace) {
(void)workspace;
// Optionally append 1s until problem shape is rank-4 (MNKL), in case it is
// only rank-3 (MNK)
auto problem_shape_MNKL = append<4>(problem_shape, 1);
auto [M, N, K, L] = problem_shape_MNKL;
auto ptr_A = reinterpret_cast<InternalElementA const *>(args.ptr_A);
auto ptr_B0 = reinterpret_cast<InternalElementB const *>(args.ptr_B0);
Tensor tensor_a =
make_tensor(ptr_A, make_layout(make_shape(M, K, L), args.dA));
Tensor tensor_b =
make_tensor(ptr_B0, make_layout(make_shape(N, K, L), args.dB));
typename Params::TMA_A tma_load_a = make_tma_copy(
GmemTiledCopyA{}, tensor_a, SmemLayoutA{}(_, _, cute::Int<0>{}),
make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})),
size<1>(ClusterShape{})); // mcast along N mode for this M load, if any
typename Params::TMA_B tma_load_b = make_tma_copy(
GmemTiledCopyB{}, tensor_b, SmemLayoutB{}(_, _, cute::Int<0>{}),
make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})),
size<0>(ClusterShape{})); // mcast along M mode for this N load, if any
if constexpr (SwapAB) {
auto ptr_Aux = reinterpret_cast<InternalElementA const *>(args.ptr_B1);
Tensor tensor_aux =
make_tensor(ptr_Aux, make_layout(make_shape(M, K, L), args.dA));
typename Params::TMA_Aux tma_load_aux = make_tma_copy(
GmemTiledCopyA{}, tensor_aux, SmemLayoutA{}(_, _, cute::Int<0>{}),
make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})),
size<1>(
ClusterShape{})); // mcast along N mode for this M load, if any
return {tma_load_a, tma_load_b, tma_load_aux, args.scale_d0,
args.scale_d1};
} else {
auto ptr_Aux = reinterpret_cast<InternalElementB const *>(args.ptr_B1);
Tensor tensor_aux =
make_tensor(ptr_Aux, make_layout(make_shape(N, K, L), args.dB));
typename Params::TMA_Aux tma_load_aux = make_tma_copy(
GmemTiledCopyB{}, tensor_aux, SmemLayoutB{}(_, _, cute::Int<0>{}),
make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})),
size<0>(
ClusterShape{})); // mcast along M mode for this N load, if any
return {tma_load_a, tma_load_b, tma_load_aux, args.scale_d0,
args.scale_d1};
}
}
template <class ProblemShape>
static bool can_implement(ProblemShape const &problem_shape,
[[maybe_unused]] Arguments const &args) {
constexpr int tma_alignment_bits = 128;
auto problem_shape_MNKL = append<4>(problem_shape, 1);
auto [M, N, K, L] = problem_shape_MNKL;
bool implementable = true;
constexpr int min_tma_aligned_elements_A =
tma_alignment_bits / cutlass::sizeof_bits<ElementA>::value;
implementable =
implementable &&
cutlass::detail::check_alignment<min_tma_aligned_elements_A>(
cute::make_shape(M, K, L), StrideA{});
constexpr int min_tma_aligned_elements_B =
tma_alignment_bits / cutlass::sizeof_bits<ElementB>::value;
implementable =
implementable &&
cutlass::detail::check_alignment<min_tma_aligned_elements_B>(
cute::make_shape(N, K, L), StrideB{});
if (!implementable) {
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the "
"minimum alignment requirements for TMA.\n");
}
return implementable;
}
static constexpr int K_PIPE_MAX = DispatchPolicy::Stages;
static constexpr int K_PIPE_MMAS = 1;
static constexpr uint32_t TmaTransactionBytes =
(size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) *
static_cast<uint32_t>(sizeof_bits<ElementA>::value)) /
8 +
(size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) *
static_cast<uint32_t>(sizeof_bits<ElementB>::value)) /
8 +
(size<0>(SmemLayoutAux{}) * size<1>(SmemLayoutAux{}) *
static_cast<uint32_t>(sizeof_bits<ElementAux>::value)) /
8;
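// Worked example (illustrative, 16-bit A/B/Aux, TileShapeMNK = 128x128x64):
// each stage transfers 16384 (A) + 16384 (B) + 16384 (Aux) = 49152 bytes
// through TMA, which the kernel layer uses to arm the barrier per k-tile.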
/// Issue Tma Descriptor Prefetch -- ideally from a single thread for best
/// performance
CUTLASS_DEVICE
static void prefetch_tma_descriptors(Params const &mainloop_params) {
cute::prefetch_tma_descriptor(
mainloop_params.tma_load_a.get_tma_descriptor());
cute::prefetch_tma_descriptor(
mainloop_params.tma_load_b.get_tma_descriptor());
cute::prefetch_tma_descriptor(
mainloop_params.tma_load_aux.get_tma_descriptor());
}
/// Set up the data needed by this collective for load and mma.
/// Returns a tuple of tensors. The collective and the kernel layer have the
/// contract that the returned tuple must contain at least two elements, with
/// the first two elements being:
///   gA_mkl   - The TMA tensor, A after a local tile, of shape (BLK_M,BLK_K,m,k,l)
///   gB_nkl   - The TMA tensor, B after a local tile, of shape (BLK_N,BLK_K,n,k,l)
///   gAux_xkl - The TMA tensor, A or B after a local tile, of shape (BLK_M/N,BLK_K,m/n,k,l)
/// The rest of the tensors can be specified as needed by this collective.
template <class ProblemShape_MNKL>
CUTLASS_DEVICE auto load_init(ProblemShape_MNKL const &problem_shape_MNKL,
Params const &mainloop_params) const {
using X = Underscore;
// Separate out problem shape for convenience
auto [M, N, K, L] = problem_shape_MNKL;
// TMA requires special handling of strides to deal with coord codomain
// mapping.
// Represent the full tensors -- get these from TMA.
Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor(
make_shape(M, K, L)); // (m,k,l)
Tensor mB_nkl = mainloop_params.tma_load_b.get_tma_tensor(
make_shape(N, K, L)); // (n,k,l)
// Make tiled views, defer the slice
Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_, _, _),
Step<_1, X, _1>{}); // (BLK_M,BLK_K,m,k,l)
Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_, _, _),
Step<X, _1, _1>{}); // (BLK_N,BLK_K,n,k,l)
if constexpr (SwapAB) {
Tensor mAux_xkl = mainloop_params.tma_load_aux.get_tma_tensor(
make_shape(M, K, L)); // (m,k,l)
Tensor gAux_xkl = local_tile(mAux_xkl, TileShape{}, make_coord(_, _, _),
Step<_1, X, _1>{}); // (BLK_M,BLK_K,m,k,l)
return cute::make_tuple(gA_mkl, gB_nkl, gAux_xkl);
} else {
Tensor mAux_xkl = mainloop_params.tma_load_aux.get_tma_tensor(
make_shape(N, K, L)); // (n,k,l)
Tensor gAux_xkl = local_tile(mAux_xkl, TileShape{}, make_coord(_, _, _),
Step<X, _1, _1>{}); // (BLK_N,BLK_K,n,k,l)
return cute::make_tuple(gA_mkl, gB_nkl, gAux_xkl);
}
}
/// Perform a collective-scoped matrix multiply-accumulate
/// Producer Perspective
template <class TensorA, class TensorB, class TensorAux, class KTileIterator,
class BlockCoord>
CUTLASS_DEVICE void
load(Params const &mainloop_params, MainloopPipeline pipeline,
PipelineState smem_pipe_write,
cute::tuple<TensorA, TensorB, TensorAux> const &load_inputs,
BlockCoord const &blk_coord, KTileIterator k_tile_iter, int k_tile_count,
int thread_idx, uint32_t block_rank_in_cluster,
TensorStorage &shared_tensors) {
int lane_predicate = cute::elect_one_sync();
if (lane_predicate) {
Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()),
SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()),
SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
Tensor sAux = make_tensor(make_smem_ptr(shared_tensors.smem_Aux.data()),
SmemLayoutAux{});
//
// Prepare the TMA loads for A and B
//
constexpr uint32_t cluster_shape_x =
get<0>(typename DispatchPolicy::ClusterShape());
uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x,
block_rank_in_cluster / cluster_shape_x};
Tensor gA_mkl = get<0>(load_inputs);
Tensor gB_nkl = get<1>(load_inputs);
Tensor gAux_xkl = get<2>(load_inputs);
auto block_tma_a =
mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y);
auto block_tma_b =
mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x);
auto block_tma_aux =
SwapAB
? mainloop_params.tma_load_aux.get_slice(cluster_local_block_id.y)
: mainloop_params.tma_load_aux.get_slice(
cluster_local_block_id.x);
// Partition the inputs based on the current block coordinates.
auto [m_coord, n_coord, k_coord, l_coord] = blk_coord;
Tensor gA = gA_mkl(_, _, m_coord, _, l_coord); // (BLK_M,BLK_K,k)
Tensor gB = gB_nkl(_, _, n_coord, _, l_coord); // (BLK_N,BLK_K,k)
Tensor gAux = SwapAB ? gAux_xkl(_, _, m_coord, _, l_coord)
: gAux_xkl(_, _, n_coord, _, l_coord);
// Applies the mapping from block_tma_a
Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k)
Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE)
Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k)
Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE)
Tensor tAuxgAux = block_tma_aux.partition_S(gAux);
Tensor tAuxsAux = block_tma_aux.partition_D(sAux);
uint16_t mcast_mask_a = 0;
uint16_t mcast_mask_b = 0;
uint16_t mcast_mask_aux = 0;
// Issue TmaLoads
// Maps the tile -> block, value
if constexpr (cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD_MULTICAST>) {
auto block_layout =
Layout<typename DispatchPolicy::ClusterShape>{}; // (m,n) ->
// block_id
for (int n = 0; n < size<1>(block_layout); ++n) {
mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x,
n, Int<0>{}));
}
}
if constexpr (cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD_MULTICAST>) {
auto block_layout =
Layout<typename DispatchPolicy::ClusterShape>{}; // (m,n) ->
// block_id
for (int m = 0; m < size<0>(block_layout); ++m) {
mcast_mask_b |= (uint16_t(1) << block_layout(
m, cluster_local_block_id.y, Int<0>{}));
}
}
if constexpr (SwapAB) {
mcast_mask_aux = mcast_mask_a;
} else {
mcast_mask_aux = mcast_mask_b;
}
// Mainloop
CUTLASS_PRAGMA_NO_UNROLL
for (; k_tile_count > 0; --k_tile_count) {
// LOCK smem_pipe_write for _writing_
pipeline.producer_acquire(smem_pipe_write);
//
// Copy gmem to smem for *k_tile_iter
//
using BarrierType = typename MainloopPipeline::ProducerBarrierType;
BarrierType *tma_barrier =
pipeline.producer_get_barrier(smem_pipe_write);
int write_stage = smem_pipe_write.index();
copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a),
tAgA(_, _, _, *k_tile_iter), tAsA(_, _, _, write_stage));
copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b),
tBgB(_, _, _, *k_tile_iter), tBsB(_, _, _, write_stage));
copy(mainloop_params.tma_load_aux.with(*tma_barrier, mcast_mask_aux),
tAuxgAux(_, _, _, *k_tile_iter), tAuxsAux(_, _, _, write_stage));
++k_tile_iter;
// Advance smem_pipe_write
++smem_pipe_write;
}
}
}
/// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster
CUTLASS_DEVICE void load_tail(MainloopPipeline pipeline,
PipelineState smem_pipe_write) {
int lane_predicate = cute::elect_one_sync();
// Issue the epilogue waits
if (lane_predicate) {
/* This helps avoid early exit of blocks in Cluster
* Waits for all stages to either be released (all
* Consumer UNLOCKs), or if the stage was never used
* then would just be acquired since the phase was
* still inverted from make_producer_start_state
*/
pipeline.producer_tail(smem_pipe_write);
}
}
/// Perform a collective-scoped matrix multiply-accumulate
/// Consumer Perspective
template <class FrgTensorC>
CUTLASS_DEVICE void
mma(MainloopPipeline pipeline, PipelineState smem_pipe_read,
FrgTensorC &accum0, FrgTensorC &accum1, int k_tile_count, int thread_idx,
TensorStorage &shared_tensors, Params const &mainloop_params) {
static_assert(is_rmem<FrgTensorC>::value,
"C tensor must be rmem resident.");
static_assert(cute::rank(SmemLayoutA{}) == 3,
"Smem layout must be rank 3.");
static_assert(cute::rank(SmemLayoutB{}) == 3,
"Smem layout must be rank 3.");
static_assert(cute::rank(SmemLayoutAux{}) == 3,
"Smem layout must be rank 3.");
static_assert(cute::is_void_v<SmemCopyAtomA>,
"SM90 GMMA mainloops cannot have a non-void copy atom for "
"smem sourced instructions.");
static_assert(cute::is_void_v<SmemCopyAtomB>,
"SM90 GMMA mainloops cannot have a non-void copy atom for "
"smem sourced instructions.");
Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()),
SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()),
SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
Tensor sAux = make_tensor(make_smem_ptr(shared_tensors.smem_Aux.data()),
SmemLayoutAux{});
//
// Define C accumulators and A/B partitioning
//
TiledMma tiled_mma;
auto thread_mma = tiled_mma.get_thread_slice(thread_idx);
Tensor tCsA = thread_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE)
Tensor tCsB = thread_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE)
// Allocate "fragments/descriptors"
Tensor tCrA = thread_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE)
Tensor tCrB = thread_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE)
auto tCsAux = [&]() -> auto {
if constexpr (SwapAB) {
return thread_mma.partition_A(sAux);
} else {
return thread_mma.partition_B(sAux);
}
}();
auto tCrAux = [&]() -> auto {
if constexpr (SwapAB) {
return thread_mma.make_fragment_A(tCsAux);
} else {
return thread_mma.make_fragment_B(tCsAux);
}
}();
CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum0)); // M
CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum0)); // N
CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K
CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE
if constexpr (SwapAB) {
CUTE_STATIC_ASSERT_V(size<1>(tCsAux) == size<1>(accum1)); // M
CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum1)); // N
CUTE_STATIC_ASSERT_V(size<2>(tCsB) == size<2>(tCsAux)); // K
CUTE_STATIC_ASSERT_V(size<3>(tCsB) == size<3>(tCsAux)); // PIPE
} else {
CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum1)); // M
CUTE_STATIC_ASSERT_V(size<1>(tCsAux) == size<2>(accum1)); // N
CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsAux)); // K
CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsAux)); // PIPE
}
CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} == size<2>(sA)); // PIPE
CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} == size<2>(sB)); // PIPE
CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} ==
size<2>(sAux)); // PIPE
//
// PIPELINED MAIN LOOP
//
static_assert((0 <= K_PIPE_MMAS) && (K_PIPE_MMAS < K_PIPE_MAX),
"ERROR : Incorrect number of MMAs in flight");
// We release buffers to the producer warps (DMA load) with some MMAs in flight
PipelineState smem_pipe_release = smem_pipe_read;
// Prologue GMMAs
int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count);
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
warpgroup_fence_operand(accum0);
warpgroup_fence_operand(accum1);
CUTLASS_PRAGMA_UNROLL
for (int k_tile_prologue = prologue_mma_count; k_tile_prologue > 0;
--k_tile_prologue) {
// WAIT on smem_pipe_read until its data are available (phase bit flips
// from rdPhaseBit value)
auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
pipeline.consumer_wait(smem_pipe_read, barrier_token);
int read_stage = smem_pipe_read.index();
warpgroup_arrive();
// Unroll the K mode manually to set scale D to 1
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
// (V,M,K) x (V,N,K) => (V,M,N)
cute::gemm(tiled_mma, tCrA(_, _, k_block, read_stage),
tCrB(_, _, k_block, read_stage), accum0);
if constexpr (SwapAB) {
cute::gemm(tiled_mma, tCrAux(_, _, k_block, read_stage),
tCrB(_, _, k_block, read_stage), accum1);
} else {
cute::gemm(tiled_mma, tCrA(_, _, k_block, read_stage),
tCrAux(_, _, k_block, read_stage), accum1);
}
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
}
warpgroup_commit_batch();
++smem_pipe_read;
}
warpgroup_fence_operand(accum0);
warpgroup_fence_operand(accum1);
// Mainloop GMMAs
k_tile_count -= prologue_mma_count;
CUTLASS_PRAGMA_NO_UNROLL
for (; k_tile_count > 0; --k_tile_count) {
// WAIT on smem_pipe_read until its data are available (phase bit flips
// from rdPhaseBit value)
auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
pipeline.consumer_wait(smem_pipe_read, barrier_token);
//
// Compute on k_tile
//
int read_stage = smem_pipe_read.index();
warpgroup_fence_operand(accum0);
warpgroup_fence_operand(accum1);
warpgroup_arrive();
// Unroll the K mode manually to set scale D to 1
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
// (V,M,K) x (V,N,K) => (V,M,N)
cute::gemm(tiled_mma, tCrA(_, _, k_block, read_stage),
tCrB(_, _, k_block, read_stage), accum0);
if constexpr (SwapAB) {
cute::gemm(tiled_mma, tCrAux(_, _, k_block, read_stage),
tCrB(_, _, k_block, read_stage), accum1);
} else {
cute::gemm(tiled_mma, tCrA(_, _, k_block, read_stage),
tCrAux(_, _, k_block, read_stage), accum1);
}
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
}
warpgroup_commit_batch();
/// Wait on the GMMA barrier for K_PIPE_MMAS (or fewer) outstanding to
/// ensure smem_pipe_write is consumed
warpgroup_wait<K_PIPE_MMAS>();
warpgroup_fence_operand(accum0);
warpgroup_fence_operand(accum1);
// UNLOCK smem_pipe_release, done _computing_ on it
pipeline.consumer_release(smem_pipe_release);
// Advance smem_pipe_read and smem_pipe_release
++smem_pipe_read;
++smem_pipe_release;
}
warpgroup_fence_operand(accum0);
warpgroup_fence_operand(accum1);
}
/// Perform a Consumer Epilogue to release all buffers
CUTLASS_DEVICE void mma_tail(MainloopPipeline pipeline,
PipelineState smem_pipe_release,
int k_tile_count) {
// Prologue GMMAs
int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count);
k_tile_count -= prologue_mma_count;
smem_pipe_release.advance(k_tile_count);
// Wait on all GMMAs to complete
warpgroup_wait<0>();
for (int count = 0; count < prologue_mma_count; ++count) {
pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release,
// done _computing_ on it
++smem_pipe_release;
}
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass::gemm::collective
/////////////////////////////////////////////////////////////////////////////////////////////////


@@ -0,0 +1,724 @@
/***************************************************************************************************
* Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights
*reserved. SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice,
*this list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
*POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cute/arch/cluster_sm90.hpp"
#include "cute/arch/copy_sm90.hpp"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cute/algorithm/functional.hpp"
#include "cute/algorithm/gemm.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cute/numeric/arithmetic_tuple.hpp"
#include "cute/tensor_predicate.hpp"
#include "cutlass/epilogue/thread/activation.h"
#include "cutlass/gemm/collective/fp8_accumulation.hpp"
#include "cutlass/trace.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass::gemm::collective {
using namespace cute;
/////////////////////////////////////////////////////////////////////////////////////////////////
// WarpSpecialized Mainloop
template <int Stages, class ClusterShape, class KernelSchedule,
class TileShape_, class ElementA_, class StrideA_, class ElementB_,
class StrideB_, class TiledMma_, class GmemTiledCopyA_,
class SmemLayoutAtomA_, class SmemCopyAtomA_, class TransformA_,
class GmemTiledCopyB_, class SmemLayoutAtomB_, class SmemCopyAtomB_,
class TransformB_,
template <class /* ElementCompute */> class Activation_, bool SwapAB_>
struct CollectiveMmaGated<
MainloopSm90TmaGmmaWarpSpecializedFP8<Stages, ClusterShape, KernelSchedule>,
TileShape_, ElementA_, StrideA_, ElementB_, StrideB_, TiledMma_,
GmemTiledCopyA_, SmemLayoutAtomA_, SmemCopyAtomA_, TransformA_,
GmemTiledCopyB_, SmemLayoutAtomB_, SmemCopyAtomB_, TransformB_, Activation_,
SwapAB_> {
static constexpr bool isGated = true;
static constexpr bool SwapAB = SwapAB_;
//
// Type Aliases
//
using DispatchPolicy =
MainloopSm90TmaGmmaWarpSpecializedFP8<Stages, ClusterShape,
KernelSchedule>;
using TileShape = TileShape_;
using ElementA = ElementA_;
using StrideA = StrideA_;
using ElementB = ElementB_;
using StrideB = StrideB_;
using TiledMma = TiledMma_;
using ElementAccumulator = typename TiledMma::ValTypeC;
using GmemTiledCopyA = GmemTiledCopyA_;
using GmemTiledCopyB = GmemTiledCopyB_;
using SmemLayoutAtomA = SmemLayoutAtomA_;
using SmemLayoutAtomB = SmemLayoutAtomB_;
using SmemCopyAtomA = SmemCopyAtomA_;
using SmemCopyAtomB = SmemCopyAtomB_;
using TransformA = TransformA_;
using TransformB = TransformB_;
using ArchTag = typename DispatchPolicy::ArchTag;
using Activation = Activation_<ElementAccumulator>;
using ElementAux = cute::conditional_t<SwapAB, ElementA_, ElementB_>;
using ValTypeAux = cute::conditional_t<SwapAB, typename TiledMma::ValTypeA,
typename TiledMma::ValTypeB>;
using MainloopPipeline = cutlass::PipelineTmaAsync<DispatchPolicy::Stages>;
using PipelineState = cutlass::PipelineState<DispatchPolicy::Stages>;
using PipelineParams = typename MainloopPipeline::Params;
static_assert(cute::rank(SmemLayoutAtomA{}) == 2,
"SmemLayoutAtom must be rank 2 (M/N, K)");
static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0,
"SmemLayoutAtom must evenly divide tile shape.");
static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0,
"SmemLayoutAtom must evenly divide tile shape.");
static_assert(cute::rank(SmemLayoutAtomB{}) == 2,
"SmemLayoutAtom must be rank 2 (M/N, K)");
static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0,
"SmemLayoutAtom must evenly divide tile shape.");
static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0,
"SmemLayoutAtom must evenly divide tile shape.");
// Tile along modes in a way that maximizes the TMA box size.
using SmemLayoutA = decltype(tile_to_shape(
SmemLayoutAtomA{},
make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}),
Int<DispatchPolicy::Stages>{}),
conditional_t<::cutlass::gemm::detail::is_major<0, StrideA>(),
Step<_2, _1, _3>, Step<_1, _2, _3>>{}));
using SmemLayoutB = decltype(tile_to_shape(
SmemLayoutAtomB{},
make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}),
Int<DispatchPolicy::Stages>{}),
conditional_t<::cutlass::gemm::detail::is_major<0, StrideB>(),
Step<_2, _1, _3>, Step<_1, _2, _3>>{}));
using SmemLayoutAux = cute::conditional_t<SwapAB, SmemLayoutA, SmemLayoutB>;
static_assert(DispatchPolicy::Stages >= 2,
"Specialization requires Stages set to value 1 or more.");
static_assert(cute::is_base_of<cute::GMMA::DescriptorIterator,
typename TiledMma::FrgTypeA>::value &&
cute::is_base_of<cute::GMMA::DescriptorIterator,
typename TiledMma::FrgTypeB>::value,
"MMA atom must source both A and B operand from smem_desc for "
"this mainloop.");
static_assert(cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD> ||
cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD_MULTICAST>,
"GmemTiledCopy - invalid SM90 TMA copy atom specified.");
static_assert(cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD> ||
cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD_MULTICAST>,
"GmemTiledCopy - invalid SM90 TMA copy atom specified.");
struct SharedStorage {
struct TensorStorage : cute::aligned_struct<128> {
cute::array_aligned<typename TiledMma::ValTypeA,
cute::cosize_v<SmemLayoutA>>
smem_A;
cute::array_aligned<typename TiledMma::ValTypeB,
cute::cosize_v<SmemLayoutB>>
smem_B;
cute::array_aligned<ValTypeAux, cute::cosize_v<SmemLayoutAux>> smem_Aux;
} tensors;
using PipelineStorage = typename MainloopPipeline::SharedStorage;
PipelineStorage pipeline;
};
using TensorStorage = typename SharedStorage::TensorStorage;
using PipelineStorage = typename SharedStorage::PipelineStorage;
// Host side kernel arguments
struct Arguments {
ElementA const *ptr_A;
StrideA dA;
ElementB const *ptr_B0;
ElementB const *ptr_B1;
StrideB dB;
float scale_d0 = 1.0f;
float scale_d1 = 1.0f;
uint32_t mma_promotion_interval = 4;
};
// Device side kernel params
struct Params {
// Assumption: StrideA is congruent with Problem_MK
using TMA_A = decltype(make_tma_copy(
GmemTiledCopyA{},
make_tensor(static_cast<ElementA const *>(nullptr),
repeat_like(StrideA{}, int32_t(0)), StrideA{}),
SmemLayoutA{}(_, _, 0),
make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})),
size<1>(ClusterShape{}))); // mcast along N mode for this M load, if any
// Assumption: StrideB is congruent with Problem_NK
using TMA_B = decltype(make_tma_copy(
GmemTiledCopyB{},
make_tensor(static_cast<ElementB const *>(nullptr),
repeat_like(StrideB{}, int32_t(0)), StrideB{}),
SmemLayoutB{}(_, _, 0),
make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})),
size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any
using TMA_Aux = cute::conditional_t<SwapAB, TMA_A, TMA_B>;
TMA_A tma_load_a;
TMA_B tma_load_b;
TMA_Aux tma_load_aux;
float scale_d0 = 1.0f;
float scale_d1 = 1.0f;
uint32_t mma_promotion_interval = 4;
};
//
// Methods
//
template <class ProblemShape>
static constexpr Params
to_underlying_arguments(ProblemShape const &problem_shape,
Arguments const &args, void *workspace) {
(void)workspace;
// Optionally append 1s until problem shape is rank-4 (MNKL), in case it is
// only rank-3 (MNK)
auto problem_shape_MNKL = append<4>(problem_shape, 1);
auto [M, N, K, L] = problem_shape_MNKL;
auto ptr_A = reinterpret_cast<ElementA const *>(args.ptr_A);
auto ptr_B0 = reinterpret_cast<ElementB const *>(args.ptr_B0);
Tensor tensor_a =
make_tensor(ptr_A, make_layout(make_shape(M, K, L), args.dA));
Tensor tensor_b =
make_tensor(ptr_B0, make_layout(make_shape(N, K, L), args.dB));
typename Params::TMA_A tma_load_a = make_tma_copy(
GmemTiledCopyA{}, tensor_a, SmemLayoutA{}(_, _, cute::Int<0>{}),
make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})),
size<1>(ClusterShape{})); // mcast along N mode for this M load, if any
typename Params::TMA_B tma_load_b = make_tma_copy(
GmemTiledCopyB{}, tensor_b, SmemLayoutB{}(_, _, cute::Int<0>{}),
make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})),
size<0>(ClusterShape{})); // mcast along M mode for this N load, if any
if constexpr (SwapAB) {
auto ptr_Aux = reinterpret_cast<ElementA const *>(args.ptr_B1);
Tensor tensor_aux =
make_tensor(ptr_Aux, make_layout(make_shape(M, K, L), args.dA));
typename Params::TMA_Aux tma_load_aux = make_tma_copy(
GmemTiledCopyA{}, tensor_aux, SmemLayoutA{}(_, _, cute::Int<0>{}),
make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})),
size<1>(
ClusterShape{})); // mcast along N mode for this M load, if any
return {tma_load_a, tma_load_b, tma_load_aux,
args.scale_d0, args.scale_d1, args.mma_promotion_interval};
} else {
auto ptr_Aux = reinterpret_cast<ElementB const *>(args.ptr_B1);
Tensor tensor_aux =
make_tensor(ptr_Aux, make_layout(make_shape(N, K, L), args.dB));
typename Params::TMA_Aux tma_load_aux = make_tma_copy(
GmemTiledCopyB{}, tensor_aux, SmemLayoutB{}(_, _, cute::Int<0>{}),
make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})),
size<0>(
ClusterShape{})); // mcast along M mode for this N load, if any
return {tma_load_a, tma_load_b, tma_load_aux,
args.scale_d0, args.scale_d1, args.mma_promotion_interval};
}
}
template <class ProblemShape>
static bool can_implement(ProblemShape const &problem_shape,
[[maybe_unused]] Arguments const &args) {
constexpr int tma_alignment_bits = 128;
auto problem_shape_MNKL = append<4>(problem_shape, 1);
auto [M, N, K, L] = problem_shape_MNKL;
bool implementable = true;
constexpr int min_tma_aligned_elements_A =
tma_alignment_bits / cutlass::sizeof_bits<ElementA>::value;
implementable =
implementable &&
cutlass::detail::check_alignment<min_tma_aligned_elements_A>(
cute::make_shape(M, K, L), StrideA{});
constexpr int min_tma_aligned_elements_B =
tma_alignment_bits / cutlass::sizeof_bits<ElementB>::value;
implementable =
implementable &&
cutlass::detail::check_alignment<min_tma_aligned_elements_B>(
cute::make_shape(N, K, L), StrideB{});
/* MMA promotion interval should be a multiple of 4, since each mainloop
* iteration would issue 4 MMA instructions. */
implementable = implementable && (args.mma_promotion_interval % 4 == 0);
if (!implementable) {
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the "
"minimum alignment requirements for TMA.\n");
}
return implementable;
}
static constexpr int K_PIPE_MAX = DispatchPolicy::Stages;
static constexpr int K_PIPE_MMAS = 1;
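// Bytes of A, B, and the auxiliary operand loaded per pipeline stage; used to
// size the TMA transaction barrier for each mainloop load.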
static constexpr uint32_t TmaTransactionBytes =
(size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) *
static_cast<uint32_t>(sizeof_bits<ElementA>::value)) /
8 +
(size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) *
static_cast<uint32_t>(sizeof_bits<ElementB>::value)) /
8 +
(size<0>(SmemLayoutAux{}) * size<1>(SmemLayoutAux{}) *
static_cast<uint32_t>(sizeof_bits<ElementAux>::value)) /
8;
/// Issue Tma Descriptor Prefetch -- ideally from a single thread for best
/// performance
CUTLASS_DEVICE
static void prefetch_tma_descriptors(Params const &mainloop_params) {
cute::prefetch_tma_descriptor(
mainloop_params.tma_load_a.get_tma_descriptor());
cute::prefetch_tma_descriptor(
mainloop_params.tma_load_b.get_tma_descriptor());
cute::prefetch_tma_descriptor(
mainloop_params.tma_load_aux.get_tma_descriptor());
}
/// Set up the data needed by this collective for load and mma.
/// Returns a tuple of tensors. The collective and the kernel layer have the
/// contract that the returned tuple contains at least two elements, the first
/// two being:
///   gA_mkl   - The tma tensor, A after a local tile so it has shape
///              (BLK_M,BLK_K,m,k,l)
///   gB_nkl   - The tma tensor, B after a local tile so it has shape
///              (BLK_N,BLK_K,n,k,l)
/// This gated mainloop additionally returns:
///   gAux_xkl - The tma tensor, A or B (depending on SwapAB) after a local
///              tile so it has shape (BLK_M/BLK_N,BLK_K,m/n,k,l)
template <class ProblemShape_MNKL>
CUTLASS_DEVICE auto load_init(ProblemShape_MNKL const &problem_shape_MNKL,
Params const &mainloop_params) const {
using X = Underscore;
// Separate out problem shape for convenience
auto [M, N, K, L] = problem_shape_MNKL;
// TMA requires special handling of strides to deal with coord codomain
// mapping. Represent the full tensors -- get these from TMA.
Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor(
make_shape(M, K, L)); // (m,k,l)
Tensor mB_nkl = mainloop_params.tma_load_b.get_tma_tensor(
make_shape(N, K, L)); // (n,k,l)
// Make tiled views, defer the slice
Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_, _, _),
Step<_1, X, _1>{}); // (BLK_M,BLK_K,m,k,l)
Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_, _, _),
Step<X, _1, _1>{}); // (BLK_N,BLK_K,n,k,l)
if constexpr (SwapAB) {
Tensor mAux_xkl = mainloop_params.tma_load_aux.get_tma_tensor(
make_shape(M, K, L)); // (m,k,l)
Tensor gAux_xkl = local_tile(mAux_xkl, TileShape{}, make_coord(_, _, _),
Step<_1, X, _1>{}); // (BLK_M,BLK_K,m,k,l)
return cute::make_tuple(gA_mkl, gB_nkl, gAux_xkl);
} else {
Tensor mAux_xkl = mainloop_params.tma_load_aux.get_tma_tensor(
make_shape(N, K, L)); // (n,k,l)
Tensor gAux_xkl = local_tile(mAux_xkl, TileShape{}, make_coord(_, _, _),
Step<X, _1, _1>{}); // (BLK_N,BLK_K,n,k,l)
return cute::make_tuple(gA_mkl, gB_nkl, gAux_xkl);
}
}
/// Perform a collective-scoped matrix multiply-accumulate
/// Producer Perspective
template <class TensorA, class TensorB, class TensorAux, class KTileIterator,
class BlockCoord>
CUTLASS_DEVICE void
load(Params const &mainloop_params, MainloopPipeline pipeline,
PipelineState smem_pipe_write,
cute::tuple<TensorA, TensorB, TensorAux> const &load_inputs,
BlockCoord const &blk_coord, KTileIterator k_tile_iter, int k_tile_count,
int thread_idx, uint32_t block_rank_in_cluster,
TensorStorage &shared_tensors) {
int lane_predicate = cute::elect_one_sync();
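// Only the single elected lane issues the TMA copies below; the remaining
// lanes of the producer warp skip the load path.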
if (lane_predicate) {
Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()),
SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()),
SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
Tensor sAux = make_tensor(make_smem_ptr(shared_tensors.smem_Aux.data()),
SmemLayoutAux{});
//
// Prepare the TMA loads for A and B
//
constexpr uint32_t cluster_shape_x = get<0>(ClusterShape());
uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x,
block_rank_in_cluster / cluster_shape_x};
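// The CTA's (x, y) position within the cluster selects which slice of the
// (possibly multicast) TMA copy this block is responsible for.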
Tensor gA_mkl = get<0>(load_inputs);
Tensor gB_nkl = get<1>(load_inputs);
Tensor gAux_xkl = get<2>(load_inputs);
auto block_tma_a =
mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y);
auto block_tma_b =
mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x);
auto block_tma_aux =
SwapAB
? mainloop_params.tma_load_aux.get_slice(cluster_local_block_id.y)
: mainloop_params.tma_load_aux.get_slice(
cluster_local_block_id.x);
// Partition the inputs based on the current block coordinates.
auto [m_coord, n_coord, k_coord, l_coord] = blk_coord;
Tensor gA = gA_mkl(_, _, m_coord, _, l_coord); // (BLK_M,BLK_K,k)
Tensor gB = gB_nkl(_, _, n_coord, _, l_coord); // (BLK_N,BLK_K,k)
Tensor gAux = SwapAB ? gAux_xkl(_, _, m_coord, _, l_coord)
: gAux_xkl(_, _, n_coord, _, l_coord);
// Applies the mapping from block_tma_a
Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k)
Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE)
Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k)
Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE)
Tensor tAuxgAux = block_tma_aux.partition_S(gAux);
Tensor tAuxsAux = block_tma_aux.partition_D(sAux);
uint16_t mcast_mask_a = 0;
uint16_t mcast_mask_b = 0;
uint16_t mcast_mask_aux = 0;
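// Each multicast mask has one bit per CTA in the cluster that should receive
// the corresponding TMA multicast.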
// Issue TmaLoads
// Maps the tile -> block, value
if constexpr (cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD_MULTICAST>) {
auto block_layout =
Layout<typename DispatchPolicy::ClusterShape>{}; // (m,n) ->
// block_id
for (int n = 0; n < size<1>(block_layout); ++n) {
mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x,
n, Int<0>{}));
}
}
if constexpr (cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD_MULTICAST>) {
auto block_layout =
Layout<typename DispatchPolicy::ClusterShape>{}; // (m,n) ->
// block_id
for (int m = 0; m < size<0>(block_layout); ++m) {
mcast_mask_b |= (uint16_t(1) << block_layout(
m, cluster_local_block_id.y, Int<0>{}));
}
}
if constexpr (SwapAB) {
mcast_mask_aux = mcast_mask_a;
} else {
mcast_mask_aux = mcast_mask_b;
}
// Mainloop
CUTLASS_PRAGMA_NO_UNROLL
for (; k_tile_count > 0; --k_tile_count) {
// LOCK smem_pipe_write for _writing_
pipeline.producer_acquire(smem_pipe_write);
//
// Copy gmem to smem for *k_tile_iter
//
using BarrierType = typename MainloopPipeline::ProducerBarrierType;
BarrierType *tma_barrier =
pipeline.producer_get_barrier(smem_pipe_write);
int write_stage = smem_pipe_write.index();
copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a),
tAgA(_, _, _, *k_tile_iter), tAsA(_, _, _, write_stage));
copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b),
tBgB(_, _, _, *k_tile_iter), tBsB(_, _, _, write_stage));
copy(mainloop_params.tma_load_aux.with(*tma_barrier, mcast_mask_aux),
tAuxgAux(_, _, _, *k_tile_iter), tAuxsAux(_, _, _, write_stage));
++k_tile_iter;
// Advance smem_pipe_write
++smem_pipe_write;
}
}
}
/// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster
CUTLASS_DEVICE void load_tail(MainloopPipeline pipeline,
PipelineState smem_pipe_write) {
int lane_predicate = cute::elect_one_sync();
// Issue the epilogue waits
if (lane_predicate) {
/* This helps avoid early exit of blocks in Cluster
* Waits for all stages to either be released (all
* Consumer UNLOCKs), or if the stage was never used
* then would just be acquired since the phase was
* still inverted from make_producer_start_state
*/
pipeline.producer_tail(smem_pipe_write);
}
}
/// Perform a collective-scoped matrix multiply-accumulate
/// Consumer Perspective
template <class FrgTensorC>
CUTLASS_DEVICE void
mma(MainloopPipeline pipeline, PipelineState smem_pipe_read,
FrgTensorC &accum0, FrgTensorC &accum1, int k_tile_count, int thread_idx,
TensorStorage &shared_tensors, Params const &mainloop_params) {
static_assert(is_rmem<FrgTensorC>::value,
"C tensor must be rmem resident.");
static_assert(cute::rank(SmemLayoutA{}) == 3,
"Smem layout must be rank 3.");
static_assert(cute::rank(SmemLayoutB{}) == 3,
"Smem layout must be rank 3.");
static_assert(cute::is_void_v<SmemCopyAtomA>,
"SM90 GMMA mainloops cannot have a non-void copy atom for "
"smem sourced instructions.");
static_assert(cute::is_void_v<SmemCopyAtomB>,
"SM90 GMMA mainloops cannot have a non-void copy atom for "
"smem sourced instructions.");
Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()),
SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()),
SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
Tensor sAux = make_tensor(make_smem_ptr(shared_tensors.smem_Aux.data()),
SmemLayoutAux{});
//
// Define C accumulators and A/B partitioning
//
TiledMma tiled_mma;
auto thread_mma = tiled_mma.get_thread_slice(thread_idx);
Tensor tCsA = thread_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE)
Tensor tCsB = thread_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE)
// Allocate "fragments/descriptors"
Tensor tCrA = thread_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE)
Tensor tCrB = thread_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE)
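// The auxiliary operand plays the role of A when SwapAB is set and the role
// of B otherwise, so partition it and build its fragments accordingly.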
auto tCsAux = [&]() -> auto {
if constexpr (SwapAB) {
return thread_mma.partition_A(sAux);
} else {
return thread_mma.partition_B(sAux);
}
}();
auto tCrAux = [&]() -> auto {
if constexpr (SwapAB) {
return thread_mma.make_fragment_A(tCsAux);
} else {
return thread_mma.make_fragment_B(tCsAux);
}
}();
CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum0)); // M
CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum0)); // N
CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K
CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE
if constexpr (SwapAB) {
CUTE_STATIC_ASSERT_V(size<1>(tCsAux) == size<1>(accum1)); // M
CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum1)); // N
CUTE_STATIC_ASSERT_V(size<2>(tCsB) == size<2>(tCsAux)); // K
CUTE_STATIC_ASSERT_V(size<3>(tCsB) == size<3>(tCsAux)); // PIPE
} else {
CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum1)); // M
CUTE_STATIC_ASSERT_V(size<1>(tCsAux) == size<2>(accum1)); // N
CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsAux)); // K
CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsAux)); // PIPE
}
CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} == size<2>(sA)); // PIPE
CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} == size<2>(sB)); // PIPE
CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} ==
size<2>(sAux)); // PIPE
//
// PIPELINED MAIN LOOP
//
static_assert((0 <= K_PIPE_MMAS) && (K_PIPE_MMAS < K_PIPE_MAX),
"ERROR : Incorrect number of MMAs in flight");
// We release buffers to the producer warps (DMA load) with some MMAs in flight
PipelineState smem_pipe_release = smem_pipe_read;
// Prologue GMMAs
int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count);
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
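// FP8 accumulation helpers: partial results are promoted into accum0/accum1
// at the mma_promotion_interval cadence to limit precision loss from FP8
// tensor-core accumulation.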
GmmaFP8Accumulation accumulation0(
accum0, mainloop_params.mma_promotion_interval, size<2>(tCrA));
GmmaFP8Accumulation accumulation1(
accum1, mainloop_params.mma_promotion_interval, size<2>(tCrA));
warpgroup_fence_operand(accumulation0());
warpgroup_fence_operand(accumulation1());
CUTLASS_PRAGMA_UNROLL
for (int k_tile_prologue = prologue_mma_count; k_tile_prologue > 0;
--k_tile_prologue) {
// WAIT on smem_pipe_read until its data are available (phase bit flips
// from rdPhaseBit value)
auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
pipeline.consumer_wait(smem_pipe_read, barrier_token);
if (accumulation0.prepare_if_needed()) {
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
}
int read_stage = smem_pipe_read.index();
warpgroup_arrive();
// Unroll the K mode manually to set scale D to 1
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
// (V,M,K) x (V,N,K) => (V,M,N)
cute::gemm(tiled_mma, tCrA(_, _, k_block, read_stage),
tCrB(_, _, k_block, read_stage), accumulation0());
if constexpr (SwapAB) {
cute::gemm(tiled_mma, tCrAux(_, _, k_block, read_stage),
tCrB(_, _, k_block, read_stage), accumulation1());
} else {
cute::gemm(tiled_mma, tCrA(_, _, k_block, read_stage),
tCrAux(_, _, k_block, read_stage), accumulation1());
}
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
}
warpgroup_commit_batch();
accumulation0.promote_if_needed();
accumulation1.promote_if_needed();
++smem_pipe_read;
}
warpgroup_fence_operand(accumulation0());
warpgroup_fence_operand(accumulation1());
// Mainloop GMMAs
k_tile_count -= prologue_mma_count;
CUTLASS_PRAGMA_NO_UNROLL
for (; k_tile_count > 0; --k_tile_count) {
// WAIT on smem_pipe_read until its data are available (phase bit flips
// from rdPhaseBit value)
auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
pipeline.consumer_wait(smem_pipe_read, barrier_token);
//
// Compute on k_tile
//
int read_stage = smem_pipe_read.index();
if (accumulation0.prepare_if_needed()) {
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
}
warpgroup_fence_operand(accumulation0());
warpgroup_fence_operand(accumulation1());
warpgroup_arrive();
// Unroll the K mode manually to set scale D to 1
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
// (V,M,K) x (V,N,K) => (V,M,N)
cute::gemm(tiled_mma, tCrA(_, _, k_block, read_stage),
tCrB(_, _, k_block, read_stage), accumulation0());
if constexpr (SwapAB) {
cute::gemm(tiled_mma, tCrAux(_, _, k_block, read_stage),
tCrB(_, _, k_block, read_stage), accumulation1());
} else {
cute::gemm(tiled_mma, tCrA(_, _, k_block, read_stage),
tCrAux(_, _, k_block, read_stage), accumulation1());
}
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
}
warpgroup_commit_batch();
/// Wait on the GMMA barrier for K_PIPE_MMAS (or fewer) outstanding to
/// ensure smem_pipe_write is consumed
warpgroup_wait<K_PIPE_MMAS>();
warpgroup_fence_operand(accumulation0());
warpgroup_fence_operand(accumulation1());
accumulation0.promote_if_needed();
accumulation1.promote_if_needed();
pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release,
// done _computing_ on it
// Advance smem_pipe_read and smem_pipe_release
++smem_pipe_read;
++smem_pipe_release;
}
accumulation0.promote_residue_if_needed();
accumulation1.promote_residue_if_needed();
warpgroup_fence_operand(accumulation0());
warpgroup_fence_operand(accumulation1());
}
/// Perform a Consumer Epilogue to release all buffers
CUTLASS_DEVICE void mma_tail(MainloopPipeline pipeline,
PipelineState smem_pipe_release,
int k_tile_count) {
// Prologue GMMAs
int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count);
k_tile_count -= prologue_mma_count;
smem_pipe_release.advance(k_tile_count);
// Wait on all GMMAs to complete
warpgroup_wait<0>();
for (int count = 0; count < prologue_mma_count; ++count) {
pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release,
// done _computing_ on it
++smem_pipe_release;
}
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass::gemm::collective
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@@ -0,0 +1,71 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights
*reserved. SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice,
*this list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
*POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/kernel/tile_scheduler.hpp"
////////////////////////////////////////////////////////////////////////////////
namespace cutlass::gemm::kernel {
////////////////////////////////////////////////////////////////////////////////
/*
* Stateless universal device GEMM kernel type that treats GEMM as
* a composition of a collective mainloop and a collective epilogue.
*
* Supports both the 2.x and 3.x APIs based on whether the first type is
* a cute::tuple<> or not.
* 2.x API implementation: cutlass/gemm/kernel/gemm_universal.h
* 3.x API implementation: cutlass/gemm/kernel/gemm_*.hpp
*
* In the following declaration, the name preceding the 'Or' refers to
* 3.x API type argument order, and the name succeeding the 'Or' refers to
* 2.x API type argument order. Template arguments without two names
* belong to the 3.x API only.
**/
template <class ProblemShapeOrThreadblockMma_, // (m, n, k) or (m, n, k, l)
class CollectiveMainloopOrEpilogue_,
class CollectiveEpilogueOrThreadblockSwizzle_,
class TileScheduler_ = void, class Enable = void>
class GemmUniversalGated;
////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass::gemm::kernel
////////////////////////////////////////////////////////////////////////////////
#include "cutlass_extensions/gemm/kernel/sm90_gemm_gated_tma_warpspecialized_cooperative.hpp"
#include "cutlass_extensions/gemm/kernel/sm90_gemm_gated_tma_warpspecialized_pingpong.hpp"
////////////////////////////////////////////////////////////////////////////////

View File

@@ -130,6 +130,15 @@ public:
using Operator = cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA;
};
template <typename TypeA, typename Arch>
struct LayoutDetailsB<TypeA, uint2b_t, Arch, typename platform::enable_if<Arch::kMinComputeCapability >= 75>::type>
{
static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits<TypeA>::value;
using Layout = layout::RowMajor;
static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<TypeA>::value;
using Operator = cutlass::arch::OpMultiplyAdd;
};
template <typename TypeA, typename Arch>
struct LayoutDetailsB<TypeA, uint8_t, Arch, typename platform::enable_if<Arch::kMinComputeCapability >= 90>::type>
{

View File

@@ -0,0 +1,705 @@
/***************************************************************************************************
* Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights
*reserved. SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice,
*this list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
*POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cute/arch/cluster_sm90.hpp"
#include "cute/tensor.hpp"
#include "cutlass/arch/mma_sm90.h"
#include "cutlass/arch/reg_reconfig.h"
#include "cutlass/cutlass.h"
#include "cutlass/epilogue/collective/detail.hpp"
#include "cutlass/fast_math.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/kernel/tile_scheduler.hpp"
#include "cutlass/kernel_hardware_info.hpp"
#include "cutlass/pipeline/pipeline.hpp"
#include "cutlass/trace.h"
#include "cutlass/workspace.h"
///////////////////////////////////////////////////////////////////////////////
namespace cutlass::gemm::kernel {
///////////////////////////////////////////////////////////////////////////////
template <class ProblemShape_, class CollectiveMainloop_,
class CollectiveEpilogue_, class TileScheduler_>
class GemmUniversalGated<
ProblemShape_, CollectiveMainloop_, CollectiveEpilogue_, TileScheduler_,
cute::enable_if_t<cute::is_base_of_v<KernelTmaWarpSpecializedCooperative,
typename CollectiveMainloop_::
DispatchPolicy::Schedule> &&
CollectiveMainloop_::isGated>> {
public:
//
// Type Aliases
//
using ProblemShape = ProblemShape_;
static_assert(cute::rank(ProblemShape{}) == 3 or
cute::rank(ProblemShape{}) == 4,
"ProblemShape{} should be <M,N,K> or <M,N,K,L>");
// Mainloop derived types
using CollectiveMainloop = CollectiveMainloop_;
using TileShape = typename CollectiveMainloop::TileShape;
using TiledMma = typename CollectiveMainloop::TiledMma;
using ArchTag = typename CollectiveMainloop::ArchTag;
using ElementA = typename CollectiveMainloop::ElementA;
using StrideA = typename CollectiveMainloop::StrideA;
using ElementB = typename CollectiveMainloop::ElementB;
using StrideB = typename CollectiveMainloop::StrideB;
using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy;
using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator;
using ClusterShape = typename DispatchPolicy::ClusterShape;
using MainloopArguments = typename CollectiveMainloop::Arguments;
using MainloopParams = typename CollectiveMainloop::Params;
using Activation = typename CollectiveMainloop::Activation;
// Epilogue derived types
using CollectiveEpilogue = CollectiveEpilogue_;
using ElementC = typename CollectiveEpilogue::ElementC;
using StrideC = typename CollectiveEpilogue::StrideC;
using ElementD = typename CollectiveEpilogue::ElementD;
using StrideD = typename CollectiveEpilogue::StrideD;
using EpilogueArguments = typename CollectiveEpilogue::Arguments;
using EpilogueParams = typename CollectiveEpilogue::Params;
static_assert(ArchTag::kMinComputeCapability >= 90);
using TileSchedulerTag = TileScheduler_;
using TileScheduler =
typename detail::TileSchedulerSelector<TileScheduler_, ArchTag, TileShape,
ClusterShape>::Scheduler;
using TileSchedulerArguments = typename TileScheduler::Arguments;
using TileSchedulerParams = typename TileScheduler::Params;
static constexpr uint32_t NumLoadWarpGroups = 1;
static constexpr uint32_t NumMmaWarpGroups =
CUTE_STATIC_V(size(TiledMma{})) / NumThreadsPerWarpGroup;
static constexpr uint32_t MaxThreadsPerBlock =
CUTE_STATIC_V(size(TiledMma{})) +
(NumLoadWarpGroups * NumThreadsPerWarpGroup);
static constexpr uint32_t MinBlocksPerMultiprocessor = 1;
/// Register requirement for Load and Math WGs
static constexpr uint32_t LoadRegisterRequirement = 40;
static constexpr uint32_t MmaRegisterRequirement = 232;
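// The producer (load) warp group deallocates registers down to
// LoadRegisterRequirement so the math warp groups can allocate up to
// MmaRegisterRequirement; see warpgroup_reg_dealloc/alloc in operator().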
// 1 stage ordered sequence between mainloop and epilogue producer load
// threads
using LoadWarpOrderBarrier = cutlass::OrderedSequenceBarrier<1, 2>;
// Kernel level shared memory storage
struct SharedStorage {
struct TensorStorage : cute::aligned_struct<128> {
using MainloopTensorStorage = typename CollectiveMainloop::TensorStorage;
using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage;
MainloopTensorStorage mainloop;
EpilogueTensorStorage epilogue;
} tensors;
struct PipelineStorage : cute::aligned_struct<16> {
using MainloopPipelineStorage =
typename CollectiveMainloop::PipelineStorage;
using EpiLoadPipelineStorage =
typename CollectiveEpilogue::PipelineStorage;
alignas(16) MainloopPipelineStorage mainloop;
alignas(16) EpiLoadPipelineStorage epi_load;
alignas(16) typename LoadWarpOrderBarrier::SharedStorage load_order;
} pipelines;
};
static constexpr int SharedStorageSize = sizeof(SharedStorage);
// Device side arguments
struct Arguments {
GemmUniversalMode mode{};
ProblemShape problem_shape{};
MainloopArguments mainloop{};
EpilogueArguments epilogue{};
KernelHardwareInfo hw_info{};
TileSchedulerArguments scheduler{};
};
// Kernel entry point API
struct Params {
GemmUniversalMode mode{};
ProblemShape problem_shape{};
MainloopParams mainloop{};
EpilogueParams epilogue{};
KernelHardwareInfo hw_info{};
TileSchedulerParams scheduler{};
void *workspace{nullptr};
};
//
// Methods
//
// Convert to underlying arguments. In this case, a simple copy for the
// aliased type.
static Params to_underlying_arguments(Arguments const &args,
void *workspace) {
CUTLASS_TRACE_HOST("to_underlying_arguments():");
auto problem_shape = args.problem_shape;
// if constexpr (detail::IF_SWAP_AB<CollectiveMainloop>::value) {
// // swap M/N
// get<0>(problem_shape) = get<1>(args.problem_shape);
// get<1>(problem_shape) = get<0>(args.problem_shape);
// }
auto problem_shape_MNKL = append<4>(problem_shape, 1);
// Get SM count if needed, otherwise use user supplied SM count
int sm_count = args.hw_info.sm_count;
if (sm_count <= 0) {
CUTLASS_TRACE_HOST(
" WARNING: Arguments do not include a valid SM count.\n"
" For optimal performance, populate the arguments "
"KernelHardwareInfo struct with the SM count.");
sm_count = KernelHardwareInfo::query_device_multiprocessor_count(
args.hw_info.device_id);
}
CUTLASS_TRACE_HOST(
"to_underlying_arguments(): Setting persistent grid SM count to "
<< sm_count);
KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count};
// Calculate workspace pointers
uint8_t *workspace_ptr = reinterpret_cast<uint8_t *>(workspace);
size_t workspace_offset = 0;
void *scheduler_workspace = workspace_ptr;
workspace_offset +=
TileScheduler::template get_workspace_size<ProblemShape,
ElementAccumulator>(
args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups);
workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment);
void *epilogue_workspace = workspace_ptr + workspace_offset;
workspace_offset += CollectiveEpilogue::get_workspace_size(
args.problem_shape, args.epilogue);
workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment);
void *mainloop_workspace = nullptr;
// Precompute the number of epilogue subtiles and pass it into the tile
// scheduler so it can be used by the separate-reduction scheme in the
// stream-K case. NumEpilogueSubTiles defaults to 1, meaning subtiles are not
// used and separate reduction is not enabled.
constexpr uint32_t NumEpilogueSubTiles =
CollectiveEpilogue::get_store_pipe_increment(TileShape{});
TileSchedulerParams scheduler = TileScheduler::to_underlying_arguments(
problem_shape_MNKL, TileShape{}, ClusterShape{}, hw_info,
args.scheduler, scheduler_workspace, NumEpilogueSubTiles);
return {args.mode,
problem_shape,
CollectiveMainloop::to_underlying_arguments(
args.problem_shape, args.mainloop, mainloop_workspace),
CollectiveEpilogue::to_underlying_arguments(
args.problem_shape, args.epilogue, epilogue_workspace),
hw_info,
scheduler,
workspace};
}
static bool can_implement(Arguments const &args) {
bool implementable = (args.mode == GemmUniversalMode::kGemm) or
(args.mode == GemmUniversalMode::kBatched &&
cute::rank(ProblemShape{}) == 4);
if (!implementable) {
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't "
"meet the requirements.\n");
return implementable;
}
implementable &=
CollectiveMainloop::can_implement(args.problem_shape, args.mainloop);
implementable &=
CollectiveEpilogue::can_implement(args.problem_shape, args.epilogue);
implementable &= TileScheduler::can_implement(args.scheduler);
return implementable;
}
static size_t get_workspace_size(Arguments const &args) {
size_t workspace_size = 0;
constexpr uint32_t NumEpilogueSubTiles =
CollectiveEpilogue::get_store_pipe_increment(TileShape{});
workspace_size +=
TileScheduler::template get_workspace_size<ProblemShape,
ElementAccumulator>(
args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups,
NumEpilogueSubTiles);
workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment);
workspace_size += CollectiveEpilogue::get_workspace_size(args.problem_shape,
args.epilogue);
workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment);
return workspace_size;
}
static cutlass::Status
initialize_workspace(Arguments const &args, void *workspace = nullptr,
cudaStream_t stream = nullptr,
CudaHostAdapter *cuda_adapter = nullptr) {
Status status = Status::kSuccess;
uint8_t *workspace_ptr = reinterpret_cast<uint8_t *>(workspace);
size_t workspace_offset = 0;
constexpr uint32_t NumEpilogueSubTiles =
CollectiveEpilogue::get_store_pipe_increment(TileShape{});
status = TileScheduler::template initialize_workspace<ProblemShape,
ElementAccumulator>(
args.scheduler, workspace_ptr + workspace_offset, stream,
args.problem_shape, args.hw_info, NumMmaWarpGroups,
NumEpilogueSubTiles);
workspace_offset +=
TileScheduler::template get_workspace_size<ProblemShape,
ElementAccumulator>(
args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups,
NumEpilogueSubTiles);
workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment);
if (status != Status::kSuccess) {
return status;
}
status = CollectiveEpilogue::initialize_workspace(
args.problem_shape, args.epilogue, workspace_ptr + workspace_offset,
stream, cuda_adapter);
workspace_offset += CollectiveEpilogue::get_workspace_size(
args.problem_shape, args.epilogue);
workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment);
if (status != Status::kSuccess) {
return status;
}
return status;
}
// Computes the kernel launch grid shape based on runtime parameters
static dim3 get_grid_shape(Params const &params) {
// Given device SM count, set grid size s.t. we do not launch more thread
// blocks than we can run concurrently
TileSchedulerArguments args{};
if constexpr (!std::is_const_v<decltype(args.max_swizzle_size)>) {
args.max_swizzle_size = 1 << params.scheduler.log_swizzle_size_;
}
args.raster_order =
params.scheduler.raster_order_ == TileScheduler::RasterOrder::AlongN
? TileScheduler::RasterOrderOptions::AlongN
: TileScheduler::RasterOrderOptions::AlongM;
return TileScheduler::get_grid_shape(params.scheduler, params.problem_shape,
TileShape{}, ClusterShape{},
params.hw_info, args);
}
static dim3 get_block_shape() { return dim3(MaxThreadsPerBlock, 1, 1); }
CUTLASS_DEVICE
void operator()(Params const &params, char *smem_buf) {
using namespace cute;
using X = Underscore;
// Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a.
#if !defined(__CUDA_ARCH_FEAT_SM90_ALL)
printf("ERROR : Arch conditional MMA instruction used without targeting "
"sm90a compute capability. Aborting.\n");
#else
// Preconditions
static_assert(
size(TiledMma{}) == 256,
"Cooperative kernel must have TiledMMA operating using 256 threads.");
static_assert(size<0>(TileShape{}) >= 128,
"Cooperative kernel requires Tile Size to be greater than or "
"equal to 128 along the M-dimension.");
static_assert(cute::rank(StrideA{}) == 3,
"StrideA must be rank-3: [M, K, L]. If batch mode is not "
"needed, set L stride to Int<0>.");
static_assert(cute::rank(StrideB{}) == 3,
"StrideB must be rank-3: [N, K, L]. If batch mode is not "
"needed, set L stride to Int<0>.");
static_assert(cute::rank(StrideC{}) == 3,
"StrideC must be rank-3: [M, N, L]. If batch mode is not "
"needed, set L stride to Int<0>.");
static_assert(cute::rank(StrideD{}) == 3,
"StrideD must be rank-3: [M, N, L]. If batch mode is not "
"needed, set L stride to Int<0>.");
/* In the Cooperative kernel, Consumer0 and Consumer1 collaborate on the
* same tile */
enum class WarpGroupRole { Producer = 0, Consumer0 = 1, Consumer1 = 2 };
enum class ProducerWarpRole {
Mainloop = 0,
Warp1 = 1,
Epilogue = 2,
Warp3 = 3
};
// Kernel level shared memory storage
SharedStorage &shared_storage =
*reinterpret_cast<SharedStorage *>(smem_buf);
int thread_idx = int(threadIdx.x);
int lane_idx = canonical_lane_idx();
int warp_idx = canonical_warp_idx_sync();
int warp_idx_in_warp_group = warp_idx % NumWarpsPerWarpGroup;
int warp_group_thread_idx = thread_idx % NumThreadsPerWarpGroup;
int mma_thread_idx = thread_idx % size(TiledMma{});
auto warp_group_role = WarpGroupRole(canonical_warp_group_idx());
auto producer_warp_role = ProducerWarpRole(warp_idx_in_warp_group);
int lane_predicate = cute::elect_one_sync();
uint32_t block_rank_in_cluster = cute::block_rank_in_cluster();
// Issue Tma Descriptor Prefetch from a single thread
if ((warp_idx == 0) && lane_predicate) {
CollectiveMainloop::prefetch_tma_descriptors(params.mainloop);
CollectiveEpilogue::prefetch_tma_descriptors(params.epilogue);
}
// Mainloop Load pipeline
using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline;
typename MainloopPipeline::Params mainloop_pipeline_params;
if (warp_group_role == WarpGroupRole::Producer &&
producer_warp_role == ProducerWarpRole::Mainloop) {
mainloop_pipeline_params.role =
MainloopPipeline::ThreadCategory::Producer;
}
if (warp_group_role == WarpGroupRole::Consumer0 ||
warp_group_role == WarpGroupRole::Consumer1) {
mainloop_pipeline_params.role =
MainloopPipeline::ThreadCategory::Consumer;
}
mainloop_pipeline_params.is_leader = warp_group_thread_idx == 0;
mainloop_pipeline_params.num_consumers = size(TiledMma{});
mainloop_pipeline_params.transaction_bytes =
CollectiveMainloop::TmaTransactionBytes;
MainloopPipeline mainloop_pipeline(shared_storage.pipelines.mainloop,
mainloop_pipeline_params,
ClusterShape{});
// Epilogue Load pipeline
using EpiLoadPipeline = typename CollectiveEpilogue::LoadPipeline;
typename EpiLoadPipeline::Params epi_load_pipeline_params;
if (warp_group_role == WarpGroupRole::Producer &&
producer_warp_role == ProducerWarpRole::Epilogue) {
epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Producer;
}
if (warp_group_role == WarpGroupRole::Consumer0 ||
warp_group_role == WarpGroupRole::Consumer1) {
epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Consumer;
}
epi_load_pipeline_params.dst_blockid = cute::block_rank_in_cluster();
epi_load_pipeline_params.producer_arv_count = NumThreadsPerWarp;
epi_load_pipeline_params.consumer_arv_count = size(TiledMma{});
epi_load_pipeline_params.transaction_bytes =
CollectiveEpilogue::TmaTransactionBytes;
EpiLoadPipeline epi_load_pipeline(shared_storage.pipelines.epi_load,
epi_load_pipeline_params);
// Epilogue Store pipeline
using EpiStorePipeline = typename CollectiveEpilogue::StorePipeline;
typename EpiStorePipeline::Params epi_store_pipeline_params;
epi_store_pipeline_params.always_wait = true;
EpiStorePipeline epi_store_pipeline(epi_store_pipeline_params);
typename LoadWarpOrderBarrier::Params params_load_order_barrier;
params_load_order_barrier.group_id =
producer_warp_role == ProducerWarpRole::Mainloop ? 0 : 1;
params_load_order_barrier.group_size = NumThreadsPerWarp;
LoadWarpOrderBarrier load_order_barrier(shared_storage.pipelines.load_order,
params_load_order_barrier);
// Initialize starting pipeline states for the collectives
// Epilogue store pipe is producer-only (consumer is TMA unit, waits via
// scoreboarding)
typename CollectiveMainloop::PipelineState mainloop_pipe_consumer_state;
typename CollectiveEpilogue::LoadPipelineState epi_load_pipe_consumer_state;
// For the DMA Load (producer) we start with an opposite phase
// i.e., we skip all waits since we know that the buffer is indeed empty
PipelineState mainloop_pipe_producer_state =
cutlass::make_producer_start_state<MainloopPipeline>();
PipelineState epi_load_pipe_producer_state =
cutlass::make_producer_start_state<EpiLoadPipeline>();
PipelineState epi_store_pipe_producer_state =
cutlass::make_producer_start_state<EpiStorePipeline>();
auto cluster_wait_fn = []() {
// We need this to guarantee that the Pipeline init is visible
// To all producers and consumer thread blocks in the Cluster
if constexpr (size(ClusterShape{}) > 1) {
cute::cluster_arrive_relaxed();
return []() { cute::cluster_wait(); };
} else {
__syncthreads();
return []() {}; // do nothing
}
}();
// Optionally append 1s until problem shape is rank-4 in case it is only
// rank-3 (MNK)
auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{});
// Get the appropriate blocks for this thread block -- potential for thread
// block locality
TiledMma tiled_mma;
auto blk_shape = TileShape{}; // (BLK_M,BLK_N,BLK_K)
TileScheduler scheduler{params.scheduler};
auto work_tile_info = scheduler.get_current_work();
// In a warp specialized kernel, collectives expose data movement and
// compute operations separately
CollectiveMainloop collective_mainloop;
CollectiveEpilogue collective_epilogue(params.epilogue,
shared_storage.tensors.epilogue);
// Prepare and partition the input tensors. Expects a tuple of tensors
// where: get<0>(load_inputs) is the tma tensor A after local tiling so that
// it has shape (BLK_M,BLK_K,m,k,l) get<1>(load_inputs) is the tma tensor B
// after local tiling so that it has shape (BLK_N,BLK_K,n,k,l)
auto load_inputs =
collective_mainloop.load_init(problem_shape_MNKL, params.mainloop);
static_assert(
cute::tuple_size_v<decltype(load_inputs)> >= 3,
"Output of load_init must have at least three elements (A, B, Aux)");
// Extract out partitioned A and B.
Tensor gA_mkl = get<0>(load_inputs);
Tensor gB_nkl = get<1>(load_inputs);
Tensor gAux_xkl = get<2>(load_inputs);
// Get pipeline stage increments from tensor shapes
auto k_tile_count = size<3>(gA_mkl);
// Wait for all thread blocks in the Cluster
cluster_wait_fn();
if (warp_group_role == WarpGroupRole::Producer) {
cutlass::arch::warpgroup_reg_dealloc<LoadRegisterRequirement>();
// Mainloop Producer Warp
if (producer_warp_role == ProducerWarpRole::Mainloop) {
bool do_load_order_arrive = true;
while (work_tile_info.is_valid()) {
if (!TileScheduler::valid_warpgroup_in_work_tile(work_tile_info)) {
work_tile_info = fetch_next_work(work_tile_info, scheduler);
continue;
}
// Compute m_coord, n_coord, l_coord with the post-tiled m-shape and
// n-shape
auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl));
auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl));
auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl));
auto blk_coord = make_coord(m_coord, n_coord, _, l_coord);
// Get the number of K tiles to compute for this work as well as the
// starting K tile offset of the work.
auto work_k_tile_count = TileScheduler::get_work_k_tile_count(
work_tile_info, problem_shape_MNKL, blk_shape);
auto work_k_tile_start =
TileScheduler::get_work_k_tile_start(work_tile_info);
auto k_tile_iter = cute::make_coord_iterator(
idx2crd(work_k_tile_start, shape<3>(gA_mkl)), shape<3>(gA_mkl));
collective_mainloop.load(
params.mainloop, mainloop_pipeline, mainloop_pipe_producer_state,
load_inputs, blk_coord, k_tile_iter, work_k_tile_count, lane_idx,
block_rank_in_cluster, shared_storage.tensors.mainloop);
// Update starting pipeline state for the next tile
mainloop_pipe_producer_state.advance(work_k_tile_count);
// Signal for the epilogue load warp to begin
if (do_load_order_arrive) {
load_order_barrier.arrive();
do_load_order_arrive = false;
}
// Get next work tile
work_tile_info = fetch_next_work(work_tile_info, scheduler);
} // Scheduler work fetch loop
// Make sure all Consumer Warp Groups have been waited upon
collective_mainloop.load_tail(mainloop_pipeline,
mainloop_pipe_producer_state);
} // Mainloop Producer Warp End
// Epilogue Producer Warp
else if (producer_warp_role == ProducerWarpRole::Epilogue &&
collective_epilogue.is_producer_load_needed()) {
while (work_tile_info.is_valid()) {
if (!TileScheduler::requires_separate_reduction(params.scheduler)) {
load_order_barrier.wait();
}
if (TileScheduler::compute_epilogue(work_tile_info,
params.scheduler)) {
// Compute m_coord, n_coord, l_coord with the post-tiled m-shape and
// n-shape
auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl));
auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl));
auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl));
auto blk_coord = make_coord(m_coord, n_coord, _, l_coord);
epi_load_pipe_producer_state = collective_epilogue.load(
epi_load_pipeline, epi_load_pipe_producer_state,
problem_shape_MNKL, blk_shape, blk_coord, tiled_mma, lane_idx,
shared_storage.tensors.epilogue,
work_tile_info.reduction_subtile_idx());
}
// Get next work tile
work_tile_info = fetch_next_work(work_tile_info, scheduler);
} // Scheduler work fetch loop
// Make sure all Consumer Warp Groups have been waited upon
collective_epilogue.load_tail(epi_load_pipeline,
epi_load_pipe_producer_state);
} // Epilogue Producer Warp End
} // Producer Warp Group End
else if (warp_group_role == WarpGroupRole::Consumer0 ||
warp_group_role == WarpGroupRole::Consumer1) {
cutlass::arch::warpgroup_reg_alloc<MmaRegisterRequirement>();
// Whether we may need to issue tail arrives for TMA stores, in case the
// epilogue load is waiting on them
bool do_store_tail = false;
float scale_d0 = params.mainloop.scale_d0;
float scale_d1 = params.mainloop.scale_d1;
while (work_tile_info.is_valid()) {
// Compute m_coord, n_coord, l_coord with the post-tiled m-shape and
// n-shape
auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl));
auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl));
auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl));
auto blk_coord = make_coord(m_coord, n_coord, _, l_coord);
auto work_k_tile_count = TileScheduler::get_work_k_tile_count(
work_tile_info, problem_shape_MNKL, blk_shape);
// Allocate the accumulators for the (M,N) blk_shape
//
// MSVC CTAD breaks if we say "Tensor" here, so we use "auto" instead.
auto accumulators0 = partition_fragment_C(
tiled_mma, take<0, 2>(blk_shape)); // (MMA,MMA_M,MMA_N)
auto accumulators1 = partition_fragment_C(
tiled_mma, take<0, 2>(blk_shape)); // (MMA,MMA_M,MMA_N)
if (TileScheduler::valid_warpgroup_in_work_tile(work_tile_info)) {
collective_mainloop.mma(
mainloop_pipeline, mainloop_pipe_consumer_state, accumulators0,
accumulators1, work_k_tile_count, mma_thread_idx,
shared_storage.tensors.mainloop, params.mainloop);
// Make sure the math instructions are done and free buffers before
// entering the epilogue
collective_mainloop.mma_tail(mainloop_pipeline,
mainloop_pipe_consumer_state,
work_k_tile_count);
// Update starting mainloop pipeline state for the next tile
mainloop_pipe_consumer_state.advance(work_k_tile_count);
}
// Index of warp group within consumer warp groups
int consumer_warp_group_idx =
canonical_warp_group_idx() - NumLoadWarpGroups;
// Perform reduction across splits, if needed
TileScheduler::fixup(params.scheduler, work_tile_info, accumulators0,
NumMmaWarpGroups, consumer_warp_group_idx);
TileScheduler::fixup(params.scheduler, work_tile_info, accumulators1,
NumMmaWarpGroups, consumer_warp_group_idx);
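// Gated fusion: apply the activation to the first GEMM's accumulator and
// multiply element-wise by the gate GEMM's accumulator, each scaled by its
// scale_d factor.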
Activation elt_op;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(accumulators0); i++) {
accumulators0[i] = elt_op(accumulators0[i] * scale_d0) *
(scale_d1 * accumulators1[i]);
}
if (TileScheduler::compute_epilogue(work_tile_info, params.scheduler)) {
// Epilogue and write to gD
auto [epi_load_pipe_consumer_state_next,
epi_store_pipe_producer_state_next] =
collective_epilogue.store(
epi_load_pipeline, epi_load_pipe_consumer_state,
epi_store_pipeline, epi_store_pipe_producer_state,
problem_shape_MNKL, blk_shape, blk_coord, accumulators0,
tiled_mma, mma_thread_idx, shared_storage.tensors.epilogue,
work_tile_info.reduction_subtile_idx());
epi_load_pipe_consumer_state = epi_load_pipe_consumer_state_next;
epi_store_pipe_producer_state = epi_store_pipe_producer_state_next;
do_store_tail = true;
}
// Get next work tile
work_tile_info = fetch_next_work(work_tile_info, scheduler);
} // Scheduler work fetch loop
if (do_store_tail) {
collective_epilogue.store_tail(
epi_load_pipeline, epi_load_pipe_consumer_state, epi_store_pipeline,
epi_store_pipe_producer_state);
}
} // Consumer Warp Groups End
#endif
}
private:
// Kernel helper function to get next work unit
CUTLASS_DEVICE
typename TileScheduler::WorkTileInfo
fetch_next_work(typename TileScheduler::WorkTileInfo &work_tile_info,
TileScheduler &scheduler) const {
// Check whether we should continue on with the current work unit. If this
// is the case, the work unit will have been updated in
// continue_current_work to reflect the new tile to be computed.
if (scheduler.continue_current_work(work_tile_info)) {
return work_tile_info;
}
// Get next work tile
scheduler.advance_to_next_work();
return scheduler.get_current_work();
}
};
///////////////////////////////////////////////////////////////////////////////
} // namespace cutlass::gemm::kernel

View File

@@ -0,0 +1,680 @@
/***************************************************************************************************
* Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights
*reserved. SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice,
*this list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
*POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cute/arch/cluster_sm90.hpp"
#include "cutlass/arch/mma_sm90.h"
#include "cutlass/arch/reg_reconfig.h"
#include "cutlass/cutlass.h"
#include "cutlass/epilogue/collective/detail.hpp"
#include "cutlass/fast_math.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/kernel/sm90_tile_scheduler.hpp"
#include "cutlass/kernel_hardware_info.hpp"
#include "cutlass/pipeline/pipeline.hpp"
#include "cutlass/trace.h"
#include "cutlass/workspace.h"
#include "cute/tensor.hpp"
#include "cute/util/debug.hpp"
///////////////////////////////////////////////////////////////////////////////
namespace cutlass::gemm::kernel {
///////////////////////////////////////////////////////////////////////////////
template <class ProblemShape_, class CollectiveMainloop_,
class CollectiveEpilogue_, class TileScheduler_>
class GemmUniversalGated<
ProblemShape_, CollectiveMainloop_, CollectiveEpilogue_, TileScheduler_,
cute::enable_if_t<cute::is_base_of_v<KernelTmaWarpSpecializedPingpong,
typename CollectiveMainloop_::
DispatchPolicy::Schedule> &&
CollectiveMainloop_::isGated>> {
public:
//
// Type Aliases
//
using ProblemShape = ProblemShape_;
static_assert(cute::rank(ProblemShape{}) == 3 or
cute::rank(ProblemShape{}) == 4,
"ProblemShape{} should be <M,N,K> or <M,N,K,L>");
// Mainloop derived types
using CollectiveMainloop = CollectiveMainloop_;
using TileShape = typename CollectiveMainloop::TileShape;
using TiledMma = typename CollectiveMainloop::TiledMma;
using ArchTag = typename CollectiveMainloop::ArchTag;
using ElementA = typename CollectiveMainloop::ElementA;
using StrideA = typename CollectiveMainloop::StrideA;
using ElementB = typename CollectiveMainloop::ElementB;
using StrideB = typename CollectiveMainloop::StrideB;
using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy;
using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator;
using ClusterShape = typename DispatchPolicy::ClusterShape;
using MainloopArguments = typename CollectiveMainloop::Arguments;
using MainloopParams = typename CollectiveMainloop::Params;
using Activation = typename CollectiveMainloop::Activation;
static_assert(ArchTag::kMinComputeCapability >= 90);
// Epilogue derived types
using CollectiveEpilogue = CollectiveEpilogue_;
using ElementC = typename CollectiveEpilogue::ElementC;
using StrideC = typename CollectiveEpilogue::StrideC;
using ElementD = typename CollectiveEpilogue::ElementD;
using StrideD = typename CollectiveEpilogue::StrideD;
using EpilogueArguments = typename CollectiveEpilogue::Arguments;
using EpilogueParams = typename CollectiveEpilogue::Params;
static_assert(
!cute::is_same_v<TileScheduler_, StreamKScheduler>,
"Ping-pong kernel does not currently support stream-K scheduler.");
using TileSchedulerTag = TileScheduler_;
using TileScheduler =
typename detail::TileSchedulerSelector<TileScheduler_, ArchTag, TileShape,
ClusterShape>::Scheduler;
using TileSchedulerArguments = typename TileScheduler::Arguments;
using TileSchedulerParams = typename TileScheduler::Params;
static constexpr uint32_t NumLoadWarpGroups = 1;
static constexpr uint32_t NumMmaWarpGroups = 2;
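// Ping-pong schedule: two math warp groups alternate between mainloop and
// epilogue work on different output tiles, ordered by MathWarpGroupOrderBarrier.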
static constexpr uint32_t MaxThreadsPerBlock =
CUTE_STATIC_V(size(TiledMma{})) +
(NumMmaWarpGroups * NumThreadsPerWarpGroup);
static constexpr uint32_t MinBlocksPerMultiprocessor = 1;
/// Register requirement for Load and Math WGs
static constexpr uint32_t LoadRegisterRequirement = 40;
static constexpr uint32_t MmaRegisterRequirement = 232;
// 1 stage ordered sequence between mainloop and epilogue producer load
// threads
using LoadWarpOrderBarrier = cutlass::OrderedSequenceBarrier<1, 2>;
// Order Sequence barrier with two stages: one for Mainloop and one for
// Epilogue
static constexpr uint32_t StagesPerMathWarpGroup = 2;
using MathWarpGroupOrderBarrier =
cutlass::OrderedSequenceBarrier<StagesPerMathWarpGroup, NumMmaWarpGroups>;
// Kernel level shared memory storage
struct SharedStorage {
struct TensorStorage : cute::aligned_struct<128> {
using MainloopTensorStorage = typename CollectiveMainloop::TensorStorage;
using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage;
MainloopTensorStorage mainloop;
EpilogueTensorStorage epilogue;
} tensors;
struct PipelineStorage : cute::aligned_struct<16> {
using MainloopPipelineStorage =
typename CollectiveMainloop::PipelineStorage;
using EpiLoadPipelineStorage =
typename CollectiveEpilogue::PipelineStorage;
using MathWarpGroupOrderBarrierStorage =
typename MathWarpGroupOrderBarrier::SharedStorage;
alignas(16) MainloopPipelineStorage mainloop;
alignas(16) EpiLoadPipelineStorage epi_load;
alignas(16) MathWarpGroupOrderBarrierStorage math_wg_order;
alignas(16) typename LoadWarpOrderBarrier::SharedStorage load_order;
} pipelines;
};
static constexpr int SharedStorageSize = sizeof(SharedStorage);
// Device side arguments
struct Arguments {
GemmUniversalMode mode{};
ProblemShape problem_shape{};
MainloopArguments mainloop{};
EpilogueArguments epilogue{};
KernelHardwareInfo hw_info{};
TileSchedulerArguments scheduler{};
};
// Kernel entry point API
struct Params {
GemmUniversalMode mode{};
ProblemShape problem_shape{};
MainloopParams mainloop{};
EpilogueParams epilogue{};
KernelHardwareInfo hw_info{};
TileSchedulerParams scheduler{};
};
//
// Methods
//
// Convert to underlying arguments. In this case, a simple copy for the
// aliased type.
static Params to_underlying_arguments(Arguments const &args,
void *workspace) {
CUTLASS_TRACE_HOST("to_underlying_arguments():");
(void)workspace;
auto problem_shape = args.problem_shape;
// if constexpr (detail::IF_SWAP_AB<CollectiveMainloop>::value) {
// // swap M/N
// get<0>(problem_shape) = get<1>(args.problem_shape);
// get<1>(problem_shape) = get<0>(args.problem_shape);
// }
auto problem_shape_MNKL = append<4>(problem_shape, 1);
// Get SM count if needed, otherwise use user supplied SM count
int sm_count = args.hw_info.sm_count;
if (sm_count <= 0) {
CUTLASS_TRACE_HOST(
" WARNING: Arguments do not include a valid SM count.\n"
" For optimal performance, populate the arguments "
"KernelHardwareInfo struct with the SM count.");
sm_count = KernelHardwareInfo::query_device_multiprocessor_count(
args.hw_info.device_id);
}
CUTLASS_TRACE_HOST(
"to_underlying_arguments(): Setting persistent grid SM count to "
<< sm_count);
KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count};
// Calculate workspace pointers
uint8_t *workspace_ptr = reinterpret_cast<uint8_t *>(workspace);
size_t workspace_offset = 0;
void *scheduler_workspace = workspace_ptr;
workspace_offset +=
TileScheduler::template get_workspace_size<ProblemShape,
ElementAccumulator>(
args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups);
workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment);
void *epilogue_workspace = workspace_ptr + workspace_offset;
workspace_offset += CollectiveEpilogue::get_workspace_size(
args.problem_shape, args.epilogue);
workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment);
void *mainloop_workspace = nullptr;
return {args.mode,
problem_shape,
CollectiveMainloop::to_underlying_arguments(
args.problem_shape, args.mainloop, mainloop_workspace),
CollectiveEpilogue::to_underlying_arguments(
args.problem_shape, args.epilogue, epilogue_workspace),
hw_info,
TileScheduler::to_underlying_arguments(
problem_shape_MNKL, TileShape{}, ClusterShape{}, hw_info,
args.scheduler, scheduler_workspace)};
}
static bool can_implement(Arguments const &args) {
bool implementable = (args.mode == GemmUniversalMode::kGemm) or
(args.mode == GemmUniversalMode::kBatched &&
cute::rank(ProblemShape{}) == 4);
if (!implementable) {
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't "
"meet the requirements.\n");
return implementable;
}
implementable &=
CollectiveMainloop::can_implement(args.problem_shape, args.mainloop);
implementable &=
CollectiveEpilogue::can_implement(args.problem_shape, args.epilogue);
implementable &= TileScheduler::can_implement(args.scheduler);
return implementable;
}
static size_t get_workspace_size(Arguments const &args) {
size_t workspace_size = 0;
workspace_size +=
TileScheduler::template get_workspace_size<ProblemShape,
ElementAccumulator>(
args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups);
workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment);
workspace_size += CollectiveEpilogue::get_workspace_size(args.problem_shape,
args.epilogue);
workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment);
return workspace_size;
}
static cutlass::Status
initialize_workspace(Arguments const &args, void *workspace = nullptr,
cudaStream_t stream = nullptr,
CudaHostAdapter *cuda_adapter = nullptr) {
Status status = Status::kSuccess;
uint8_t *workspace_ptr = reinterpret_cast<uint8_t *>(workspace);
size_t workspace_offset = 0;
status = TileScheduler::template initialize_workspace<ProblemShape,
ElementAccumulator>(
args.scheduler, workspace_ptr + workspace_offset, stream,
args.problem_shape, args.hw_info, NumMmaWarpGroups);
workspace_offset +=
TileScheduler::template get_workspace_size<ProblemShape,
ElementAccumulator>(
args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups);
workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment);
if (status != Status::kSuccess) {
return status;
}
status = CollectiveEpilogue::initialize_workspace(
args.problem_shape, args.epilogue, workspace_ptr + workspace_offset,
stream, cuda_adapter);
workspace_offset += CollectiveEpilogue::get_workspace_size(
args.problem_shape, args.epilogue);
workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment);
if (status != Status::kSuccess) {
return status;
}
return status;
}
// Computes the kernel launch grid shape based on runtime parameters
static dim3 get_grid_shape(Params const &params) {
// Given device SM count, set grid size s.t. we do not launch more thread
// blocks than we can run concurrently
TileSchedulerArguments args{};
if constexpr (!std::is_const_v<decltype(args.max_swizzle_size)>) {
args.max_swizzle_size = 1 << params.scheduler.log_swizzle_size_;
}
args.raster_order =
params.scheduler.raster_order_ == TileScheduler::RasterOrder::AlongN
? TileScheduler::RasterOrderOptions::AlongN
: TileScheduler::RasterOrderOptions::AlongM;
return TileScheduler::get_grid_shape(params.scheduler, params.problem_shape,
TileShape{}, ClusterShape{},
params.hw_info, args);
}
static dim3 get_block_shape() { return dim3(MaxThreadsPerBlock, 1, 1); }
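// Launch configuration: a persistent grid sized by the tile scheduler from the
// SM count (see get_grid_shape above) and a fixed block of MaxThreadsPerBlock
// threads, i.e. the producer warp group plus the two math (consumer) warp
// groups.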
CUTLASS_DEVICE
void operator()(Params const &params, char *smem_buf) {
using namespace cute;
using X = Underscore;
// Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a.
#if !defined(__CUDA_ARCH_FEAT_SM90_ALL)
printf("ERROR : Arch conditional MMA instruction used without targeting "
"sm90a compute capability. Aborting.\n");
#else
// Preconditions
static_assert(cute::rank(StrideA{}) == 3,
"StrideA must be rank-3: [M, K, L]. If batch mode is not "
"needed, set L stride to Int<0>.");
static_assert(cute::rank(StrideB{}) == 3,
"StrideB must be rank-3: [N, K, L]. If batch mode is not "
"needed, set L stride to Int<0>.");
static_assert(cute::rank(StrideC{}) == 3,
"StrideC must be rank-3: [M, N, L]. If batch mode is not "
"needed, set L stride to Int<0>.");
static_assert(cute::rank(StrideD{}) == 3,
"StrideD must be rank-3: [M, N, L]. If batch mode is not "
"needed, set L stride to Int<0>.");
enum class WarpGroupRole { Producer = 0, Consumer0 = 1, Consumer1 = 2 };
enum class ProducerWarpRole {
Mainloop = 0,
Warp1 = 1,
Epilogue = 2,
Warp3 = 3
};
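// Warp specialization: warp group 0 is the producer (its warp 0 issues the
// mainloop loads for A/B, its warp 2 issues the epilogue source loads), while
// warp groups 1 and 2 are the math consumers that run the MMAs and epilogue in
// ping-pong fashion.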
// Kernel level shared memory storage
SharedStorage &shared_storage =
*reinterpret_cast<SharedStorage *>(smem_buf);
int thread_idx = int(threadIdx.x);
int lane_idx = canonical_lane_idx();
int warp_idx = canonical_warp_idx_sync();
int warp_idx_in_warp_group = warp_idx % NumWarpsPerWarpGroup;
int warp_group_thread_idx = thread_idx % NumThreadsPerWarpGroup;
auto warp_group_role = WarpGroupRole(canonical_warp_group_idx());
auto producer_warp_role = ProducerWarpRole(warp_idx_in_warp_group);
int lane_predicate = cute::elect_one_sync();
uint32_t block_rank_in_cluster = cute::block_rank_in_cluster();
// Issue Tma Descriptor Prefetch from a single thread
if ((warp_idx == 0) && lane_predicate) {
CollectiveMainloop::prefetch_tma_descriptors(params.mainloop);
CollectiveEpilogue::prefetch_tma_descriptors(params.epilogue);
}
// Mainloop Load pipeline
using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline;
typename MainloopPipeline::Params mainloop_pipeline_params;
if (warp_group_role == WarpGroupRole::Producer &&
producer_warp_role == ProducerWarpRole::Mainloop) {
mainloop_pipeline_params.role =
MainloopPipeline::ThreadCategory::Producer;
}
if (warp_group_role == WarpGroupRole::Consumer0 ||
warp_group_role == WarpGroupRole::Consumer1) {
mainloop_pipeline_params.role =
MainloopPipeline::ThreadCategory::Consumer;
}
mainloop_pipeline_params.is_leader = warp_group_thread_idx == 0;
mainloop_pipeline_params.num_consumers = NumThreadsPerWarpGroup;
mainloop_pipeline_params.transaction_bytes =
CollectiveMainloop::TmaTransactionBytes;
MainloopPipeline mainloop_pipeline(shared_storage.pipelines.mainloop,
mainloop_pipeline_params,
ClusterShape{});
// Epilogue Load pipeline
using EpiLoadPipeline = typename CollectiveEpilogue::LoadPipeline;
typename EpiLoadPipeline::Params epi_load_pipeline_params;
if (warp_group_role == WarpGroupRole::Producer &&
producer_warp_role == ProducerWarpRole::Epilogue) {
epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Producer;
}
if (warp_group_role == WarpGroupRole::Consumer0 ||
warp_group_role == WarpGroupRole::Consumer1) {
epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Consumer;
}
epi_load_pipeline_params.dst_blockid = cute::block_rank_in_cluster();
epi_load_pipeline_params.producer_arv_count = NumThreadsPerWarp;
epi_load_pipeline_params.consumer_arv_count = NumThreadsPerWarpGroup;
epi_load_pipeline_params.transaction_bytes =
CollectiveEpilogue::TmaTransactionBytes;
EpiLoadPipeline epi_load_pipeline(shared_storage.pipelines.epi_load,
epi_load_pipeline_params);
// Epilogue Store pipeline
using EpiStorePipeline = typename CollectiveEpilogue::StorePipeline;
typename EpiStorePipeline::Params epi_store_pipeline_params;
epi_store_pipeline_params.always_wait = true;
EpiStorePipeline epi_store_pipeline(epi_store_pipeline_params);
typename LoadWarpOrderBarrier::Params params_load_order_barrier;
params_load_order_barrier.group_id =
producer_warp_role == ProducerWarpRole::Mainloop ? 0 : 1;
params_load_order_barrier.group_size = NumThreadsPerWarp;
LoadWarpOrderBarrier load_order_barrier(shared_storage.pipelines.load_order,
params_load_order_barrier);
typename MathWarpGroupOrderBarrier::Params params_math_wg_order_barrier;
// DMA Load WG will not participate in these Ordered Barrier syncs
params_math_wg_order_barrier.group_id =
canonical_warp_group_idx() - static_cast<int>(WarpGroupRole::Consumer0);
params_math_wg_order_barrier.group_size =
NumThreadsPerWarpGroup; // Number of threads / participants in a group
MathWarpGroupOrderBarrier math_wg_order_barrier(
shared_storage.pipelines.math_wg_order, params_math_wg_order_barrier);
// Initialize starting pipeline states for the collectives
// Epilogue store pipe is producer-only (consumer is TMA unit, waits via
// scoreboarding)
typename CollectiveMainloop::PipelineState mainloop_pipe_consumer_state;
typename CollectiveEpilogue::LoadPipelineState epi_load_pipe_consumer_state;
// For the DMA Load (producer) we start with an opposite phase
// i.e., we skip all waits since we know that the buffer is indeed empty
PipelineState mainloop_pipe_producer_state =
cutlass::make_producer_start_state<MainloopPipeline>();
PipelineState epi_load_pipe_producer_state =
cutlass::make_producer_start_state<EpiLoadPipeline>();
PipelineState epi_store_pipe_producer_state =
cutlass::make_producer_start_state<EpiStorePipeline>();
auto cluster_wait_fn = [&]() {
// We need this to guarantee that the Pipeline init is visible
// To all producers and consumer thread blocks in the Cluster
if constexpr (size(ClusterShape{}) > 1) {
cute::cluster_arrive_relaxed();
return []() { cute::cluster_wait(); };
} else {
__syncthreads();
return []() {}; // do nothing
}
}();
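// The matching wait (cluster_wait_fn) is deferred until after the scheduler and
// pipeline-state setup below, so that per-warp-group setup overlaps with the
// cluster-wide barrier.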
// Separate out problem shape for convenience
// Optionally append 1s until problem shape is rank-4 in case it is only
// rank-3 (MNK)
auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{});
// Get the appropriate blocks for this thread block -- potential for thread
// block locality
TiledMma tiled_mma;
auto blk_shape = TileShape{}; // (BLK_M,BLK_N,BLK_K)
// In a warp specialized kernel, collectives expose data movement and
// compute operations separately
CollectiveMainloop collective_mainloop;
CollectiveEpilogue collective_epilogue(params.epilogue,
shared_storage.tensors.epilogue);
// Prepare and partition the input tensors. Expects a tuple of tensors where:
//   get<0>(load_inputs) is the TMA tensor A after local tiling, with shape
//   (BLK_M,BLK_K,m,k,l);
//   get<1>(load_inputs) is the TMA tensor B after local tiling, with shape
//   (BLK_N,BLK_K,n,k,l).
auto load_inputs =
collective_mainloop.load_init(problem_shape_MNKL, params.mainloop);
static_assert(
cute::tuple_size_v<decltype(load_inputs)> >= 3,
"Output of load_init must have at least three elements (A, B, Aux)");
// Extract out partitioned A and B.
Tensor gA_mkl = get<0>(load_inputs);
Tensor gB_nkl = get<1>(load_inputs);
Tensor gAux_xkl = get<2>(load_inputs);
// Get pipeline stage increments from tensor shapes
auto k_tile_count = size<3>(gA_mkl);
auto c_tile_count = CollectiveEpilogue::get_load_pipe_increment(blk_shape);
auto d_tile_count = CollectiveEpilogue::get_store_pipe_increment(blk_shape);
TileScheduler scheduler{params.scheduler};
if (warp_group_role == WarpGroupRole::Consumer1) {
// Advance 2nd Math WG to the next work tile for the startup
scheduler.advance_to_next_work();
// Advance 2nd Math WG pipeline states to the end of 1st Math WG
mainloop_pipe_consumer_state.advance(k_tile_count);
epi_load_pipe_consumer_state.advance(c_tile_count);
epi_store_pipe_producer_state.advance(d_tile_count);
}
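// Ping-pong schedule: the two math warp groups process alternating work tiles,
// so Consumer1 starts one tile ahead and skips the pipeline stages that
// Consumer0 will consume for its first tile.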
auto work_tile_info = scheduler.get_current_work();
// Wait for all thread blocks in the Cluster
cluster_wait_fn();
if (warp_group_role == WarpGroupRole::Producer) {
cutlass::arch::warpgroup_reg_dealloc<LoadRegisterRequirement>();
// Mainloop Producer Warp
if (producer_warp_role == ProducerWarpRole::Mainloop) {
bool do_load_order_arrive = true;
while (work_tile_info.is_valid()) {
// Compute m_coord, n_coord, l_coord with the post-tiled m-shape and
// n-shape
auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl));
auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl));
auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl));
auto blk_coord = make_coord(m_coord, n_coord, _, l_coord);
auto k_tile_iter = cute::make_coord_iterator(shape<3>(gA_mkl));
collective_mainloop.load(
params.mainloop, mainloop_pipeline, mainloop_pipe_producer_state,
load_inputs, blk_coord, k_tile_iter, k_tile_count, lane_idx,
block_rank_in_cluster, shared_storage.tensors.mainloop);
// Update starting pipeline state for the next tile
mainloop_pipe_producer_state.advance(k_tile_count);
// Signal for the epilogue load warp to begin
if (do_load_order_arrive) {
load_order_barrier.arrive();
do_load_order_arrive = false;
}
// Get next work tile
scheduler.advance_to_next_work();
work_tile_info = scheduler.get_current_work();
} // Scheduler work fetch loop
// Make sure all Consumer Warp Groups have been waited upon
collective_mainloop.load_tail(mainloop_pipeline,
mainloop_pipe_producer_state);
} // Mainloop Producer Warp End
// Epilogue Producer Warp
else if (producer_warp_role == ProducerWarpRole::Epilogue &&
collective_epilogue.is_producer_load_needed()) {
load_order_barrier.wait();
while (work_tile_info.is_valid()) {
// Compute m_coord, n_coord, l_coord with the post-tiled m-shape and
// n-shape
auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl));
auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl));
auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl));
auto blk_coord = make_coord(m_coord, n_coord, _, l_coord);
epi_load_pipe_producer_state = collective_epilogue.load(
epi_load_pipeline, epi_load_pipe_producer_state,
problem_shape_MNKL, blk_shape, blk_coord, tiled_mma, lane_idx,
shared_storage.tensors.epilogue);
// Get next work tile
scheduler.advance_to_next_work();
work_tile_info = scheduler.get_current_work();
} // Scheduler work fetch loop
// Make sure all Consumer Warp Groups have been waited upon
collective_epilogue.load_tail(epi_load_pipeline,
epi_load_pipe_producer_state);
} // Epilogue Producer Warp End
} // Producer Warp Group End
else if (warp_group_role == WarpGroupRole::Consumer0 ||
warp_group_role == WarpGroupRole::Consumer1) {
cutlass::arch::warpgroup_reg_alloc<MmaRegisterRequirement>();
float scale_d0 = params.mainloop.scale_d0;
float scale_d1 = params.mainloop.scale_d1;
while (work_tile_info.is_valid()) {
// Compute m_coord, n_coord, l_coord with the post-tiled m-shape and
// n-shape
auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl));
auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl));
auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl));
auto blk_coord = make_coord(m_coord, n_coord, _, l_coord);
// Allocate the accumulators for the (M,N) blk_shape
Tensor accumulators0 = partition_fragment_C(
tiled_mma, take<0, 2>(blk_shape)); // (MMA,MMA_M,MMA_N)
Tensor accumulators1 = partition_fragment_C(
tiled_mma, take<0, 2>(blk_shape)); // (MMA,MMA_M,MMA_N)
// Order the two math warp groups' MMAs one after the other; this helps hide the Epilogue
math_wg_order_barrier.wait();
collective_mainloop.mma(
mainloop_pipeline, mainloop_pipe_consumer_state, accumulators0,
accumulators1, k_tile_count, warp_group_thread_idx,
shared_storage.tensors.mainloop, params.mainloop);
// Cue for next Math WG's MMA to start
math_wg_order_barrier.arrive();
// Make sure the math instructions are done and free buffers before
// entering the epilogue
collective_mainloop.mma_tail(
mainloop_pipeline, mainloop_pipe_consumer_state, k_tile_count);
// Update starting mainloop pipeline state for the next tile
mainloop_pipe_consumer_state.advance(k_tile_count * NumMmaWarpGroups);
Activation elt_op;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(accumulators0); i++) {
accumulators0[i] = elt_op(accumulators0[i] * scale_d0) *
(scale_d1 * accumulators1[i]);
}
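// The two accumulators are combined as Act(acc0 * scale_d0) * (scale_d1 * acc1),
// i.e. a gated (GLU-style) fusion of the dual-GEMM results (e.g. SwiGLU when
// Activation is SiLU); only accumulators0 is carried into the epilogue below.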
// Order the two math warp groups' epilogues one after the other
math_wg_order_barrier.wait();
// Epilogue and write to gD
auto [epi_load_pipe_consumer_state_next,
epi_store_pipe_producer_state_next] =
collective_epilogue.store(
epi_load_pipeline, epi_load_pipe_consumer_state,
epi_store_pipeline, epi_store_pipe_producer_state,
problem_shape_MNKL, blk_shape, blk_coord, accumulators0,
tiled_mma, warp_group_thread_idx,
shared_storage.tensors.epilogue);
// The TMA store pipeline wait is only visible to the TMA-issuing warp, so for
// multi-consumer kernels we must wait for all TMA stores to complete before
// issuing the consumer order-barrier arrive; otherwise the next math consumer
// could overwrite shared memory still referenced by this consumer's in-flight
// TMA stores.
auto [epi_load_pipe_consumer_state_next_,
epi_store_pipe_producer_state_next_] =
collective_epilogue.store_tail(
epi_load_pipeline, epi_load_pipe_consumer_state_next,
epi_store_pipeline, epi_store_pipe_producer_state_next);
// Update the starting load/store pipeline states for the next tile.
// The states were already advanced by one tile inside the collective calls;
// advance once more to skip the tile handled by the other (ping-pong) math
// warp group.
epi_load_pipe_consumer_state = epi_load_pipe_consumer_state_next_;
epi_store_pipe_producer_state = epi_store_pipe_producer_state_next_;
epi_load_pipe_consumer_state.advance(c_tile_count);
epi_store_pipe_producer_state.advance(d_tile_count);
// Cue for next Math WG's Epilogue to start
math_wg_order_barrier.arrive();
// Get next work tile
scheduler.advance_to_next_work(NumMmaWarpGroups);
work_tile_info = scheduler.get_current_work();
} // Scheduler work fetch loop
} // Consumer Warp Groups End
#endif
}
};
///////////////////////////////////////////////////////////////////////////////
} // namespace cutlass::gemm::kernel

View File

@@ -77,6 +77,7 @@ public:
};
////////////////////////////////////////////////////////////////////////////////
/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int4 weight, mma pipelined (stage=2)
template <
/// Layout type for A matrix operand
@@ -125,6 +126,7 @@ public:
};
////////////////////////////////////////////////////////////////////////////////
/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int8 weight, mma multistage
/// (stage>=3)
template <
@@ -148,7 +150,7 @@ template <
typename InstructionShape,
/// Operation performed by GEMM
typename Operator,
///
/// Number of stages used in the multistage mainloop
int kStages,
/// Shared memory clear option
SharedMemoryClearOption SharedMemoryClear>
@@ -179,6 +181,7 @@ public:
};
////////////////////////////////////////////////////////////////////////////////
/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int4 weight, mma multistage
/// (stage>=3)
template <
@@ -234,6 +237,7 @@ public:
#ifdef ENABLE_FP8
////////////////////////////////////////////////////////////////////////////////
/// Specialization for row-major output (OperatorClass TensorOp), fp8 activation & int4 weight, mma multistage
/// (stage>=3)
template <
@@ -346,6 +350,131 @@ struct DefaultMma<half_t, LayoutA, kAlignmentA, half_t, LayoutB, kAlignmentB, El
MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, typename MmaCore::MmaPolicy, 2>;
};
////////////////////////////////////////////////////////////////////////////////
/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int2 weight, mma multistage
template <
/// Layout type for A matrix operand
typename LayoutA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Layout type for B matrix operand
typename LayoutB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Tag indicating architecture to tune for
typename ArchTag,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape,
/// Operation performed by GEMM
typename Operator>
struct DefaultMma<cutlass::half_t, LayoutA, kAlignmentA, uint2b_t, LayoutB, kAlignmentB, ElementAccumulator,
layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, 2, Operator>
{
static cutlass::arch::CacheOperation::Kind const CacheOpA =
((sizeof_bits<half_t>::value * kAlignmentA) == 128) ? cutlass::arch::CacheOperation::Global
: cutlass::arch::CacheOperation::Always;
static cutlass::arch::CacheOperation::Kind const CacheOpB =
((sizeof_bits<half_t>::value * kAlignmentB) == 128) ? cutlass::arch::CacheOperation::Global
: cutlass::arch::CacheOperation::Always;
// Define the MmaCore components
using MmaCore =
typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape, WarpShape, InstructionShape, half_t,
LayoutA, half_t, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, 3, Operator,
false, CacheOpA, CacheOpB>;
// Define iterators over tiles from the A operand
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
using AccessTypeA = cutlass::Array<half_t, kAlignmentA>;
using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator<
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>, half_t, LayoutA, 1, ThreadMapA,
AccessTypeA>;
// Define iterators over tiles from the B operand
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
using AccessTypeB = cutlass::Array<half_t, kAlignmentB>;
using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator<
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>, half_t, LayoutB, 0, ThreadMapB,
AccessTypeB>;
// Define the threadblock-scoped multistage matrix multiply
using ThreadblockMma = cutlass::gemm::threadblock::Wint2xMmaMultistage<typename MmaCore::Shape, IteratorA,
typename MmaCore::SmemIteratorA, MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB,
MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, typename MmaCore::MmaPolicy, 2>;
};
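// Note: although this specialization is selected for uint2b_t weights, the B
// iterators above are declared over half_t tiles. The packed 2-bit weights are
// staged in zipped form and dequantized to half inside Wint2xMmaMultistage
// (see wint2x_mma_multistage.h), so IteratorB here describes the dequantized
// tile that the warp-level MMA consumes rather than the raw 2-bit data.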
template <
/// Layout type for A matrix operand
typename LayoutA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Layout type for B matrix operand
typename LayoutB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Tag indicating architecture to tune for
typename ArchTag,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape,
/// Operation performed by GEMM
typename Operator,
/// Number of stages used in the multistage mainloop
int kStages,
/// Shared memory clear option
SharedMemoryClearOption SharedMemoryClear>
struct DefaultMma<half_t, LayoutA, kAlignmentA, uint2b_t, LayoutB, kAlignmentB, ElementAccumulator,
layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, kStages, Operator,
false, SharedMemoryClear>
{
static cutlass::arch::CacheOperation::Kind const CacheOpA =
((sizeof_bits<half_t>::value * kAlignmentA) == 128) ? cutlass::arch::CacheOperation::Global
: cutlass::arch::CacheOperation::Always;
static cutlass::arch::CacheOperation::Kind const CacheOpB =
((sizeof_bits<half_t>::value * kAlignmentB) == 128) ? cutlass::arch::CacheOperation::Global
: cutlass::arch::CacheOperation::Always;
// Define the MmaCore components
using MmaCore =
typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape, WarpShape, InstructionShape, half_t,
LayoutA, half_t, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, kStages, Operator,
false, CacheOpA, CacheOpB>;
// Define iterators over tiles from the A operand
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
using AccessTypeA = cutlass::Array<half_t, kAlignmentA>;
using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator<
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>, half_t, LayoutA, 1, ThreadMapA,
AccessTypeA>;
// Define iterators over tiles from the B operand
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
using AccessTypeB = cutlass::Array<half_t, kAlignmentB>;
using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator<
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>, half_t, LayoutB, 0, ThreadMapB,
AccessTypeB>;
// Define the threadblock-scoped multistage matrix multiply
using ThreadblockMma = cutlass::gemm::threadblock::Wint2xMmaMultistage<typename MmaCore::Shape, IteratorA,
typename MmaCore::SmemIteratorA, MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB,
MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, typename MmaCore::MmaPolicy, kStages, SharedMemoryClear>;
};
} // namespace threadblock
} // namespace gemm
} // namespace cutlass

View File

@@ -19,13 +19,11 @@
#include "cutlass/gemm/threadblock/default_mma.h"
#include "cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h"
#include "cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h"
#include "cutlass_extensions/gemm/threadblock/wint2x_mma_multistage.h"
namespace cutlass
{
namespace gemm
{
namespace threadblock
{
namespace cutlass {
namespace gemm {
namespace threadblock {
////////////////////////////////////////////////////////////////////////////////
@@ -197,6 +195,7 @@ public:
};
////////////////////////////////////////////////////////////////////////////////
/// Specialization for row-major output (OperatorClass TensorOp), bf16 activation & int4 weight
template <
/// Layout type for A matrix operand
@@ -244,6 +243,9 @@ public:
using ThreadblockMma = typename Mma::ThreadblockMma;
};
////////////////////////////////////////////////////////////////////////////////
/// Specialization for row-major output (OperatorClass TensorOp), bf16 activation & int8 weight
template <
/// Layout type for A matrix operand
typename LayoutA,
@@ -265,7 +267,7 @@ template <
typename InstructionShape,
/// Operation performed by GEMM
typename Operator,
///
/// Number of stages used in the multistage mainloop
int kStages,
/// Shared memory clear option
SharedMemoryClearOption SharedMemoryClear>
@@ -296,6 +298,7 @@ public:
};
////////////////////////////////////////////////////////////////////////////////
/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int4 weight
template <
/// Layout type for A matrix operand
@@ -318,11 +321,11 @@ template <
typename InstructionShape,
/// Operation performed by GEMM
typename Operator,
///
/// Number of stages used in the multistage mainloop
int kStages,
/// Shared memory clear option
SharedMemoryClearOption SharedMemoryClear>
struct DefaultMma<cutlass::bfloat16_t, LayoutA, kAlignmentA, uint4b_t, LayoutB, kAlignmentB, ElementAccumulator,
struct DefaultMma<bfloat16_t, LayoutA, kAlignmentA, uint4b_t, LayoutB, kAlignmentB, ElementAccumulator,
layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, kStages, Operator,
false, SharedMemoryClear>
{
@@ -348,6 +351,131 @@ public:
using ThreadblockMma = typename Mma::ThreadblockMma;
};
////////////////////////////////////////////////////////////////////////////////
/// Specialization for row-major output (OperatorClass TensorOp), bf16 activation & int2 weight, mma multistage
template <
/// Layout type for A matrix operand
typename LayoutA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Layout type for B matrix operand
typename LayoutB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Tag indicating architecture to tune for
typename ArchTag,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape,
/// Operation performed by GEMM
typename Operator>
struct DefaultMma<cutlass::bfloat16_t, LayoutA, kAlignmentA, uint2b_t, LayoutB, kAlignmentB, ElementAccumulator,
layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, 2, Operator>
{
static cutlass::arch::CacheOperation::Kind const CacheOpA =
((sizeof_bits<bfloat16_t>::value * kAlignmentA) == 128) ? cutlass::arch::CacheOperation::Global
: cutlass::arch::CacheOperation::Always;
static cutlass::arch::CacheOperation::Kind const CacheOpB =
((sizeof_bits<bfloat16_t>::value * kAlignmentB) == 128) ? cutlass::arch::CacheOperation::Global
: cutlass::arch::CacheOperation::Always;
// Define the MmaCore components
using MmaCore =
typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape, WarpShape, InstructionShape, bfloat16_t,
LayoutA, bfloat16_t, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, 3, Operator,
false, CacheOpA, CacheOpB>;
// Define iterators over tiles from the A operand
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
using AccessTypeA = cutlass::Array<bfloat16_t, kAlignmentA>;
using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator<
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>, bfloat16_t, LayoutA, 1, ThreadMapA,
AccessTypeA>;
// Define iterators over tiles from the B operand
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
using AccessTypeB = cutlass::Array<bfloat16_t, kAlignmentB>;
using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator<
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>, bfloat16_t, LayoutB, 0, ThreadMapB,
AccessTypeB>;
// Define the threadblock-scoped multistage matrix multiply
using ThreadblockMma = cutlass::gemm::threadblock::Wint2xMmaMultistage<typename MmaCore::Shape, IteratorA,
typename MmaCore::SmemIteratorA, MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB,
MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, typename MmaCore::MmaPolicy, 2>;
};
template <
/// Layout type for A matrix operand
typename LayoutA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Layout type for B matrix operand
typename LayoutB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Tag indicating architecture to tune for
typename ArchTag,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape,
/// Operation performed by GEMM
typename Operator,
/// Number of stages used in the multistage mainloop
int kStages,
/// Shared memory clear option
SharedMemoryClearOption SharedMemoryClear>
struct DefaultMma<bfloat16_t, LayoutA, kAlignmentA, uint2b_t, LayoutB, kAlignmentB, ElementAccumulator,
layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, kStages, Operator,
false, SharedMemoryClear>
{
static cutlass::arch::CacheOperation::Kind const CacheOpA =
((sizeof_bits<bfloat16_t>::value * kAlignmentA) == 128) ? cutlass::arch::CacheOperation::Global
: cutlass::arch::CacheOperation::Always;
static cutlass::arch::CacheOperation::Kind const CacheOpB =
((sizeof_bits<bfloat16_t>::value * kAlignmentB) == 128) ? cutlass::arch::CacheOperation::Global
: cutlass::arch::CacheOperation::Always;
// Define the MmaCore components
using MmaCore =
typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape, WarpShape, InstructionShape, bfloat16_t,
LayoutA, bfloat16_t, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, kStages, Operator,
false, CacheOpA, CacheOpB>;
// Define iterators over tiles from the A operand
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
using AccessTypeA = cutlass::Array<bfloat16_t, kAlignmentA>;
using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator<
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>, bfloat16_t, LayoutA, 1, ThreadMapA,
AccessTypeA>;
// Define iterators over tiles from the B operand
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
using AccessTypeB = cutlass::Array<bfloat16_t, kAlignmentB>;
using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator<
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>, bfloat16_t, LayoutB, 0, ThreadMapB,
AccessTypeB>;
// Define the threadblock-scoped multistage matrix multiply
using ThreadblockMma = cutlass::gemm::threadblock::Wint2xMmaMultistage<typename MmaCore::Shape, IteratorA,
typename MmaCore::SmemIteratorA, MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB,
MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, typename MmaCore::MmaPolicy, kStages, SharedMemoryClear>;
};
} // namespace threadblock
} // namespace gemm
} // namespace cutlass

View File

@@ -0,0 +1,237 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Template for a double-buffered threadblock-scoped GEMM kernel.
*/
#pragma once
#include "cutlass/aligned_buffer.h"
#include "cutlass/arch/memory.h"
#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/numeric_types.h"
#include "cutlass/tensor_ref.h"
#include "cutlass/gemm/threadblock/mma_base.h"
////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace threadblock {
////////////////////////////////////////////////////////////////////////////////
/// Structure to compute the matrix product targeting CUDA cores and SIMT math
/// instructions.
template <
/// Size of the Gemm problem - concept: gemm::GemmShape<>
typename Shape_,
/// Policy describing tuning details (concept: MmaPolicy)
typename Policy_,
/// Number of stages,
int Stages,
/// Used for partial specialization
typename Enable = bool>
class Wint2xMmaBase {
public:
///< Size of the Gemm problem - concept: gemm::GemmShape<>
using Shape = Shape_;
///< Policy describing tuning details
using Policy = Policy_;
//
// Dependent types
//
/// Warp-level Mma
using Operator = typename Policy::Operator;
/// Shape describing the overall GEMM computed from shared memory
/// by each warp.
using WarpGemm = typename Policy::Operator::Shape;
/// Shape describing the number of warps filling the CTA
using WarpCount =
GemmShape<Shape::kM / WarpGemm::kM, Shape::kN / WarpGemm::kN,
Shape::kK / WarpGemm::kK>;
/// Number of warp-level GEMM operations
static int const kWarpGemmIterations =
(WarpGemm::kK / Operator::Policy::MmaShape::kK);
/// Number of stages
static int const kStages = Stages;
/// Tensor reference to the A operand
using TensorRefA =
TensorRef<typename Operator::ElementA, typename Operator::LayoutA>;
/// Tensor reference to the B operand
using TensorRefB =
TensorRef<typename Operator::ElementB, typename Operator::LayoutB>;
// using TensorRefZippedB = TensorRef<uint8_t, typename Operator::LayoutB>;
static_assert(kWarpGemmIterations > 1,
"The pipelined structure requires at least two warp-level "
"GEMM operations.");
static_assert((kWarpGemmIterations % 2) == 0,
"Inner loop iteration must be an even number.");
//
// Nested structs
//
/// Shared storage object needed by threadblock-scoped GEMM
class SharedStorage {
public:
//
// Type definitions
//
/// Shape of the A matrix operand in shared memory
using ShapeA =
MatrixShape<Shape::kM + Policy::SmemPaddingA::kRow,
Shape::kK * kStages + Policy::SmemPaddingA::kColumn>;
/// Shape of the B matrix operand in shared memory
using ShapeB = MatrixShape<Shape::kK + Policy::SmemPaddingB::kRow,
Shape::kN + Policy::SmemPaddingB::kColumn>;
// Zipped B rows per stage, stored as uint8: packed 2-bit weights (Shape::kK / 4 rows)
// plus one local_scale byte per 128 rows of K ((Shape::kK + 127) / 128 rows).
constexpr static int kZippedRowsPerStages =
Shape::kK / 4 + (Shape::kK + 127) / 128;
// Column-wise parameters per column, stored as uint8 rows: code_scale (float),
// code_zp (float) and super_scale (ElementB).
constexpr static int kColumnWiseParamsRows = 2 * sizeof(float) +
sizeof_bits<typename Operator::ElementB>::value / 8;
using ZippedShapeB = MatrixShape<kColumnWiseParamsRows + kZippedRowsPerStages * kStages, Shape::kN>;
using NopaddingShapeB = MatrixShape<Shape::kK, Shape::kN>;
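// Example footprint (a rough sketch, assuming Shape::kK = 64, kStages = 3 and
// ElementB = half_t): kZippedRowsPerStages = 64/4 + 1 = 17 uint8 rows per stage,
// kColumnWiseParamsRows = 2*4 + 2 = 10 uint8 rows, so ZippedShapeB has
// 10 + 17*3 = 61 rows of Shape::kN bytes for the zipped B buffer.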
public:
//
// Data members
//
/// Buffer for A operand
AlignedBuffer<typename Operator::ElementA, ShapeA::kCount> operand_A;
/// Buffer for B operand
AlignedBuffer<typename Operator::ElementB, ShapeB::kCount> operand_B;
/// Buffer for the quantized (zipped) B operand
AlignedBuffer<uint8_t, ZippedShapeB::kCount> operand_zipped_B;
/// Buffer for the unzipped (dequantized) B operand
AlignedBuffer<typename Operator::ElementB, NopaddingShapeB::kCount>
operand_unzip_B;
public:
//
// Methods
//
/// Returns a layout object for the A matrix
CUTLASS_DEVICE
static typename Operator::LayoutA LayoutA() {
return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn});
}
/// Returns a layout object for the B matrix
CUTLASS_HOST_DEVICE
static typename Operator::LayoutB LayoutB() {
return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn});
}
/// Returns a TensorRef to the A operand
CUTLASS_HOST_DEVICE
TensorRefA operand_A_ref() {
return TensorRefA{operand_A.data(), LayoutA()};
}
/// Returns a TensorRef to the B operand
CUTLASS_HOST_DEVICE
TensorRefB operand_B_ref() {
return TensorRefB{operand_B.data(), LayoutB()};
}
CUTLASS_HOST_DEVICE
uint8_t *operand_zipped_B_ptr() { return operand_zipped_B.data(); }
CUTLASS_HOST_DEVICE
typename Operator::ElementB *operand_unzip_B_ptr() {
return operand_unzip_B.data();
}
};
protected:
//
// Data members
//
/// Iterator to load a warp-scoped tile of A operand from shared memory
typename Operator::IteratorA warp_tile_iterator_A_;
/// Iterator to load a warp-scoped tile of B operand from shared memory
typename Operator::IteratorB warp_tile_iterator_B_;
public:
/// Construct from tensor references
CUTLASS_DEVICE
Wint2xMmaBase(
///< Shared storage needed for internal use by threadblock-scoped GEMM
SharedStorage &shared_storage,
///< ID within the threadblock
int thread_idx,
///< ID of warp
int warp_idx,
///< ID of each thread within a warp
int lane_idx)
: warp_tile_iterator_A_(shared_storage.operand_A_ref(), lane_idx),
warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx) {}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@@ -0,0 +1,807 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Template for a double-buffered threadblock-scoped GEMM kernel.
*/
#pragma once
#include "cutlass/aligned_buffer.h"
#include "cutlass/arch/memory.h"
#include "cutlass/array.h"
#include "cutlass/arch/memory_sm80.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/numeric_types.h"
#include "cutlass_extensions/arch/memory_copy_sm80.h"
#include "cutlass_extensions/gemm/threadblock/wint2x_mma_base.h"
#include "cutlass_extensions/gemm/threadblock/wint2x_tile_dequanter.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace threadblock {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Structure to compute the matrix product targeting CUDA cores and SIMT math
/// instructions.
template <
/// Size of the Gemm problem - concept: gemm::GemmShape<>
typename Shape_,
/// Iterates over tiles of A operand in global memory
// (concept: ReadableTileIterator | ForwardTileIterator |
// MaskedTileIterator)
typename IteratorA_,
/// Iterates over tiles of A operand in shared memory
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
typename SmemIteratorA_,
/// Cache operation for operand A
cutlass::arch::CacheOperation::Kind CacheOpA,
/// Iterates over tiles of B operand in global memory
// (concept: ReadableTileIterator | ForwardTileIterator |
// MaskedTileIterator)
typename IteratorB_,
/// Iterates over tiles of B operand in shared memory
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
typename SmemIteratorB_,
/// Cache operation for operand B
cutlass::arch::CacheOperation::Kind CacheOpB,
/// Data type of accumulator matrix
typename ElementC_,
/// Data type of accumulator matrix
typename LayoutC_,
/// Policy describing tuning details (concept: MmaPolicy)
typename Policy_,
/// Number of stages,
int Stages,
/// Use zfill or predicate for out-of-bound cp.async
SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone,
/// Used for partial specialization
typename Enable = bool>
class Wint2xMmaMultistage :
public Wint2xMmaBase<Shape_, Policy_, Stages> {
public:
///< Base class
using Base = Wint2xMmaBase<Shape_, Policy_, Stages>;
///< Size of the Gemm problem - concept: gemm::GemmShape<>
using Shape = Shape_;
///< Iterates over tiles of A operand in global memory
using IteratorA = IteratorA_;
///< Iterates over tiles of B operand in global memory
using IteratorB = IteratorB_;
///< Data type of accumulator matrix
using ElementC = ElementC_;
///< Layout of accumulator matrix
using LayoutC = LayoutC_;
///< Policy describing tuning details
using Policy = Policy_;
using ZippedShapeB = typename Base::SharedStorage::ZippedShapeB;
using SmemIteratorA = SmemIteratorA_;
using SmemIteratorB = SmemIteratorB_;
static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA;
static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB;
//
// Dependent types
//
/// Fragment of accumulator tile
using FragmentC = typename Policy::Operator::FragmentC;
/// Warp-level Mma
using Operator = typename Policy::Operator;
/// Minimum architecture is Sm80 to support cp.async
using ArchTag = arch::Sm80;
/// Complex transform on A operand
static ComplexTransform const kTransformA = Operator::kTransformA;
/// Complex transform on B operand
static ComplexTransform const kTransformB = Operator::kTransformB;
/// Internal structure exposed for introspection.
struct Detail {
/// Number of cp.async instructions to load one stage of operand A
static int const AsyncCopyIterationsPerStageA =
IteratorA::ThreadMap::Iterations::kCount;
/// Number of cp.async instructions to load one stage of operand B
static int const AsyncCopyIterationsPerStageB =
IteratorB::ThreadMap::Iterations::kCount;
/// Number of stages
static int const kStages = Stages;
/// Number of cp.async instructions to load one group of operand A
static int const kAccessesPerGroupA =
(AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations;
/// Number of cp.async instructions to load one group of operand B
static int const kAccessesPerGroupB =
(AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations;
// Optional staged-accumulation (e.g., tf32x3 kernels) for improved numerical
// accuracy, where each mainloop iteration first accumulates into a temporary
// set of freshly-cleared accumulators, which are subsequently added to the
// final accumulator set.
static bool const kStagedAccumulation = arch::detail::UseStagedAccumulation<Operator>::value;
};
private:
// Structure encapsulating pipeline state live from one iteration to the next
struct PipeState {
using WarpLoadedFragmentA = typename Operator::FragmentA;
using WarpLoadedFragmentB = typename Operator::FragmentB;
using WarpTransformedFragmentA = typename Operator::TransformedFragmentA;
using WarpTransformedFragmentB = typename Operator::TransformedFragmentB;
/// Temporary accumulator to facilitate staged-accumulation
FragmentC tmp_accum_;
/// Pair of A fragments used to overlap shared memory loads and math instructions
WarpLoadedFragmentA warp_loaded_frag_A_[2];
WarpTransformedFragmentA warp_transformed_frag_A_[2];
/// Pair of B fragments used to overlap shared memory loads and math instructions
WarpLoadedFragmentB warp_loaded_frag_B_[2];
WarpTransformedFragmentB warp_transformed_frag_B_[2];
};
private:
//
// Data members
//
/// Warp-level MMA operator
Operator warp_mma_;
/// Iterator to write threadblock-scoped tile of A operand to shared memory
SmemIteratorA smem_iterator_A_;
/// Iterator to write threadblock-scoped tile of B operand to shared memory
SmemIteratorB smem_iterator_B_;
/// Shared memory write stage index
int smem_write_stage_idx_;
/// Shared memory read stage index
int smem_read_stage_idx_;
uint8_t* column_wise_smem_ptr_B_;
uint8_t* smem_zipped_ptr_B_;
int smem_zipped_bytes_per_stage_B_;
public:
/// Construct from tensor references
CUTLASS_DEVICE
Wint2xMmaMultistage(
///< Shared storage needed for internal use by threadblock-scoped GEMM
typename Base::SharedStorage &shared_storage,
///< ID within the threadblock
int thread_idx,
///< ID of warp
int warp_idx,
///< ID of each thread within a warp
int lane_idx
):
Base(shared_storage, thread_idx, warp_idx, lane_idx),
smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx),
smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx),
smem_write_stage_idx_(0),
smem_read_stage_idx_(0)
{
// Compute warp location within threadblock tile by mapping the warp_id to
// three coordinates:
// _m: the warp's position within the threadblock along the M dimension
// _n: the warp's position within the threadblock along the N dimension
// _k: the warp's position within the threadblock along the K dimension
int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN);
int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN);
int warp_idx_m = warp_idx_mn % Base::WarpCount::kM;
int warp_idx_n = warp_idx_mn / Base::WarpCount::kM;
// Add per-warp offsets in units of warp-level tiles
this->warp_tile_iterator_A_.add_tile_offset(
{warp_idx_m, Base::kWarpGemmIterations * warp_idx_k});
this->warp_tile_iterator_B_.add_tile_offset(
{Base::kWarpGemmIterations * warp_idx_k, warp_idx_n});
column_wise_smem_ptr_B_ = shared_storage.operand_zipped_B_ptr();
smem_zipped_ptr_B_ = column_wise_smem_ptr_B_ + Base::SharedStorage::kColumnWiseParamsRows * ZippedShapeB::kColumn;
smem_zipped_bytes_per_stage_B_ = Base::SharedStorage::kZippedRowsPerStages * ZippedShapeB::kColumn;
}
/// Advance shared memory read-iterators to the next stage
CUTLASS_DEVICE
void advance_smem_read_stage()
{
++smem_read_stage_idx_;
if (smem_read_stage_idx_ == Base::kStages) {
// Wrap back around to the 'start' of the circular buffer in shared memory
this->warp_tile_iterator_A_.add_tile_offset({0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations});
// this->warp_tile_iterator_B_.add_tile_offset({-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations, 0});
smem_read_stage_idx_ = 0;
}
this->warp_tile_iterator_B_.add_tile_offset({-Policy::kPartitionsK * Base::kWarpGemmIterations, 0});
}
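// Unlike A, the dequantized B tile lives in a single-stage buffer (ShapeB has
// no kStages factor), so the B warp iterator is rewound by one stage on every
// call instead of cycling through a kStages-deep circular buffer.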
/// Advance global memory read-iterators and shared memory write-iterators to the stage
template <typename TileDequanterB>
CUTLASS_DEVICE
void advance_smem_write_stage(
IteratorA &iterator_A,
IteratorB &iterator_B,
TileDequanterB &tile_dequanter_B)
{
// Advance global iterators
iterator_A.add_tile_offset({0, 1});
//iterator_B.add_tile_offset({1, 0});
tile_dequanter_B.AddTileOffset({1, 0});
// Advance shared iterators
smem_iterator_A_.add_tile_offset({0, 1});
//smem_iterator_B_.add_tile_offset({1, 0});
// Increment shared memory write stage index
++smem_write_stage_idx_;
if (smem_write_stage_idx_ == Base::kStages) {
// Wrap back around to the 'start' of the circular buffer in shared memory
smem_iterator_A_.add_tile_offset({0, -Base::kStages});
//smem_iterator_B_.add_tile_offset({-Base::kStages, 0});
smem_write_stage_idx_ = 0;
}
}
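// Note: iterator_B / smem_iterator_B_ are not advanced here (see the
// commented-out lines); the raw zipped B tile is instead staged by
// tile_dequanter_B into the zipped shared-memory region and only becomes a
// regular B tile after dequantization.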
CUTLASS_DEVICE
void copy_tiles_and_advance_A(IteratorA &iterator_A, int group_start_A = 0) {
iterator_A.set_iteration_index(group_start_A *
IteratorA::kAccessesPerVector);
this->smem_iterator_A_.set_iteration_index(group_start_A);
// Async Copy for operand A
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) {
if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) {
typename IteratorA::AccessType *dst_ptr =
reinterpret_cast<typename IteratorA::AccessType *>(
this->smem_iterator_A_.get());
int const kSrcBytes = sizeof_bits<typename IteratorA::Element>::value *
IteratorA::ThreadMap::kElementsPerAccess /
IteratorA::kAccessesPerVector / 8;
CUTLASS_PRAGMA_UNROLL
for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) {
auto gmem_ptr = iterator_A.get();
if (SharedMemoryClear == SharedMemoryClearOption::kZfill) {
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA>(
dst_ptr + v, gmem_ptr, iterator_A.valid());
} else {
cutlass::arch::cp_async<kSrcBytes, kCacheOpA>(
dst_ptr + v, gmem_ptr, iterator_A.valid());
}
++iterator_A;
}
++this->smem_iterator_A_;
}
}
}
template <bool GlobalToSharedB>
CUTLASS_DEVICE
void copy_tiles_and_advance_B(IteratorB &iterator_B, int group_start_B = 0) {
iterator_B.set_iteration_index(group_start_B *
IteratorB::kAccessesPerVector);
this->smem_iterator_B_.set_iteration_index(group_start_B);
// Async Copy for operand B
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) {
if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) {
typename IteratorB::AccessType *dst_ptr =
reinterpret_cast<typename IteratorB::AccessType *>(
this->smem_iterator_B_.get());
int const kSrcBytes = sizeof_bits<typename IteratorB::Element>::value *
IteratorB::ThreadMap::kElementsPerAccess /
IteratorB::kAccessesPerVector / 8;
CUTLASS_PRAGMA_UNROLL
for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) {
auto gmem_ptr = iterator_B.get();
if (SharedMemoryClear == SharedMemoryClearOption::kZfill) {
cutlass::arch::copy_zfill<kSrcBytes, kCacheOpB, GlobalToSharedB>(
dst_ptr + v, gmem_ptr, iterator_B.valid());
} else {
cutlass::arch::copy<kSrcBytes, kCacheOpB, GlobalToSharedB>(
dst_ptr + v, gmem_ptr, iterator_B.valid());
}
++iterator_B;
}
++this->smem_iterator_B_;
}
}
__syncthreads();
}
CUTLASS_DEVICE
void copy_tiles_and_advance_per_stage_A(IteratorA &iterator_A) {
iterator_A.set_iteration_index(0);
this->smem_iterator_A_.set_iteration_index(0);
// Async Copy for operand A
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) {
typename IteratorA::AccessType *dst_ptr =
reinterpret_cast<typename IteratorA::AccessType *>(
this->smem_iterator_A_.get());
CUTLASS_PRAGMA_UNROLL
for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) {
auto gmem_ptr = iterator_A.get();
int const kSrcBytes =
sizeof_bits<typename IteratorA::Element>::value *
IteratorA::ThreadMap::kElementsPerAccess /
IteratorA::kAccessesPerVector / 8;
int src_bytes = (iterator_A.valid() ? kSrcBytes : 0);
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA>(
dst_ptr + v, iterator_A.get(), iterator_A.valid());
++iterator_A;
}
++this->smem_iterator_A_;
}
}
template <bool GlobalToSharedB, bool InitStage>
CUTLASS_DEVICE
void copy_tiles_and_advance_per_stage_B(IteratorB &iterator_B) {
iterator_B.set_iteration_index(0);
this->smem_iterator_B_.set_iteration_index(0);
// Async Copy for operand B
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) {
typename IteratorB::AccessType *dst_ptr =
reinterpret_cast<typename IteratorB::AccessType *>(
this->smem_iterator_B_.get());
CUTLASS_PRAGMA_UNROLL
for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) {
auto gmem_ptr = iterator_B.get();
int const kSrcBytes =
sizeof_bits<typename IteratorB::Element>::value *
IteratorB::ThreadMap::kElementsPerAccess /
IteratorB::kAccessesPerVector / 8;
if (InitStage) {
cutlass::arch::copy_zfill<kSrcBytes, kCacheOpB, GlobalToSharedB>(
dst_ptr + v, iterator_B.get(), iterator_B.valid());
} else {
if (SharedMemoryClear == SharedMemoryClearOption::kZfill) {
cutlass::arch::copy_zfill<kSrcBytes, kCacheOpB, GlobalToSharedB>(
dst_ptr + v, gmem_ptr, iterator_B.valid());
} else {
cutlass::arch::copy<kSrcBytes, kCacheOpB, GlobalToSharedB>(
dst_ptr + v, gmem_ptr, iterator_B.valid());
}
}
++iterator_B;
}
++this->smem_iterator_B_;
}
__syncthreads();
}
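// The trailing __syncthreads() in the B copy routines is presumably there so
// that the freshly written B tile is visible to every warp before the
// warp-level iterators read it, since B is rewritten in place each stage
// rather than multi-buffered.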
/// GEMM prologue. Bootstrap the global->shared memory pipeline by fetching
/// the global fragments needed by the first kStages-1 threadblock mainloop iterations
template <typename TileDequanterB>
CUTLASS_DEVICE
void prologue(
IteratorA &iterator_A, ///< [in|out] iterator over A operand in global memory
IteratorB &iterator_B, ///< [in|out] iterator over B operand in global memory
TileDequanterB &tile_dequanter_B,
int &gemm_k_iterations) ///< [in|out] number of threadblock mainloop iterations remaining
{
// Issue several complete stages
CUTLASS_PRAGMA_UNROLL
for (int stage = 0; stage < Base::kStages - 1; ++stage, --gemm_k_iterations) {
// Disable global fetching if done with global fetch iterations
iterator_A.clear_mask(gemm_k_iterations == 0);
iterator_B.clear_mask(gemm_k_iterations == 0);
// Async copy operand A to shared memory.
copy_tiles_and_advance_per_stage_A(iterator_A);
// Async copy zipped B to shared memory.
tile_dequanter_B.Load(smem_zipped_ptr_B_ + (stage % Base::kStages) * smem_zipped_bytes_per_stage_B_,
column_wise_smem_ptr_B_, stage);
// Move to the next write stage
advance_smem_write_stage(iterator_A, iterator_B, tile_dequanter_B);
// Defines the boundary of a stage of cp.async.
cutlass::arch::cp_async_fence();
}
// Optionally clear the remaining stages of SMEM. This is a functional requirement for
// some kernels so that all accumulator elements outside the GEMM footprint are zero.
if (SharedMemoryClear == SharedMemoryClearOption::kClearLastStage) {
/// Iterator to write threadblock-scoped tile of A operand to shared memory
SmemIteratorA last_smem_iterator_A(this->smem_iterator_A_);
typename IteratorA::AccessType zero_A;
zero_A.clear();
last_smem_iterator_A.set_iteration_index(0);
// Async Copy for operand A
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) {
typename IteratorA::AccessType *dst_ptr =
reinterpret_cast<typename IteratorA::AccessType *>(
last_smem_iterator_A.get());
*dst_ptr = zero_A;
++last_smem_iterator_A;
}
/// Iterator to write threadblock-scoped tile of B operand to shared memory
SmemIteratorB last_smem_iterator_B(this->smem_iterator_B_);
typename IteratorB::AccessType zero_B;
zero_B.clear();
last_smem_iterator_B.set_iteration_index(0);
// Async Copy for operand B
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) {
typename IteratorB::AccessType *dst_ptr =
reinterpret_cast<typename IteratorB::AccessType *>(
last_smem_iterator_B.get());
*dst_ptr = zero_B;
++last_smem_iterator_B;
}
}
}
/// Wait until we have at least one completed global fetch stage
CUTLASS_DEVICE
void gmem_wait()
{
// Wait until we have at least one committed global fetch stage. (#uncommitted = Base::kStages - 1 - #committed)
cutlass::arch::cp_async_wait<Base::kStages - 2>();
__syncthreads();
}
/// Perform a threadblock mainloop iteration of matrix multiply-accumulate
template <typename TileDequanterB>
CUTLASS_DEVICE
void mac_loop_iter(
PipeState &pipe_state, ///< [in|out] loop-carried pipeline state
FragmentC &accum, ///< [in|out] destination accumulator tile
IteratorA &iterator_A, ///< [in|out] iterator over A operand in global memory
IteratorB &iterator_B, ///< [in|out] iterator over B operand in global memory
TileDequanterB &tile_dequanter_B, ///< [in|out] tile dequantizer for B operand
int &gemm_k_iterations, ///< [in|out] number of threadblock mainloop iterations remaining
int stage)
{
// Unroll the warp-level MMA tiles of a threadblock's mainloop iteration
CUTLASS_PRAGMA_UNROLL
for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) {
// CUTLASS_TRACE_DEVICE(" [MMa] stage=%d, warp_mma_k=%d", stage, warp_mma_k);
// Load the next warp-tile's A fragment from shared memory
this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
this->warp_tile_iterator_A_.load(pipe_state.warp_loaded_frag_A_[(warp_mma_k + 1) % 2]);
++this->warp_tile_iterator_A_;
if (warp_mma_k + 1 == Base::kWarpGemmIterations) {
        // Unpack and dequantize the oldest staged tile of B.
int unpack_stage = stage - Base::kStages + 2;
tile_dequanter_B.UnpackAndDequant(smem_zipped_ptr_B_ + (unpack_stage % Base::kStages) * smem_zipped_bytes_per_stage_B_,
column_wise_smem_ptr_B_, unpack_stage);
        // Copy dequantized data to shared memory used by the MMA core.
copy_tiles_and_advance_per_stage_B<false, false>(iterator_B);
}
// Load the next warp-tile's B fragment from shared memory
this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
this->warp_tile_iterator_B_.load(pipe_state.warp_loaded_frag_B_[(warp_mma_k + 1) % 2]);
++this->warp_tile_iterator_B_;
// Except for the first warp-tile, all warp-tiles convert their incoming shared memory fragments as necessary
if (warp_mma_k > 0) {
warp_mma_.transform(
pipe_state.warp_transformed_frag_A_[warp_mma_k % 2],
pipe_state.warp_transformed_frag_B_[warp_mma_k % 2],
pipe_state.warp_loaded_frag_A_[warp_mma_k % 2],
pipe_state.warp_loaded_frag_B_[warp_mma_k % 2]);
}
// Execute the current warp-tile of MMA operations
if (Detail::kStagedAccumulation) {
warp_mma_(
pipe_state.tmp_accum_,
pipe_state.warp_transformed_frag_A_[warp_mma_k % 2],
pipe_state.warp_transformed_frag_B_[warp_mma_k % 2],
pipe_state.tmp_accum_
);
if (warp_mma_k == 0) {
plus<FragmentC> plus_accum;
accum = plus_accum(accum, pipe_state.tmp_accum_);
pipe_state.tmp_accum_.clear();
}
} else {
warp_mma_(
accum,
pipe_state.warp_transformed_frag_A_[warp_mma_k % 2],
pipe_state.warp_transformed_frag_B_[warp_mma_k % 2],
accum
);
}
// Except for the last warp-tile, all warp-tiles issue their share of
// global->shared fragment copies
if (warp_mma_k < Base::kWarpGemmIterations - 1) {
int group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA;
copy_tiles_and_advance_A(iterator_A, group_start_iteration_A);
if (warp_mma_k == 0) {
tile_dequanter_B.Load(smem_zipped_ptr_B_ + (stage % Base::kStages) * smem_zipped_bytes_per_stage_B_,
column_wise_smem_ptr_B_, stage);
}
}
// The second-to-last warp-tile also:
// - performs the last warp-tile's share of global->shared fragment copies
// - moves to the next global fetch stage
if (warp_mma_k + 2 == Base::kWarpGemmIterations) {
// Performs the last warp-tile's share of global->shared fragment copies
int group_start_iteration_A = (warp_mma_k + 1) * Detail::kAccessesPerGroupA;
copy_tiles_and_advance_A(iterator_A, group_start_iteration_A);
// Inserts a memory fence between stages of cp.async instructions.
cutlass::arch::cp_async_fence();
// Wait until we have at least one completed global fetch stage
gmem_wait();
// Move to the next global fetch stage
advance_smem_write_stage(iterator_A, iterator_B, tile_dequanter_B);
advance_smem_read_stage();
// Disable global fetching when done with global fetch iterations
--gemm_k_iterations;
iterator_A.clear_mask(gemm_k_iterations == 0);
iterator_B.clear_mask(gemm_k_iterations == (-Base::kStages + 1));
}
// The last warp-tile also converts the shared memory fragments used by
// the first warp-tile of the next iteration, if necessary (so we can
// immediately start issuing MMA instructions at the top of the loop )
if (warp_mma_k + 1 == Base::kWarpGemmIterations) {
warp_mma_.transform(
pipe_state.warp_transformed_frag_A_[(warp_mma_k + 1) % 2],
pipe_state.warp_transformed_frag_B_[(warp_mma_k + 1) % 2],
pipe_state.warp_loaded_frag_A_[(warp_mma_k + 1) % 2],
pipe_state.warp_loaded_frag_B_[(warp_mma_k + 1) % 2]);
}
}
}
/// Perform the specified number of threadblock mainloop iterations of matrix
/// multiply-accumulate. Assumes prologue has been initiated.
template <typename TileDequanterB>
CUTLASS_DEVICE
void gemm_iters(
int gemm_k_iterations, ///< number of threadblock mainloop iterations
FragmentC &accum, ///< [in|out] accumulator tile
IteratorA &iterator_A, ///< [in|out] iterator over A operand in global memory
        IteratorB &iterator_B,            ///< [in|out] iterator over B operand in global memory
        TileDequanterB &tile_dequanter_B) ///< [in|out] tile dequantizer for B operand
{
PipeState pipe_state;
// Unpack and dequant the first stage of B.
tile_dequanter_B.UnpackAndDequant(smem_zipped_ptr_B_, column_wise_smem_ptr_B_, 0);
// Disable global fetching if done with global fetch iterations
iterator_A.clear_mask(gemm_k_iterations == 0);
iterator_B.clear_mask(gemm_k_iterations == (-Base::kStages + 1));
// Load first warp-tile's A fragment from shared memory
this->warp_tile_iterator_A_.set_kgroup_index(0);
this->warp_tile_iterator_A_.load(pipe_state.warp_loaded_frag_A_[0]);
++this->warp_tile_iterator_A_;
    // Copy dequantized data to shared memory used by the MMA core.
copy_tiles_and_advance_per_stage_B<false, true>(iterator_B);
// Load first warp-tile's B fragment from shared memory
this->warp_tile_iterator_B_.set_kgroup_index(0);
this->warp_tile_iterator_B_.load(pipe_state.warp_loaded_frag_B_[0]);
++this->warp_tile_iterator_B_;
// Transform, if necessary, the first warp-tile's shared memory fragments
warp_mma_.transform(
pipe_state.warp_transformed_frag_A_[0],
pipe_state.warp_transformed_frag_B_[0],
pipe_state.warp_loaded_frag_A_[0],
pipe_state.warp_loaded_frag_B_[0]);
if (Detail::kStagedAccumulation) {
pipe_state.tmp_accum_.clear();
}
int stage = Base::kStages - 1;
// Mainloop
CUTLASS_GEMM_LOOP
for (; gemm_k_iterations > (-Base::kStages + 1);) {
mac_loop_iter(
pipe_state,
accum,
iterator_A,
iterator_B,
tile_dequanter_B,
gemm_k_iterations,
stage);
stage += 1;
}
if (Detail::kStagedAccumulation) {
plus<FragmentC> plus_accum;
accum = plus_accum(accum, pipe_state.tmp_accum_);
}
// Commit and drain all pending and predicated cp.async pnz from the GEMM mainloop
cutlass::arch::cp_async_fence();
cutlass::arch::cp_async_wait<0>();
__syncthreads();
}
/// Prepares the class for another prologue.
CUTLASS_DEVICE
void wind_down()
{
// Catch-up the smem-read iterator to the smem-write iterator (so this class can be reused for another tile's prologue)
// First, increment remaining warp tiles to get to the next full stage. (Ideally we would
// just decrement one tile, but not all iterators implement --() decrement.)
#pragma unroll
for (int warp_mma_k = 1; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k)
{
this->warp_tile_iterator_A_.set_kgroup_index(warp_mma_k);
this->warp_tile_iterator_B_.set_kgroup_index(warp_mma_k);
++this->warp_tile_iterator_A_;
++this->warp_tile_iterator_B_;
}
smem_read_stage_idx_++;
// Then wrap back two full stages (one for the tile advancing we just did, and one to catch the write iterators)
static const int kStageIters = Policy::kPartitionsK * Base::kWarpGemmIterations;
if (smem_read_stage_idx_ > 1)
{
this->warp_tile_iterator_A_.add_tile_offset({0, (-2 * kStageIters)});
this->warp_tile_iterator_B_.add_tile_offset({(-2 * kStageIters), 0});
}
else
{
this->warp_tile_iterator_A_.add_tile_offset({0, ((Base::kStages - 2) * kStageIters)});
//this->warp_tile_iterator_B_.add_tile_offset({((Base::kStages - 2) * kStageIters), 0});
this->warp_tile_iterator_B_.add_tile_offset({(-2 * kStageIters), 0});
}
smem_read_stage_idx_ = smem_write_stage_idx_;
}
/// Perform a threadblock-scoped matrix multiply-accumulate, pre-load B to shared memory.
template <typename TileDequanterB>
CUTLASS_DEVICE
void operator()(
///< problem size of GEMM
int gemm_k_iterations,
///< destination accumulator tile
FragmentC &accum,
///< iterator over A operand in global memory
IteratorA iterator_A,
///< iterator over B operand in global memory
IteratorB iterator_B,
///< pre-load and dequantize B to shared memory
TileDequanterB tile_dequanter_B,
///< initial value of accumulator
FragmentC const &src_accum) {
// Prologue (start fetching iterations of global fragments into shared memory)
prologue(iterator_A, iterator_B, tile_dequanter_B, gemm_k_iterations);
// Wait until we have at least one completed global fetch stage
gmem_wait();
// Initialize destination accumulators with source accumulators
accum = src_accum;
// Perform the MAC-iterations
gemm_iters(gemm_k_iterations, accum, iterator_A, iterator_B, tile_dequanter_B);
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
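The stage bookkeeping above is easy to lose track of: the prologue consumes kStages - 1 global-fetch iterations up front, and the mainloop keeps running until gemm_k_iterations reaches -(kStages - 1) so the already-prefetched stages can drain with fetching masked off. Below is a minimal, host-side counter model of that arithmetic only (assuming kStages = 3 and 8 mainloop iterations; it mirrors the decrements in prologue() and mac_loop_iter(), not the actual memory traffic).

// Illustrative counter model only, not part of this commit.
#include <cassert>

int main() {
  const int kStages = 3;
  int gemm_k_iterations = 8;  // mainloop iterations required for this tile

  // prologue(): issues kStages - 1 global fetch stages, decrementing each time.
  for (int stage = 0; stage < kStages - 1; ++stage) {
    --gemm_k_iterations;
  }

  // gemm_iters(): mac_loop_iter() decrements once per iteration and loops
  // until gemm_k_iterations == -(kStages - 1), i.e. the pipeline has drained.
  int executed = 0;
  while (gemm_k_iterations > -kStages + 1) {
    --gemm_k_iterations;
    ++executed;
  }
  assert(executed == 8);  // every k-iteration is executed exactly once
  return 0;
}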

View File

@@ -0,0 +1,130 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "cutlass/gemm_coord.h"
#include "cutlass/trace.h"
#include "cutlass_extensions/gemm/threadblock/wint2x_unzip.h"
namespace cutlass {
namespace gemm {
namespace threadblock {
template <typename ElementT, typename ScaleElementT, int Rows, int Columns,
int Stages, int NumThreads, WintQuantMethod Method>
struct TileDequanter {
using WeightQuantTraits = WintQuantTraits<ElementT, Method>;
using MmaElementT = typename WeightQuantTraits::MmaWeightType;
using QuantArguments = typename WeightQuantTraits::Arguments;
using UnzipAndDequantFunctor =
UnzipAndDequantFunctor<MmaElementT, Method, Rows, Columns, NumThreads>;
static constexpr bool kUseSharedMemory = true;
static constexpr int kRows = Rows;
static constexpr int kColumns = Columns;
static constexpr int kStages = Stages;
MmaElementT *out_smem_ptr{nullptr};
char *pointer{nullptr};
int64_t ldm{0};
cutlass::MatrixCoord tb_offset;
cutlass::MatrixCoord extent;
ScaleElementT *super_scale_ptr{nullptr};
cutlass::MatrixCoord tb_offset_scale;
QuantArguments quant_args;
int64_t block_start_rows[kStages];
bool need_preload{true};
UnzipAndDequantFunctor unzip_functor;
CUTLASS_DEVICE
TileDequanter(MmaElementT *out_smem_ptr, char *pointer, int64_t ldm,
const cutlass::MatrixCoord &extent,
const cutlass::MatrixCoord &tb_offset,
ScaleElementT *super_scale_ptr,
const cutlass::MatrixCoord &tb_offset_scale,
const QuantArguments &quant_args)
: out_smem_ptr(out_smem_ptr), pointer(pointer), ldm(ldm), extent(extent),
tb_offset(tb_offset), super_scale_ptr(super_scale_ptr),
tb_offset_scale(tb_offset_scale), quant_args(quant_args) {}
CUTLASS_DEVICE
MmaElementT *GetOutPtr() { return out_smem_ptr; }
CUTLASS_DEVICE
void AddTileOffset(const cutlass::MatrixCoord &tile_offset) {
tb_offset.row() += tile_offset.row() * kRows;
tb_offset.column() += tile_offset.column() * kColumns;
tb_offset_scale.column() += tile_offset.column() * kColumns;
}
CUTLASS_DEVICE
void Load(uint8_t *zipped_smem_ptr, uint8_t *column_wise_smem_ptr, int stage) {
int zipped_row = WeightQuantTraits::CaclPackedDim(tb_offset.row());
if (tb_offset.row() >= extent.row() ||
tb_offset.column() >= extent.column()) {
return;
}
block_start_rows[stage % kStages] = tb_offset.row();
using ZippedT = typename WeightQuantTraits::WeightType;
ZippedT *in_ptr = reinterpret_cast<ZippedT *>(pointer) + zipped_row * ldm +
tb_offset.column();
ScaleElementT *scale_ptr = super_scale_ptr + tb_offset_scale.column();
if constexpr (Method == WintQuantMethod::kWeightOnlyInt2) {
const uint8_t *local_scale_ptr = quant_args.local_scale_ptr +
(tb_offset.row() / 128) * ldm +
tb_offset_scale.column();
const float *code_scale_ptr =
quant_args.code_scale_ptr + tb_offset_scale.column();
const float *code_zp_ptr =
quant_args.code_zp_ptr + tb_offset_scale.column();
typename UnzipAndDequantFunctor::Arguments args(zipped_smem_ptr, column_wise_smem_ptr);
unzip_functor.LoadAsync(in_ptr, local_scale_ptr, code_scale_ptr, code_zp_ptr,
scale_ptr, &args, ldm, need_preload);
need_preload = false;
} else {
// CUTLASS_TRACE_DEVICE("Not Supported!");
}
}
CUTLASS_DEVICE
void UnpackAndDequant(uint8_t *zipped_smem_ptr, uint8_t *column_wise_smem_ptr, int stage) {
int64_t block_start_row = block_start_rows[stage % kStages];
if (block_start_row >= extent.row()) {
return;
}
if constexpr (Method == WintQuantMethod::kWeightOnlyInt2) {
typename UnzipAndDequantFunctor::Arguments args(zipped_smem_ptr, column_wise_smem_ptr);
unzip_functor.ComputeVectorized(args, out_smem_ptr, block_start_row);
} else {
// CUTLASS_TRACE_DEVICE("Not Supported!");
}
}
};
} // namespace threadblock
} // namespace gemm
} // namespace cutlass
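To make the division of labour explicit: Load() only stages the still-packed weights (and, on the first call, the column-wise scales) into shared memory, while UnpackAndDequant() later expands a previously loaded stage into out_smem_ptr for the MMA iterators. The sketch below is a hypothetical driver showing that interleaving across a kStages-deep ring buffer; buffer names are illustrative, and the cp.async fences/waits and AddTileOffset advancing of the real mainloop are elided.

// Hypothetical driver, illustrating the Load / UnpackAndDequant interleaving.
// Synchronization and tile-offset advancing are intentionally omitted.
template <typename TileDequanterB, int kStages>
__device__ void drive_tile_dequanter(TileDequanterB &deq,
                                     uint8_t *zipped_smem,       // ring buffer of packed B tiles
                                     uint8_t *column_wise_smem,  // code/super scales, zero points
                                     int zipped_bytes_per_stage,
                                     int num_stages) {
  // Prefetch the first kStages - 1 packed stages (prologue).
  for (int s = 0; s < kStages - 1 && s < num_stages; ++s) {
    deq.Load(zipped_smem + (s % kStages) * zipped_bytes_per_stage, column_wise_smem, s);
  }
  // Steady state: expand the oldest staged tile while fetching a new one.
  for (int s = kStages - 1; s < num_stages; ++s) {
    int unpack_stage = s - kStages + 1;
    deq.UnpackAndDequant(zipped_smem + (unpack_stage % kStages) * zipped_bytes_per_stage,
                         column_wise_smem, unpack_stage);
    deq.Load(zipped_smem + (s % kStages) * zipped_bytes_per_stage, column_wise_smem, s);
  }
  // Drain: expand the remaining prefetched stages.
  for (int s = (num_stages > kStages - 1 ? num_stages - kStages + 1 : 0); s < num_stages; ++s) {
    deq.UnpackAndDequant(zipped_smem + (s % kStages) * zipped_bytes_per_stage,
                         column_wise_smem, s);
  }
}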

View File

@@ -0,0 +1,447 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cuda.h>
#include <cuda_fp16.h>
#include <stdio.h>
#include <cuda_runtime.h>
#include "cutlass/arch/memory.h"
#include "cutlass/trace.h"
#include "cutlass_extensions/wint_type_traits.h"
namespace cutlass {
namespace gemm {
namespace threadblock {
template <typename T, int N>
using UnzipArray = cutlass::AlignedArray<T, N, (N * cutlass::sizeof_bits<T>::value / 8)>;
template <typename T, WintQuantMethod QuantMethod, int TileRows,
int TileColumns, int NumThreads = 128>
struct UnzipAndDequantFunctor {
  __device__ void operator()(const T *in_ptr, const T *super_scale_ptr,
                             T *out_ptr, const int64_t in_stride) {}
};
template <typename T, int TileRows, int TileColumns, int NumThreads>
struct UnzipAndDequantFunctor<T, WintQuantMethod::kWeightOnlyInt25, TileRows,
TileColumns, NumThreads> {
using ZippedT = uint16_t;
using ScaleComputeT = float;
static constexpr int32_t kGroupSize = 64;
static constexpr int32_t kZippedGroupSize = 10;
static constexpr int32_t kNumPackedValues = 7;
static constexpr int32_t kWeightMask = 0x7;
static constexpr int32_t kLocalScaleMask = 0x1FFF;
static constexpr int32_t kBZP = 4;
__device__ inline T Compute(int32_t zipped_value, int32_t shift_bit,
ScaleComputeT scale) {
int32_t shifted_value = (zipped_value >> shift_bit) & kWeightMask;
int32_t value = shifted_value - kBZP;
ScaleComputeT scaled_value = static_cast<ScaleComputeT>(value) * scale;
return static_cast<T>(scaled_value);
}
__device__ void operator()(const uint16_t *in_ptr, const T *super_scale_ptr,
T *out_ptr, const int64_t in_stride) {
int32_t shift_bits[7] = {13, 11, 9, 6, 4, 2, 0};
int tid = threadIdx.x;
#pragma unroll
for (int col = tid; col < TileColumns; col += NumThreads) {
ScaleComputeT super_scale =
static_cast<ScaleComputeT>(super_scale_ptr[col]);
#pragma unroll
for (int group_id = 0; group_id < TileRows / 64; ++group_id) {
// the last row in group
int zipped_row_last = group_id * 10 + 9;
int zipped_offset_last = zipped_row_last * in_stride + col;
int32_t zipped_value_last =
static_cast<int32_t>(in_ptr[zipped_offset_last]);
ScaleComputeT local_scale =
static_cast<ScaleComputeT>(zipped_value_last & kLocalScaleMask);
ScaleComputeT scale = local_scale * super_scale;
#pragma unroll
for (int zipped_row_in_group = 0; zipped_row_in_group < 9;
++zipped_row_in_group) {
int zipped_row = group_id * 10 + zipped_row_in_group;
int zipped_offset = zipped_row * in_stride + col;
int32_t zipped_value = static_cast<int32_t>(in_ptr[zipped_offset]);
int row_in_group = group_id * 64 + zipped_row_in_group * 7;
#pragma unroll
for (int shift_bit_id = 0; shift_bit_id < 7; ++shift_bit_id) {
int32_t shift_bit = shift_bits[shift_bit_id];
T value = Compute(zipped_value, shift_bit, scale);
out_ptr[(row_in_group + shift_bit_id) * TileColumns + col] = value;
}
}
int row_in_group_last = group_id * 64 + 63;
T value_last = Compute(zipped_value_last, shift_bits[0], scale);
out_ptr[row_in_group_last * TileColumns + col] = value_last;
}
}
__syncthreads();
}
};
template <typename T, int TileRows, int TileColumns, int NumThreads>
struct UnzipAndDequantFunctor<T, WintQuantMethod::kWeightOnlyInt2, TileRows,
TileColumns, NumThreads> {
using ZippedT = uint8_t;
using ScaleComputeT = float;
static constexpr int32_t kGroupSize = 64;
static constexpr int32_t kPackNum = 4;
static constexpr int32_t kWeightMask = 0x3F;
static constexpr int32_t kLocalScaleMask = 0xF;
static constexpr int32_t kBZP = 32;
// weight [16, N] uint8_t
// local_scale [1, N] uint8_t
// code_scale [N] float
// code_zp [N] float
// super_scale [N] T
// code_scale, code_zp and super_scale
static constexpr int32_t kColumnWiseSmemBytes = (2 * sizeof(float) + sizeof(T)) * TileColumns;
// zipped weights and local_scale
static constexpr int32_t kZippedSmemBytes = (TileRows / 4 + (TileRows + 127) / 128) * TileColumns;
struct Arguments {
uint8_t *weight_ptr;
uint8_t *local_scale_ptr;
float *code_scale_ptr;
float *code_zp_ptr;
T *super_scale_ptr;
__device__ Arguments() : weight_ptr(nullptr), local_scale_ptr(nullptr), code_scale_ptr(nullptr), code_zp_ptr(nullptr), super_scale_ptr(nullptr) {}
__device__ explicit Arguments(uint8_t *smem_ptr) {
SetZippedPtrs(smem_ptr);
SetColumnWisePtrs(smem_ptr + kZippedSmemBytes);
}
__device__ Arguments(uint8_t *zipped_smem_ptr, uint8_t *column_wise_smem_ptr) {
SetZippedPtrs(zipped_smem_ptr);
SetColumnWisePtrs(column_wise_smem_ptr);
}
__device__ void SetZippedPtrs(uint8_t *zipped_smem_ptr) {
weight_ptr = zipped_smem_ptr;
local_scale_ptr = zipped_smem_ptr + (TileRows / 4) * TileColumns;
}
__device__ void SetColumnWisePtrs(uint8_t *column_wise_smem_ptr) {
code_scale_ptr = reinterpret_cast<float *>(column_wise_smem_ptr);
code_zp_ptr = reinterpret_cast<float *>(column_wise_smem_ptr + sizeof(float) * TileColumns);
super_scale_ptr = reinterpret_cast<T *>(column_wise_smem_ptr + 2 * sizeof(float) * TileColumns);
}
};
__device__ void Load(const uint8_t *g_weight_ptr, const uint8_t *g_local_scale_ptr,
const float *g_code_scale_ptr, const float *g_code_zp_ptr,
const T *g_super_scale_ptr,
Arguments *args, const int64_t in_stride, bool need_preload) {
int tid = threadIdx.x;
#pragma unroll
for (int col = tid; col < TileColumns; col += NumThreads) {
if (need_preload) {
if (g_super_scale_ptr) {
args->super_scale_ptr[col] = g_super_scale_ptr[col];
} else {
args->super_scale_ptr[col] = static_cast<T>(1);
}
args->code_scale_ptr[col] = g_code_scale_ptr[col];
args->code_zp_ptr[col] = g_code_zp_ptr[col];
}
#pragma unroll
for (int ls_row_id = 0; ls_row_id < TileRows / 128; ++ls_row_id) {
int local_scale_offset = ls_row_id * in_stride + col;
args->local_scale_ptr[ls_row_id * TileColumns + col] = g_local_scale_ptr[local_scale_offset];
}
#pragma unroll
for (int zipped_row = 0; zipped_row < TileRows / 4; ++zipped_row) {
int s_zipped_offset = zipped_row * TileColumns + col;
int g_zipped_offset = zipped_row * 4 * in_stride + col;
args->weight_ptr[s_zipped_offset] = g_weight_ptr[g_zipped_offset];
}
}
__syncthreads();
}
__device__ void LoadAsync(const uint8_t *g_weight_ptr,
const uint8_t *g_local_scale_ptr,
const float *g_code_scale_ptr,
const float *g_code_zp_ptr,
const T *g_super_scale_ptr,
Arguments *args, const int64_t in_stride, bool need_preload) {
int tid = threadIdx.x;
constexpr int kBytesPerThread = 16; // 16B per thread
constexpr int weight_size = TileRows / 4 * TileColumns;
constexpr int local_scale_size = (TileRows + 127) / 128 * TileColumns;
constexpr int code_scale_size = sizeof(float) * TileColumns;
constexpr int code_zp_size = sizeof(float) * TileColumns;
constexpr int super_scale_size = sizeof(T) * TileColumns;
constexpr int total_size = weight_size + local_scale_size + code_scale_size + code_zp_size + super_scale_size;
constexpr int total_tasks = total_size / kBytesPerThread;
constexpr int cur_num_threads = total_tasks / ((total_tasks + NumThreads - 1) / NumThreads);
constexpr int weight_threads = weight_size * cur_num_threads / total_size;
constexpr int local_scale_threads = local_scale_size * cur_num_threads / total_size;
constexpr int code_scale_threads = code_scale_size * cur_num_threads / total_size;
constexpr int code_zp_threads = code_zp_size * cur_num_threads / total_size;
constexpr int super_scale_threads = super_scale_size * cur_num_threads / total_size;
static_assert(TileColumns % weight_threads == 0,
"TileColumns must be divisible by weight_threads to ensure correct thread mapping.");
static_assert(TileColumns % local_scale_threads == 0,
"TileColumns must be divisible by local_scale_threads to ensure correct thread mapping.");
if (tid < weight_threads) {
constexpr int weight_per_thread_size = weight_size / weight_threads;
constexpr int kIterations = (weight_per_thread_size + kBytesPerThread - 1) / kBytesPerThread;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kIterations; ++i) {
int z_offset = (tid * weight_per_thread_size + i * kBytesPerThread);
int g_offset = z_offset / TileColumns * in_stride + z_offset % TileColumns;
cutlass::arch::cp_async<kBytesPerThread, cutlass::arch::CacheOperation::Global>(
args->weight_ptr + z_offset, g_weight_ptr + g_offset, true);
}
} else if (tid < weight_threads + local_scale_threads) {
constexpr int start_thread_id = weight_threads;
constexpr int local_scale_per_thread_size = local_scale_size / local_scale_threads;
constexpr int kIterations = (local_scale_per_thread_size + kBytesPerThread - 1) / kBytesPerThread;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kIterations; ++i) {
int z_offset = (tid - start_thread_id) * local_scale_per_thread_size + i * kBytesPerThread;
int g_offset = z_offset / TileColumns * in_stride + z_offset % TileColumns;
cutlass::arch::cp_async<kBytesPerThread, cutlass::arch::CacheOperation::Global>(
args->local_scale_ptr + z_offset, g_local_scale_ptr + g_offset, true);
}
} else if (need_preload) {
if (tid < weight_threads + local_scale_threads + code_scale_threads) {
constexpr int start_thread_id = weight_threads + local_scale_threads;
constexpr int code_scale_per_thread_size = code_scale_size / code_scale_threads;
constexpr int kIterations = (code_scale_per_thread_size + kBytesPerThread - 1) / kBytesPerThread;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kIterations; ++i) {
int offset = ((tid - start_thread_id) * code_scale_per_thread_size + i * kBytesPerThread) / sizeof(float);
cutlass::arch::cp_async<kBytesPerThread, cutlass::arch::CacheOperation::Global>(
args->code_scale_ptr + offset, g_code_scale_ptr + offset, true);
}
} else if (tid < weight_threads + local_scale_threads + code_scale_threads + code_zp_threads) {
constexpr int start_thread_id = weight_threads + local_scale_threads + code_scale_threads;
constexpr int code_zp_per_thread_size = code_zp_size / code_zp_threads;
constexpr int kIterations = (code_zp_per_thread_size + kBytesPerThread - 1) / kBytesPerThread;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kIterations; ++i) {
int offset = ((tid - start_thread_id) * code_zp_per_thread_size + i * kBytesPerThread) / sizeof(float);
cutlass::arch::cp_async<kBytesPerThread, cutlass::arch::CacheOperation::Global>(
args->code_zp_ptr + offset, g_code_zp_ptr + offset, true);
}
} else if (tid < weight_threads + local_scale_threads + code_scale_threads + code_zp_threads + super_scale_threads) {
if (g_super_scale_ptr) {
constexpr int start_thread_id = weight_threads + local_scale_threads + code_scale_threads + code_zp_threads;
constexpr int super_scale_per_thread_size = super_scale_size / super_scale_threads;
constexpr int kIterations = (super_scale_per_thread_size + kBytesPerThread - 1) / kBytesPerThread;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kIterations; ++i) {
int offset = ((tid - start_thread_id) * super_scale_per_thread_size + i * kBytesPerThread) / sizeof(T);
cutlass::arch::cp_async<kBytesPerThread, cutlass::arch::CacheOperation::Global>(
args->super_scale_ptr + offset, g_super_scale_ptr + offset, true);
}
}
}
}
}
__device__ void Compute(const Arguments &args, T *out_ptr,
const int64_t block_start_row) {
int32_t shift_bits[4] = {9, 6, 3, 0};
int tid = threadIdx.x;
#pragma unroll
for (int col = tid; col < TileColumns; col += NumThreads) {
ScaleComputeT super_scale =
static_cast<ScaleComputeT>(args.super_scale_ptr[col]);
ScaleComputeT code_scale =
static_cast<ScaleComputeT>(args.code_scale_ptr[col]);
ScaleComputeT code_zp = static_cast<ScaleComputeT>(args.code_zp_ptr[col]);
#pragma unroll
for (int group_id = 0; group_id < TileRows / 64; ++group_id) {
int local_scale_offset = (group_id / 2) * TileColumns + col;
int32_t local_scale =
static_cast<int32_t>(args.local_scale_ptr[local_scale_offset]);
ScaleComputeT zipped_value[16];
#pragma unroll
for (int zipped_row = 0; zipped_row < 16; ++zipped_row) {
int zipped_offset = (group_id * 16 + zipped_row) * TileColumns + col;
zipped_value[zipped_row] =
static_cast<ScaleComputeT>(args.weight_ptr[zipped_offset]);
}
int local_scale_shift = ((block_start_row / 64 + group_id + 1) & 1) * 4;
int32_t shifted_local_scale =
(local_scale >> local_scale_shift) & kLocalScaleMask;
ScaleComputeT scale =
static_cast<ScaleComputeT>(shifted_local_scale) * super_scale;
#pragma unroll
for (int zipped_row = 0; zipped_row < 16; ++zipped_row) {
int32_t decode_value =
static_cast<int32_t>(floor(zipped_value[zipped_row] * code_scale + code_zp +
static_cast<ScaleComputeT>(0.5)));
int row = group_id * 64 + zipped_row * 4;
#pragma unroll
for (int shift_bit_id = 0; shift_bit_id < 4; ++shift_bit_id) {
int32_t shift_bit = shift_bits[shift_bit_id];
int32_t shifted_value = (decode_value >> shift_bit) & kWeightMask;
ScaleComputeT value =
static_cast<ScaleComputeT>(shifted_value - kBZP);
out_ptr[(row + shift_bit_id) * TileColumns + col] =
static_cast<T>(scale * value);
}
}
}
}
__syncthreads();
}
__device__ void ComputeVectorized(const Arguments &args, T *out_ptr,
const int64_t block_start_row) {
constexpr int kNumWeightsPerThread = TileRows * TileColumns / (4 * NumThreads);
constexpr int N = (kNumWeightsPerThread >= 32) ? 4 : 2;
constexpr int RowStride = NumThreads * N / TileColumns;
constexpr int kNumIters = kNumWeightsPerThread / N;
static_assert(N * NumThreads >= TileColumns, "N * NumThreads should be no less than TileColumns.");
constexpr ScaleComputeT decode_value_zp = static_cast<ScaleComputeT>(0.5);
int tid = threadIdx.x;
int begin_col_id = (tid * N) % TileColumns;
int begin_row_id = (tid * N) / TileColumns;
    static_assert(TileRows <= 128, "TileRows is expected to be no more than 128.");
UnzipArray<uint8_t, N> local_scales =
*reinterpret_cast<const UnzipArray<uint8_t, N> *>(args.local_scale_ptr + begin_col_id);
UnzipArray<uint8_t, N> zipped_values[2];
int zipped_offset = begin_row_id * TileColumns + begin_col_id;
zipped_values[0] =
*reinterpret_cast<const UnzipArray<uint8_t, N> *>(args.weight_ptr + zipped_offset);
UnzipArray<T, N> super_scales =
*reinterpret_cast<const UnzipArray<T, N> *>(args.super_scale_ptr + begin_col_id);
UnzipArray<float, N> code_scales =
*reinterpret_cast<const UnzipArray<float, N> *>(args.code_scale_ptr + begin_col_id);
UnzipArray<float, N> code_zps =
*reinterpret_cast<const UnzipArray<float, N> *>(args.code_zp_ptr + begin_col_id);
    // Specialized for TileRows = 64: every row of the tile falls in one 64-row group, so a single local-scale nibble applies.
int local_scale_shift = (((block_start_row / 64) + 1) & 1) * 4;
UnzipArray<ScaleComputeT, N> scales;
#pragma unroll
for (int i = 0; i < N; ++i) {
int32_t shifted_local_scale =
(static_cast<int32_t>(local_scales[i]) >> local_scale_shift) & kLocalScaleMask;
scales[i] =
static_cast<ScaleComputeT>(shifted_local_scale) * static_cast<ScaleComputeT>(super_scales[i]);
}
#pragma unroll
for (int iter_id = 0; iter_id < kNumIters; ++iter_id) {
int zipped_row = begin_row_id + iter_id * RowStride;
int row = zipped_row * 4;
if (iter_id < kNumIters - 1) {
int zipped_offset = (zipped_row + RowStride) * TileColumns + begin_col_id;
zipped_values[(iter_id + 1) & 1] =
*reinterpret_cast<const UnzipArray<uint8_t, N> *>(args.weight_ptr + zipped_offset);
}
UnzipArray<T, N> outs[4];
#pragma unroll
for (int i = 0; i < N; ++i) {
int32_t decode_value =
static_cast<int32_t>(floor(static_cast<ScaleComputeT>(zipped_values[iter_id & 1][i]) * code_scales[i]
+ code_zps[i] + decode_value_zp));
ScaleComputeT value_3 = static_cast<ScaleComputeT>((decode_value & kWeightMask) - kBZP);
decode_value >>= 3;
ScaleComputeT value_2 = static_cast<ScaleComputeT>((decode_value & kWeightMask) - kBZP);
decode_value >>= 3;
ScaleComputeT value_1 = static_cast<ScaleComputeT>((decode_value & kWeightMask) - kBZP);
decode_value >>= 3;
ScaleComputeT value_0 = static_cast<ScaleComputeT>((decode_value & kWeightMask) - kBZP);
outs[0][i] = static_cast<T>(scales[i] * value_0);
outs[1][i] = static_cast<T>(scales[i] * value_1);
outs[2][i] = static_cast<T>(scales[i] * value_2);
outs[3][i] = static_cast<T>(scales[i] * value_3);
}
#pragma unroll
for (int shift_bit_id = 0; shift_bit_id < 4; ++shift_bit_id) {
UnzipArray<T, N> *tmp_out_ptr = reinterpret_cast<UnzipArray<T, N> *>(
out_ptr + (row + shift_bit_id) * TileColumns + begin_col_id);
*tmp_out_ptr = outs[shift_bit_id];
}
}
__syncthreads();
}
};
} // namespace threadblock
} // namespace gemm
} // namespace cutlass
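The kWeightOnlyInt2 decode above happens in two steps: each stored byte is first mapped back to an integer code through the per-column affine pair (code_scale, code_zp), and that code then yields four weights through the overlapping 6-bit windows at shifts {9, 6, 3, 0} (mask 0x3F), each offset by kBZP = 32 and scaled by the group's 4-bit local scale times the column super scale. Below is a standalone scalar sketch of the same arithmetic, mirroring the Compute() path; the function and argument names are illustrative only, and it emits float for simplicity.

// Standalone scalar sketch of the kWeightOnlyInt2 decode (illustrative only).
#include <cmath>
#include <cstdint>

float decode_one_weight(uint8_t packed, int which,           // which in [0, 4)
                        float code_scale, float code_zp,
                        uint8_t local_scale_byte, int local_scale_shift,
                        float super_scale) {
  constexpr int32_t kWeightMask = 0x3F;
  constexpr int32_t kLocalScaleMask = 0xF;
  constexpr int32_t kBZP = 32;
  const int32_t shift_bits[4] = {9, 6, 3, 0};
  // Recover the integer code from the per-column affine (code_scale, code_zp) mapping.
  int32_t decode_value = static_cast<int32_t>(
      std::floor(static_cast<float>(packed) * code_scale + code_zp + 0.5f));
  // Extract one of the four packed weights and remove the zero point.
  int32_t shifted = (decode_value >> shift_bits[which]) & kWeightMask;
  // Per-group scale: 4-bit local scale nibble combined with the column super scale.
  int32_t local_scale =
      (static_cast<int32_t>(local_scale_byte) >> local_scale_shift) & kLocalScaleMask;
  float scale = static_cast<float>(local_scale) * super_scale;
  return scale * static_cast<float>(shifted - kBZP);
}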

View File

@@ -0,0 +1,140 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cuda.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include "cutlass/cutlass.h"
#include "cutlass/layout/layout.h"
#include "cutlass/numeric_types.h"
namespace cutlass {
enum WintQuantMethod {
kNone = 0,
kWeightOnlyInt8 = 1,
kWeightOnlyInt4 = 2,
kWeightOnlyInt25 = 3,
kWeightOnlyInt2 = 4
};
// Convert CUDA data type to cutlass data type
template <typename T> struct CutlassDataType {
using Type = T;
};
template <> struct CutlassDataType<half> {
using Type = cutlass::half_t;
};
template <> struct CutlassDataType<__nv_bfloat16> {
using Type = cutlass::bfloat16_t;
};
template <typename ElementT, WintQuantMethod Method> struct WintQuantTraits;
template <typename ElementT>
struct WintQuantTraits<ElementT, WintQuantMethod::kNone> {
using WeightType = ElementT;
using MmaKernelType = typename CutlassDataType<ElementT>::Type;
using MmaWeightType = typename CutlassDataType<ElementT>::Type;
static constexpr WintQuantMethod kQuantMethod = WintQuantMethod::kNone;
struct Arguments {};
CUTLASS_DEVICE
static int64_t CaclPackedDim(int64_t dim) { return dim; }
};
template <typename ElementT>
struct WintQuantTraits<ElementT, WintQuantMethod::kWeightOnlyInt8> {
using WeightType = uint8_t;
using MmaKernelType = uint8_t;
using MmaWeightType = uint8_t;
static constexpr WintQuantMethod kQuantMethod =
WintQuantMethod::kWeightOnlyInt8;
struct Arguments {};
CUTLASS_DEVICE
static int64_t CaclPackedDim(int64_t dim) { return dim; }
};
template <typename ElementT>
struct WintQuantTraits<ElementT, WintQuantMethod::kWeightOnlyInt4> {
using WeightType = cutlass::uint4b_t;
using MmaKernelType = cutlass::uint4b_t;
using MmaWeightType = cutlass::uint4b_t;
static constexpr WintQuantMethod kQuantMethod =
WintQuantMethod::kWeightOnlyInt4;
struct Arguments {};
CUTLASS_DEVICE
static int64_t CaclPackedDim(int64_t dim) { return dim; }
};
template <typename ElementT>
struct WintQuantTraits<ElementT, WintQuantMethod::kWeightOnlyInt25> {
using WeightType = uint16_t;
using MmaKernelType = typename CutlassDataType<ElementT>::Type;
using MmaWeightType = typename CutlassDataType<ElementT>::Type;
static constexpr WintQuantMethod kQuantMethod =
WintQuantMethod::kWeightOnlyInt25;
static constexpr int32_t kGroupSize = 64;
static constexpr int32_t kNumPackedValues = 7;
static constexpr int32_t kPackedSize = 10;
struct Arguments {};
CUTLASS_DEVICE
static int64_t CaclPackedDim(int64_t dim) {
return dim * kPackedSize / kGroupSize;
}
};
template <typename ElementT>
struct WintQuantTraits<ElementT, WintQuantMethod::kWeightOnlyInt2> {
using WeightType = uint8_t;
using MmaKernelType = cutlass::uint2b_t;
using MmaWeightType = typename CutlassDataType<ElementT>::Type;
static constexpr WintQuantMethod kQuantMethod =
WintQuantMethod::kWeightOnlyInt2;
static constexpr int32_t kGroupSize = 64;
static constexpr int32_t kNumPackedValues = 4;
static constexpr int32_t kPackedSize = 16;
struct Arguments {
    const uint8_t *local_scale_ptr; // quantized to 4 bits
const float *code_scale_ptr;
const float *code_zp_ptr;
};
CUTLASS_DEVICE
static int64_t CaclPackedDim(int64_t dim) {
return dim * kPackedSize / kGroupSize;
}
};
} // namespace cutlass
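As a concrete reading of CaclPackedDim: for kWeightOnlyInt2, every 64 logical rows are stored as kPackedSize = 16 rows of uint8 (with the local scales kept separately), and for kWeightOnlyInt25 as 10 rows of uint16. A small worked example, assuming a hypothetical K = 4096:

// Illustrative packed-row accounting implied by CaclPackedDim (not part of the commit).
#include <cstdint>
#include <cstdio>

int main() {
  const int64_t K = 4096;
  const int64_t packed_int2 = K * 16 / 64;   // kWeightOnlyInt2:  64 rows -> 16 uint8 rows
  const int64_t packed_int25 = K * 10 / 64;  // kWeightOnlyInt25: 64 rows -> 10 uint16 rows
  std::printf("int2 packed rows: %lld, int2.5 packed rows: %lld\n",
              static_cast<long long>(packed_int2),
              static_cast<long long>(packed_int25));  // prints 1024 and 640
  return 0;
}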

View File

@@ -16,106 +16,127 @@
#include <mutex>
#include "cutlass/bfloat16.h"
#include "cutlass/cutlass.h"
#include "cutlass/half.h"
#include "helper.h"
#include "paddle/extension.h"
template <paddle::DataType D>
class CutlassDtypeTraits;
/**
* Helper function for checking CUTLASS errors
*/
#define CUTLASS_CHECK(status) \
{ \
cutlass::Status error = status; \
PD_CHECK(error == cutlass::Status::kSuccess, \
cutlassGetStatusString(error)); \
}
template <>
class CutlassDtypeTraits<paddle::DataType::FLOAT32> {
public:
typedef float DataType;
typedef float data_t;
/**
* A wrapper for a kernel that is used to guard against compilation on
* architectures that will never use the kernel. The purpose of this is to
* reduce the size of the compiled binary.
* __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef
* into code that will be executed on the device where it is defined.
*/
template <typename Kernel> struct enable_sm90_or_later : Kernel {
template <typename... Args> CUTLASS_DEVICE void operator()(Args &&...args) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 900
Kernel::operator()(std::forward<Args>(args)...);
#endif
}
};
template <>
class CutlassDtypeTraits<paddle::DataType::FLOAT16> {
public:
typedef cutlass::half_t DataType;
typedef paddle::float16 data_t;
template <paddle::DataType D> class CutlassDtypeTraits;
template <> class CutlassDtypeTraits<paddle::DataType::FLOAT32> {
public:
typedef float DataType;
typedef float data_t;
};
template <>
class CutlassDtypeTraits<paddle::DataType::BFLOAT16> {
public:
typedef cutlass::bfloat16_t DataType;
typedef paddle::bfloat16 data_t;
template <> class CutlassDtypeTraits<paddle::DataType::FLOAT16> {
public:
typedef cutlass::half_t DataType;
typedef paddle::float16 data_t;
};
template <> class CutlassDtypeTraits<paddle::DataType::BFLOAT16> {
public:
typedef cutlass::bfloat16_t DataType;
typedef paddle::bfloat16 data_t;
};
class CutlassGemmConfigMannager {
public:
static CutlassGemmConfigMannager& getInstance() {
static CutlassGemmConfigMannager instance;
return instance;
}
public:
static CutlassGemmConfigMannager &getInstance() {
static CutlassGemmConfigMannager instance;
return instance;
}
CutlassGemmConfigMannager(const CutlassGemmConfigMannager&) = delete;
CutlassGemmConfigMannager& operator=(const CutlassGemmConfigMannager&) =
delete;
CutlassGemmConfigMannager(const CutlassGemmConfigMannager &) = delete;
CutlassGemmConfigMannager &
operator=(const CutlassGemmConfigMannager &) = delete;
void up_date_configs(const nlohmann::json& j) {
std::lock_guard<std::mutex> lock(mutex_);
for (auto it = j.begin(); it != j.end(); ++it) {
json_[it.key()] = it.value();
}
void up_date_configs(const nlohmann::json &j) {
std::lock_guard<std::mutex> lock(mutex_);
for (auto it = j.begin(); it != j.end(); ++it) {
json_[it.key()] = it.value();
}
}
nlohmann::json* get_gemm_best_configs(const std::string& config_file_path) {
if (!load_initialized_) {
std::ifstream file(config_file_path);
if (!file.good()) {
throw std::runtime_error(
"cutlass gemm_best_config can not be found, please set "
"gemm_best_config'path as "
"FLAGS_use_cutlass_device_best_config_path, or unset "
"FLAGS_use_cutlass_device_best_config_path to tune "
"gemm_best_config");
}
json_ = readJsonFromFile(config_file_path);
load_initialized_ = true;
save_initialized_ = false;
}
return &json_;
nlohmann::json *get_gemm_best_configs(const std::string &config_file_path) {
if (!load_initialized_) {
std::ifstream file(config_file_path);
if (!file.good()) {
throw std::runtime_error(
"cutlass gemm_best_config can not be found, please set "
"gemm_best_config'path as "
"FLAGS_use_cutlass_device_best_config_path, or unset "
"FLAGS_use_cutlass_device_best_config_path to tune "
"gemm_best_config");
}
json_ = readJsonFromFile(config_file_path);
load_initialized_ = true;
save_initialized_ = false;
}
return &json_;
}
private:
void save_gemm_best_configs_(const std::string& config_file_path) {
std::ifstream file(config_file_path);
if (!file.good()) {
std::ofstream new_file(config_file_path);
new_file << json_.dump(4);
new_file.close();
} else {
nlohmann::json old_json = readJsonFromFile(config_file_path);
for (auto it = json_.begin(); it != json_.end(); ++it) {
old_json[it.key()] = it.value();
}
json_ = old_json;
std::ofstream new_file(config_file_path,
std::ios::out | std::ios::trunc);
new_file << json_.dump(4);
new_file.close();
file.close();
}
return;
private:
void save_gemm_best_configs_(const std::string &config_file_path) {
std::ifstream file(config_file_path);
if (!file.good()) {
std::ofstream new_file(config_file_path);
new_file << json_.dump(4);
new_file.close();
} else {
nlohmann::json old_json = readJsonFromFile(config_file_path);
for (auto it = json_.begin(); it != json_.end(); ++it) {
old_json[it.key()] = it.value();
}
json_ = old_json;
std::ofstream new_file(config_file_path, std::ios::out | std::ios::trunc);
new_file << json_.dump(4);
new_file.close();
file.close();
}
return;
}
CutlassGemmConfigMannager()
: json_(nullptr), load_initialized_(false), save_initialized_(true) {}
~CutlassGemmConfigMannager() {
std::lock_guard<std::mutex> lock(mutex_);
if (save_initialized_) {
std::string config_file_path = "fp8_fuse_gemm_config.json";
save_gemm_best_configs_(config_file_path);
}
save_initialized_ = true;
load_initialized_ = false;
json_.clear();
CutlassGemmConfigMannager()
: json_(nullptr), load_initialized_(false), save_initialized_(true) {}
~CutlassGemmConfigMannager() {
std::lock_guard<std::mutex> lock(mutex_);
if (save_initialized_) {
std::string config_file_path = "fp8_fuse_gemm_config.json";
save_gemm_best_configs_(config_file_path);
}
mutable std::mutex mutex_;
nlohmann::json json_;
bool load_initialized_;
bool save_initialized_;
save_initialized_ = true;
load_initialized_ = false;
json_.clear();
}
mutable std::mutex mutex_;
nlohmann::json json_;
bool load_initialized_;
bool save_initialized_;
};

View File

@@ -15,8 +15,8 @@
#pragma once
#include "fp8_common.h"
#include "fuse_dual_gemm_swiglu_template.h"
#include "fuse_dual_gemm_act_template_3x.h"
#include "fuse_dual_gemm_geglu_template.h"
#include "fuse_dual_gemm_swiglu_template.h"
bool fp8_fp8_dual_gemm_scale_bias_act(
DualGemmEpilogueAllParams params);
bool fp8_fp8_dual_gemm_scale_bias_act(DualGemmEpilogueAllParams params);

View File

@@ -15,12 +15,13 @@
#pragma once
#include "fp8_common.h"
#include "fuse_gemm_gelu_template.h"
#include "fuse_gemm_noact_template.h"
#include "fuse_gemm_relu_template.h"
#include "fuse_gemm_gelu_template.h"
#include "fuse_block_gemm_act_template_3x.h"
#include "fuse_gemm_act_template_3x.h"
bool fp8_fp8_gemm_scale_bias_act(GemmEpilogueAllParams params);
bool fp8_fp8_block_gemm_scale_bias_act(GemmEpilogueAllParams params);
bool fp8_fp8_block_gemm_scale_bias_act(GemmEpilogueAllParams params);

View File

@@ -0,0 +1,173 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/float8.h"
#include "cutlass/gemm/device/gemm_universal.h"
#include "fp8_common.h"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/kernel/tile_scheduler.hpp"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass_extensions/gemm/collective/collective_builder_gated.hpp"
#include "cutlass_extensions/gemm/kernel/gemm_universal_gated.hpp"
template <typename InputType, typename CTAShape, typename ClusterShape,
typename MainloopScheduleType, typename EpilogueScheduleType,
typename TileSchedulerType = void,
template <class /* ElementCompute */> class Activation =
cutlass::epilogue::thread::SiLu,
bool SwapAB = true>
bool dispatch_dual_gemm_act_sm90(DualGemmEpilogueAllParams params) {
using namespace cute;
using ElementA = typename std::conditional_t<
std::is_same_v<InputType, phi::dtype::float8_e4m3fn>,
cutlass::float_e4m3_t, cutlass::float_e5m2_t>;
using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand
static constexpr int AlignmentA =
128 /
cutlass::sizeof_bits<
ElementA>::value; // Memory access granularity/alignment of A
// matrix in units of elements (up to 16 bytes)
// B matrix configuration
using ElementB = ElementA; // Element type for B matrix operand
using LayoutB =
cutlass::layout::ColumnMajor; // Layout type for B matrix operand
static constexpr int AlignmentB =
128 /
cutlass::sizeof_bits<
ElementB>::value; // Memory access granularity/alignment of B
// matrix in units of elements (up to 16 bytes)
using ElementC = ElementA; // Element type for C matrix operands
using LayoutC = cute::conditional_t<SwapAB, cutlass::layout::ColumnMajor,
cutlass::layout::RowMajor>;
static constexpr int AlignmentC =
128 /
cutlass::sizeof_bits<
ElementC>::value; // Memory access granularity/alignment of C matrices
// in units of elements (up to 16 bytes)
// Output matrix configuration
using ElementOutput = ElementA; // Element type for output matrix operands
// using LayoutOutput = cutlass::layout::RowMajor; // Layout type for output
// matrix operands
using LayoutOutput = cute::conditional_t<SwapAB, cutlass::layout::ColumnMajor,
cutlass::layout::RowMajor>;
static constexpr int AlignmentOutput =
128 / cutlass::sizeof_bits<ElementOutput>::value;
// Multiply-accumulate blocking/pipelining details
using ElementAccumulator = float; // Element type for internal accumulation
using ElementCompute = float; // Element type for compute
using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that
// supports the intended feature
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
using TileShape = CTAShape; // Threadblock-level tile size
using KernelSchedule = MainloopScheduleType;
using EpilogueSchedule = EpilogueScheduleType;
using TileScheduler = TileSchedulerType;
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
static constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest;
using FusionOperation =
cutlass::epilogue::fusion::ScaledAcc<ElementOutput, ElementCompute>;
using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag, OperatorClass, TileShape, ClusterShape, EpilogueTileType,
ElementAccumulator, ElementAccumulator, ElementC, LayoutC, AlignmentC,
ElementOutput, LayoutOutput, AlignmentOutput, EpilogueSchedule,
FusionOperation>::CollectiveOp;
using CollectiveMainloop =
typename cutlass::gemm::collective::CollectiveBuilderGated<
ArchTag, OperatorClass, ElementA, LayoutA, AlignmentA, ElementB,
LayoutB, AlignmentB, ElementAccumulator, TileShape, ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
sizeof(typename CollectiveEpilogue::SharedStorage))>,
KernelSchedule, Activation, SwapAB>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversalGated<
Shape<int, int, int, int>, // Indicates ProblemShape
CollectiveMainloop, CollectiveEpilogue, TileScheduler>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
using StrideA = typename Gemm::GemmKernel::StrideA;
using StrideB = typename Gemm::GemmKernel::StrideB;
using StrideC = typename Gemm::GemmKernel::StrideC;
using StrideD = typename Gemm::GemmKernel::StrideD;
int arg_m = params.M;
int arg_n = params.N;
ElementA const *ptr_A = reinterpret_cast<ElementA const *>(params.A);
ElementB const *ptr_B0 = reinterpret_cast<ElementB const *>(params.B0);
ElementB const *ptr_B1 = reinterpret_cast<ElementB const *>(params.B1);
if constexpr (SwapAB) {
arg_m = params.N;
arg_n = params.M;
ptr_A = reinterpret_cast<ElementB const *>(params.B0);
ptr_B0 = reinterpret_cast<ElementA const *>(params.A);
}
StrideA stride_A = cutlass::make_cute_packed_stride(
StrideA{}, cute::make_shape(arg_m, params.K, params.batch_count));
StrideB stride_B = cutlass::make_cute_packed_stride(
StrideB{}, cute::make_shape(arg_n, params.K, params.batch_count));
StrideC stride_C;
StrideD stride_D = cutlass::make_cute_packed_stride(
StrideD{}, cute::make_shape(arg_m, arg_n, params.batch_count));
typename Gemm::Arguments arguments = {
cutlass::gemm::GemmUniversalMode::kGemm,
{arg_m, arg_n, params.K, params.batch_count},
{ptr_A, stride_A, ptr_B0, ptr_B1, stride_B, params.scale0, params.scale1},
{{}, // epilogue.thread
nullptr,
stride_C,
reinterpret_cast<ElementOutput *>(params.D),
stride_D}};
arguments.epilogue.thread.alpha = params.scale_out;
Gemm gemm_op;
cutlass::Status status = gemm_op.can_implement(arguments);
if (status != cutlass::Status::kSuccess) {
std::cerr << "Gemm::can_implement() failed" << std::endl;
return false;
}
size_t workspace_size = Gemm::get_workspace_size(arguments);
phi::Allocator *allocator = paddle::GetAllocator(params.place);
auto workspace = allocator->Allocate(workspace_size);
//
// Run the GEMM
//
status = gemm_op(arguments, workspace->ptr(), params.stream);
if (status != cutlass::Status::kSuccess) {
std::cerr << "Gemm::run() failed" << std::endl;
return false;
}
return true;
}
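For orientation, here is one hypothetical instantiation of the dispatcher above. The tile/cluster shapes and schedule types named below are standard CUTLASS 3.x names chosen purely for illustration; the commit's actual dispatch tables may use different combinations, and the wrapper function name is invented.

// Illustrative only: one possible SM90 configuration for the FP8 dual GEMM above.
using IllustrativeCTAShape     = cute::Shape<cute::_128, cute::_128, cute::_128>;
using IllustrativeClusterShape = cute::Shape<cute::_1, cute::_2, cute::_1>;
using IllustrativeMainloop     = cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
using IllustrativeEpilogue     = cutlass::epilogue::TmaWarpSpecialized;

inline bool run_dual_gemm_swiglu_example(DualGemmEpilogueAllParams params) {
  // Relies on the template defaults above: SiLu activation and SwapAB = true.
  return dispatch_dual_gemm_act_sm90<phi::dtype::float8_e4m3fn,
                                     IllustrativeCTAShape, IllustrativeClusterShape,
                                     IllustrativeMainloop, IllustrativeEpilogue>(params);
}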

View File

@@ -0,0 +1,151 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "fp8_common.h"
#include "cutlass/cutlass.h"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/kernel/tile_scheduler.hpp"
#include "cutlass/util/packed_stride.hpp"
template <
typename InputType,
typename OutType,
bool hasbias,
template <class> typename Activation,
typename TileShape,
typename ClusterShape,
typename KernelSchedule =
cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum,
typename EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized,
typename SM = cutlass::arch::Sm90>
bool dispatch_fuse_gemm_act_sm90(GemmEpilogueAllParams params) {
using namespace cute;
using ElementA = typename std::conditional_t<
std::is_same_v<InputType, phi::dtype::float8_e4m3fn>,
cutlass::float_e4m3_t, cutlass::float_e5m2_t>;
using ElementB = ElementA;
using ElementD =
typename std::conditional_t<std::is_same_v<OutType, phi::dtype::bfloat16>,
cutlass::bfloat16_t, cutlass::half_t>;
using ElementC = std::conditional_t<hasbias, ElementD, void>;
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = cutlass::layout::RowMajor;
using ElementAccumulator = float;
using ElementCompute = float;
using ElementScalar = float;
// 16B alignment lets us use TMA
static constexpr int AlignmentA = 16 / sizeof(ElementA);
static constexpr int AlignmentB = 16 / sizeof(ElementB);
static constexpr int AlignmentC = hasbias ? 16 / sizeof(ElementC) : 8;
static constexpr int AlignmentD = 16 / sizeof(ElementD);
static constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest;
using FusionOperation =
cutlass::epilogue::fusion::LinCombEltAct<Activation, ElementD,
ElementCompute, ElementC,
ElementScalar, RoundStyle>;
using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
SM, cutlass::arch::OpClassTensorOp, TileShape, ClusterShape,
cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator,
ElementCompute, ElementC, LayoutC, AlignmentC, ElementD, LayoutD,
AlignmentD, EpilogueSchedule, FusionOperation>::CollectiveOp;
using CollectiveMainloop =
typename cutlass::gemm::collective::CollectiveBuilder<
SM, cutlass::arch::OpClassTensorOp, ElementA, LayoutA, AlignmentA,
ElementB, LayoutB, AlignmentB, ElementAccumulator, TileShape,
ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
sizeof(typename CollectiveEpilogue::SharedStorage))>,
KernelSchedule>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue,
cutlass::gemm::PersistentScheduler>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape;
using StrideA = typename Gemm::GemmKernel::StrideA;
using StrideB = typename Gemm::GemmKernel::StrideB;
using StrideC = typename Gemm::GemmKernel::StrideC;
using StrideD = typename Gemm::GemmKernel::StrideD;
//
// Data members
//
/// Initialization
StrideA stride_A{params.lda, cute::Int<1>{}, params.M * params.lda};
StrideB stride_B{params.ldb, cute::Int<1>{}, params.N * params.ldb};
StrideC stride_C{0, cute::Int<1>{}, 0};
StrideD stride_D{params.ldd, cute::Int<1>{}, params.ldd * params.M};
auto a_ptr = reinterpret_cast<ElementA *>(const_cast<void *>(params.A));
auto b_ptr = reinterpret_cast<ElementB *>(const_cast<void *>(params.B));
auto c_ptr = reinterpret_cast<ElementC *>(const_cast<void *>(params.bias));
auto d_ptr = reinterpret_cast<ElementD *>(params.D);
ProblemShapeType problem_size =
ProblemShapeType{params.M, params.N, params.K, params.batch_count};
typename Gemm::Arguments arguments{cutlass::gemm::GemmUniversalMode::kGemm,
problem_size,
{a_ptr, stride_A, b_ptr, stride_B},
{{params.scale}, // epilogue.thread
c_ptr,
stride_C,
d_ptr,
stride_D}};
if constexpr (hasbias) {
arguments.epilogue.thread.beta = 1.0;
}
Gemm gemm_op;
cutlass::Status status = gemm_op.can_implement(arguments);
if (status != cutlass::Status::kSuccess) {
std::cout << "Gemm::can_implement() failed. "
<< cutlassGetStatusString(status) << std::endl;
return false;
}
size_t workspace_size = Gemm::get_workspace_size(arguments);
phi::Allocator *allocator = paddle::GetAllocator(params.place);
auto workspace = allocator->Allocate(workspace_size);
status = gemm_op(arguments, workspace->ptr(), params.stream);
if (status != cutlass::Status::kSuccess) {
std::cout << "Gemm::run() failed." << cutlassGetStatusString(status)
<< std::endl;
return false;
}
return true;
}

View File

@@ -43,7 +43,9 @@
#include "cutlass/trace.h"
#include "cutlass_extensions/gemm/kernel/gemm_moe_problem_visitor.h"
#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/tile_interleaved_layout.h"
#include "cutlass_extensions/gemm/threadblock/wint2x_tile_dequanter.h"
#include "cutlass_extensions/tile_interleaved_layout.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
@@ -156,9 +158,6 @@ struct MoeFCGemm {
using LayoutC = typename MapArguments::LayoutC;
using ElementScale = ElementC;
static ComplexTransform const kTransformA = MapArguments::kTransformA;
static ComplexTransform const kTransformB = MapArguments::kTransformB;
// Type definitions about the mainloop.
using Operator = typename Mma::Operator;
using OperatorClass = typename Mma::Operator::OperatorClass;
@@ -209,6 +208,13 @@ struct MoeFCGemm {
int64_t gemm_n;
int64_t gemm_k;
WintQuantMethod quant_method;
// Extra arguments for wint2.0
uint8_t* local_scale;
float* code_scale;
float* code_zp;
// Only used by device-level operator
GemmCoord* host_problem_sizes;
@@ -230,6 +236,10 @@ struct MoeFCGemm {
total_rows(-1),
gemm_n(0),
gemm_k(0),
quant_method(WintQuantMethod::kNone),
local_scale(nullptr),
code_scale(nullptr),
code_zp(nullptr),
host_problem_sizes(nullptr) {}
/// Ctor
@@ -246,6 +256,10 @@ struct MoeFCGemm {
int64_t total_rows,
int64_t gemm_n,
int64_t gemm_k,
WintQuantMethod quant_method,
const uint8_t* local_scale,
const float* code_scale,
const float* code_zp,
GemmCoord* host_problem_sizes = nullptr)
: problem_count(problem_count),
threadblock_count(threadblock_count),
@@ -259,8 +273,12 @@ struct MoeFCGemm {
total_rows(total_rows),
gemm_n(gemm_n),
gemm_k(gemm_k),
quant_method(quant_method),
local_scale(const_cast<uint8_t*>(local_scale)),
code_scale(const_cast<float*>(code_scale)),
code_zp(const_cast<float*>(code_zp)),
host_problem_sizes(nullptr) {
if (platform::is_same<uint8_t, ElementB>::value ||
if (quant_method != WintQuantMethod::kNone || platform::is_same<uint8_t, ElementB>::value ||
platform::is_same<uint4b_t, ElementB>::value) {
assert(weight_scales);
}
@@ -284,6 +302,8 @@ struct MoeFCGemm {
ElementC* ptr_C;
ElementC* ptr_D;
WintQuantMethod quant_method;
//
// Methods
//
@@ -294,7 +314,8 @@ struct MoeFCGemm {
ptr_B(nullptr),
weight_scales(nullptr),
ptr_C(nullptr),
ptr_D(nullptr) {}
ptr_D(nullptr),
quant_method(WintQuantMethod::kNone) {}
CUTLASS_HOST_DEVICE
Params(Arguments const& args,
@@ -313,7 +334,8 @@ struct MoeFCGemm {
ptr_B(args.ptr_B),
weight_scales(args.weight_scales),
ptr_C(args.ptr_C),
ptr_D(args.ptr_D) {}
ptr_D(args.ptr_D),
quant_method(args.quant_method) {}
CUTLASS_HOST_DEVICE
void update(Arguments const& args,
@@ -334,6 +356,7 @@ struct MoeFCGemm {
weight_scales = args.weight_scales;
ptr_C = args.ptr_C;
ptr_D = args.ptr_D;
quant_method = args.quant_method;
}
};
@@ -358,7 +381,7 @@ struct MoeFCGemm {
}
static Status can_implement(Arguments const& args) {
if (platform::is_same<uint8_t, ElementB>::value ||
if (args.quant_method != WintQuantMethod::kNone || platform::is_same<uint8_t, ElementB>::value ||
platform::is_same<uint4b_t, ElementB>::value) {
if (args.weight_scales == nullptr) {
CUTLASS_TRACE_HOST(
@@ -394,6 +417,7 @@ struct MoeFCGemm {
template <typename dummy>
struct KernelRunner<true, dummy> {
CUTLASS_DEVICE
static void run_kernel(Params const& params,
SharedStorage& shared_storage) { // NOLINT
@@ -401,12 +425,14 @@ struct MoeFCGemm {
// These types shadow the type-level definitions and support the ability
// to implement a 'transposed' GEMM that computes the transposed problems.
//
using ElementA = typename Mma::IteratorA::Element;
using LayoutA = typename Mma::IteratorA::Layout;
using ElementB = typename Mma::IteratorB::Element;
using LayoutB = typename Mma::IteratorB::Layout;
using ElementC = typename Epilogue::OutputTileIterator::Element;
using LayoutC = typename Epilogue::OutputTileIterator::Layout;
static constexpr int kInterleave =
Mma::IteratorB::Shape::kRow / Mma::Shape::kK;
static_assert(
@@ -435,6 +461,7 @@ struct MoeFCGemm {
GemmCoord grid_shape = problem_visitor.grid_shape(problem_size);
// threadblock_offset of C
cutlass::gemm::GemmCoord threadblock_offset(
int(cta_idx / grid_shape.n()) * Mma::Shape::kM, // NOLINT
int(cta_idx % grid_shape.n()) * Mma::Shape::kN, // NOLINT
@@ -450,6 +477,7 @@ struct MoeFCGemm {
rows_to_jump = problem_idx * (params.problem_visitor.total_rows / params.problem_visitor.problem_count);
}
// starting address offset of A for the current tile
ElementA* ptr_A =
reinterpret_cast<ElementA*>(params.ptr_A) + rows_to_jump * gemm_k;
typename LayoutA::LongIndex ldm_A = gemm_k;
@@ -463,14 +491,17 @@ struct MoeFCGemm {
: gemm_k * kInterleave;
// Compute initial location in logical coordinates
// the starting threadblock_offset of A, which shares the same row id as C
cutlass::MatrixCoord tb_offset_A{
threadblock_offset.m(),
0,
};
// the starting threadblock_offset of B, which shares the same column id as C
cutlass::MatrixCoord tb_offset_B{0,
threadblock_offset.n() / kInterleave};
// the starting threadblock_offset of the scale, which shares the same column id as C but has no row id
cutlass::MatrixCoord tb_offset_scale{0, threadblock_offset.n()};
// Compute position within threadblock
@@ -610,6 +641,381 @@ struct MoeFCGemm {
/////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate
typename Epilogue_, ///! Epilogue
typename ThreadblockSwizzle_, ///! Threadblock swizzling function
typename KernelArch, ///! The Architecture this kernel is compiled
/// for. Used since SIMT kernels lose top-level
/// arch.
GroupScheduleMode GroupScheduleMode_ ///! Type of scheduling to perform // NOLINT
>
struct Wint2xMoeFCGemm : public MoeFCGemm<Mma_, Epilogue_, ThreadblockSwizzle_, KernelArch, GroupScheduleMode_> {
public:
using Base = MoeFCGemm<Mma_, Epilogue_, ThreadblockSwizzle_, KernelArch, GroupScheduleMode_>;
using Mma = Mma_;
using Epilogue = Epilogue_;
using EpilogueOutputOp = typename Epilogue::OutputOp;
using ThreadblockSwizzle = ThreadblockSwizzle_;
static GroupScheduleMode const kGroupScheduleMode = GroupScheduleMode_;
static bool const kTransposed = false;
// Optional transpose
using MapArguments = typename Base::MapArguments;
// Public-facing type definitions related to operand element type, layout, and
// complex conjugate operation. Must interact with the 'kTransposed' notion.
static_assert(!kTransposed, "Transpose problem not supported");
using ElementA = typename MapArguments::ElementA;
using LayoutA = typename MapArguments::LayoutA;
using ElementB = typename MapArguments::ElementB;
using LayoutB = typename MapArguments::LayoutB;
using ElementC = typename Epilogue::OutputTileIterator::Element;
using LayoutC = typename MapArguments::LayoutC;
using ElementScale = ElementC;
// Type definitions about the mainloop.
using Operator = typename Mma::Operator;
using OperatorClass = typename Mma::Operator::OperatorClass;
using ThreadblockShape = typename Mma::Shape;
using WarpShape = typename Mma::Operator::Shape;
using InstructionShape = typename Mma::Policy::Operator::InstructionShape;
using ArchTag = typename Mma::ArchTag;
static int const kStages = Mma::kStages;
static int const kAlignmentA = MapArguments::kAlignmentA;
static int const kAlignmentB = MapArguments::kAlignmentB;
static int const kAlignmentC =
Epilogue::OutputTileIterator::kElementsPerAccess;
/// Warp count (concept: GemmShape)
using WarpCount = typename Mma::WarpCount;
static int const kThreadCount = 32 * WarpCount::kCount;
using ProblemVisitor = typename Base::ProblemVisitor;
using Arguments = typename Base::Arguments;
//
// Structure for precomputing values in host memory and passing to kernels
//
/// Parameters structure
struct Params : Base::Params {
// Extra arguments for wint2.0
uint8_t* local_scale;
float* code_scale;
float* code_zp;
//
// Methods
//
CUTLASS_HOST_DEVICE
Params() : Base::Params(), local_scale(nullptr), code_scale(nullptr), code_zp(nullptr) {}
CUTLASS_HOST_DEVICE
Params(Arguments const& args,
void* workspace = nullptr,
int tile_count = 0) // NOLINT
: Base::Params(args, workspace, tile_count),
local_scale(args.local_scale),
code_scale(args.code_scale),
code_zp(args.code_zp) {}
CUTLASS_HOST_DEVICE
void update(Arguments const& args,
void* workspace = nullptr,
int tile_count = 0) {
Base::update(args, workspace, tile_count);
local_scale = args.local_scale;
code_scale = args.code_scale;
code_zp = args.code_zp;
}
};
/// Shared memory storage structure
using SharedStorage = typename Base::SharedStorage;
public:
//
// Methods
//
CUTLASS_DEVICE
Wint2xMoeFCGemm() {}
static Status can_implement(Arguments const& args) {
if (args.quant_method != WintQuantMethod::kWeightOnlyInt2) {
CUTLASS_TRACE_HOST(
"Wint2xMoeFCGemm::can_implement() - only support weight_only_int2!");
return Status::kInvalid;
} else if (args.weight_scales == nullptr || args.local_scale == nullptr) {
CUTLASS_TRACE_HOST(
"Wint2xMoeFCGemm::can_implement() - weight_scales and local_scale is expected to be not nullptr!");
return Status::kInvalid;
}
return Status::kSuccess;
}
// The dummy template parameter is not used and exists so that we can compile
// this code using a standard earlier than C++17. Prior to C++17, fully
// specialized templates HAD to exist in a namespace
template <WintQuantMethod QuantMethod, bool B, typename dummy = void>
struct KernelRunner {
CUTLASS_DEVICE
static void run_kernel(Params const& params,
SharedStorage& shared_storage) { // NOLINT
CUTLASS_NOT_IMPLEMENTED();
}
};
template <WintQuantMethod QuantMethod, typename dummy>
struct KernelRunner<QuantMethod, true, dummy> {
using WeightQuantTraits = WintQuantTraits<ElementA, QuantMethod>;
using QuantArguments = typename WeightQuantTraits::Arguments;
CUTLASS_DEVICE
static QuantArguments get_quant_args(Params const& params, int32_t problem_idx, const int64_t gemm_k, const int64_t gemm_n) {
QuantArguments quant_args;
if constexpr (QuantMethod == WintQuantMethod::kWeightOnlyInt2) {
quant_args.local_scale_ptr = params.local_scale + problem_idx * gemm_k * gemm_n / 128;
quant_args.code_scale_ptr = params.code_scale + problem_idx * gemm_n;
quant_args.code_zp_ptr = params.code_zp + problem_idx * gemm_n;
}
return quant_args;
}
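// Worked example (illustration only, assuming local_scale stores one uint8 per
// 128 packed weights and code_scale/code_zp store one float per output column,
// matching the offsets above): for problem_idx = 2, gemm_k = 4096, gemm_n = 1024,
// the per-expert pointer advances are
//   local_scale: 2 * 4096 * 1024 / 128 = 65536 elements,
//   code_scale:  2 * 1024 = 2048 elements,
//   code_zp:     2 * 1024 = 2048 elements.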
CUTLASS_DEVICE
static void run_kernel(Params const& params,
SharedStorage& shared_storage) { // NOLINT
//
// These types shadow the type-level definitions and support the ability
// to implement a 'transposed' GEMM that computes the transposed problems.
//
using ElementA = typename Mma::IteratorA::Element;
using LayoutA = typename Mma::IteratorA::Layout;
using ElementB = typename Mma::IteratorB::Element;
using LayoutB = typename Mma::IteratorB::Layout;
using ElementC = typename Epilogue::OutputTileIterator::Element;
using LayoutC = typename Epilogue::OutputTileIterator::Layout;
using QuantElementB = typename WeightQuantTraits::WeightType;
using MmaElementB = typename WeightQuantTraits::MmaWeightType;
static constexpr int kInterleave =
Mma::IteratorB::Shape::kRow / Mma::Shape::kK;
static_assert(
platform::is_same<LayoutB, layout::RowMajor>::value &&
kInterleave == 1 ||
platform::is_same<LayoutB, layout::ColumnMajor>::value &&
kInterleave >= 1,
"B must be row major/col major OR col major interleaved.");
// LayoutB should be RowMajor
using TileDequanterB = cutlass::gemm::threadblock::TileDequanter<ElementA, ElementScale, ThreadblockShape::kK, ThreadblockShape::kN, kStages, kThreadCount, QuantMethod>;
//
// Problem visitor.
//
ProblemVisitor problem_visitor(
params.problem_visitor, shared_storage.problem_visitor, blockIdx.x);
const int64_t gemm_k = params.problem_visitor.gemm_k;
const int64_t gemm_n = params.problem_visitor.gemm_n;
// wint2.5 and wint2.0 weights are quantized and packed along the k dimension with group_size 64.
const int64_t quant_gemm_k = WintQuantTraits<ElementA, QuantMethod>::CaclPackedDim(gemm_k);
int64_t bytes_per_expert_matrix = (quant_gemm_k * gemm_n / 8) * cutlass::sizeof_bits<QuantElementB>::value;
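// Note: quant_gemm_k is the k extent of the packed weight buffer, so the
// expression above converts the per-expert element count (quant_gemm_k * gemm_n)
// into bytes by multiplying with sizeof_bits<QuantElementB> and dividing by 8.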
// Outer 'persistent' loop to iterate over tiles
while (problem_visitor.next_tile()) {
GemmCoord problem_size = problem_visitor.problem_size();
int32_t problem_idx = problem_visitor.problem_index();
int32_t cta_idx = int32_t(problem_visitor.threadblock_idx());
GemmCoord grid_shape = problem_visitor.grid_shape(problem_size);
// threadblock_offset of C
cutlass::gemm::GemmCoord threadblock_offset(
int(cta_idx / grid_shape.n()) * Mma::Shape::kM, // NOLINT
int(cta_idx % grid_shape.n()) * Mma::Shape::kN, // NOLINT
0);
// starting address offset of weight_scale for the current expert
ElementScale* weight_scale_ptr =
params.weight_scales ? params.weight_scales + problem_idx * problem_size.n() : nullptr;
// the starting threadblock_offset of the scale, which shares the same column id as C but has no row id
cutlass::MatrixCoord tb_offset_scale{0, threadblock_offset.n()};
// Load element pointers. Exchange pointers and strides if working on
// the transpose
int64_t rows_to_jump = 0;
if (params.problem_visitor.total_rows < 0) {
rows_to_jump = problem_idx == 0 ? 0 : params.problem_visitor.last_row_for_problem[problem_idx - 1];
} else {
rows_to_jump = problem_idx * (params.problem_visitor.total_rows / params.problem_visitor.problem_count);
}
// starting address offset of A for the current tile
ElementA* ptr_A =
reinterpret_cast<ElementA*>(params.ptr_A) + rows_to_jump * gemm_k;
typename LayoutA::LongIndex ldm_A = gemm_k;
// Compute initial location in logical coordinates
// the starting threadblock_offset of A, which shares the same row id as C
cutlass::MatrixCoord tb_offset_A{
threadblock_offset.m(),
0,
};
// starting address offset of B for the current problem_idx (one of num_experts problems in total)
char* byte_ptr_B = ((char*)params.ptr_B) + // NOLINT
problem_idx * bytes_per_expert_matrix; // NOLINT
typename LayoutB::LongIndex ldm_B =
platform::is_same<layout::RowMajor, LayoutB>::value
? gemm_n
: gemm_k * kInterleave;
typename LayoutB::LongIndex ldm_B_shared = TileDequanterB::kColumns;
// the starting threadblock_offset of B, which shares the same column id as C
cutlass::MatrixCoord tb_offset_B{0,
threadblock_offset.n() / kInterleave};
cutlass::MatrixCoord extent_B{problem_size.k() * kInterleave, problem_size.n() / kInterleave};
cutlass::MatrixCoord extent_B_shared{TileDequanterB::kRows, TileDequanterB::kColumns};
MmaElementB* smem_unzip_B_ptr = nullptr;
if constexpr (QuantMethod == WintQuantMethod::kWeightOnlyInt2) {
smem_unzip_B_ptr = shared_storage.main_loop.operand_unzip_B_ptr();
}
QuantArguments quant_args = get_quant_args(params, problem_idx, gemm_k, gemm_n);
TileDequanterB tile_dequanter_B(smem_unzip_B_ptr,
byte_ptr_B,
ldm_B,
extent_B,
tb_offset_B,
weight_scale_ptr,
tb_offset_scale,
quant_args);
MmaElementB* ptr_B = tile_dequanter_B.GetOutPtr();
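// ptr_B is the buffer the B iterator will read from: presumably the shared-memory
// staging buffer (operand_unzip_B_ptr) holding dequantized wint2 tiles when
// TileDequanterB::kUseSharedMemory is true, and the packed global-memory weights
// otherwise (see the iterator_B construction below).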
// Compute position within threadblock
int thread_idx = threadIdx.x;
// Construct iterators to A and B operands
typename Mma::IteratorA iterator_A(LayoutA(ldm_A),
ptr_A,
{problem_size.m(), problem_size.k()},
thread_idx,
tb_offset_A);
typename Mma::IteratorB iterator_B(
LayoutB(TileDequanterB::kUseSharedMemory ? ldm_B_shared : ldm_B),
ptr_B,
TileDequanterB::kUseSharedMemory ? extent_B_shared : extent_B,
thread_idx,
TileDequanterB::kUseSharedMemory ? cutlass::make_Coord(0, 0) : tb_offset_B);
typename Mma::FragmentC accumulators;
accumulators.clear();
// Broadcast the warp_id computed by lane 0 to ensure dependent code
// is compiled as warp-uniform.
int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
int lane_idx = threadIdx.x % 32;
//
// Matrix multiply phase
//
// Construct thread-scoped matrix multiply
Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);
// Compute threadblock-scoped matrix multiply-add
int gemm_k_iterations =
(problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK;
// Wait for all threads to finish their epilogue phases from the
// previous tile.
__syncthreads();
// Compute threadblock-scoped matrix multiply-add
mma(gemm_k_iterations,
accumulators,
iterator_A,
iterator_B,
tile_dequanter_B,
accumulators);
//
// Epilogue
//
EpilogueOutputOp output_op(params.output_op);
ElementC* ptr_C =
params.ptr_C ? reinterpret_cast<ElementC*>(params.ptr_C) + problem_idx * gemm_n : nullptr;
ElementC* ptr_D =
reinterpret_cast<ElementC*>(params.ptr_D) + rows_to_jump * gemm_n;
LayoutC layout_C(0);
LayoutC layout_D(gemm_n);
typename Epilogue::OutputTileIterator::Params params_C(layout_C);
typename Epilogue::OutputTileIterator::Params params_D(layout_D);
// Tile iterator loading from source tensor.
typename Epilogue::OutputTileIterator iterator_C(
params_C,
ptr_C,
problem_size.mn(),
thread_idx,
threadblock_offset.mn());
// Tile iterator writing to destination tensor.
typename Epilogue::OutputTileIterator iterator_D(
params_D,
ptr_D,
problem_size.mn(),
thread_idx,
threadblock_offset.mn());
Epilogue epilogue(
shared_storage.epilogue, thread_idx, warp_idx, lane_idx);
// Execute the epilogue operator to update the destination tensor.
epilogue(output_op, iterator_D, accumulators, iterator_C);
// Next tile
problem_visitor.advance(gridDim.x);
}
}
};
/*
To improve compilation speed, we do not compile the device operator if the
CUDA_ARCH does not correspond to the ArchTag of the cutlass kernel operator.
*/
/// Executes one GEMM
CUTLASS_DEVICE
void operator()(Params const& params,
SharedStorage& shared_storage) { // NOLINT
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 910)
KernelRunner<WintQuantMethod::kWeightOnlyInt2, true>::run_kernel(params, shared_storage);
#else
CUTLASS_NOT_IMPLEMENTED();
#endif
}
};
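// Composition sketch (assumption: BaseGemmKernel comes from
// cutlass::gemm::kernel::DefaultGemmGrouped, as in the launcher later in this
// commit). The wint2 kernel is wrapped by the device-level grouped GEMM adapter
// exactly like MoeFCGemm:
//
//   using Wint2Kernel = cutlass::gemm::kernel::Wint2xMoeFCGemm<
//       typename BaseGemmKernel::Mma, typename BaseGemmKernel::Epilogue,
//       typename BaseGemmKernel::ThreadblockSwizzle, cutlass::arch::Sm80,
//       BaseGemmKernel::kGroupScheduleMode>;
//   using Wint2Gemm = cutlass::gemm::device::GemmGrouped<Wint2Kernel>;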
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace kernel
} // namespace gemm
} // namespace cutlass

View File

@@ -15,16 +15,22 @@
*/
#pragma once
#include <cuda_runtime_api.h>
#include <string>
#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/ft_gemm_configs.h"
#include "cutlass_extensions/wint_type_traits.h"
namespace phi {
template <typename T, /*The type used for activations/scales/compute*/
typename WeightType /* The type for the MoE weights */>
typename WeightQuantTraits /* The quant traits for the MoE weights */>
class MoeGemmRunner {
public:
using WeightType = typename WeightQuantTraits::WeightType;
using Arguments = typename WeightQuantTraits::Arguments;
MoeGemmRunner();
void moe_gemm_bias_act(const T* A,
@@ -38,6 +44,7 @@ class MoeGemmRunner {
int64_t gemm_n,
int64_t gemm_k,
int num_experts,
const Arguments& quant_args_B,
std::string activation_type,
cudaStream_t stream);
@@ -51,6 +58,7 @@ class MoeGemmRunner {
int64_t gemm_n,
int64_t gemm_k,
int num_experts,
const Arguments& quant_args_B,
cudaStream_t stream);
private:
@@ -65,6 +73,7 @@ class MoeGemmRunner {
int64_t gemm_n,
int64_t gemm_k,
int num_experts,
const Arguments& quant_args_B,
CutlassGemmConfig gemm_config,
cudaStream_t stream,
int* occupancy = nullptr);
@@ -81,6 +90,7 @@ class MoeGemmRunner {
int64_t gemm_n,
int64_t gemm_k,
int num_experts,
const Arguments& quant_args_B,
cudaStream_t stream);
private:

View File

@@ -13,7 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "cutlass_kernels/moe_gemm/fused_moe_gemm_kernels.h"
#include "cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_template.h"
@@ -22,7 +22,8 @@
namespace phi {
#ifdef PADDLE_CUDA_BF16
template class MoeGemmRunner<__nv_bfloat16, __nv_bfloat16>;
template class MoeGemmRunner<
__nv_bfloat16, cutlass::WintQuantTraits<__nv_bfloat16, cutlass::WintQuantMethod::kNone>>;
#endif
} // namespace phi
} // namespace phi

View File

@@ -0,0 +1,30 @@
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "cutlass_kernels/moe_gemm/fused_moe_gemm_kernels.h"
#include "cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_template.h"
#include "helper.h"
namespace phi {
#ifdef PADDLE_CUDA_BF16
template class MoeGemmRunner<
__nv_bfloat16,
cutlass::WintQuantTraits<__nv_bfloat16, cutlass::WintQuantMethod::kWeightOnlyInt2>>;
#endif
} // namespace phi

View File

@@ -21,7 +21,9 @@
namespace phi {
#ifdef PADDLE_CUDA_BF16
template class MoeGemmRunner<__nv_bfloat16, cutlass::uint4b_t>;
template class MoeGemmRunner<
__nv_bfloat16,
cutlass::WintQuantTraits<__nv_bfloat16, cutlass::WintQuantMethod::kWeightOnlyInt4>>;
#endif
} // namespace phi
} // namespace phi

View File

@@ -22,8 +22,9 @@
namespace phi {
#ifdef PADDLE_CUDA_BF16
template class MoeGemmRunner<__nv_bfloat16, uint8_t>;
template class MoeGemmRunner<
__nv_bfloat16,
cutlass::WintQuantTraits<__nv_bfloat16, cutlass::WintQuantMethod::kWeightOnlyInt8>>;
#endif
} // namespace phi
} // namespace phi

View File

@@ -21,6 +21,7 @@
namespace phi {
template class MoeGemmRunner<half, half>;
template class MoeGemmRunner<half,
cutlass::WintQuantTraits<half, cutlass::WintQuantMethod::kNone>>;
} // namespace phi
} // namespace phi

View File

@@ -0,0 +1,27 @@
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "cutlass_kernels/moe_gemm/fused_moe_gemm_kernels.h"
#include "cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_template.h"
#include "helper.h"
namespace phi {
template class MoeGemmRunner<
half, cutlass::WintQuantTraits<half, cutlass::WintQuantMethod::kWeightOnlyInt2>>;
} // namespace phi

View File

@@ -21,6 +21,7 @@
namespace phi {
template class MoeGemmRunner<half, cutlass::uint4b_t>;
template class MoeGemmRunner<
half, cutlass::WintQuantTraits<half, cutlass::WintQuantMethod::kWeightOnlyInt4>>;
} // namespace phi
} // namespace phi

View File

@@ -21,6 +21,7 @@
namespace phi {
template class MoeGemmRunner<half, uint8_t>;
template class MoeGemmRunner<
half, cutlass::WintQuantTraits<half, cutlass::WintQuantMethod::kWeightOnlyInt8>>;
} // namespace phi
} // namespace phi

View File

@@ -24,9 +24,10 @@
#include <math.h>
#include <optional>
#include <sstream>
#include "cutlass/array.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/array.h"
#include "cutlass/trace.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/gemm/device/gemm_grouped.h"
#include "cutlass/gemm/kernel/default_gemm_grouped.h"
@@ -35,8 +36,11 @@
#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/compute_occupancy.h"
#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/epilogue_helpers.h"
#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h"
#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/threadblock/default_mma.h"
#include "cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h"
#include "cutlass_extensions/gemm/threadblock/default_mma.h"
#include "cutlass_extensions/wint_type_traits.h"
#include "cutlass_kernels/moe_gemm/fused_moe_cutlass_kernel.h"
#include "cutlass_kernels/moe_gemm/fused_moe_gemm_kernels.h"
@@ -48,17 +52,47 @@
#include "helper.h"
namespace phi {
// ============================= Variable batched Gemm things
// ===========================
template <typename MixedGemmArchTraits, cutlass::WintQuantMethod Method>
struct CutlassLayoutB {
using Type = typename MixedGemmArchTraits::LayoutB;
};
template <typename MixedGemmArchTraits>
struct CutlassLayoutB<MixedGemmArchTraits, cutlass::WintQuantMethod::kNone> {
using Type = cutlass::layout::RowMajor;
};
template <typename BaseGemmKernel, typename Arch, cutlass::WintQuantMethod Method>
struct CutlassGemmKernel {
using Type =
cutlass::gemm::kernel::MoeFCGemm<typename BaseGemmKernel::Mma,
typename BaseGemmKernel::Epilogue,
typename BaseGemmKernel::ThreadblockSwizzle,
Arch,
BaseGemmKernel::kGroupScheduleMode>;
};
template <typename BaseGemmKernel, typename Arch>
struct CutlassGemmKernel<BaseGemmKernel, Arch, cutlass::WintQuantMethod::kWeightOnlyInt2> {
using Type =
cutlass::gemm::kernel::Wint2xMoeFCGemm<typename BaseGemmKernel::Mma,
typename BaseGemmKernel::Epilogue,
typename BaseGemmKernel::ThreadblockSwizzle,
Arch,
BaseGemmKernel::kGroupScheduleMode>;
};
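// For example, CutlassGemmKernel<BaseGemmKernel, cutlass::arch::Sm80,
// cutlass::WintQuantMethod::kWeightOnlyInt2>::Type resolves to Wint2xMoeFCGemm,
// while every other quant method keeps the generic MoeFCGemm.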
// ======================= Variable batched Gemm things =======================
template <typename T,
typename WeightType,
typename WeightQuantTraits,
typename arch,
typename EpilogueTag,
typename ThreadblockShape,
typename WarpShape,
int Stages>
void generic_moe_gemm_kernelLauncher(const T* A,
const WeightType* B,
const typename WeightQuantTraits::WeightType* B,
const T* weight_scales,
const T* biases,
T* C,
@@ -67,6 +101,7 @@ void generic_moe_gemm_kernelLauncher(const T* A,
int64_t gemm_n,
int64_t gemm_k,
int num_experts,
const typename WeightQuantTraits::Arguments& quant_args_B,
CutlassGemmConfig gemm_config,
const int multi_processor_count,
cudaStream_t stream,
@@ -86,44 +121,26 @@ void generic_moe_gemm_kernelLauncher(const T* A,
"Specialized for half, float");
#endif
using WeightType = typename WeightQuantTraits::WeightType;
static_assert(
cutlass::platform::is_same<T, WeightType>::value ||
cutlass::platform::is_same<WeightType, uint8_t>::value ||
cutlass::platform::is_same<WeightType, cutlass::uint4b_t>::value,
"");
cutlass::platform::is_same<WeightType, cutlass::uint4b_t>::value ||
cutlass::platform::is_same<WeightType, uint16_t>::value,
"Specialized for bfloat16, half, float, uint8_t (wint8), uint4b_t (wint4), uint16_t (wint2.5)");
// The cutlass type for the input elements. This is needed to convert to
// cutlass::half_t if necessary.
using ElementType_ = typename cutlass::platform::conditional<
cutlass::platform::is_same<T, half>::value,
cutlass::half_t,
T>::type;
#ifdef PADDLE_CUDA_BF16
using ElementType = typename cutlass::platform::conditional<
cutlass::platform::is_same<ElementType_, __nv_bfloat16>::value,
cutlass::bfloat16_t,
ElementType_>::type;
#else
using ElementType = ElementType_;
#endif
using CutlassWeightType_ = typename cutlass::platform::conditional<
cutlass::platform::is_same<WeightType, half>::value,
cutlass::half_t,
WeightType>::type;
#ifdef PADDLE_CUDA_BF16
using CutlassWeightType = typename cutlass::platform::conditional<
cutlass::platform::is_same<CutlassWeightType_, __nv_bfloat16>::value,
cutlass::bfloat16_t,
CutlassWeightType_>::type;
#else
using CutlassWeightType = CutlassWeightType_;
#endif
using ElementType = typename cutlass::CutlassDataType<T>::Type;
using CutlassWeightType = typename cutlass::CutlassDataType<typename WeightQuantTraits::WeightType>::Type;
using CutlassMmaWeightType = typename WeightQuantTraits::MmaWeightType;
using CutlassMmaKernelType = typename WeightQuantTraits::MmaKernelType;
// We need a separate config for each architecture since we will target
// different tensor core instructions. For float, we do not target TCs.
using MixedGemmArchTraits = cutlass::gemm::kernel::
MixedGemmArchTraits<ElementType, CutlassWeightType, arch>;
MixedGemmArchTraits<ElementType, CutlassMmaKernelType, arch>;
using ElementAccumulator = typename MixedGemmArchTraits::AccType;
using EpilogueOp = typename Epilogue<ElementType,
@@ -132,13 +149,13 @@ void generic_moe_gemm_kernelLauncher(const T* A,
EpilogueTag>::Op;
// Finally, set up the kernel.
using GemmKernel_ = typename cutlass::gemm::kernel::DefaultGemmGrouped<
using BaseGemmKernel = typename cutlass::gemm::kernel::DefaultGemmGrouped<
ElementType,
cutlass::layout::RowMajor,
cutlass::ComplexTransform::kNone,
MixedGemmArchTraits::ElementsPerAccessA,
CutlassWeightType,
typename MixedGemmArchTraits::LayoutB,
CutlassMmaKernelType,
typename CutlassLayoutB<MixedGemmArchTraits, WeightQuantTraits::kQuantMethod>::Type,
cutlass::ComplexTransform::kNone,
MixedGemmArchTraits::ElementsPerAccessB,
ElementType,
@@ -155,14 +172,7 @@ void generic_moe_gemm_kernelLauncher(const T* A,
cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly,
typename MixedGemmArchTraits::Operator>::GemmKernel;
using GemmKernel =
cutlass::gemm::kernel::MoeFCGemm<typename GemmKernel_::Mma,
typename GemmKernel_::Epilogue,
typename GemmKernel_::ThreadblockSwizzle,
arch, // Ensure top level arch is used
// for dispatch
GemmKernel_::kGroupScheduleMode>;
using GemmKernel = typename CutlassGemmKernel<BaseGemmKernel, arch, WeightQuantTraits::kQuantMethod>::Type;
using GemmGrouped = cutlass::gemm::device::GemmGrouped<GemmKernel>;
if (kernel_occupancy != nullptr) {
@@ -181,19 +191,32 @@ void generic_moe_gemm_kernelLauncher(const T* A,
typename EpilogueOp::Params epilogue_op(ElementAccumulator(1.f),
ElementAccumulator(0.f));
const uint8_t* local_scale_B = nullptr;
const float* code_scale_B = nullptr;
const float* code_zp_B = nullptr;
if constexpr (WeightQuantTraits::kQuantMethod == cutlass::WintQuantMethod::kWeightOnlyInt2) {
local_scale_B = quant_args_B.local_scale_ptr;
code_scale_B = quant_args_B.code_scale_ptr;
code_zp_B = quant_args_B.code_zp_ptr;
}
typename GemmGrouped::Arguments args(
num_experts,
threadblock_count,
epilogue_op,
reinterpret_cast<const ElementType*>(A),
reinterpret_cast<const CutlassWeightType*>(B),
reinterpret_cast<const CutlassMmaWeightType*>(B),
reinterpret_cast<const ElementType*>(weight_scales),
reinterpret_cast<const ElementType*>(biases),
reinterpret_cast<ElementType*>(C),
total_rows_before_expert,
total_rows,
gemm_n,
gemm_k);
gemm_k,
WeightQuantTraits::kQuantMethod,
local_scale_B,
code_scale_B,
code_zp_B);
GemmGrouped gemm;
@@ -222,7 +245,7 @@ void generic_moe_gemm_kernelLauncher(const T* A,
}
template <typename T,
typename WeightType,
typename WeightQuantTraits,
typename arch,
typename EpilogueTag,
typename ThreadblockShape,
@@ -231,7 +254,7 @@ template <typename T,
typename Enable = void>
struct dispatch_stages {
static void dispatch(const T* A,
const WeightType* B,
const typename WeightQuantTraits::WeightType* B,
const T* weight_scales,
const T* biases,
T* C,
@@ -240,6 +263,7 @@ struct dispatch_stages {
int64_t gemm_n,
int64_t gemm_k,
int num_experts,
const typename WeightQuantTraits::Arguments& quant_args_B,
CutlassGemmConfig gemm_config,
int multi_processor_count,
cudaStream_t stream,
@@ -253,20 +277,20 @@ struct dispatch_stages {
};
template <typename T,
typename WeightType,
typename WeightQuantTraits,
typename arch,
typename EpilogueTag,
typename ThreadblockShape,
typename WarpShape>
struct dispatch_stages<T,
WeightType,
WeightQuantTraits,
arch,
EpilogueTag,
ThreadblockShape,
WarpShape,
2> {
static void dispatch(const T* A,
const WeightType* B,
const typename WeightQuantTraits::WeightType* B,
const T* weight_scales,
const T* biases,
T* C,
@@ -275,12 +299,13 @@ struct dispatch_stages<T,
int64_t gemm_n,
int64_t gemm_k,
int num_experts,
const typename WeightQuantTraits::Arguments& quant_args_B,
CutlassGemmConfig gemm_config,
int multi_processor_count,
cudaStream_t stream,
int* occupancy = nullptr) {
generic_moe_gemm_kernelLauncher<T,
WeightType,
WeightQuantTraits,
arch,
EpilogueTag,
ThreadblockShape,
@@ -295,6 +320,7 @@ struct dispatch_stages<T,
gemm_n,
gemm_k,
num_experts,
quant_args_B,
gemm_config,
multi_processor_count,
stream,
@@ -303,13 +329,13 @@ struct dispatch_stages<T,
};
template <typename T,
typename WeightType,
typename WeightQuantTraits,
typename EpilogueTag,
typename ThreadblockShape,
typename WarpShape,
int Stages>
struct dispatch_stages<T,
WeightType,
WeightQuantTraits,
cutlass::arch::Sm80,
EpilogueTag,
ThreadblockShape,
@@ -317,7 +343,7 @@ struct dispatch_stages<T,
Stages,
typename std::enable_if<(Stages > 2)>::type> {
static void dispatch(const T* A,
const WeightType* B,
const typename WeightQuantTraits::WeightType* B,
const T* weight_scales,
const T* biases,
T* C,
@@ -326,12 +352,13 @@ struct dispatch_stages<T,
int64_t gemm_n,
int64_t gemm_k,
int num_experts,
const typename WeightQuantTraits::Arguments& quant_args_B,
CutlassGemmConfig gemm_config,
int multi_processor_count,
cudaStream_t stream,
int* occupancy = nullptr) {
generic_moe_gemm_kernelLauncher<T,
WeightType,
WeightQuantTraits,
cutlass::arch::Sm80,
EpilogueTag,
ThreadblockShape,
@@ -346,6 +373,7 @@ struct dispatch_stages<T,
gemm_n,
gemm_k,
num_experts,
quant_args_B,
gemm_config,
multi_processor_count,
stream,
@@ -354,13 +382,13 @@ struct dispatch_stages<T,
};
template <typename T,
typename WeightType,
typename WeightQuantTraits,
typename arch,
typename EpilogueTag,
typename ThreadblockShape,
typename WarpShape>
void dispatch_gemm_config(const T* A,
const WeightType* B,
const typename WeightQuantTraits::WeightType* B,
const T* weight_scales,
const T* biases,
T* C,
@@ -369,6 +397,7 @@ void dispatch_gemm_config(const T* A,
int64_t gemm_n,
int64_t gemm_k,
int num_experts,
const typename WeightQuantTraits::Arguments& quant_args_B,
CutlassGemmConfig gemm_config,
int multi_processor_count,
cudaStream_t stream,
@@ -376,7 +405,7 @@ void dispatch_gemm_config(const T* A,
#define dispatch_stages_macro(STAGE) \
case STAGE: \
dispatch_stages<T, \
WeightType, \
WeightQuantTraits, \
arch, \
EpilogueTag, \
ThreadblockShape, \
@@ -391,6 +420,7 @@ void dispatch_gemm_config(const T* A,
gemm_n, \
gemm_k, \
num_experts, \
quant_args_B, \
gemm_config, \
multi_processor_count, \
stream, \
@@ -414,7 +444,7 @@ void dispatch_gemm_config(const T* A,
case CutlassTileConfig:: \
CtaShape##AA##x##BB##x##CC##_WarpShape##DD##x##EE##x##FF: \
dispatch_gemm_config<T, \
WeightType, \
WeightQuantTraits, \
arch, \
EpilogueTag, \
cutlass::gemm::GemmShape<AA, BB, CC>, \
@@ -425,10 +455,11 @@ void dispatch_gemm_config(const T* A,
biases, \
C, \
total_rows_before_expert, \
total_rows, \
total_rows, \
gemm_n, \
gemm_k, \
num_experts, \
quant_args_B, \
gemm_config, \
multi_processor_count, \
stream, \
@@ -438,14 +469,14 @@ void dispatch_gemm_config(const T* A,
// This overload will handle tensorop gemms. It is disabled via SFINAE for fp32.
// This overload is only enabled when T == WeightQuantTraits::WeightType.
template <typename T,
typename WeightType,
typename WeightQuantTraits,
typename arch,
typename EpilogueTag,
typename std::enable_if<!std::is_same<T, float>::value &&
std::is_same<T, WeightType>::value>::type* =
std::is_same<T, typename WeightQuantTraits::WeightType>::value>::type* =
nullptr>
void dispatch_moe_gemm_to_cutlass(const T* A,
const WeightType* B,
const typename WeightQuantTraits::WeightType* B,
const T* weight_scales,
const T* biases,
T* C,
@@ -454,6 +485,7 @@ void dispatch_moe_gemm_to_cutlass(const T* A,
int64_t gemm_n,
int64_t gemm_k,
int num_experts,
const typename WeightQuantTraits::Arguments& quant_args_B,
CutlassGemmConfig gemm_config,
int sm_version,
int multi_processor_count,
@@ -474,7 +506,7 @@ void dispatch_moe_gemm_to_cutlass(const T* A,
default:
throw std::runtime_error(
"[dispatch_moe_gemm_to_cutlass] Config is invalid for same "
"type MoE tensorop GEMM.");
"type MoE tensorop GEMM for FP16/BF16.");
break;
}
}
@@ -483,14 +515,14 @@ void dispatch_moe_gemm_to_cutlass(const T* A,
// Overload for quantized MoE GEMMs. We disable some warp configs here since they
// will not be used and we can improve compile time
template <typename T,
typename WeightType,
typename WeightQuantTraits,
typename arch,
typename EpilogueTag,
typename std::enable_if<!std::is_same<T, float>::value &&
!std::is_same<T, WeightType>::value>::type* =
!std::is_same<T, typename WeightQuantTraits::WeightType>::value>::type* =
nullptr>
void dispatch_moe_gemm_to_cutlass(const T* A,
const WeightType* B,
const typename WeightQuantTraits::WeightType* B,
const T* weight_scales,
const T* biases,
T* C,
@@ -499,28 +531,34 @@ void dispatch_moe_gemm_to_cutlass(const T* A,
int64_t gemm_n,
int64_t gemm_k,
int num_experts,
const typename WeightQuantTraits::Arguments& quant_args_B,
CutlassGemmConfig gemm_config,
int sm_version,
int multi_processor_count,
cudaStream_t stream,
int* occupancy = nullptr) {
if constexpr (std::is_same<arch, cutlass::arch::Sm70>::value) {
switch (gemm_config.tile_config) {
dispatch_gemm_config_macro(32, 128, 64, 32, 32, 64);
dispatch_gemm_config_macro(64, 128, 64, 64, 64, 64);
case CutlassTileConfig::Undefined:
throw std::runtime_error("[dispatch_moe_gemm_to_cutlass] gemm config undefined.");
break;
case CutlassTileConfig::ChooseWithHeuristic:
throw std::runtime_error(
"[dispatch_moe_gemm_to_cutlass] gemm config should have "
"already been set by heuristic.");
break;
default:
throw std::runtime_error(
"[dispatch_moe_gemm_to_cutlass] Config is invalid for "
"mixed type tensorop GEMM.");
break;
if constexpr (WeightQuantTraits::kQuantMethod != cutlass::WintQuantMethod::kWeightOnlyInt2) {
switch (gemm_config.tile_config) {
dispatch_gemm_config_macro(32, 128, 64, 32, 32, 64);
dispatch_gemm_config_macro(64, 128, 64, 64, 64, 64);
case CutlassTileConfig::Undefined:
throw std::runtime_error("[dispatch_moe_gemm_to_cutlass] gemm config undefined.");
break;
case CutlassTileConfig::ChooseWithHeuristic:
throw std::runtime_error(
"[dispatch_moe_gemm_to_cutlass] gemm config should have "
"already been set by heuristic.");
break;
default:
throw std::runtime_error(
"[dispatch_moe_gemm_to_cutlass] Config is invalid for "
"mixed type tensorop GEMM for sm70.");
break;
}
} else {
throw std::runtime_error(
"[dispatch_moe_gemm_to_cutlass] weight_only_int2 does not support sm70.");
}
} else {
switch (gemm_config.tile_config) {
@@ -555,12 +593,12 @@ void dispatch_moe_gemm_to_cutlass(const T* A,
// This overload will handle simt gemms. It is disabled via SFINAE for tensorop.
template <
typename T,
typename WeightType,
typename WeightQuantTraits,
typename arch,
typename EpilogueTag,
typename std::enable_if<std::is_same<T, float>::value>::type* = nullptr>
void dispatch_moe_gemm_to_cutlass(const T* A,
const WeightType* B,
const typename WeightQuantTraits::WeightType* B,
const T* weight_scales,
const T* biases,
T* C,
@@ -569,6 +607,7 @@ void dispatch_moe_gemm_to_cutlass(const T* A,
int64_t gemm_n,
int64_t gemm_k,
int num_experts,
const typename WeightQuantTraits::Arguments& quant_args_B,
CutlassGemmConfig gemm_config,
int sm_version,
int multi_processor_count,
@@ -594,8 +633,8 @@ void dispatch_moe_gemm_to_cutlass(const T* A,
}
}
template <typename T, typename WeightType>
MoeGemmRunner<T, WeightType>::MoeGemmRunner() {
template <typename T, typename WeightQuantTraits>
MoeGemmRunner<T, WeightQuantTraits>::MoeGemmRunner() {
int device{-1};
check_cuda_error(cudaGetDevice(&device));
sm_ = getSMVersion();
@@ -603,11 +642,11 @@ MoeGemmRunner<T, WeightType>::MoeGemmRunner() {
&multi_processor_count_, cudaDevAttrMultiProcessorCount, device));
}
template <typename T, typename WeightType>
template <typename T, typename WeightQuantTraits>
template <typename EpilogueTag>
void MoeGemmRunner<T, WeightType>::dispatch_to_arch<EpilogueTag>(
void MoeGemmRunner<T, WeightQuantTraits>::dispatch_to_arch<EpilogueTag>(
const T* A,
const WeightType* B,
const typename WeightQuantTraits::WeightType* B,
const T* weight_scales,
const T* biases,
T* C,
@@ -616,11 +655,12 @@ void MoeGemmRunner<T, WeightType>::dispatch_to_arch<EpilogueTag>(
int64_t gemm_n,
int64_t gemm_k,
int num_experts,
const typename WeightQuantTraits::Arguments& quant_args_B,
CutlassGemmConfig gemm_config,
cudaStream_t stream,
int* occupancy) {
#define dispatch_moe_gemm_to_cutlass_macro(ARCH) \
dispatch_moe_gemm_to_cutlass<T, WeightType, ARCH, EpilogueTag>( \
dispatch_moe_gemm_to_cutlass<T, WeightQuantTraits, ARCH, EpilogueTag>( \
A, \
B, \
weight_scales, \
@@ -631,6 +671,7 @@ void MoeGemmRunner<T, WeightType>::dispatch_to_arch<EpilogueTag>(
gemm_n, \
gemm_k, \
num_experts, \
quant_args_B, \
gemm_config, \
sm_, \
multi_processor_count_, \
@@ -648,25 +689,28 @@ void MoeGemmRunner<T, WeightType>::dispatch_to_arch<EpilogueTag>(
}
}
template <typename T, typename WeightType>
template <typename T, typename WeightQuantTraits>
template <typename EpilogueTag>
void MoeGemmRunner<T, WeightType>::run_gemm<EpilogueTag>(
void MoeGemmRunner<T, WeightQuantTraits>::run_gemm<EpilogueTag>(
const T* A,
const WeightType* B,
const typename WeightQuantTraits::WeightType* B,
const T* weight_scales,
const T* biases,
T* C,
int64_t* total_rows_before_expert,
int64_t total_rows,
int64_t tune_total_rows,
int64_t actual_total_rows,
int64_t gemm_n,
int64_t gemm_k,
int num_experts,
const typename WeightQuantTraits::Arguments& quant_args_B,
cudaStream_t stream) {
static constexpr bool is_weight_only = !std::is_same<T, WeightType>::value;
static constexpr bool is_weight_only = !std::is_same<T, typename WeightQuantTraits::WeightType>::value;
static constexpr bool only_simt_configs = std::is_same<T, float>::value;
std::vector<CutlassGemmConfig> candidate_configs =
get_candidate_configs(sm_, -1, is_weight_only, only_simt_configs, true);
static constexpr int warm_time = 5;
static constexpr int test_time = 10;
auto& gemmConfigManager = GemmConfigManager::Instance();
@@ -676,17 +720,19 @@ void MoeGemmRunner<T, WeightType>::run_gemm<EpilogueTag>(
gemm_n, gemm_k, GemmType::MOEGEMM, dtype, wdtype, num_experts};
CutlassGemmConfig chosen_config;
auto chosen_config_optional =
gemmConfigManager.getBestConfig(gemmId, tune_total_rows);
gemmConfigManager.getBestConfig(gemmId, actual_total_rows);
if (chosen_config_optional != std::nullopt) {
chosen_config = chosen_config_optional.value();
} else {
size_t best_id = -1;
float best_time = std::numeric_limits<float>::max();
CutlassGemmConfig best_config;
int profile_total_rows =
std::min(gemmConfigManager.nextPowerOfTwo(tune_total_rows),
std::min(gemmConfigManager.nextPowerOfTwo(actual_total_rows),
gemmConfigManager.getMaxProfileM());
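// Presumably the profiled M is rounded up to a power of two (capped by
// getMaxProfileM()) so that nearby batch sizes can reuse the same cached best
// config instead of re-profiling for every distinct actual_total_rows.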
bool find_one = false;
for (size_t ii = 0; ii < candidate_configs.size(); ++ii) {
size_t num_candidate_configs_size = candidate_configs.size();
for (size_t ii = 0; ii < num_candidate_configs_size; ++ii) {
try {
for (int i = 0; i < warm_time; i++) {
dispatch_to_arch<EpilogueTag>(A,
@@ -699,6 +745,7 @@ void MoeGemmRunner<T, WeightType>::run_gemm<EpilogueTag>(
gemm_n,
gemm_k,
num_experts,
quant_args_B,
candidate_configs[ii],
stream);
}
@@ -719,6 +766,7 @@ void MoeGemmRunner<T, WeightType>::run_gemm<EpilogueTag>(
gemm_n,
gemm_k,
num_experts,
quant_args_B,
candidate_configs[ii],
stream);
}
@@ -728,7 +776,9 @@ void MoeGemmRunner<T, WeightType>::run_gemm<EpilogueTag>(
check_cuda_error(cudaEventElapsedTime(&elapsed, start, stop));
check_cuda_error(cudaEventDestroy(start));
check_cuda_error(cudaEventDestroy(stop));
//std::cout << "[TUNING] config: " << ii << ", time: " << elapsed << " ms" << std::endl;
if (elapsed < best_time) {
best_id = ii;
best_time = elapsed;
best_config = candidate_configs[ii];
}
@@ -739,6 +789,7 @@ void MoeGemmRunner<T, WeightType>::run_gemm<EpilogueTag>(
}
}
if (find_one) {
//std::cout << "[TUNING] best_config: " << best_id << ", time: " << best_time << " ms" << std::endl;
gemmConfigManager.addBestConfig(gemmId, profile_total_rows, best_config);
chosen_config = best_config;
} else {
@@ -756,23 +807,25 @@ void MoeGemmRunner<T, WeightType>::run_gemm<EpilogueTag>(
gemm_n,
gemm_k,
num_experts,
quant_args_B,
chosen_config,
stream);
}
template <typename T, typename WeightType>
void MoeGemmRunner<T, WeightType>::moe_gemm_bias_act(
template <typename T, typename WeightQuantTraits>
void MoeGemmRunner<T, WeightQuantTraits>::moe_gemm_bias_act(
const T* A,
const WeightType* B,
const typename WeightQuantTraits::WeightType* B,
const T* weight_scales,
const T* biases,
T* C,
int64_t* total_rows_before_expert,
int64_t total_rows,
int64_t tune_total_rows,
int64_t actual_total_rows,
int64_t gemm_n,
int64_t gemm_k,
int num_experts,
const typename WeightQuantTraits::Arguments& quant_args_B,
std::string activation_type,
cudaStream_t stream) {
if (activation_type == "none") {
@@ -784,10 +837,11 @@ void MoeGemmRunner<T, WeightType>::moe_gemm_bias_act(
C,
total_rows_before_expert,
total_rows,
tune_total_rows,
actual_total_rows,
gemm_n,
gemm_k,
num_experts,
quant_args_B,
stream);
} else {
run_gemm<EpilogueOpNoBias>(A,
@@ -797,27 +851,30 @@ void MoeGemmRunner<T, WeightType>::moe_gemm_bias_act(
C,
total_rows_before_expert,
total_rows,
tune_total_rows,
actual_total_rows,
gemm_n,
gemm_k,
num_experts,
quant_args_B,
stream);
}
}
}
template <typename T, typename WeightType>
void MoeGemmRunner<T, WeightType>::moe_gemm(const T* A,
const WeightType* B,
const T* weight_scales,
T* C,
int64_t* total_rows_before_expert,
int64_t total_rows,
int64_t tune_total_rows,
int64_t gemm_n,
int64_t gemm_k,
int num_experts,
cudaStream_t stream) {
template <typename T, typename WeightQuantTraits>
void MoeGemmRunner<T, WeightQuantTraits>::moe_gemm(
const T* A,
const typename WeightQuantTraits::WeightType* B,
const T* weight_scales,
T* C,
int64_t* total_rows_before_expert,
int64_t total_rows,
int64_t actual_total_rows,
int64_t gemm_n,
int64_t gemm_k,
int num_experts,
const typename WeightQuantTraits::Arguments& quant_args_B,
cudaStream_t stream) {
run_gemm<EpilogueOpNoBias>(A,
B,
weight_scales,
@@ -825,10 +882,11 @@ void MoeGemmRunner<T, WeightType>::moe_gemm(const T* A,
C,
total_rows_before_expert,
total_rows,
tune_total_rows,
actual_total_rows,
gemm_n,
gemm_k,
num_experts,
quant_args_B,
stream);
}

View File

@@ -0,0 +1,102 @@
// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/c3x/cutlass_gemm_caller.cuh
#pragma once
// clang-format will break include orders
// clang-format off
#include "helper.h"
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cutlass/numeric_types.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass_helper.h"
// clang-format on
namespace fastdeploy::c3x {
static inline cute::Shape<int, int, int, int>
get_problem_shape(paddle::Tensor const &a, paddle::Tensor const &b) {
int32_t m = a.dims()[0], n = b.dims()[0], k = a.dims()[1];
return {m, n, k, 1};
}
template <typename GemmKernel>
void cutlass_gemm_caller(
phi::Place device, cute::Shape<int, int, int, int> prob_shape,
typename GemmKernel::MainloopArguments mainloop_args,
typename GemmKernel::EpilogueArguments epilogue_args,
typename GemmKernel::TileSchedulerArguments scheduler = {}) {
cutlass::KernelHardwareInfo hw_info;
typename GemmKernel::Arguments args{cutlass::gemm::GemmUniversalMode::kGemm,
prob_shape,
mainloop_args,
epilogue_args,
hw_info,
scheduler};
// Launch the CUTLASS GEMM kernel.
using GemmOp = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
GemmOp gemm_op;
CUTLASS_CHECK(gemm_op.can_implement(args));
size_t workspace_size = gemm_op.get_workspace_size(args);
phi::Allocator *allocator = paddle::GetAllocator(device);
auto workspace = allocator->Allocate(workspace_size);
auto stream = paddle::GetCurrentCUDAStream(device)->raw_stream();
cutlass::Status status = gemm_op.run(args, workspace->ptr(), stream);
CUTLASS_CHECK(status);
}
template <typename Gemm, typename... EpilogueArgs>
void cutlass_gemm_caller(paddle::Tensor &out, paddle::Tensor const &a,
paddle::Tensor const &b,
EpilogueArgs &&...epilogue_params) {
using ElementAB = typename Gemm::ElementAB;
using ElementC = typename Gemm::ElementC;
using ElementD = typename Gemm::ElementD;
using GemmKernel = typename Gemm::GemmKernel;
using StrideA = typename Gemm::GemmKernel::StrideA;
using StrideB = typename Gemm::GemmKernel::StrideB;
using StrideC = typename Gemm::GemmKernel::StrideC;
using StrideD = StrideC;
using StrideAux = StrideC;
typename GemmKernel::ProblemShape prob_shape = get_problem_shape(a, b);
auto [M, N, K, L] = prob_shape;
StrideA a_stride =
cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L));
StrideB b_stride =
cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L));
StrideC c_stride =
cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L));
StrideD d_stride =
cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L));
StrideAux aux_stride = d_stride;
auto a_ptr = static_cast<ElementAB *>(const_cast<void *>(a.data()));
auto b_ptr = static_cast<ElementAB *>(const_cast<void *>(b.data()));
typename GemmKernel::MainloopArguments mainloop_args{a_ptr, a_stride, b_ptr,
b_stride};
auto c_ptr = static_cast<ElementD *>(const_cast<void *>(out.data()));
typename GemmKernel::EpilogueArguments epilogue_args{
Gemm::Epilogue::prepare_args(
std::forward<EpilogueArgs>(epilogue_params)...),
c_ptr, c_stride, c_ptr, d_stride};
cutlass_gemm_caller<GemmKernel>(a.place(), prob_shape, mainloop_args,
epilogue_args);
}
} // namespace fastdeploy::c3x
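// Typical call site (see the SM90 dispatch headers later in this commit): a
// concrete Gemm configuration is selected from the problem shape and forwarded
// here together with its epilogue arguments, e.g.
//   fastdeploy::c3x::cutlass_gemm_caller<Cutlass3xGemmM64>(
//       out, a, b, std::forward<EpilogueArgs>(args)...);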

View File

@@ -0,0 +1,149 @@
// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/c3x/scaled_mm.cuh
#pragma once
// clang-format will break include orders
// clang-format off
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cutlass/numeric_types.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass_helper.h"
#include "helper.h"
// clang-format on
/*
Epilogues defined in
csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp
must contain a public type named EVTCompute of type Sm90EVT, as well as a
static prepare_args function that constructs an EVTCompute::Arguments struct.
*/
using namespace cute;
namespace fastdeploy {
template <typename ElementAB_, typename ElementD_,
template <typename, typename, typename> typename Epilogue_,
typename TileShape, typename ClusterShape, typename KernelSchedule,
typename EpilogueSchedule>
struct cutlass_3x_gemm {
using ElementAB = ElementAB_;
using ElementD = ElementD_;
using ElementAcc =
typename std::conditional<std::is_same_v<ElementAB, int8_t>, int32_t,
float>::type;
using Epilogue = Epilogue_<ElementAcc, ElementD, TileShape>;
using StrideD = Stride<int64_t, Int<1>, Int<0>>;
using ElementC = void;
using StrideC = StrideD;
using EVTCompute = typename Epilogue::EVTCompute;
// These are the minimum alignments needed for the kernels to compile
static constexpr int AlignmentAB =
128 / cutlass::sizeof_bits<ElementAB>::value;
static constexpr int AlignmentCD = 4;
using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape,
ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto,
ElementAcc, float, ElementC, StrideC, AlignmentCD, ElementD, StrideD,
AlignmentCD, EpilogueSchedule, EVTCompute>::CollectiveOp;
static constexpr size_t CEStorageSize =
sizeof(typename CollectiveEpilogue::SharedStorage);
using Stages = typename cutlass::gemm::collective::StageCountAutoCarveout<
static_cast<int>(CEStorageSize)>;
// clang-format off
using CollectiveMainloop =
typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
ElementAB, cutlass::layout::RowMajor, AlignmentAB,
ElementAB, cutlass::layout::ColumnMajor, AlignmentAB,
ElementAcc, TileShape, ClusterShape,
Stages,
KernelSchedule>::CollectiveOp;
// clang-format on
using KernelType = enable_sm90_or_later<cutlass::gemm::kernel::GemmUniversal<
cute::Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue,
cutlass::gemm::PersistentScheduler>>;
struct GemmKernel : public KernelType {};
};
template <typename ElementAB_, typename ElementD_,
template <typename, typename, typename> typename Epilogue_,
typename TileShape, typename ClusterShape, typename KernelSchedule,
typename EpilogueSchedule>
struct cutlass_3x_gemm_sm100 {
using ElementAB = ElementAB_;
using LayoutA = cutlass::layout::RowMajor;
static constexpr int AlignmentA =
128 / cutlass::sizeof_bits<ElementAB>::value;
using LayoutB = cutlass::layout::ColumnMajor;
static constexpr int AlignmentB =
128 / cutlass::sizeof_bits<ElementAB>::value;
using ElementC = void;
using LayoutC = cutlass::layout::RowMajor;
static constexpr int AlignmentC =
128 / cutlass::sizeof_bits<ElementD_>::value;
using ElementD = ElementD_;
using LayoutD = cutlass::layout::RowMajor;
static constexpr int AlignmentD = AlignmentC;
using ElementAcc =
typename std::conditional<std::is_same_v<ElementAB, int8_t>, int32_t,
float>::type;
using Epilogue = Epilogue_<ElementAcc, ElementD, TileShape>;
// MMA type
using ElementAccumulator = float;
// Epilogue types
using ElementBias = cutlass::half_t;
using ElementCompute = float;
using ElementAux = ElementD;
using LayoutAux = LayoutD;
using ElementAmax = float;
using EVTCompute = typename Epilogue::EVTCompute;
using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, TileShape,
ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator, ElementCompute, ElementC, LayoutC, AlignmentC,
ElementD, LayoutD, AlignmentD, EpilogueSchedule,
EVTCompute>::CollectiveOp;
using CollectiveMainloop =
typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, ElementAB,
LayoutA, AlignmentA, ElementAB, LayoutB, AlignmentB,
ElementAccumulator, TileShape, ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
sizeof(typename CollectiveEpilogue::SharedStorage))>,
KernelSchedule>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue, void>;
};
} // namespace fastdeploy

View File

@@ -0,0 +1,27 @@
// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_azp_sm90_int8.cu
// clang-format will break include orders
// clang-format off
#include "scaled_mm_kernels.hpp"
#include "scaled_mm_sm90_int8_dispatch.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
// clang-format on
namespace fastdeploy {
void cutlass_scaled_mm_azp_sm90_int8(
paddle::Tensor &out, paddle::Tensor const &a, paddle::Tensor const &b,
paddle::Tensor const &a_scales, paddle::Tensor const &b_scales,
paddle::Tensor const &azp_adj, paddle::optional<paddle::Tensor> const &azp,
paddle::optional<paddle::Tensor> const &bias) {
if (azp) {
return cutlass_scaled_mm_sm90_int8_epilogue<
c3x::ScaledEpilogueBiasAzpToken>(out, a, b, a_scales, b_scales, azp_adj,
*azp, bias);
} else {
return cutlass_scaled_mm_sm90_int8_epilogue<c3x::ScaledEpilogueBiasAzp>(
out, a, b, a_scales, b_scales, azp_adj, bias);
}
}
} // namespace fastdeploy

View File

@@ -0,0 +1,34 @@
// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_helper.hpp
#include "helper.h"
template <typename Fp8Func, typename Int8Func>
void dispatch_scaled_mm(paddle::Tensor &c, paddle::Tensor const &a,
paddle::Tensor const &b, paddle::Tensor const &a_scales,
paddle::Tensor const &b_scales,
paddle::optional<paddle::Tensor> const &bias,
Fp8Func fp8_func, Int8Func int8_func) {
PD_CHECK(a_scales.dtype() == paddle::DataType::FLOAT32);
PD_CHECK(b_scales.dtype() == paddle::DataType::FLOAT32);
int M = a.dims()[0], N = b.dims()[0], K = a.dims()[1];
if ((a_scales.numel() == 1 || a_scales.numel() == a.dims()[0]) &&
(b_scales.numel() == 1 || b_scales.numel() == b.dims()[0])) {
// Standard per-tensor/per-token/per-channel scaling
PD_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
if (a.dtype() == phi::DataType::FLOAT8_E4M3FN) {
fp8_func(c, a, b, a_scales, b_scales, bias);
} else {
PD_CHECK(a.dtype() == paddle::DataType::INT8);
if constexpr (!std::is_same_v<Int8Func, std::nullptr_t>) {
int8_func(c, a, b, a_scales, b_scales, bias);
} else {
PD_CHECK(false, "Int8 not supported for this architecture");
}
}
} else {
PADDLE_THROW(phi::errors::Unimplemented(
"No kernel for this combination of input dtypes is implemented."));
}
}
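// Usage sketch (assuming the SM90 entry points declared in scaled_mm_kernels.hpp):
// the caller passes the fp8 and int8 kernels as callables and this helper routes
// on a.dtype(), e.g.
//   dispatch_scaled_mm(c, a, b, a_scales, b_scales, bias,
//                      fastdeploy::cutlass_scaled_mm_sm90_fp8,
//                      fastdeploy::cutlass_scaled_mm_sm90_int8);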

View File

@@ -0,0 +1,35 @@
// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_kernels.hpp
#pragma once
#include "helper.h"
namespace fastdeploy {
void cutlass_scaled_mm_sm90_fp8(paddle::Tensor &out, paddle::Tensor const &a,
paddle::Tensor const &b,
paddle::Tensor const &a_scales,
paddle::Tensor const &b_scales,
paddle::optional<paddle::Tensor> const &bias);
void cutlass_scaled_mm_sm90_int8(paddle::Tensor &out, paddle::Tensor const &a,
paddle::Tensor const &b,
paddle::Tensor const &a_scales,
paddle::Tensor const &b_scales,
paddle::optional<paddle::Tensor> const &bias);
void cutlass_scaled_mm_azp_sm90_int8(paddle::Tensor& out, paddle::Tensor const& a,
paddle::Tensor const& b,
paddle::Tensor const& a_scales,
paddle::Tensor const& b_scales,
paddle::Tensor const& azp_adj,
paddle::optional<paddle::Tensor> const& azp,
paddle::optional<paddle::Tensor> const& bias);
void cutlass_scaled_mm_sm100_fp8(paddle::Tensor &out, paddle::Tensor const &a,
paddle::Tensor const &b,
paddle::Tensor const &a_scales,
paddle::Tensor const &b_scales,
paddle::optional<paddle::Tensor> const &bias);
} // namespace fastdeploy

View File

@@ -0,0 +1,28 @@
// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_fp8.cu
// clang-format will break include orders
// clang-format off
#include "scaled_mm_kernels.hpp"
#include "scaled_mm_sm90_fp8_dispatch.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
// clang-format on
namespace fastdeploy {
void cutlass_scaled_mm_sm90_fp8(paddle::Tensor &out, paddle::Tensor const &a,
paddle::Tensor const &b,
paddle::Tensor const &a_scales,
paddle::Tensor const &b_scales,
paddle::optional<paddle::Tensor> const &bias) {
PD_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
if (bias) {
PD_CHECK(bias->dtype() == out.dtype(),
"currently bias dtype must match output dtype ", out.dtype());
return cutlass_scaled_mm_sm90_fp8_epilogue<c3x::ScaledEpilogueBias>(
out, a, b, a_scales, b_scales, *bias);
} else {
return cutlass_scaled_mm_sm90_fp8_epilogue<c3x::ScaledEpilogue>(
out, a, b, a_scales, b_scales);
}
}
} // namespace fastdeploy

View File

@@ -0,0 +1,125 @@
// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_fp8_dispatch.cuh
#pragma once
// clang-format will break include orders
// clang-format off
#include "scaled_mm.cuh"
#include "cutlass_gemm_caller.cuh"
// clang-format on
/**
* This file defines Gemm kernel configurations for SM90 (fp8) based on the Gemm
* shape.
*/
namespace fastdeploy {
using c3x::cutlass_gemm_caller;
template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue>
struct sm90_fp8_config_default {
// M in (128, inf)
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
using KernelSchedule =
cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
using TileShape = Shape<_128, _128, _128>;
using ClusterShape = Shape<_2, _1, _1>;
using Cutlass3xGemm =
cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule>;
};
template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue>
struct sm90_fp8_config_M128 {
// M in (64, 128]
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
using KernelSchedule =
cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
using TileShape = Shape<_64, _128, _128>;
using ClusterShape = Shape<_2, _1, _1>;
using Cutlass3xGemm =
cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule>;
};
template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue>
struct sm90_fp8_config_M64 {
// M in [1, 64]
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
using KernelSchedule =
cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
using TileShape = Shape<_64, _64, _128>;
using ClusterShape = Shape<_1, _8, _1>;
using Cutlass3xGemm =
cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule>;
};
template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue,
typename... EpilogueArgs>
inline void cutlass_gemm_sm90_fp8_dispatch(paddle::Tensor &out,
paddle::Tensor const &a,
paddle::Tensor const &b,
EpilogueArgs &&...args) {
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
PD_CHECK(a.dtype() == phi::DataType::FLOAT8_E4M3FN);
PD_CHECK(b.dtype() == phi::DataType::FLOAT8_E4M3FN);
using Cutlass3xGemmDefault =
typename sm90_fp8_config_default<InType, OutType,
Epilogue>::Cutlass3xGemm;
using Cutlass3xGemmM64 =
typename sm90_fp8_config_M64<InType, OutType, Epilogue>::Cutlass3xGemm;
using Cutlass3xGemmM128 =
typename sm90_fp8_config_M128<InType, OutType, Epilogue>::Cutlass3xGemm;
uint32_t const m = a.dims()[0];
uint32_t const mp2 =
std::max(static_cast<uint32_t>(64), next_pow_2(m)); // next power of 2
if (mp2 <= 64) {
// m in [1, 64]
return cutlass_gemm_caller<Cutlass3xGemmM64>(
out, a, b, std::forward<EpilogueArgs>(args)...);
} else if (mp2 <= 128) {
// m in (64, 128]
return cutlass_gemm_caller<Cutlass3xGemmM128>(
out, a, b, std::forward<EpilogueArgs>(args)...);
} else {
// m in (128, inf)
return cutlass_gemm_caller<Cutlass3xGemmDefault>(
out, a, b, std::forward<EpilogueArgs>(args)...);
}
}
template <template <typename, typename, typename> typename Epilogue,
typename... EpilogueArgs>
void cutlass_scaled_mm_sm90_fp8_epilogue(paddle::Tensor &out,
paddle::Tensor const &a,
paddle::Tensor const &b,
EpilogueArgs &&...epilogue_args) {
PD_CHECK(a.dtype() == phi::DataType::FLOAT8_E4M3FN);
PD_CHECK(b.dtype() == phi::DataType::FLOAT8_E4M3FN);
if (out.dtype() == paddle::DataType::BFLOAT16) {
return cutlass_gemm_sm90_fp8_dispatch<cutlass::float_e4m3_t,
cutlass::bfloat16_t, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
} else {
PD_CHECK(out.dtype() == paddle::DataType::FLOAT16);
return cutlass_gemm_sm90_fp8_dispatch<cutlass::float_e4m3_t,
cutlass::half_t, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
}
}
} // namespace fastdeploy

View File

@@ -0,0 +1,29 @@
// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_int8.cu
// clang-format will break include orders
// clang-format off
#include "scaled_mm_kernels.hpp"
#include "scaled_mm_sm90_int8_dispatch.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
// clang-format on
namespace fastdeploy {
void cutlass_scaled_mm_sm90_int8(paddle::Tensor &out, paddle::Tensor const &a,
paddle::Tensor const &b,
paddle::Tensor const &a_scales,
paddle::Tensor const &b_scales,
paddle::optional<paddle::Tensor> const &bias) {
PD_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
if (bias) {
PD_CHECK(bias->dtype() == out.dtype(),
"currently bias dtype must match output dtype ", out.dtype());
return cutlass_scaled_mm_sm90_int8_epilogue<c3x::ScaledEpilogueBias>(
out, a, b, a_scales, b_scales, *bias);
} else {
return cutlass_scaled_mm_sm90_int8_epilogue<c3x::ScaledEpilogue>(
out, a, b, a_scales, b_scales);
}
}
} // namespace fastdeploy

View File

@@ -0,0 +1,168 @@
// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_int8_dispatch.cuh
#pragma once
// clang-format will break include orders
// clang-format off
#include "scaled_mm.cuh"
#include "cutlass_gemm_caller.cuh"
// clang-format on
/**
* This file defines Gemm kernel configurations for SM90 (int8) based on the
* Gemm shape.
*/
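// Worked example of cutlass_gemm_sm90_int8_dispatch below: M = 16 rounds up to
// the 32 bucket, where the config additionally depends on N; N = 4096 (< 8192)
// picks sm90_int8_config_M32_NSmall and N = 12288 picks
// sm90_int8_config_M32_NBig. Larger M buckets ignore N entirely.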
namespace fastdeploy {
using c3x::cutlass_gemm_caller;
template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue>
struct sm90_int8_config_default {
// For M > 128 and any N
static_assert(std::is_same<InType, int8_t>());
using KernelSchedule =
typename cutlass::gemm::KernelTmaWarpSpecializedPingpong;
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
using TileShape = Shape<_128, _128, _128>;
using ClusterShape = Shape<_2, _1, _1>;
using Cutlass3xGemm =
cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule>;
};
template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue>
struct sm90_int8_config_M128 {
// For M in (64, 128] and any N
static_assert(std::is_same<InType, int8_t>());
using KernelSchedule =
typename cutlass::gemm::KernelTmaWarpSpecializedPingpong;
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
using TileShape = Shape<_64, _128, _128>;
using ClusterShape = Shape<_2, _1, _1>;
using Cutlass3xGemm =
cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule>;
};
template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue>
struct sm90_int8_config_M64 {
// For M in (32, 64] and any N
static_assert(std::is_same<InType, int8_t>());
using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized;
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
using TileShape = Shape<_64, _64, _256>;
using ClusterShape = Shape<_1, _1, _1>;
using Cutlass3xGemm =
cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule>;
};
template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue>
struct sm90_int8_config_M32_NBig {
// For M in [1, 32] and N >= 8192
static_assert(std::is_same<InType, int8_t>());
using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized;
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
using TileShape = Shape<_64, _128, _256>;
using ClusterShape = Shape<_1, _4, _1>;
using Cutlass3xGemm =
cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule>;
};
template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue>
struct sm90_int8_config_M32_NSmall {
// For M in [1, 32] and N < 8192
static_assert(std::is_same<InType, int8_t>());
using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized;
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
using TileShape = Shape<_64, _64, _256>;
using ClusterShape = Shape<_1, _8, _1>;
using Cutlass3xGemm =
cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule>;
};
template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue,
typename... EpilogueArgs>
inline void cutlass_gemm_sm90_int8_dispatch(paddle::Tensor &out,
paddle::Tensor const &a,
paddle::Tensor const &b,
EpilogueArgs &&...args) {
static_assert(std::is_same<InType, int8_t>());
PD_CHECK(a.dtype() == paddle::DataType::INT8);
PD_CHECK(b.dtype() == paddle::DataType::INT8);
using Cutlass3xGemmDefault =
typename sm90_int8_config_default<InType, OutType,
Epilogue>::Cutlass3xGemm;
using Cutlass3xGemmM128 =
typename sm90_int8_config_M128<InType, OutType, Epilogue>::Cutlass3xGemm;
using Cutlass3xGemmM64 =
typename sm90_int8_config_M64<InType, OutType, Epilogue>::Cutlass3xGemm;
using Cutlass3xGemmM32NBig =
typename sm90_int8_config_M32_NBig<InType, OutType,
Epilogue>::Cutlass3xGemm;
using Cutlass3xGemmM32NSmall =
typename sm90_int8_config_M32_NSmall<InType, OutType,
Epilogue>::Cutlass3xGemm;
uint32_t const n = out.dims()[1];
bool const is_small_n = n < 8192;
uint32_t const m = a.dims()[0];
uint32_t const mp2 =
std::max(static_cast<uint32_t>(32), next_pow_2(m)); // next power of 2
if (mp2 <= 32) {
// m in [1, 32]
if (is_small_n) {
return cutlass_gemm_caller<Cutlass3xGemmM32NSmall>(
out, a, b, std::forward<EpilogueArgs>(args)...);
} else {
return cutlass_gemm_caller<Cutlass3xGemmM32NBig>(
out, a, b, std::forward<EpilogueArgs>(args)...);
}
} else if (mp2 <= 64) {
// m in (32, 64]
return cutlass_gemm_caller<Cutlass3xGemmM64>(
out, a, b, std::forward<EpilogueArgs>(args)...);
} else if (mp2 <= 128) {
// m in (64, 128]
return cutlass_gemm_caller<Cutlass3xGemmM128>(
out, a, b, std::forward<EpilogueArgs>(args)...);
} else {
// m in (128, inf)
return cutlass_gemm_caller<Cutlass3xGemmDefault>(
out, a, b, std::forward<EpilogueArgs>(args)...);
}
}
template <template <typename, typename, typename> typename Epilogue,
typename... EpilogueArgs>
void cutlass_scaled_mm_sm90_int8_epilogue(paddle::Tensor &out,
paddle::Tensor const &a,
paddle::Tensor const &b,
EpilogueArgs &&...epilogue_args) {
PD_CHECK(a.dtype() == paddle::DataType::INT8);
PD_CHECK(b.dtype() == paddle::DataType::INT8);
if (out.dtype() == paddle::DataType::BFLOAT16) {
return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::bfloat16_t,
Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
} else {
PD_CHECK(out.dtype() == paddle::DataType::FLOAT16);
return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::half_t, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
}
}
} // namespace fastdeploy

View File

@@ -0,0 +1,200 @@
// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu
#include "helper.h"
#include <stddef.h>
#include "cutlass/cutlass.h"
#include "scaled_mm_c2x.cuh"
#include "scaled_mm_c2x_sm75_dispatch.cuh"
#include "scaled_mm_c2x_sm80_dispatch.cuh"
#include "scaled_mm_c2x_sm89_fp8_dispatch.cuh"
#include "scaled_mm_c2x_sm89_int8_dispatch.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp"
using namespace fastdeploy;
/*
This file defines quantized GEMM operations using the CUTLASS 2.x API, for
NVIDIA GPUs with SM versions prior to sm90 (Hopper).
*/
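// Minimal usage sketch (hedged; tensor names are illustrative): with int8
// inputs a [M, K] and b [N, K] (matching the dims()/strides() handling in
// cutlass_gemm_caller), float32 a_scales / b_scales, and a preallocated fp16
// or bf16 out [M, N], the call
//
//   cutlass_scaled_mm_sm80(out, a, b, a_scales, b_scales,
//                          /*bias=*/paddle::optional<paddle::Tensor>());
//
// runs the quantized GEMM with the scales (and, when provided, a bias) fused
// into the CUTLASS epilogue.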
template <template <typename, typename> typename Epilogue,
typename... EpilogueArgs>
void cutlass_scaled_mm_sm75_epilogue(paddle::Tensor& out, paddle::Tensor const& a,
paddle::Tensor const& b,
EpilogueArgs&&... epilogue_args) {
PD_CHECK(a.dtype() == paddle::DataType::INT8);
PD_CHECK(b.dtype() == paddle::DataType::INT8);
if (out.dtype() == paddle::DataType::BFLOAT16) {
return cutlass_gemm_sm75_dispatch<int8_t, cutlass::bfloat16_t, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
} else {
PD_CHECK(out.dtype() == paddle::DataType::FLOAT16);
return cutlass_gemm_sm75_dispatch<int8_t, cutlass::half_t, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
}
}
void cutlass_scaled_mm_sm75(paddle::Tensor& out, paddle::Tensor const& a,
paddle::Tensor const& b,
paddle::Tensor const& a_scales,
paddle::Tensor const& b_scales,
paddle::optional<paddle::Tensor> const& bias) {
PD_CHECK(a_scales.dtype() == paddle::DataType::FLOAT32);
PD_CHECK(b_scales.dtype() == paddle::DataType::FLOAT32);
if (bias) {
PD_CHECK(bias->dtype() == out.dtype(),
"currently bias dtype must match output dtype ", out.dtype());
return cutlass_scaled_mm_sm75_epilogue<c2x::ScaledEpilogueBias>(
out, a, b, a_scales, b_scales, *bias);
} else {
return cutlass_scaled_mm_sm75_epilogue<c2x::ScaledEpilogue>(
out, a, b, a_scales, b_scales);
}
}
void cutlass_scaled_mm_azp_sm75(paddle::Tensor& out, paddle::Tensor const& a,
paddle::Tensor const& b,
paddle::Tensor const& a_scales,
paddle::Tensor const& b_scales,
paddle::Tensor const& azp_adj,
paddle::optional<paddle::Tensor> const& azp,
paddle::optional<paddle::Tensor> const& bias) {
PD_CHECK(a_scales.dtype() == paddle::DataType::FLOAT32);
PD_CHECK(b_scales.dtype() == paddle::DataType::FLOAT32);
if (azp) {
return cutlass_scaled_mm_sm75_epilogue<c2x::ScaledEpilogueBiasAzpToken>(
out, a, b, a_scales, b_scales, azp_adj, *azp, bias);
} else {
return cutlass_scaled_mm_sm75_epilogue<c2x::ScaledEpilogueBiasAzp>(
out, a, b, a_scales, b_scales, azp_adj, bias);
}
}
template <template <typename, typename> typename Epilogue,
typename... EpilogueArgs>
void cutlass_scaled_mm_sm80_epilogue(paddle::Tensor& out, paddle::Tensor const& a,
paddle::Tensor const& b,
EpilogueArgs&&... epilogue_args) {
PD_CHECK(a.dtype() == paddle::DataType::INT8);
PD_CHECK(b.dtype() == paddle::DataType::INT8);
if (out.dtype() == paddle::DataType::BFLOAT16) {
return cutlass_gemm_sm80_dispatch<int8_t, cutlass::bfloat16_t, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
} else {
PD_CHECK(out.dtype() == paddle::DataType::FLOAT16);
return cutlass_gemm_sm80_dispatch<int8_t, cutlass::half_t, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
}
}
void cutlass_scaled_mm_sm80(paddle::Tensor& out, paddle::Tensor const& a,
paddle::Tensor const& b,
paddle::Tensor const& a_scales,
paddle::Tensor const& b_scales,
paddle::optional<paddle::Tensor> const& bias) {
PD_CHECK(a_scales.dtype() == paddle::DataType::FLOAT32);
PD_CHECK(b_scales.dtype() == paddle::DataType::FLOAT32);
if (bias) {
PD_CHECK(bias->dtype() == out.dtype(),
"currently bias dtype must match output dtype ", out.dtype());
return cutlass_scaled_mm_sm80_epilogue<c2x::ScaledEpilogueBias>(
out, a, b, a_scales, b_scales, *bias);
} else {
return cutlass_scaled_mm_sm80_epilogue<c2x::ScaledEpilogue>(
out, a, b, a_scales, b_scales);
}
}
void cutlass_scaled_mm_azp_sm80(paddle::Tensor& out, paddle::Tensor const& a,
paddle::Tensor const& b,
paddle::Tensor const& a_scales,
paddle::Tensor const& b_scales,
paddle::Tensor const& azp_adj,
paddle::optional<paddle::Tensor> const& azp,
paddle::optional<paddle::Tensor> const& bias) {
PD_CHECK(a_scales.dtype() == paddle::DataType::FLOAT32);
PD_CHECK(b_scales.dtype() == paddle::DataType::FLOAT32);
if (azp) {
return cutlass_scaled_mm_sm80_epilogue<c2x::ScaledEpilogueBiasAzpToken>(
out, a, b, a_scales, b_scales, azp_adj, *azp, bias);
} else {
return cutlass_scaled_mm_sm80_epilogue<c2x::ScaledEpilogueBiasAzp>(
out, a, b, a_scales, b_scales, azp_adj, bias);
}
}
template <template <typename, typename> typename Epilogue,
typename... EpilogueArgs>
void cutlass_scaled_mm_sm89_epilogue(paddle::Tensor& out, paddle::Tensor const& a,
paddle::Tensor const& b,
EpilogueArgs&&... epilogue_args) {
if (a.dtype() == paddle::DataType::INT8) {
PD_CHECK(b.dtype() == paddle::DataType::INT8);
if (out.dtype() == paddle::DataType::BFLOAT16) {
return cutlass_gemm_sm89_int8_dispatch<int8_t, cutlass::bfloat16_t,
Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
} else {
      PD_CHECK(out.dtype() == paddle::DataType::FLOAT16);
return cutlass_gemm_sm89_int8_dispatch<int8_t, cutlass::half_t, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
}
} else {
PD_CHECK(a.dtype() == paddle::DataType::FLOAT8_E4M3FN);
PD_CHECK(b.dtype() == paddle::DataType::FLOAT8_E4M3FN);
if (out.dtype() == paddle::DataType::BFLOAT16) {
return cutlass_gemm_sm89_fp8_dispatch<cutlass::float_e4m3_t,
cutlass::bfloat16_t, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
} else {
PD_CHECK(out.dtype() == paddle::DataType::FLOAT16);
return cutlass_gemm_sm89_fp8_dispatch<cutlass::float_e4m3_t,
cutlass::half_t, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
}
}
}
void cutlass_scaled_mm_sm89(paddle::Tensor& out, paddle::Tensor const& a,
paddle::Tensor const& b,
paddle::Tensor const& a_scales,
paddle::Tensor const& b_scales,
paddle::optional<paddle::Tensor> const& bias) {
PD_CHECK(a_scales.dtype() == paddle::DataType::FLOAT32);
PD_CHECK(b_scales.dtype() == paddle::DataType::FLOAT32);
if (bias) {
PD_CHECK(bias->dtype() == out.dtype(),
"currently bias dtype must match output dtype ", out.dtype());
return cutlass_scaled_mm_sm89_epilogue<c2x::ScaledEpilogueBias>(
out, a, b, a_scales, b_scales, *bias);
} else {
return cutlass_scaled_mm_sm89_epilogue<c2x::ScaledEpilogue>(
out, a, b, a_scales, b_scales);
}
}
void cutlass_scaled_mm_azp_sm89(paddle::Tensor& out, paddle::Tensor const& a,
paddle::Tensor const& b,
paddle::Tensor const& a_scales,
paddle::Tensor const& b_scales,
paddle::Tensor const& azp_adj,
paddle::optional<paddle::Tensor> const& azp,
paddle::optional<paddle::Tensor> const& bias) {
PD_CHECK(a_scales.dtype() == paddle::DataType::FLOAT32);
PD_CHECK(b_scales.dtype() == paddle::DataType::FLOAT32);
if (azp) {
return cutlass_scaled_mm_sm89_epilogue<c2x::ScaledEpilogueBiasAzpToken>(
out, a, b, a_scales, b_scales, azp_adj, *azp, bias);
} else {
return cutlass_scaled_mm_sm89_epilogue<c2x::ScaledEpilogueBiasAzp>(
out, a, b, a_scales, b_scales, azp_adj, bias);
}
}

View File

@@ -0,0 +1,223 @@
// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh
#pragma once
#include <stddef.h>
#include "helper.h"
// clang-format will break include orders
// clang-format off
#include "cute/tensor.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cutlass/numeric_types.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm_coord.h"
#include "cutlass/arch/mma_sm75.h"
#include "cutlass/arch/arch.h"
#include "cutlass/arch/mma.h"
#include "cutlass/gemm/device/gemm.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/epilogue/threadblock/fusion/visitors.hpp"
#include "cutlass/gemm/kernel/default_gemm_universal_with_visitor.h"
#include "cutlass_helper.h"
// clang-format on
/*
Epilogues defined in,
csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp
must contain a public type named EVTCompute of type Sm80EVT,
as well as a static prepare_args function that constructs an
EVTCompute::Arguments struct.
*/
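// Hedged sketch of that contract (ExampleEpilogue is illustrative only, not
// one of the real epilogues shipped in scaled_mm_epilogues_c2x.hpp):
//
//   template <typename ElementD, typename OutputTileThreadMap>
//   struct ExampleEpilogue {
//     using EVTCompute = /* some Sm80EVT<...> visitor tree */;
//     static typename EVTCompute::Arguments prepare_args(
//         paddle::Tensor const& a_scales, paddle::Tensor const& b_scales);
//   };
//
// cutlass_gemm_caller below forwards its trailing arguments directly to
// Epilogue::prepare_args, so each epilogue chooses its own runtime inputs.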
namespace fastdeploy {
using namespace cute;
// Wrappers for the GEMM kernel that is used to guard against compilation on
// architectures that will never use the kernel. The purpose of this is to
// reduce the size of the compiled binary.
// __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef
// into code that will be executed on the device where it is defined.
template <typename Kernel>
struct enable_sm75_to_sm80 : Kernel {
template <typename... Args>
CUTLASS_DEVICE static void invoke(Args&&... args) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 750 && __CUDA_ARCH__ < 800
Kernel::invoke(std::forward<Args>(args)...);
#endif
}
};
template <typename Kernel>
struct enable_sm80_to_sm89 : Kernel {
template <typename... Args>
CUTLASS_DEVICE static void invoke(Args&&... args) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800 && __CUDA_ARCH__ < 890
Kernel::invoke(std::forward<Args>(args)...);
#endif
}
};
template <typename Kernel>
struct enable_sm89_to_sm90 : Kernel {
template <typename... Args>
CUTLASS_DEVICE static void invoke(Args&&... args) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 890 && __CUDA_ARCH__ < 900
Kernel::invoke(std::forward<Args>(args)...);
#endif
}
};
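// Illustrative composition of these guards (KernelBody is a hypothetical
// placeholder, not a type defined in this file):
//
//   struct KernelBody {
//     template <typename... Args>
//     CUTLASS_DEVICE static void invoke(Args&&...) { /* device-side work */ }
//   };
//   using Sm80OnlyKernel = enable_sm80_to_sm89<KernelBody>;
//
// When this translation unit is compiled for sm75 or sm90, the guarded
// invoke() body preprocesses away, so no dead device code is emitted for
// architectures that can never run it.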
template <typename Arch, template <typename> typename ArchGuard,
typename ElementAB_, typename ElementD_,
template <typename, typename> typename Epilogue_, typename TileShape,
typename WarpShape, typename InstructionShape, int32_t MainLoopStages,
typename FP8MathOperator = cutlass::arch::OpMultiplyAdd>
struct cutlass_2x_gemm {
using ElementAB = ElementAB_;
using ElementD = ElementD_;
using ElementAcc =
typename std::conditional<std::is_same_v<ElementAB, int8_t>, int32_t,
float>::type;
using Operator =
typename std::conditional<std::is_same_v<ElementAB, int8_t>,
cutlass::arch::OpMultiplyAddSaturate,
FP8MathOperator>::type;
using OutputTileThreadMap =
cutlass::epilogue::threadblock::OutputTileThreadLayout<
TileShape, WarpShape, float, 4, 1 /* epilogue stages */
>;
using Epilogue = Epilogue_<ElementD, OutputTileThreadMap>;
using EVTCompute = typename Epilogue::EVTCompute;
using D = cutlass::epilogue::threadblock::VisitorAuxStore<
OutputTileThreadMap, ElementD, cutlass::FloatRoundStyle::round_to_nearest,
Stride<int64_t, Int<1>, Int<0>>>;
using EVTD = cutlass::epilogue::threadblock::Sm80EVT<D, EVTCompute>;
// These are the minimum alignments needed for the kernels to compile
static constexpr int AlignmentAB =
128 / cutlass::sizeof_bits<ElementAB>::value;
static constexpr int AlignmentCD = 4;
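  // Worked example: int8_t and cutlass::float_e4m3_t are both 8-bit types, so
  // AlignmentAB evaluates to 128 / 8 = 16 elements, i.e. one 128-bit access.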
// clang-format off
using RowMajor = typename cutlass::layout::RowMajor;
using ColumnMajor = typename cutlass::layout::ColumnMajor;
using KernelType =
ArchGuard<typename cutlass::gemm::kernel::DefaultGemmWithVisitor<
ElementAB, RowMajor, cutlass::ComplexTransform::kNone, AlignmentAB,
ElementAB, ColumnMajor, cutlass::ComplexTransform::kNone, AlignmentAB,
float, cutlass::layout::RowMajor, AlignmentCD,
ElementAcc, float, cutlass::arch::OpClassTensorOp,
Arch,
TileShape, WarpShape, InstructionShape,
EVTD,
cutlass::gemm::threadblock::ThreadblockSwizzleStreamK,
MainLoopStages, Operator,
1 /* epilogue stages */
>::GemmKernel>;
// clang-format on
using Op = cutlass::gemm::device::GemmUniversalAdapter<KernelType>;
};
template <typename Gemm, typename... EpilogueArgs>
inline void cutlass_gemm_caller(paddle::Tensor& out, paddle::Tensor const& a,
paddle::Tensor const& b,
EpilogueArgs&&... epilogue_params) {
using ElementAB = typename Gemm::ElementAB;
using ElementD = typename Gemm::ElementD;
int32_t m = a.dims()[0];
int32_t n = b.dims()[0];
int32_t k = a.dims()[1];
cutlass::gemm::GemmCoord problem_size{m, n, k};
int64_t lda = a.strides()[0];
int64_t ldb = b.strides()[0];
int64_t ldc = out.strides()[0];
using StrideC = Stride<int64_t, Int<1>, Int<0>>;
StrideC c_stride{ldc, Int<1>{}, Int<0>{}};
auto a_ptr = static_cast<ElementAB const*>(a.data());
auto b_ptr = static_cast<ElementAB const*>(b.data());
auto c_ptr = static_cast<ElementD*>(out.data());
typename Gemm::D::Arguments d_args{c_ptr, c_stride};
using Epilogue = typename Gemm::Epilogue;
auto evt_args =
Epilogue::prepare_args(std::forward<EpilogueArgs>(epilogue_params)...);
typename Gemm::EVTD::Arguments epilogue_args{
evt_args,
d_args,
};
typename Gemm::Op::Arguments args{
cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel, // universal mode
problem_size, // problem size
1, // batch count
epilogue_args,
a_ptr,
b_ptr,
nullptr,
nullptr,
0,
0,
0,
0,
lda,
ldb,
ldc,
ldc};
// Launch the CUTLASS GEMM kernel.
typename Gemm::Op gemm_op;
size_t workspace_size = gemm_op.get_workspace_size(args);
phi::Allocator *allocator = paddle::GetAllocator(a.place());
auto workspace = allocator->Allocate(workspace_size);
auto stream = a.stream();
CUTLASS_CHECK(gemm_op.can_implement(args));
cutlass::Status status = gemm_op(args, workspace->ptr(), stream);
CUTLASS_CHECK(status);
}
template <typename Gemm, typename FallbackGemm, typename... EpilogueArgs>
inline void fallback_cutlass_gemm_caller(paddle::Tensor& out,
paddle::Tensor const& a,
paddle::Tensor const& b,
EpilogueArgs&&... args) {
// In some cases, the GPU isn't able to accommodate the
// shared memory requirements of the Gemm. In such cases, use
// the FallbackGemm instead.
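  // Hedged illustration: if Gemm's SharedStorage were 120 KiB on a device that
  // only opts in to 100 KiB of shared memory per block, the check below would
  // route the call to FallbackGemm, whose smaller tile is expected to fit.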
static const int max_shared_mem_per_block_opt_in =
get_cuda_max_shared_memory_per_block_opt_in(0);
size_t const gemm_shared_mem_size =
sizeof(typename Gemm::KernelType::SharedStorage);
size_t const fallback_gemm_shared_mem_size =
sizeof(typename FallbackGemm::KernelType::SharedStorage);
if (gemm_shared_mem_size <= max_shared_mem_per_block_opt_in) {
return cutlass_gemm_caller<Gemm>(out, a, b,
std::forward<EpilogueArgs>(args)...);
} else {
PD_CHECK(fallback_gemm_shared_mem_size <=
max_shared_mem_per_block_opt_in);
return cutlass_gemm_caller<FallbackGemm>(
out, a, b, std::forward<EpilogueArgs>(args)...);
}
}
} // namespace fastdeploy

View File

@@ -0,0 +1,125 @@
// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm75_dispatch.cuh
#pragma once
#include "scaled_mm_c2x.cuh"
/**
* This file defines Gemm kernel configurations for SM75 based on the Gemm
* shape.
*/
namespace fastdeploy {
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue>
struct sm75_config_default {
// This config is used in 2 cases,
  // - M in (256, inf)
// - M in (64, 128]
// Shared memory required by this Gemm 32768
static_assert(std::is_same<InType, int8_t>());
using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
using InstructionShape = typename cutlass::gemm::GemmShape<8, 8, 16>;
using Cutlass2xGemm =
cutlass_2x_gemm<cutlass::arch::Sm75, enable_sm75_to_sm80, InType, OutType,
Epilogue, TileShape, WarpShape, InstructionShape, 2>;
};
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue>
struct sm75_config_M256 {
// M in (128, 256]
// Shared memory required by this Gemm 65536
static_assert(std::is_same<InType, int8_t>());
using TileShape = typename cutlass::gemm::GemmShape<128, 128, 128>;
using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
using InstructionShape = typename cutlass::gemm::GemmShape<8, 8, 16>;
using Cutlass2xGemm =
cutlass_2x_gemm<cutlass::arch::Sm75, enable_sm75_to_sm80, InType, OutType,
Epilogue, TileShape, WarpShape, InstructionShape, 2>;
};
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue>
struct sm75_config_M64 {
// M in (32, 64]
// Shared memory required by this Gemm 49152
static_assert(std::is_same<InType, int8_t>());
using TileShape = typename cutlass::gemm::GemmShape<64, 128, 128>;
using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
using InstructionShape = typename cutlass::gemm::GemmShape<8, 8, 16>;
using Cutlass2xGemm =
cutlass_2x_gemm<cutlass::arch::Sm75, enable_sm75_to_sm80, InType, OutType,
Epilogue, TileShape, WarpShape, InstructionShape, 2>;
};
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue>
struct sm75_config_M32 {
// M in [1, 32]
// Shared memory required by this Gemm 49152
static_assert(std::is_same<InType, int8_t>());
using TileShape = typename cutlass::gemm::GemmShape<32, 128, 64>;
using WarpShape = typename cutlass::gemm::GemmShape<32, 64, 64>;
using InstructionShape = typename cutlass::gemm::GemmShape<8, 8, 16>;
using Cutlass2xGemm =
cutlass_2x_gemm<cutlass::arch::Sm75, enable_sm75_to_sm80, InType, OutType,
Epilogue, TileShape, WarpShape, InstructionShape, 2>;
};
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue,
typename... EpilogueArgs>
inline void cutlass_gemm_sm75_dispatch(paddle::Tensor& out,
paddle::Tensor const& a,
paddle::Tensor const& b,
EpilogueArgs&&... args) {
static_assert(std::is_same<InType, int8_t>());
PD_CHECK(a.dtype() == paddle::DataType::INT8);
PD_CHECK(b.dtype() == paddle::DataType::INT8);
using Cutlass2xGemmDefault =
typename sm75_config_default<InType, OutType, Epilogue>::Cutlass2xGemm;
using Cutlass2xGemmM256 =
typename sm75_config_M256<InType, OutType, Epilogue>::Cutlass2xGemm;
using Cutlass2xGemmM128 = Cutlass2xGemmDefault;
using Cutlass2xGemmM64 =
typename sm75_config_M64<InType, OutType, Epilogue>::Cutlass2xGemm;
using Cutlass2xGemmM32 =
typename sm75_config_M32<InType, OutType, Epilogue>::Cutlass2xGemm;
// Due to shared memory requirements, some Gemms may fail to run on some
// GPUs. As the name indicates, the Fallback Gemm is used as an alternative
// in such cases.
// sm75_config_default has the least shared-memory requirements.
using FallbackGemm = Cutlass2xGemmDefault;
  uint32_t const m = a.dims()[0];
uint32_t const mp2 =
std::max(static_cast<uint32_t>(32), next_pow_2(m)); // next power of 2
if (mp2 <= 32) {
// M in [1, 32]
return fallback_cutlass_gemm_caller<Cutlass2xGemmM32, FallbackGemm>(
out, a, b, std::forward<EpilogueArgs>(args)...);
} else if (mp2 <= 64) {
// M in (32, 64]
return fallback_cutlass_gemm_caller<Cutlass2xGemmM64, FallbackGemm>(
out, a, b, std::forward<EpilogueArgs>(args)...);
} else if (mp2 <= 128) {
// M in (64, 128]
return fallback_cutlass_gemm_caller<Cutlass2xGemmM128, FallbackGemm>(
out, a, b, std::forward<EpilogueArgs>(args)...);
} else if (mp2 <= 256) {
// M in (128, 256]
return fallback_cutlass_gemm_caller<Cutlass2xGemmM256, FallbackGemm>(
out, a, b, std::forward<EpilogueArgs>(args)...);
} else {
// M in (256, inf)
return fallback_cutlass_gemm_caller<Cutlass2xGemmDefault, FallbackGemm>(
out, a, b, std::forward<EpilogueArgs>(args)...);
}
}
} // namespace fastdeploy

View File

@@ -0,0 +1,141 @@
// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm80_dispatch.cuh
#pragma once
#include "scaled_mm_c2x.cuh"
/**
* This file defines Gemm kernel configurations for SM80 based on the Gemm
* shape.
*/
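// Worked example of cutlass_gemm_sm80_dispatch below: M = 100 rounds up to the
// 128 bucket, where the tile is then picked by N; N = 4096 (< 8192) reuses the
// sm80_config_M64 tile (64x128x128), while N = 8192 keeps the default
// 128x128x64 tile.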
namespace fastdeploy {
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue>
struct sm80_config_default {
// This config is used in 2 cases,
// - M in (128, inf)
// - M in (64, 128] and N >= 8192
// Shared Memory required by this Gemm - 81920 bytes
static_assert(std::is_same<InType, int8_t>());
using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
using Cutlass2xGemm =
cutlass_2x_gemm<cutlass::arch::Sm80, enable_sm80_to_sm89, InType, OutType,
Epilogue, TileShape, WarpShape, InstructionShape, 5>;
};
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue>
struct sm80_config_M64 {
// This config is used in 2 cases,
// - M in (32, 64]
// - M in (64, 128] and N < 8192
// Shared Memory required by this Gemm - 122880 bytes
static_assert(std::is_same<InType, int8_t>());
using TileShape = typename cutlass::gemm::GemmShape<64, 128, 128>;
using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
using Cutlass2xGemm =
cutlass_2x_gemm<cutlass::arch::Sm80, enable_sm80_to_sm89, InType, OutType,
Epilogue, TileShape, WarpShape, InstructionShape, 5>;
};
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue>
struct sm80_config_M32 {
// M in (16, 32]
// Shared Memory required by this Gemm - 61440 bytes
static_assert(std::is_same<InType, int8_t>());
using TileShape = typename cutlass::gemm::GemmShape<32, 64, 128>;
using WarpShape = typename cutlass::gemm::GemmShape<32, 64, 64>;
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
using Cutlass2xGemm =
cutlass_2x_gemm<cutlass::arch::Sm80, enable_sm80_to_sm89, InType, OutType,
Epilogue, TileShape, WarpShape, InstructionShape, 5>;
};
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue>
struct sm80_config_M16 {
// M in [1, 16]
// Shared Memory required by this Gemm - 51200 bytes
static_assert(std::is_same<InType, int8_t>());
using TileShape = typename cutlass::gemm::GemmShape<16, 64, 128>;
using WarpShape = typename cutlass::gemm::GemmShape<16, 64, 64>;
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
using Cutlass2xGemm =
cutlass_2x_gemm<cutlass::arch::Sm80, enable_sm80_to_sm89, InType, OutType,
Epilogue, TileShape, WarpShape, InstructionShape, 5>;
};
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue,
typename... EpilogueArgs>
inline void cutlass_gemm_sm80_dispatch(paddle::Tensor& out,
paddle::Tensor const& a,
paddle::Tensor const& b,
EpilogueArgs&&... args) {
static_assert(std::is_same<InType, int8_t>());
PD_CHECK(a.dtype() == paddle::DataType::INT8);
PD_CHECK(b.dtype() == paddle::DataType::INT8);
using Cutlass2xGemmDefault =
typename sm80_config_default<InType, OutType, Epilogue>::Cutlass2xGemm;
using Cutlass2xGemmM128BigN =
typename sm80_config_default<InType, OutType, Epilogue>::Cutlass2xGemm;
using Cutlass2xGemmM128SmallN =
typename sm80_config_M64<InType, OutType, Epilogue>::Cutlass2xGemm;
using Cutlass2xGemmM64 =
typename sm80_config_M64<InType, OutType, Epilogue>::Cutlass2xGemm;
using Cutlass2xGemmM32 =
typename sm80_config_M32<InType, OutType, Epilogue>::Cutlass2xGemm;
using Cutlass2xGemmM16 =
typename sm80_config_M16<InType, OutType, Epilogue>::Cutlass2xGemm;
// Due to shared memory requirements, some Gemms may fail to run on some
// GPUs. As the name indicates, the Fallback Gemm is used as an alternative
// in such cases.
// sm80_config_M16 has the least shared-memory requirement. However,
// based on some profiling, we select sm80_config_M32 as a better alternative
// performance wise.
using FallbackGemm =
typename sm80_config_M32<InType, OutType, Epilogue>::Cutlass2xGemm;
  uint32_t const m = a.dims()[0];
uint32_t const mp2 =
std::max(static_cast<uint32_t>(16), next_pow_2(m)); // next power of 2
if (mp2 <= 16) {
// M in [1, 16]
return fallback_cutlass_gemm_caller<Cutlass2xGemmM16, FallbackGemm>(
out, a, b, std::forward<EpilogueArgs>(args)...);
} else if (mp2 <= 32) {
// M in (16, 32]
return fallback_cutlass_gemm_caller<Cutlass2xGemmM32, FallbackGemm>(
out, a, b, std::forward<EpilogueArgs>(args)...);
} else if (mp2 <= 64) {
// M in (32, 64]
return fallback_cutlass_gemm_caller<Cutlass2xGemmM64, FallbackGemm>(
out, a, b, std::forward<EpilogueArgs>(args)...);
} else if (mp2 <= 128) {
// M in (64, 128]
    uint32_t const n = out.dims()[1];
bool const small_n = n < 8192;
if (small_n) {
return fallback_cutlass_gemm_caller<Cutlass2xGemmM128SmallN,
FallbackGemm>(
out, a, b, std::forward<EpilogueArgs>(args)...);
} else {
return fallback_cutlass_gemm_caller<Cutlass2xGemmM128BigN, FallbackGemm>(
out, a, b, std::forward<EpilogueArgs>(args)...);
}
} else {
// M in (128, inf)
return fallback_cutlass_gemm_caller<Cutlass2xGemmDefault, FallbackGemm>(
out, a, b, std::forward<EpilogueArgs>(args)...);
}
}
} // namespace fastdeploy

View File

@@ -0,0 +1,370 @@
// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm89_fp8_dispatch.cuh
#pragma once
#include "scaled_mm_c2x.cuh"
#include "cutlass/float8.h"
/**
* This file defines Gemm kernel configurations for SM89 (FP8) based on the Gemm
* shape.
*/
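// Each M bucket here additionally specializes on N and on the FP8 math
// operator. Worked example of cutlass_gemm_sm89_fp8_dispatch below: M = 48
// rounds up to the 64 bucket; with N = 4096 it uses a 64x64x128 tile, 5
// mainloop stages and OpMultiplyAdd, while N = 16384 switches to a 64x128x128
// tile, 3 stages and OpMultiplyAddFastAccum.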
namespace fastdeploy {
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue>
struct sm89_fp8_fallback_gemm {
// Shared Memory required by this Gemm - 61440 bytes
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
using TileShape = typename cutlass::gemm::GemmShape<64, 128, 64>;
using WarpShape = typename cutlass::gemm::GemmShape<32, 64, 64>;
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
using FP8MathOperator = typename cutlass::arch::OpMultiplyAdd;
using Cutlass2xGemm =
cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90, InType, OutType,
Epilogue, TileShape, WarpShape, InstructionShape, 5,
FP8MathOperator>;
};
struct sm89_fp8_config_default {
// M in (256, inf)
using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum;
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue,
typename... EpilogueArgs>
static void dispatch(paddle::Tensor& out, paddle::Tensor const& a,
paddle::Tensor const& b, EpilogueArgs&&... args) {
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
PD_CHECK(a.dtype() == paddle::DataType::FLOAT8_E4M3FN);
using FallbackGemm =
typename sm89_fp8_fallback_gemm<InType, OutType,
Epilogue>::Cutlass2xGemm;
uint32_t const n = out.dims()[1];
uint32_t const np2 = next_pow_2(n);
if (np2 <= 4096) {
using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
return fastdeploy::fallback_cutlass_gemm_caller<
fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89, fastdeploy::enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 5, FP8MathOperator>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
} else if (np2 <= 8192) {
using TileShape = typename cutlass::gemm::GemmShape<256, 128, 64>;
return fastdeploy::fallback_cutlass_gemm_caller<
fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89, fastdeploy::enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 3, FP8MathOperator>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
} else {
using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
return fastdeploy::fallback_cutlass_gemm_caller<
fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89, fastdeploy::enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 5, FP8MathOperator>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
}
}
};
struct sm89_fp8_config_M256 {
// M in (128, 256]
using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum;
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue,
typename... EpilogueArgs>
static void dispatch(paddle::Tensor& out, paddle::Tensor const& a,
paddle::Tensor const& b, EpilogueArgs&&... args) {
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
PD_CHECK(a.dtype() == paddle::DataType::FLOAT8_E4M3FN);
using FallbackGemm =
typename sm89_fp8_fallback_gemm<InType, OutType,
Epilogue>::Cutlass2xGemm;
uint32_t const n = out.dims()[1];
uint32_t const np2 = next_pow_2(n);
if (np2 <= 4096) {
using TileShape = typename cutlass::gemm::GemmShape<64, 128, 128>;
return fastdeploy::fallback_cutlass_gemm_caller<
fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89, fastdeploy::enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 3, FP8MathOperator>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
} else {
using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
return fastdeploy::fallback_cutlass_gemm_caller<
fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89, fastdeploy::enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 5, FP8MathOperator>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
}
}
};
struct sm89_fp8_config_M128 {
// M in (64, 128]
using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum;
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue,
typename... EpilogueArgs>
static void dispatch(paddle::Tensor& out, paddle::Tensor const& a,
paddle::Tensor const& b, EpilogueArgs&&... args) {
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
PD_CHECK(a.dtype() == paddle::DataType::FLOAT8_E4M3FN);
using FallbackGemm =
typename sm89_fp8_fallback_gemm<InType, OutType,
Epilogue>::Cutlass2xGemm;
uint32_t const n = out.dims()[1];
uint32_t const np2 = next_pow_2(n);
if (np2 <= 8192) {
using TileShape = typename cutlass::gemm::GemmShape<64, 128, 128>;
return fastdeploy::fallback_cutlass_gemm_caller<
fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89, fastdeploy::enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 3, FP8MathOperator>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
} else if (np2 <= 16384) {
using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
return fastdeploy::fallback_cutlass_gemm_caller<
fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89, fastdeploy::enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 5, FP8MathOperator>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
} else {
using TileShape = typename cutlass::gemm::GemmShape<128, 64, 128>;
return fastdeploy::fallback_cutlass_gemm_caller<
fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89, fastdeploy::enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 3, FP8MathOperator>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
}
}
};
struct sm89_fp8_config_M64 {
// M in (32, 64]
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue,
typename... EpilogueArgs>
static void dispatch(paddle::Tensor& out, paddle::Tensor const& a,
paddle::Tensor const& b, EpilogueArgs&&... args) {
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
PD_CHECK(a.dtype() == paddle::DataType::FLOAT8_E4M3FN);
using FallbackGemm =
typename sm89_fp8_fallback_gemm<InType, OutType,
Epilogue>::Cutlass2xGemm;
uint32_t const n = out.dims()[1];
uint32_t const np2 = next_pow_2(n);
    if (np2 <= 8192) {
using TileShape = typename cutlass::gemm::GemmShape<64, 64, 128>;
using WarpShape = typename cutlass::gemm::GemmShape<32, 64, 64>;
using FP8MathOperator = typename cutlass::arch::OpMultiplyAdd;
return fastdeploy::fallback_cutlass_gemm_caller<
fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89, fastdeploy::enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 5, FP8MathOperator>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
} else if (np2 <= 16384) {
using TileShape = typename cutlass::gemm::GemmShape<64, 128, 128>;
using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum;
return fastdeploy::fallback_cutlass_gemm_caller<
fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89, fastdeploy::enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 3, FP8MathOperator>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
} else {
using TileShape = typename cutlass::gemm::GemmShape<64, 64, 128>;
using WarpShape = typename cutlass::gemm::GemmShape<32, 64, 64>;
using FP8MathOperator = typename cutlass::arch::OpMultiplyAdd;
return fastdeploy::fallback_cutlass_gemm_caller<
fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89, fastdeploy::enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 5, FP8MathOperator>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
}
}
};
struct sm89_fp8_config_M32 {
// M in (16, 32]
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum;
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue,
typename... EpilogueArgs>
static void dispatch(paddle::Tensor& out, paddle::Tensor const& a,
paddle::Tensor const& b, EpilogueArgs&&... args) {
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
PD_CHECK(a.dtype() == paddle::DataType::FLOAT8_E4M3FN);
using FallbackGemm =
typename sm89_fp8_fallback_gemm<InType, OutType,
Epilogue>::Cutlass2xGemm;
uint32_t const n = out.dims()[1];
uint32_t const np2 = next_pow_2(n);
if (np2 <= 8192) {
using TileShape = typename cutlass::gemm::GemmShape<32, 64, 128>;
using WarpShape = typename cutlass::gemm::GemmShape<16, 64, 64>;
return fastdeploy::fallback_cutlass_gemm_caller<
fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89, fastdeploy::enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 5, FP8MathOperator>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
} else if (np2 <= 16384) {
using TileShape = typename cutlass::gemm::GemmShape<32, 128, 128>;
using WarpShape = typename cutlass::gemm::GemmShape<32, 64, 64>;
return fastdeploy::fallback_cutlass_gemm_caller<
fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89, fastdeploy::enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 4, FP8MathOperator>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
} else {
using TileShape = typename cutlass::gemm::GemmShape<32, 64, 128>;
using WarpShape = typename cutlass::gemm::GemmShape<16, 64, 64>;
return fastdeploy::fallback_cutlass_gemm_caller<
fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89, fastdeploy::enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 5, FP8MathOperator>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
}
}
};
struct sm89_fp8_config_M16 {
// M in [1, 16]
using WarpShape = typename cutlass::gemm::GemmShape<16, 64, 64>;
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum;
static const int32_t MainLoopStages = 5;
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue,
typename... EpilogueArgs>
static void dispatch(paddle::Tensor& out, paddle::Tensor const& a,
paddle::Tensor const& b, EpilogueArgs&&... args) {
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
PD_CHECK(a.dtype() == paddle::DataType::FLOAT8_E4M3FN);
using FallbackGemm =
typename sm89_fp8_fallback_gemm<InType, OutType,
Epilogue>::Cutlass2xGemm;
uint32_t const n = out.dims()[1];
uint32_t const np2 = next_pow_2(n);
if (np2 <= 8192) {
using TileShape = typename cutlass::gemm::GemmShape<16, 64, 128>;
return fastdeploy::fallback_cutlass_gemm_caller<
fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89, fastdeploy::enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, MainLoopStages,
FP8MathOperator>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
} else if (np2 <= 24576) {
using TileShape = typename cutlass::gemm::GemmShape<16, 128, 64>;
return fastdeploy::fallback_cutlass_gemm_caller<
fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89, fastdeploy::enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, MainLoopStages,
FP8MathOperator>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
} else {
using TileShape = typename cutlass::gemm::GemmShape<32, 64, 128>;
return fastdeploy::fallback_cutlass_gemm_caller<
fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89, fastdeploy::enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, MainLoopStages,
FP8MathOperator>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
}
}
};
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue,
typename... EpilogueArgs>
inline void cutlass_gemm_sm89_fp8_dispatch(paddle::Tensor& out,
paddle::Tensor const& a,
paddle::Tensor const& b,
EpilogueArgs&&... args) {
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
PD_CHECK(a.dtype() == paddle::DataType::FLOAT8_E4M3FN);
PD_CHECK(b.dtype() == paddle::DataType::FLOAT8_E4M3FN);
  uint32_t const m = a.dims()[0];
uint32_t const mp2 =
std::max(static_cast<uint32_t>(16), next_pow_2(m)); // next power of 2
if (mp2 <= 16) {
// M in [1, 16]
return sm89_fp8_config_M16::dispatch<InType, OutType, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(args)...);
} else if (mp2 <= 32) {
// M in (16, 32]
return sm89_fp8_config_M32::dispatch<InType, OutType, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(args)...);
} else if (mp2 <= 64) {
// M in (32, 64]
return sm89_fp8_config_M64::dispatch<InType, OutType, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(args)...);
} else if (mp2 <= 128) {
// M in (64, 128]
return sm89_fp8_config_M128::dispatch<InType, OutType, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(args)...);
} else if (mp2 <= 256) {
// M in (128, 256]
return sm89_fp8_config_M256::dispatch<InType, OutType, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(args)...);
} else {
// M in (256, inf)
return sm89_fp8_config_default::dispatch<InType, OutType, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(args)...);
}
}
} // namespace fastdeploy

View File

@@ -0,0 +1,355 @@
// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm89_int8_dispatch.cuh
#pragma once
#include "scaled_mm_c2x.cuh"
/**
* This file defines Gemm kernel configurations for SM89 (int8) based on the
* Gemm shape.
*/
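// Worked example of cutlass_gemm_sm89_int8_dispatch below: M = 48 rounds up to
// the 64 bucket; with N = 4096 it uses a 64x64x128 tile with 5 mainloop
// stages, while N = 16384 switches to a 64x128x128 tile with 3 stages.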
namespace fastdeploy {
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue>
struct sm89_int8_fallback_gemm {
// Shared mem requirement : 61440
static_assert(std::is_same<InType, int8_t>());
using TileShape = cutlass::gemm::GemmShape<32, 64, 128>;
using WarpShape = cutlass::gemm::GemmShape<16, 64, 64>;
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
static int32_t const MainLoopStages = 5;
using Cutlass2xGemm =
cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90, InType, OutType,
                      Epilogue, TileShape, WarpShape, InstructionShape,
                      MainLoopStages>;
};
struct sm89_int8_config_default {
// M in (256, inf)
using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue,
typename... EpilogueArgs>
static void dispatch(paddle::Tensor& out, paddle::Tensor const& a,
paddle::Tensor const& b, EpilogueArgs&&... args) {
static_assert(std::is_same<InType, int8_t>());
PD_CHECK(a.dtype() == paddle::DataType::INT8);
using FallbackGemm =
typename sm89_int8_fallback_gemm<InType, OutType,
Epilogue>::Cutlass2xGemm;
uint32_t const n = out.dims()[1];
uint32_t const np2 = next_pow_2(n);
if (np2 <= 4096) {
using TileShape = cutlass::gemm::GemmShape<128, 128, 64>;
return fastdeploy::fallback_cutlass_gemm_caller<
fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89, fastdeploy::enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 5>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
} else if (np2 <= 8192) {
using TileShape = cutlass::gemm::GemmShape<256, 128, 64>;
return fastdeploy::fallback_cutlass_gemm_caller<
fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89, fastdeploy::enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 3>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
} else if (np2 <= 16384) {
using TileShape = cutlass::gemm::GemmShape<128, 128, 64>;
return fastdeploy::fallback_cutlass_gemm_caller<
fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89, fastdeploy::enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 5>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
} else {
using TileShape = cutlass::gemm::GemmShape<256, 128, 64>;
return fastdeploy::fallback_cutlass_gemm_caller<
fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89, fastdeploy::enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 3>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
}
}
};
struct sm89_int8_config_M256 {
// M in (128, 256]
using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue,
typename... EpilogueArgs>
static void dispatch(paddle::Tensor& out, paddle::Tensor const& a,
paddle::Tensor const& b, EpilogueArgs&&... args) {
static_assert(std::is_same<InType, int8_t>());
PD_CHECK(a.dtype() == paddle::DataType::INT8);
using FallbackGemm =
typename sm89_int8_fallback_gemm<InType, OutType,
Epilogue>::Cutlass2xGemm;
uint32_t const n = out.dims()[1];
uint32_t const np2 = next_pow_2(n);
if (np2 <= 4096) {
using TileShape = cutlass::gemm::GemmShape<64, 128, 128>;
return fastdeploy::fallback_cutlass_gemm_caller<
fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89, fastdeploy::enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 3>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
} else if (np2 <= 8192) {
using TileShape = cutlass::gemm::GemmShape<128, 128, 64>;
return fastdeploy::fallback_cutlass_gemm_caller<
fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89, fastdeploy::enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 5>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
} else if (np2 <= 16384) {
using TileShape = cutlass::gemm::GemmShape<256, 128, 64>;
return fastdeploy::fallback_cutlass_gemm_caller<
fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89, fastdeploy::enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 3>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
} else {
using TileShape = cutlass::gemm::GemmShape<128, 128, 64>;
return fastdeploy::fallback_cutlass_gemm_caller<
fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89, fastdeploy::enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 5>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
}
}
};
struct sm89_int8_config_M128 {
// M in (64, 128]
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue,
typename... EpilogueArgs>
static void dispatch(paddle::Tensor& out, paddle::Tensor const& a,
paddle::Tensor const& b, EpilogueArgs&&... args) {
static_assert(std::is_same<InType, int8_t>());
PD_CHECK(a.dtype() == paddle::DataType::INT8);
using FallbackGemm =
typename sm89_int8_fallback_gemm<InType, OutType,
Epilogue>::Cutlass2xGemm;
uint32_t const n = out.dims()[1];
uint32_t const np2 = next_pow_2(n);
if (np2 <= 8192) {
using TileShape = cutlass::gemm::GemmShape<64, 128, 128>;
using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>;
return fastdeploy::fallback_cutlass_gemm_caller<
fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89, fastdeploy::enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 3>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
} else if (np2 <= 16384) {
using TileShape = cutlass::gemm::GemmShape<128, 128, 64>;
using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>;
return fastdeploy::fallback_cutlass_gemm_caller<
fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89, fastdeploy::enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 5>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
} else {
using TileShape = cutlass::gemm::GemmShape<64, 64, 128>;
using WarpShape = cutlass::gemm::GemmShape<32, 64, 64>;
return fastdeploy::fallback_cutlass_gemm_caller<
fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89, fastdeploy::enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 5>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
}
}
};
struct sm89_int8_config_M64 {
// M in (32, 64]
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue,
typename... EpilogueArgs>
static void dispatch(paddle::Tensor& out, paddle::Tensor const& a,
paddle::Tensor const& b, EpilogueArgs&&... args) {
static_assert(std::is_same<InType, int8_t>());
PD_CHECK(a.dtype() == paddle::DataType::INT8);
using FallbackGemm =
typename sm89_int8_fallback_gemm<InType, OutType,
Epilogue>::Cutlass2xGemm;
uint32_t const n = out.dims()[1];
uint32_t const np2 = next_pow_2(n);
if (np2 <= 8192) {
using TileShape = cutlass::gemm::GemmShape<64, 64, 128>;
using WarpShape = cutlass::gemm::GemmShape<32, 64, 64>;
return fastdeploy::fallback_cutlass_gemm_caller<
fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89, fastdeploy::enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 5>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
} else {
using TileShape = cutlass::gemm::GemmShape<64, 128, 128>;
using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>;
return fastdeploy::fallback_cutlass_gemm_caller<
fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89, fastdeploy::enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 3>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
}
}
};
struct sm89_int8_config_M32 {
// M in (16, 32]
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue,
typename... EpilogueArgs>
static void dispatch(paddle::Tensor& out, paddle::Tensor const& a,
paddle::Tensor const& b, EpilogueArgs&&... args) {
static_assert(std::is_same<InType, int8_t>());
PD_CHECK(a.dtype() == paddle::DataType::INT8);
using FallbackGemm =
typename sm89_int8_fallback_gemm<InType, OutType,
Epilogue>::Cutlass2xGemm;
uint32_t const n = out.dims()[1];
uint32_t const np2 = next_pow_2(n);
if (np2 <= 8192) {
using TileShape = cutlass::gemm::GemmShape<32, 64, 128>;
using WarpShape = cutlass::gemm::GemmShape<16, 64, 64>;
return fastdeploy::fallback_cutlass_gemm_caller<
fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89, fastdeploy::enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 5>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
} else {
using TileShape = cutlass::gemm::GemmShape<32, 128, 128>;
using WarpShape = cutlass::gemm::GemmShape<32, 64, 64>;
return fastdeploy::fallback_cutlass_gemm_caller<
fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89, fastdeploy::enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 4>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
}
}
};
struct sm89_int8_config_M16 {
// M in [1, 16]
using WarpShape = typename cutlass::gemm::GemmShape<16, 64, 64>;
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue,
typename... EpilogueArgs>
static void dispatch(paddle::Tensor& out, paddle::Tensor const& a,
paddle::Tensor const& b, EpilogueArgs&&... args) {
static_assert(std::is_same<InType, int8_t>());
PD_CHECK(a.dtype() == paddle::DataType::INT8);
using FallbackGemm =
typename sm89_int8_fallback_gemm<InType, OutType,
Epilogue>::Cutlass2xGemm;
uint32_t const n = out.dims()[1];
uint32_t const np2 = next_pow_2(n);
if (np2 <= 8192) {
using TileShape = cutlass::gemm::GemmShape<16, 64, 128>;
return fastdeploy::fallback_cutlass_gemm_caller<
fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89, fastdeploy::enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 5>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
} else {
using TileShape = cutlass::gemm::GemmShape<16, 128, 128>;
return fastdeploy::fallback_cutlass_gemm_caller<
fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89, fastdeploy::enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 4>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
}
}
};
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue,
typename... EpilogueArgs>
inline void cutlass_gemm_sm89_int8_dispatch(paddle::Tensor& out,
paddle::Tensor const& a,
paddle::Tensor const& b,
EpilogueArgs&&... args) {
static_assert(std::is_same<InType, int8_t>());
PD_CHECK(a.dtype() == paddle::DataType::INT8);
PD_CHECK(b.dtype() == paddle::DataType::INT8);
uint32_t const m = a.dims()[0];
uint32_t const mp2 =
std::max(static_cast<uint32_t>(16), next_pow_2(m)); // next power of 2
if (mp2 <= 16) {
// M in [1, 16]
return sm89_int8_config_M16::dispatch<InType, OutType, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(args)...);
} else if (mp2 <= 32) {
// M in (16, 32]
return sm89_int8_config_M32::dispatch<InType, OutType, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(args)...);
} else if (mp2 <= 64) {
// M in (32, 64]
return sm89_int8_config_M64::dispatch<InType, OutType, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(args)...);
} else if (mp2 <= 128) {
// M in (64, 128]
return sm89_int8_config_M128::dispatch<InType, OutType, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(args)...);
} else if (mp2 <= 256) {
// M in (128, 256]
return sm89_int8_config_M256::dispatch<InType, OutType, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(args)...);
} else {
// M in (256, inf)
return sm89_int8_config_default::dispatch<InType, OutType, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(args)...);
}
}
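// NOTE: the dispatch above only keys off M (rows of a), rounded to the next
// power of two and clamped to at least 16; each per-M config then branches on
// the next power of two of N (columns of out) to choose TileShape, WarpShape
// and the stage count. For example, m = 48 rounds to 64 and selects
// sm89_int8_config_M64; with n = 6144 (next_pow_2 = 8192), that config uses
// TileShape<64, 64, 128>, WarpShape<32, 64, 64> and 5 stages. K never affects
// which configuration is chosen.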
} // namespace fastdeploy

View File

@@ -0,0 +1,37 @@
// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm90.cu
#include "c3x/scaled_mm_helper.hpp"
#include "c3x/scaled_mm_kernels.hpp"
/*
This file defines quantized GEMM operations using the CUTLASS 3.x API, for
NVIDIA GPUs with sm90a (Hopper).
*/
#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90
void cutlass_scaled_mm_sm90(paddle::Tensor &c, paddle::Tensor const &a,
paddle::Tensor const &b,
paddle::Tensor const &a_scales,
paddle::Tensor const &b_scales,
paddle::optional<paddle::Tensor> const &bias) {
dispatch_scaled_mm(c, a, b, a_scales, b_scales, bias,
fastdeploy::cutlass_scaled_mm_sm90_fp8,
fastdeploy::cutlass_scaled_mm_sm90_int8);
}
void cutlass_scaled_mm_azp_sm90(paddle::Tensor& out, paddle::Tensor const& a,
paddle::Tensor const& b,
paddle::Tensor const& a_scales,
paddle::Tensor const& b_scales,
paddle::Tensor const& azp_adj,
paddle::optional<paddle::Tensor> const& azp,
paddle::optional<paddle::Tensor> const& bias) {
PD_CHECK(a_scales.dtype() == paddle::DataType::FLOAT32);
PD_CHECK(b_scales.dtype() == paddle::DataType::FLOAT32);
fastdeploy::cutlass_scaled_mm_azp_sm90_int8(out, a, b, a_scales, b_scales, azp_adj,
azp, bias);
}
#endif
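// NOTE: dispatch_scaled_mm itself is defined in c3x/scaled_mm_helper.hpp and
// is not part of this diff. Judging from the call above, it receives one
// callback per input dtype; a rough sketch of the assumed routing, for
// illustration only:
//
//   template <typename Fp8Func, typename Int8Func>
//   void dispatch_scaled_mm(paddle::Tensor& c, paddle::Tensor const& a,
//                           paddle::Tensor const& b,
//                           paddle::Tensor const& a_scales,
//                           paddle::Tensor const& b_scales,
//                           paddle::optional<paddle::Tensor> const& bias,
//                           Fp8Func fp8_func, Int8Func int8_func) {
//     if (a.dtype() == paddle::DataType::INT8) {
//       int8_func(c, a, b, a_scales, b_scales, bias);
//     } else {
//       fp8_func(c, a, b, a_scales, b_scales, bias);  // assumed FP8 path
//     }
//   }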

View File

@@ -0,0 +1,224 @@
// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu
#pragma once
#include "helper.h"
#include <iostream>
void cutlass_scaled_mm_sm75(paddle::Tensor &c, paddle::Tensor const &a,
paddle::Tensor const &b,
paddle::Tensor const &a_scales,
paddle::Tensor const &b_scales,
paddle::optional<paddle::Tensor> const &bias);
void cutlass_scaled_mm_sm80(paddle::Tensor &c, paddle::Tensor const &a,
paddle::Tensor const &b,
paddle::Tensor const &a_scales,
paddle::Tensor const &b_scales,
paddle::optional<paddle::Tensor> const &bias);
void cutlass_scaled_mm_sm89(paddle::Tensor &c, paddle::Tensor const &a,
paddle::Tensor const &b,
paddle::Tensor const &a_scales,
paddle::Tensor const &b_scales,
paddle::optional<paddle::Tensor> const &bias);
#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90
void cutlass_scaled_mm_sm90(paddle::Tensor &c, paddle::Tensor const &a,
paddle::Tensor const &b,
paddle::Tensor const &a_scales,
paddle::Tensor const &b_scales,
paddle::optional<paddle::Tensor> const &bias);
#endif
void cutlass_scaled_mm_azp_sm75(paddle::Tensor& c, paddle::Tensor const& a,
paddle::Tensor const& b,
paddle::Tensor const& a_scales,
paddle::Tensor const& b_scales,
paddle::Tensor const& azp_adj,
paddle::optional<paddle::Tensor> const& azp,
paddle::optional<paddle::Tensor> const& bias);
void cutlass_scaled_mm_azp_sm80(paddle::Tensor& c, paddle::Tensor const& a,
paddle::Tensor const& b,
paddle::Tensor const& a_scales,
paddle::Tensor const& b_scales,
paddle::Tensor const& azp_adj,
paddle::optional<paddle::Tensor> const& azp,
paddle::optional<paddle::Tensor> const& bias);
void cutlass_scaled_mm_azp_sm89(paddle::Tensor& c, paddle::Tensor const& a,
paddle::Tensor const& b,
paddle::Tensor const& a_scales,
paddle::Tensor const& b_scales,
paddle::Tensor const& azp_adj,
paddle::optional<paddle::Tensor> const& azp,
paddle::optional<paddle::Tensor> const& bias);
#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90
void cutlass_scaled_mm_azp_sm90(paddle::Tensor& c, paddle::Tensor const& a,
paddle::Tensor const& b,
paddle::Tensor const& a_scales,
paddle::Tensor const& b_scales,
paddle::Tensor const& azp_adj,
paddle::optional<paddle::Tensor> const& azp,
paddle::optional<paddle::Tensor> const& bias);
#endif
bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability) {
// CUTLASS FP8 kernels need at least
// CUDA 12.0 on SM90 systems (Hopper)
// CUDA 12.4 on SM89 systems (Lovelace)
#if defined CUDA_VERSION
if (cuda_device_capability >= 90) {
return CUDA_VERSION >= 12000;
} else if (cuda_device_capability >= 89) {
return CUDA_VERSION >= 12040;
}
#endif
return false;
}
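// NOTE: in practice this means FP8 scaled_mm is reported as available on
// Hopper (capability >= 90) whenever the build used CUDA 12.0+, on Ada
// (capability 89) only with CUDA 12.4+, and never on older parts such as
// capability 80/86, which have no FP8 tensor cores. For example,
// cutlass_scaled_mm_supports_fp8(89) returns false when compiled with CUDA
// 12.2 and true when compiled with CUDA 12.4.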
void CutlassScaledMm(paddle::Tensor &c, paddle::Tensor const &a,
paddle::Tensor const &b, paddle::Tensor const &a_scales,
paddle::Tensor const &b_scales,
paddle::optional<paddle::Tensor> const &bias) {
// Checks for conformality
PD_CHECK(a.dims().size() == 2 && b.dims().size() == 2 &&
c.dims().size() == 2);
PD_CHECK(c.dims()[0] == a.dims()[0] && a.dims()[1] == b.dims()[1] &&
b.dims()[0] == c.dims()[1]);
// Check for strides and alignment
PD_CHECK(a.strides()[1] == 1 && c.strides()[1] == 1); // Row-major
PD_CHECK(b.strides()[1] == 1); // Column-major
PD_CHECK(c.strides()[0] % 16 == 0 &&
b.strides()[0] % 16 == 0); // 16 Byte Alignment
if (bias) {
PD_CHECK(bias->numel() == b.dims()[0] && bias->is_contiguous() &&
bias->dims().size() == 1);
}
int32_t version_num = GetGPUComputeCapability(a.place().GetDeviceId());
// Guard against compilation issues for sm90 kernels
#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90
if (version_num >= 90 && version_num < 100) {
// Hopper
cutlass_scaled_mm_sm90(c, a, b, a_scales, b_scales, bias);
return;
}
#endif
#if defined ENABLE_SCALED_MM_C2X && ENABLE_SCALED_MM_C2X
if (version_num == 89) {
// Ada Lovelace
cutlass_scaled_mm_sm89(c, a, b, a_scales, b_scales, bias);
return;
}
if (version_num >= 80) {
// Ampere
cutlass_scaled_mm_sm80(c, a, b, a_scales, b_scales, bias);
return;
}
if (version_num >= 75) {
// Turing
cutlass_scaled_mm_sm75(c, a, b, a_scales, b_scales, bias);
return;
}
#endif
PADDLE_THROW(phi::errors::Unimplemented(
"No compiled cutlass_scaled_mm for a compute capability less than "
"CUDA device capability: %d",
version_num));
}
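// NOTE: spelled out, the checks above expect a as a row-major [m, k] tensor,
// b as an [n, k] tensor with unit innermost stride (i.e. the logical [k, n]
// weight stored transposed / column-major), and c as a row-major [m, n]
// tensor; b and c must also have a leading stride that is a multiple of 16
// for the alignment check. A concrete shape that satisfies every check:
// m = 32, n = k = 4096, a int8 [32, 4096], b int8 [4096, 4096] holding B^T,
// c [32, 4096] in the output dtype, and an optional bias of length n = 4096.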
void CutlassScaledMmAzp(paddle::Tensor& c, paddle::Tensor const& a,
paddle::Tensor const& b,
paddle::Tensor const& a_scales,
paddle::Tensor const& b_scales,
paddle::Tensor const& azp_adj,
paddle::optional<paddle::Tensor> const& azp,
paddle::optional<paddle::Tensor> const& bias) {
// Checks for conformality
PD_CHECK(a.dims().size() == 2 && b.dims().size() == 2 &&
c.dims().size() == 2);
PD_CHECK(c.dims()[0] == a.dims()[0] && a.dims()[1] == b.dims()[1] &&
b.dims()[0] == c.dims()[1]);
PD_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.dims()[0]);
PD_CHECK(b_scales.numel() == 1 || b_scales.numel() == b.dims()[0]);
// Check for strides and alignment
PD_CHECK(a.strides()[1] == 1 && c.strides()[1] == 1); // Row-major
PD_CHECK(b.strides()[1] == 1); // Column-major
PD_CHECK(c.strides()[0] % 16 == 0 &&
b.strides()[0] % 16 == 0); // 16 Byte Alignment
PD_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
// bias, azp, azp_adj are all 1d
// bias and azp_adj have n elements, azp has m elements
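// These shapes follow from the usual asymmetric-quantization identity
// (assuming the same convention as the azp kernels this wraps): with
// per-row zero points, A = s_a * (A_q - azp * 1^T), so
//   A @ B = s_a * s_b * (A_q @ B_q - outer(azp, colsum(B_q))),
// which makes azp_adj presumably the column sums of B_q (n entries) and azp
// the per-row zero points of A (m entries).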
if (bias) {
PD_CHECK(bias->numel() == b.dims()[0] && bias->is_contiguous());
}
if (azp) {
PD_CHECK(azp->numel() == a.dims()[0] && azp->is_contiguous());
}
PD_CHECK(azp_adj.numel() == b.dims()[0] && azp_adj.is_contiguous());
// azp & bias types
PD_CHECK(azp_adj.dtype() == paddle::DataType::INT32);
PD_CHECK(!azp || azp->dtype() == paddle::DataType::INT32);
PD_CHECK(!bias || bias->dtype() == c.dtype(),
"currently bias dtype must match output dtype ", c.dtype());
int32_t version_num = GetGPUComputeCapability(a.place().GetDeviceId());
#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90
if (version_num >= 90) {
cutlass_scaled_mm_azp_sm90(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
return;
}
#endif
#if defined ENABLE_SCALED_MM_C2X && ENABLE_SCALED_MM_C2X
if (version_num == 89) {
// Ada Lovelace
cutlass_scaled_mm_azp_sm89(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
return;
}
if (version_num >= 80) {
// Ampere
cutlass_scaled_mm_azp_sm80(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
return;
}
// Turing
PD_CHECK(version_num >= 75);
cutlass_scaled_mm_azp_sm75(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
return;
#endif
PADDLE_THROW(phi::errors::Unimplemented(
"No compiled cutlass_scaled_mm_azp for a compute capability less than "
"CUDA device capability: %d",
version_num));
}
PD_BUILD_STATIC_OP(cutlass_scaled_mm)
.Inputs({"c", "a", "b", "a_scales", "b_scales", paddle::Optional("bias")})
.Outputs({"c_out"})
.SetInplaceMap({{"c", "c_out"}})
.SetKernelFn(PD_KERNEL(CutlassScaledMm));
PD_BUILD_STATIC_OP(cutlass_scaled_mm_azp)
.Inputs({"c", "a", "b", "a_scales", "b_scales", "azp_adj", paddle::Optional("azp"), paddle::Optional("bias")})
.Outputs({"c_out"})
.SetInplaceMap({{"c", "c_out"}})
.SetKernelFn(PD_KERNEL(CutlassScaledMmAzp));

View File

@@ -1,27 +0,0 @@
# DeepGEMM
DeepGEMM installation guide
## Installation
First install the custom operators and make sure CUTLASS has already been `git clone`d into [custom_ops/third_party/cutlass](../../third_party/cutlass).
Then install deep_gemm:
```bash
# Make symbolic links for third-party (CUTLASS and CuTe) include directories
python setup.py develop
# Add the project path to PYTHONPATH
export PYTHONPATH=$(pwd):$PYTHONPATH
# or install directly
python setup.py install
```
### Test
```bash
# Test all GEMM implementations (normal, contiguous-grouped and masked-grouped)
python tests/test_core.py
```

View File

@@ -1,31 +0,0 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# The file has been adapted from DeepSeek DeepGEMM project
# Copyright (c) 2025 DeepSeek
# Licensed under the MIT License - https://github.com/deepseek-ai/DeepGEMM/blob/main/LICENSE
"""This module contains all JIT kernels used in GEMM"""
from . import jit
from .jit_kernels import (
ceil_div,
gemm_fp8_fp8_bf16_nt,
get_col_major_tma_aligned_tensor,
get_col_major_tma_aligned_tensor_prefill,
get_m_alignment_for_contiguous_layout,
get_num_sms,
m_grouped_gemm_fp8_fp8_bf16_nt_contiguous,
m_grouped_gemm_fp8_fp8_bf16_nt_masked,
set_num_sms,
)
from .utils import bench, calc_diff, get_cuda_home

View File

@@ -1,462 +0,0 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// The file has been adapted from DeepSeek DeepGEMM project
// Copyright (c) 2025 DeepSeek
// Licensed under the MIT License - https://github.com/deepseek-ai/DeepGEMM/blob/main/LICENSE
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wunknown-attributes"
#pragma once
#include <cutlass/arch/barrier.h>
#include <cutlass/arch/reg_reconfig.h>
#include <cute/arch/cluster_sm90.hpp>
#include <cute/arch/copy_sm90_desc.hpp>
#include <cute/arch/copy_sm90_tma.hpp>
#include "mma_utils.cuh"
#include "scheduler.cuh"
#include "tma_utils.cuh"
#include "utils.cuh"
namespace deep_gemm {
enum class Layout {
RowMajor,
ColMajor
};
template <uint32_t kNumTMAThreads, uint32_t kNumMathThreadsPerGroup>
__device__ __host__ constexpr int get_num_threads_per_sm(int block_m) {
DG_STATIC_ASSERT(kNumMathThreadsPerGroup == 128, "Only support 128 threads per math group");
return (block_m == 64 ? 1 : 2) * kNumMathThreadsPerGroup + kNumTMAThreads;
}
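// NOTE: with kNumTMAThreads = 128 and kNumMathThreadsPerGroup = 128 (the
// values Gemm::run below instantiates the kernel with), this evaluates to
// 1 * 128 + 128 = 256 threads per block for BLOCK_M == 64 (one math
// warp-group) and 2 * 128 + 128 = 384 threads otherwise (two math
// warp-groups), matching the __launch_bounds__ on fp8_gemm_kernel.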
template <uint32_t SHAPE_N, uint32_t SHAPE_K,
uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
uint32_t kNumGroups, uint32_t kNumStages,
uint32_t kNumTMAThreads, uint32_t kNumMathThreadsPerGroup,
uint32_t kNumTMAMulticast,
GemmType kGemmType>
__global__ void __launch_bounds__(get_num_threads_per_sm<kNumTMAThreads, kNumMathThreadsPerGroup>(BLOCK_M), 1)
fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
uint32_t shape_m,
const __grid_constant__ CUtensorMap tensor_map_a,
const __grid_constant__ CUtensorMap tensor_map_b,
const __grid_constant__ CUtensorMap tensor_map_scales_a,
const __grid_constant__ CUtensorMap tensor_map_d) {
#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__)
// Scaling checks
DG_STATIC_ASSERT(BLOCK_K == 128, "Only support per-128-channel FP8 scaling");
DG_STATIC_ASSERT(ceil_div(BLOCK_N, BLOCK_K) == 1, "Too many B scales in a single block");
// Types
using WGMMA = typename FP8MMASelector<BLOCK_N>::type;
using Barrier = cutlass::arch::ClusterTransactionBarrier;
// Shared memory
static constexpr int kMustUseUniformedScaleB = (BLOCK_K % BLOCK_N == 0);
static constexpr uint32_t SMEM_D_SIZE = BLOCK_M * BLOCK_N * sizeof(__nv_bfloat16);
static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_fp8_e4m3);
static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_fp8_e4m3);
static constexpr uint32_t SMEM_SCALES_A_SIZE_PER_STAGE = BLOCK_M * sizeof(float);
static constexpr uint32_t SHAPE_K_SCALES = ceil_div(SHAPE_K, BLOCK_K);
static constexpr uint32_t SMEM_SCALES_B_SIZE = ceil_div<uint32_t>(SHAPE_K_SCALES * (kMustUseUniformedScaleB ? 1 : 2) * sizeof(float), sizeof(Barrier)) * sizeof(Barrier);
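// NOTE: rough budget for a BLOCK_M = 128, BLOCK_N = 128, BLOCK_K = 128 tile:
// SMEM_D_SIZE = 128 * 128 * 2 B = 32 KiB of bf16 output, and every pipeline
// stage adds 16 KiB of A, 16 KiB of B and 512 B of A scales, i.e. roughly
// 32.5 KiB per stage on top of the output buffer (plus the small B-scale and
// barrier regions below).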
// Configs
constexpr uint32_t kFullKOfAllStages = kNumStages * BLOCK_K;
constexpr uint32_t kNumThreads = get_num_threads_per_sm<kNumTMAThreads, kNumMathThreadsPerGroup>(BLOCK_M);
constexpr uint32_t kNumMathThreads = kNumThreads - kNumTMAThreads;
constexpr uint32_t kNumIterations = ceil_div(SHAPE_K, kFullKOfAllStages);
const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
const uint32_t lane_idx = get_lane_id();
// Prefetch TMA descriptors at the very beginning
if (threadIdx.x == kNumMathThreads) {
cute::prefetch_tma_descriptor(reinterpret_cast<cute::TmaDescriptor const*>(&tensor_map_a));
cute::prefetch_tma_descriptor(reinterpret_cast<cute::TmaDescriptor const*>(&tensor_map_b));
cute::prefetch_tma_descriptor(reinterpret_cast<cute::TmaDescriptor const*>(&tensor_map_scales_a));
cute::prefetch_tma_descriptor(reinterpret_cast<cute::TmaDescriptor const*>(&tensor_map_d));
}
__syncwarp();
// Align to 1024 bytes for swizzle-128B
extern __shared__ __align__(1024) uint8_t smem_buffer[];
DG_STATIC_ASSERT(SMEM_D_SIZE % 1024 == 0, "Shared memory of A/B must be aligned to 1024 bytes");
// Data on shared memory
auto smem_d = reinterpret_cast<__nv_bfloat16*>(smem_buffer);
__nv_fp8_e4m3* smem_a[kNumStages];
__nv_fp8_e4m3* smem_b[kNumStages];
float* smem_scales_a[kNumStages];
float* smem_scales_b;
// TMA Barrier for both divisible and non-divisible cases
Barrier* full_barriers[kNumStages];
Barrier* empty_barriers[kNumStages];
// Fill shared memory pointers
#pragma unroll
for (int i = 0; i < kNumStages; ++ i) {
smem_a[i] = reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + i * SMEM_A_SIZE_PER_STAGE);
smem_b[i] = reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE);
smem_scales_a[i] = reinterpret_cast<float*>(smem_buffer + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE) + i * SMEM_SCALES_A_SIZE_PER_STAGE);
}
smem_scales_b = reinterpret_cast<float*>(smem_buffer + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE));
// Fill barriers
auto barrier_start_ptr = reinterpret_cast<Barrier*>(reinterpret_cast<uint8_t*>(smem_scales_b) + SMEM_SCALES_B_SIZE);
#pragma unroll
for (int i = 0; i < kNumStages; ++ i) {
full_barriers[i] = barrier_start_ptr + i;
empty_barriers[i] = barrier_start_ptr + kNumStages + i;
}
// Initialize barriers
DG_STATIC_ASSERT(kNumTMAMulticast <= 32, "Too many TMA multicast");
if (threadIdx.x == kNumMathThreads) {
#pragma unroll
for (int i = 0; i < kNumStages; ++ i) {
full_barriers[i]->init(1);
empty_barriers[i]->init(kNumTMAMulticast * kNumMathThreads / 32);
}
// Make initialized barrier visible in async proxy
cutlass::arch::fence_view_async_shared();
(kNumTMAMulticast > 1) ? cutlass::arch::fence_barrier_init() : void();
}
// Synchronize all threads to make barrier visible in normal memory model
(kNumTMAMulticast > 1) ? cute::cluster_sync() : __syncthreads();
// For pipeline unrolling
struct DivisibleK {};
struct NotDivisibleK {};
auto launch_k_iterations = [](const auto& func) {
if constexpr (SHAPE_K % kFullKOfAllStages == 0) {
for (int k_iter = 0; k_iter < kNumIterations; ++ k_iter)
func(k_iter, DivisibleK{});
} else {
for (int k_iter = 0; k_iter < kNumIterations - 1; ++ k_iter)
func(k_iter, DivisibleK{});
func(kNumIterations - 1, NotDivisibleK{});
}
};
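// NOTE: example of the split: with kNumStages = 5 and BLOCK_K = 128,
// kFullKOfAllStages = 640. For SHAPE_K = 2048, kNumIterations =
// ceil_div(2048, 640) = 4; the first three iterations run all 5 stages as
// DivisibleK, and the last runs as NotDivisibleK with (2048 % 640) / 128 = 1
// inner stage, the remaining stages only cycling their barriers.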
// Register reconfigurations
constexpr int kNumTMARegisters = 40;
constexpr int kNumMathRegisters = 232;
// Block scheduler
uint32_t m_block_idx, n_block_idx;
auto scheduler = Scheduler<kGemmType, SHAPE_N, BLOCK_M, BLOCK_N, kNumGroups, kNumTMAMulticast>(shape_m, grouped_layout);
if (threadIdx.x >= kNumMathThreads) {
// TMA warp-group for loading data
cutlass::arch::warpgroup_reg_dealloc<kNumTMARegisters>();
// NOTES: only one thread (or warp) will be used
if (threadIdx.x == kNumMathThreads) {
// Persistently schedule over blocks
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
launch_k_iterations([&](int k_iter, auto type) {
constexpr bool kHasDivisibleStages = std::is_same_v<decltype(type), DivisibleK>;
constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K;
DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages");
#pragma unroll
for (uint32_t s = 0; s < kNumInnerStages; ++ s) {
// Wait consumer release
empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter + 1) & 1);
// Issue TMA A with broadcasting
auto& full_barrier = *full_barriers[s];
int k_idx = k_iter * kFullKOfAllStages + s * BLOCK_K;
tma_copy<kNumTMAMulticast>(&tensor_map_a, reinterpret_cast<uint64_t*>(&full_barrier),
smem_a[s], k_idx, scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx));
tma_copy<kNumTMAMulticast>(&tensor_map_scales_a, reinterpret_cast<uint64_t*>(&full_barrier),
smem_scales_a[s], m_block_idx * BLOCK_M,
scheduler.get_global_idx(SHAPE_K_SCALES, 1, k_idx / BLOCK_K));
// Issue TMA B without broadcasting
tma_copy(&tensor_map_b, reinterpret_cast<uint64_t*>(&full_barrier),
smem_b[s], k_idx, scheduler.get_global_idx<false>(SHAPE_N, BLOCK_N, n_block_idx, m_block_idx));
full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE);
}
// Wait unaligned cases
#pragma unroll
for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) {
empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter + 1) & 1);
full_barriers[s]->arrive();
}
});
}
// To safely deconstruct distributed shared barriers, we need another round of empty waits
if constexpr (kNumTMAMulticast > 1) {
#pragma unroll
for (uint32_t s = 0; s < kNumStages; ++ s)
empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + 1) & 1);
}
}
} else {
// Math warp-groups for WGMMA
cutlass::arch::warpgroup_reg_alloc<kNumMathRegisters>();
// NOTES: use `__shfl_sync` to encourage NVCC to use unified registers
const auto math_wg_idx = __shfl_sync(0xffffffff, threadIdx.x / kNumMathThreadsPerGroup, 0);
const auto r_0 = warp_idx * 16 + lane_idx / 4, r_1 = r_0 + 8;
// Persistently schedule over blocks
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
// Decide the number of scales B to load
DG_STATIC_ASSERT(SHAPE_N % 8 == 0, "Invalid shape N");
uint32_t num_former_iters = BLOCK_N / 8, num_full_iters = num_former_iters;
if constexpr (not kMustUseUniformedScaleB) {
num_former_iters = min(BLOCK_N, BLOCK_K - n_block_idx * BLOCK_N % BLOCK_K) / 8;
num_full_iters = min(SHAPE_N - n_block_idx * BLOCK_N, BLOCK_N) / 8;
}
uint32_t num_scales_b = SHAPE_K_SCALES * (num_former_iters >= num_full_iters ? 1 : 2);
// Load B scales with math warp-groups
// NOTES: except the first warp, we want to overlap loading B scales with TMA stores between tasks
if (threadIdx.x >= 32) {
auto num_previous_lines = scheduler.get_global_idx<false>(ceil_div(SHAPE_N, BLOCK_K), 0, 0, m_block_idx);
auto local_scales_b = scales_b + (num_previous_lines + ((n_block_idx * BLOCK_N) / BLOCK_K)) * SHAPE_K_SCALES;
#pragma unroll
for (uint32_t i = threadIdx.x - 32; i < num_scales_b; i += kNumMathThreads - 32)
st_shared(smem_scales_b + i, __ldg(local_scales_b + i));
}
cutlass::arch::NamedBarrier(kNumMathThreads).sync();
// Accumulation for WGMMA or CUDA promotion
float accum[WGMMA::kNumAccum], final_accum[WGMMA::kNumAccum] = {0};
// Empty barrier arrival
auto empty_barrier_arrive = [&](int s) {
if constexpr (kNumTMAMulticast == 1) {
lane_idx == 0 ? empty_barriers[s]->arrive() : void();
} else {
lane_idx < kNumTMAMulticast ? empty_barriers[s]->arrive(lane_idx) : void();
}
};
// Launch MMAs
launch_k_iterations([&](int k_iter, auto type) {
constexpr bool kHasDivisibleStages = std::is_same_v<decltype(type), DivisibleK>;
constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K;
DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages");
#pragma unroll
for (int s = 0; s < kNumInnerStages; ++ s) {
// Read B scales
float scale_b_0 = ld_shared(smem_scales_b + k_iter * kNumStages + s), scale_b_1;
// NOTES: even though some blocks do not need to read the second row, we still load one to keep all blocks aligned
if constexpr (not kMustUseUniformedScaleB)
scale_b_1 = ld_shared(smem_scales_b + k_iter * kNumStages + s + SHAPE_K_SCALES);
// Wait TMA arrivals
full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1);
// Read A scales
// NOTES: all shared memory read must be prior to `warpgroup_arrive` to avoid next scheduled block polluting the results
auto scale_a_0 = ld_shared(smem_scales_a[s] + r_0), scale_a_1 = ld_shared(smem_scales_a[s] + r_1);
// Commit WGMMA instructions
#pragma unroll
for (int i = 0; i < WGMMA::kNumAccum; ++ i)
warpgroup_fence_operand(accum[i]);
warpgroup_arrive();
#pragma unroll
for (int k = 0; k < BLOCK_K / WGMMA::K; ++ k) {
auto desc_a = make_smem_desc(smem_a[s] + math_wg_idx * WGMMA::M * BLOCK_K + k * WGMMA::K, 1);
auto desc_b = make_smem_desc(smem_b[s] + k * WGMMA::K, 1);
WGMMA::wgmma(desc_a, desc_b, accum, k);
}
warpgroup_commit_batch();
#pragma unroll
for (int i = 0; i < WGMMA::kNumAccum; ++ i)
warpgroup_fence_operand(accum[i]);
warpgroup_wait<0>();
// Notify barrier arrival
empty_barrier_arrive(s);
// Promote with scales
float scale_0_0 = scale_a_0 * scale_b_0, scale_1_0 = scale_a_1 * scale_b_0;
float scale_0_1, scale_1_1;
if constexpr (not kMustUseUniformedScaleB)
scale_0_1 = scale_a_0 * scale_b_1, scale_1_1 = scale_a_1 * scale_b_1;
#pragma unroll
for (int i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
bool predicate = kMustUseUniformedScaleB or i < num_former_iters;
final_accum[i * 4 + 0] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 0];
final_accum[i * 4 + 1] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 1];
final_accum[i * 4 + 2] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 2];
final_accum[i * 4 + 3] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 3];
}
}
// Wait unaligned cases
#pragma unroll
for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) {
full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1);
empty_barrier_arrive(s);
}
});
// Write back to shared memory using STSM
DG_STATIC_ASSERT(WGMMA::kNumAccum % 4 == 0, "Invalid STSM x2 vectorization");
#pragma unroll
for (auto i = 0; i < WGMMA::kNumAccum / 8; ++ i) {
SM90_U32x4_STSM_N<nv_bfloat162>::copy(
__float22bfloat162_rn({final_accum[i * 8 + 0], final_accum[i * 8 + 1]}),
__float22bfloat162_rn({final_accum[i * 8 + 2], final_accum[i * 8 + 3]}),
__float22bfloat162_rn({final_accum[i * 8 + 4], final_accum[i * 8 + 5]}),
__float22bfloat162_rn({final_accum[i * 8 + 6], final_accum[i * 8 + 7]}),
smem_d + (warp_idx * 16 + lane_idx % 16) * BLOCK_N + i * 16 + 8 * (lane_idx / 16)
);
}
if constexpr (WGMMA::kNumAccum % 8 != 0) {
SM90_U32x2_STSM_N<nv_bfloat162>::copy(
__float22bfloat162_rn({final_accum[WGMMA::kNumAccum / 8 * 8 + 0], final_accum[WGMMA::kNumAccum / 8 * 8 + 1]}),
__float22bfloat162_rn({final_accum[WGMMA::kNumAccum / 8 * 8 + 2], final_accum[WGMMA::kNumAccum / 8 * 8 + 3]}),
smem_d + (warp_idx * 16 + lane_idx % 16) * BLOCK_N + WGMMA::kNumAccum / 8 * 16
);
}
cute::tma_store_fence();
cutlass::arch::NamedBarrier(kNumMathThreads).sync();
// Use TMA store to write back to global memory
if (threadIdx.x == 0) {
cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_d, n_block_idx * BLOCK_N,
scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx));
cute::tma_store_arrive();
cute::tma_store_wait<0>();
}
__syncwarp();
}
}
#else
if (blockIdx.x == 0 and threadIdx.x == 0)
DG_DEVICE_ASSERT(false and "This kernel only supports sm_90a");
#endif
}
template <uint32_t SHAPE_N, uint32_t SHAPE_K,
uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
uint32_t kNumGroups, uint32_t kNumStages,
uint32_t kNumTMAMulticast,
GemmType kGemmType>
class Gemm {
private:
using Barrier = cuda::barrier<cuda::thread_scope_block>;
public:
Gemm() = default;
static void run(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
uint32_t shape_m,
const CUtensorMap& tma_a_desc,
const CUtensorMap& tma_b_desc,
const CUtensorMap& tma_scales_a_desc,
const CUtensorMap& tma_d_desc,
cudaStream_t stream,
int num_sms, uint32_t smem_size) {
// NOTES: we must use 4 warps to do TMA, because `setmaxnreg.aligned` requires 4 warps
constexpr uint32_t kNumTMAThreads = 128;
constexpr uint32_t kNumMathThreadsPerGroup = 128;
auto kernel = fp8_gemm_kernel<SHAPE_N, SHAPE_K, BLOCK_M, BLOCK_N, BLOCK_K,
kNumGroups, kNumStages, kNumTMAThreads, kNumMathThreadsPerGroup,
kNumTMAMulticast, kGemmType>;
DG_HOST_ASSERT(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size) == cudaSuccess);
// Cluster launch
cudaLaunchConfig_t config;
config.gridDim = num_sms;
config.blockDim = get_num_threads_per_sm<kNumTMAThreads, kNumMathThreadsPerGroup>(BLOCK_M);
config.dynamicSmemBytes = smem_size;
config.stream = stream;
// Clusters for TMA multicast
// NOTES: `>= 4` cluster size will cause performance degradation
cudaLaunchAttribute attr;
attr.id = cudaLaunchAttributeClusterDimension;
attr.val.clusterDim = {kNumTMAMulticast, 1, 1};
config.attrs = &attr;
config.numAttrs = 1;
// Launch
auto status = cudaLaunchKernelEx(&config, kernel,
gmem_d, scales_b, grouped_layout,
shape_m,
tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc);
DG_HOST_ASSERT(status == cudaSuccess);
}
template <typename T>
static CUtensorMap make_2d_tma_a_desc(T* global_address, uint32_t shape_m) {
return make_2d_tma_desc(global_address, Layout::RowMajor,
shape_m * (kGemmType == GemmType::GroupedMasked ? kNumGroups : 1), SHAPE_K, BLOCK_M, BLOCK_K);
}
template <typename T>
static CUtensorMap make_2d_tma_b_desc(T* global_address) {
return make_2d_tma_desc(global_address, Layout::ColMajor,
SHAPE_K, SHAPE_N * (kGemmType != GemmType::Normal ? kNumGroups : 1), BLOCK_K, BLOCK_N);
}
template <typename T>
static CUtensorMap make_2d_tma_d_desc(T* global_address, uint32_t shape_m) {
return make_2d_tma_desc(global_address, Layout::RowMajor,
shape_m * (kGemmType == GemmType::GroupedMasked ? kNumGroups : 1), SHAPE_N,
min(BLOCK_M, shape_m), BLOCK_N,
CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE);
}
template <typename T>
static CUtensorMap make_2d_tma_scales_a_desc(T* global_address, uint32_t shape_m) {
// Make TMA aligned to 16 bytes
constexpr uint32_t kAlignment = 16 / sizeof(T);
shape_m = ceil_div(shape_m, kAlignment) * kAlignment;
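// e.g. with float scales, kAlignment = 16 / sizeof(float) = 4, so
// shape_m = 130 builds the descriptor as if it had 132 rows (132 * 4 B is a
// multiple of 16 bytes, 130 * 4 B is not).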
return make_2d_tma_desc(global_address, Layout::ColMajor,
shape_m, ceil_div(SHAPE_K, BLOCK_K) * (kGemmType == GemmType::GroupedMasked ? kNumGroups : 1), BLOCK_M, 1,
CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE);
}
template <typename T>
static CUtensorMap make_2d_tma_desc(
T* global_address, Layout layout,
uint32_t gmem_rows, uint32_t gmem_cols,
uint32_t smem_rows, uint32_t smem_cols,
CUtensorMapSwizzle swizzle_type = CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B) {
if (layout == Layout::RowMajor) {
uint64_t gmem_dim[2] = {gmem_cols, gmem_rows};
uint32_t smem_dim[2] = {smem_cols, smem_rows};
return make_2d_tma_copy_desc(global_address, gmem_dim, gmem_cols * sizeof(T), smem_dim, swizzle_type);
} else {
uint64_t gmem_dim[2] = {gmem_rows, gmem_cols};
uint32_t smem_dim[2] = {smem_rows, smem_cols};
return make_2d_tma_copy_desc(global_address, gmem_dim, gmem_rows * sizeof(T), smem_dim, swizzle_type);
}
}
};
}; // namespace deep_gemm
#pragma clang diagnostic pop

View File

@@ -1,903 +0,0 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// The file has been adapted from DeepSeek DeepGEMM project
// Copyright (c) 2025 DeepSeek
// Licensed under the MIT License - https://github.com/deepseek-ai/DeepGEMM/blob/main/LICENSE
#pragma once
#include <cuda.h>
#include "utils.cuh"
namespace deep_gemm {
struct SM90_64x16x32_F32E4M3E4M3_SS {
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
bool scale_d) {
asm volatile("{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %10, 0;\n"
"wgmma.mma_async.sync.aligned.m64n16k32.f32.e4m3.e4m3"
"{%0, %1, %2, %3, %4, %5, %6, %7},"
" %8,"
" %9,"
" p , 1, 1;\n"
"}\n"
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07)
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
}
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
wgmma(desc_a, desc_b,
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
scale_d);
}
static constexpr int M = 64;
static constexpr int N = 16;
static constexpr int K = 32;
static constexpr int kNumAccum = M * N / 128;
};
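// NOTE: kNumAccum = M * N / 128 is the number of fp32 accumulators each of
// the 128 threads in a warp-group owns: 64 * 16 / 128 = 8 here, matching the
// d00..d07 operands above. The wider shapes below scale the same way, e.g.
// the 64x128 variant ends with 64 accumulators per thread (d00..d63).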
struct SM90_64x24x32_F32E4M3E4M3_SS {
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
float& d08, float& d09, float& d10, float& d11,
bool scale_d) {
asm volatile("{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %14, 0;\n"
"wgmma.mma_async.sync.aligned.m64n24k32.f32.e4m3.e4m3"
"{%0, %1, %2, %3, %4, %5, %6, %7, "
" %8, %9, %10, %11},"
" %12,"
" %13,"
" p , 1, 1;\n"
"}\n"
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11)
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
}
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
wgmma(desc_a, desc_b,
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
d[8], d[9], d[10], d[11],
scale_d);
}
static constexpr int M = 64;
static constexpr int N = 24;
static constexpr int K = 32;
static constexpr int kNumAccum = M * N / 128;
};
struct SM90_64x32x32_F32E4M3E4M3_SS {
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
bool scale_d) {
asm volatile("{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %18, 0;\n"
"wgmma.mma_async.sync.aligned.m64n32k32.f32.e4m3.e4m3"
"{%0, %1, %2, %3, %4, %5, %6, %7, "
" %8, %9, %10, %11, %12, %13, %14, %15},"
" %16,"
" %17,"
" p , 1, 1;\n"
"}\n"
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15)
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
}
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
wgmma(desc_a, desc_b,
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
scale_d);
}
static constexpr int M = 64;
static constexpr int N = 32;
static constexpr int K = 32;
static constexpr int kNumAccum = M * N / 128;
};
struct SM90_64x40x32_F32E4M3E4M3_SS {
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
float& d16, float& d17, float& d18, float& d19,
bool scale_d) {
asm volatile("{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %22, 0;\n"
"wgmma.mma_async.sync.aligned.m64n40k32.f32.e4m3.e4m3"
"{%0, %1, %2, %3, %4, %5, %6, %7, "
" %8, %9, %10, %11, %12, %13, %14, %15, "
" %16, %17, %18, %19},"
" %20,"
" %21,"
" p , 1, 1;\n"
"}\n"
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19)
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
}
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
wgmma(desc_a, desc_b,
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
d[16], d[17], d[18], d[19],
scale_d);
}
static constexpr int M = 64;
static constexpr int N = 40;
static constexpr int K = 32;
static constexpr int kNumAccum = M * N / 128;
};
struct SM90_64x48x32_F32E4M3E4M3_SS {
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
bool scale_d) {
asm volatile("{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %26, 0;\n"
"wgmma.mma_async.sync.aligned.m64n48k32.f32.e4m3.e4m3"
"{%0, %1, %2, %3, %4, %5, %6, %7, "
" %8, %9, %10, %11, %12, %13, %14, %15, "
" %16, %17, %18, %19, %20, %21, %22, %23},"
" %24,"
" %25,"
" p , 1, 1;\n"
"}\n"
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23)
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
}
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
wgmma(desc_a, desc_b,
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
scale_d);
}
static constexpr int M = 64;
static constexpr int N = 48;
static constexpr int K = 32;
static constexpr int kNumAccum = M * N / 128;
};
struct SM90_64x56x32_F32E4M3E4M3_SS {
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
float& d24, float& d25, float& d26, float& d27,
bool scale_d) {
asm volatile("{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %30, 0;\n"
"wgmma.mma_async.sync.aligned.m64n56k32.f32.e4m3.e4m3"
"{%0, %1, %2, %3, %4, %5, %6, %7, "
" %8, %9, %10, %11, %12, %13, %14, %15, "
" %16, %17, %18, %19, %20, %21, %22, %23, "
" %24, %25, %26, %27}, "
" %28,"
" %29,"
" p , 1, 1;\n"
"}\n"
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27)
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
}
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
wgmma(desc_a, desc_b,
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
d[24], d[25], d[26], d[27],
scale_d);
}
static constexpr int M = 64;
static constexpr int N = 56;
static constexpr int K = 32;
static constexpr int kNumAccum = M * N / 128;
};
struct SM90_64x64x32_F32E4M3E4M3_SS {
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31,
bool scale_d) {
asm volatile("{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %34, 0;\n"
"wgmma.mma_async.sync.aligned.m64n64k32.f32.e4m3.e4m3"
"{%0, %1, %2, %3, %4, %5, %6, %7, "
" %8, %9, %10, %11, %12, %13, %14, %15, "
" %16, %17, %18, %19, %20, %21, %22, %23, "
" %24, %25, %26, %27, %28, %29, %30, %31}, "
" %32,"
" %33,"
" p , 1, 1;\n"
"}\n"
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31)
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
}
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
wgmma(desc_a, desc_b,
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31],
scale_d);
}
static constexpr int M = 64;
static constexpr int N = 64;
static constexpr int K = 32;
static constexpr int kNumAccum = M * N / 128;
};
struct SM90_64x72x32_F32E4M3E4M3_SS {
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31,
float& d32, float& d33, float& d34, float& d35,
bool scale_d) {
asm volatile("{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %38, 0;\n"
"wgmma.mma_async.sync.aligned.m64n72k32.f32.e4m3.e4m3"
"{%0, %1, %2, %3, %4, %5, %6, %7, "
" %8, %9, %10, %11, %12, %13, %14, %15, "
" %16, %17, %18, %19, %20, %21, %22, %23, "
" %24, %25, %26, %27, %28, %29, %30, %31, "
" %32, %33, %34, %35}, "
" %36,"
" %37,"
" p , 1, 1;\n"
"}\n"
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35)
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
}
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
wgmma(desc_a, desc_b,
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31],
d[32], d[33], d[34], d[35],
scale_d);
}
static constexpr int M = 64;
static constexpr int N = 72;
static constexpr int K = 32;
static constexpr int kNumAccum = M * N / 128;
};
struct SM90_64x80x32_F32E4M3E4M3_SS {
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31,
float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39,
bool scale_d) {
asm volatile("{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %42, 0;\n"
"wgmma.mma_async.sync.aligned.m64n80k32.f32.e4m3.e4m3"
"{%0, %1, %2, %3, %4, %5, %6, %7, "
" %8, %9, %10, %11, %12, %13, %14, %15, "
" %16, %17, %18, %19, %20, %21, %22, %23, "
" %24, %25, %26, %27, %28, %29, %30, %31, "
" %32, %33, %34, %35, %36, %37, %38, %39}, "
" %40,"
" %41,"
" p , 1, 1;\n"
"}\n"
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39)
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
}
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
wgmma(desc_a, desc_b,
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31],
d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39],
scale_d);
}
static constexpr int M = 64;
static constexpr int N = 80;
static constexpr int K = 32;
static constexpr int kNumAccum = M * N / 128;
};
struct SM90_64x88x32_F32E4M3E4M3_SS {
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31,
float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39,
float& d40, float& d41, float& d42, float& d43,
bool scale_d) {
asm volatile("{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %46, 0;\n"
"wgmma.mma_async.sync.aligned.m64n88k32.f32.e4m3.e4m3"
"{%0, %1, %2, %3, %4, %5, %6, %7, "
" %8, %9, %10, %11, %12, %13, %14, %15, "
" %16, %17, %18, %19, %20, %21, %22, %23, "
" %24, %25, %26, %27, %28, %29, %30, %31, "
" %32, %33, %34, %35, %36, %37, %38, %39, "
" %40, %41, %42, %43}, "
" %44,"
" %45,"
" p , 1, 1;\n"
"}\n"
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
"+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43)
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
}
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
wgmma(desc_a, desc_b,
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31],
d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39],
d[40], d[41], d[42], d[43],
scale_d);
}
static constexpr int M = 64;
static constexpr int N = 88;
static constexpr int K = 32;
static constexpr int kNumAccum = M * N / 128;
};
struct SM90_64x96x32_F32E4M3E4M3_SS {
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31,
float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39,
float& d40, float& d41, float& d42, float& d43, float& d44, float& d45, float& d46, float& d47,
bool scale_d) {
asm volatile("{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %50, 0;\n"
"wgmma.mma_async.sync.aligned.m64n96k32.f32.e4m3.e4m3"
"{%0, %1, %2, %3, %4, %5, %6, %7, "
" %8, %9, %10, %11, %12, %13, %14, %15, "
" %16, %17, %18, %19, %20, %21, %22, %23, "
" %24, %25, %26, %27, %28, %29, %30, %31, "
" %32, %33, %34, %35, %36, %37, %38, %39, "
" %40, %41, %42, %43, %44, %45, %46, %47}, "
" %48,"
" %49,"
" p , 1, 1;\n"
"}\n"
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
"+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47)
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
}
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
wgmma(desc_a, desc_b,
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31],
d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39],
d[40], d[41], d[42], d[43], d[44], d[45], d[46], d[47],
scale_d);
}
static constexpr int M = 64;
static constexpr int N = 96;
static constexpr int K = 32;
static constexpr int kNumAccum = M * N / 128;
};
struct SM90_64x104x32_F32E4M3E4M3_SS {
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31,
float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39,
float& d40, float& d41, float& d42, float& d43, float& d44, float& d45, float& d46, float& d47,
float& d48, float& d49, float& d50, float& d51,
bool scale_d) {
asm volatile("{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %54, 0;\n"
"wgmma.mma_async.sync.aligned.m64n104k32.f32.e4m3.e4m3"
"{%0, %1, %2, %3, %4, %5, %6, %7, "
" %8, %9, %10, %11, %12, %13, %14, %15, "
" %16, %17, %18, %19, %20, %21, %22, %23, "
" %24, %25, %26, %27, %28, %29, %30, %31, "
" %32, %33, %34, %35, %36, %37, %38, %39, "
" %40, %41, %42, %43, %44, %45, %46, %47, "
" %48, %49, %50, %51}, "
" %52,"
" %53,"
" p , 1, 1;\n"
"}\n"
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
"+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
"+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51)
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
}
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
wgmma(desc_a, desc_b,
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31],
d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39],
d[40], d[41], d[42], d[43], d[44], d[45], d[46], d[47],
d[48], d[49], d[50], d[51],
scale_d);
}
static constexpr int M = 64;
static constexpr int N = 104;
static constexpr int K = 32;
static constexpr int kNumAccum = M * N / 128;
};
struct SM90_64x112x32_F32E4M3E4M3_SS {
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31,
float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39,
float& d40, float& d41, float& d42, float& d43, float& d44, float& d45, float& d46, float& d47,
float& d48, float& d49, float& d50, float& d51, float& d52, float& d53, float& d54, float& d55,
bool scale_d) {
asm volatile("{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %58, 0;\n"
"wgmma.mma_async.sync.aligned.m64n112k32.f32.e4m3.e4m3"
"{%0, %1, %2, %3, %4, %5, %6, %7, "
" %8, %9, %10, %11, %12, %13, %14, %15, "
" %16, %17, %18, %19, %20, %21, %22, %23, "
" %24, %25, %26, %27, %28, %29, %30, %31, "
" %32, %33, %34, %35, %36, %37, %38, %39, "
" %40, %41, %42, %43, %44, %45, %46, %47, "
" %48, %49, %50, %51, %52, %53, %54, %55}, "
" %56,"
" %57,"
" p , 1, 1;\n"
"}\n"
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
"+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
"+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55)
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
}
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
wgmma(desc_a, desc_b,
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31],
d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39],
d[40], d[41], d[42], d[43], d[44], d[45], d[46], d[47],
d[48], d[49], d[50], d[51], d[52], d[53], d[54], d[55],
scale_d);
}
static constexpr int M = 64;
static constexpr int N = 112;
static constexpr int K = 32;
static constexpr int kNumAccum = M * N / 128;
};
struct SM90_64x120x32_F32E4M3E4M3_SS {
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31,
float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39,
float& d40, float& d41, float& d42, float& d43, float& d44, float& d45, float& d46, float& d47,
float& d48, float& d49, float& d50, float& d51, float& d52, float& d53, float& d54, float& d55,
float& d56, float& d57, float& d58, float& d59,
bool scale_d) {
asm volatile("{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %62, 0;\n"
"wgmma.mma_async.sync.aligned.m64n120k32.f32.e4m3.e4m3"
"{%0, %1, %2, %3, %4, %5, %6, %7, "
" %8, %9, %10, %11, %12, %13, %14, %15, "
" %16, %17, %18, %19, %20, %21, %22, %23, "
" %24, %25, %26, %27, %28, %29, %30, %31, "
" %32, %33, %34, %35, %36, %37, %38, %39, "
" %40, %41, %42, %43, %44, %45, %46, %47, "
" %48, %49, %50, %51, %52, %53, %54, %55, "
" %56, %57, %58, %59}, "
" %60,"
" %61,"
" p , 1, 1;\n"
"}\n"
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
"+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
"+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55),
"+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59)
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
}
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
wgmma(desc_a, desc_b,
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31],
d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39],
d[40], d[41], d[42], d[43], d[44], d[45], d[46], d[47],
d[48], d[49], d[50], d[51], d[52], d[53], d[54], d[55],
d[56], d[57], d[58], d[59],
scale_d);
}
static constexpr int M = 64;
static constexpr int N = 120;
static constexpr int K = 32;
static constexpr int kNumAccum = M * N / 128;
};
struct SM90_64x128x32_F32E4M3E4M3_SS {
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31,
float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39,
float& d40, float& d41, float& d42, float& d43, float& d44, float& d45, float& d46, float& d47,
float& d48, float& d49, float& d50, float& d51, float& d52, float& d53, float& d54, float& d55,
float& d56, float& d57, float& d58, float& d59, float& d60, float& d61, float& d62, float& d63,
bool scale_d) {
asm volatile("{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %66, 0;\n"
"wgmma.mma_async.sync.aligned.m64n128k32.f32.e4m3.e4m3"
"{%0, %1, %2, %3, %4, %5, %6, %7, "
" %8, %9, %10, %11, %12, %13, %14, %15, "
" %16, %17, %18, %19, %20, %21, %22, %23, "
" %24, %25, %26, %27, %28, %29, %30, %31, "
" %32, %33, %34, %35, %36, %37, %38, %39, "
" %40, %41, %42, %43, %44, %45, %46, %47, "
" %48, %49, %50, %51, %52, %53, %54, %55, "
" %56, %57, %58, %59, %60, %61, %62, %63}, "
" %64,"
" %65,"
" p , 1, 1;\n"
"}\n"
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
"+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
"+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55),
"+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63)
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
}
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
wgmma(desc_a, desc_b,
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31],
d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39],
d[40], d[41], d[42], d[43], d[44], d[45], d[46], d[47],
d[48], d[49], d[50], d[51], d[52], d[53], d[54], d[55],
d[56], d[57], d[58], d[59], d[60], d[61], d[62], d[63],
scale_d);
}
static constexpr int M = 64;
static constexpr int N = 128;
static constexpr int K = 32;
static constexpr int kNumAccum = M * N / 128;
};
struct SM90_64x192x32_F32E4M3E4M3_SS {
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31,
float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39,
float& d40, float& d41, float& d42, float& d43, float& d44, float& d45, float& d46, float& d47,
float& d48, float& d49, float& d50, float& d51, float& d52, float& d53, float& d54, float& d55,
float& d56, float& d57, float& d58, float& d59, float& d60, float& d61, float& d62, float& d63,
float& d64, float& d65, float& d66, float& d67, float& d68, float& d69, float& d70, float& d71,
float& d72, float& d73, float& d74, float& d75, float& d76, float& d77, float& d78, float& d79,
float& d80, float& d81, float& d82, float& d83, float& d84, float& d85, float& d86, float& d87,
float& d88, float& d89, float& d90, float& d91, float& d92, float& d93, float& d94, float& d95,
bool scale_d) {
asm volatile("{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %98, 0;\n"
"wgmma.mma_async.sync.aligned.m64n192k32.f32.e4m3.e4m3"
"{%0, %1, %2, %3, %4, %5, %6, %7, "
" %8, %9, %10, %11, %12, %13, %14, %15, "
" %16, %17, %18, %19, %20, %21, %22, %23, "
" %24, %25, %26, %27, %28, %29, %30, %31, "
" %32, %33, %34, %35, %36, %37, %38, %39, "
" %40, %41, %42, %43, %44, %45, %46, %47, "
" %48, %49, %50, %51, %52, %53, %54, %55, "
" %56, %57, %58, %59, %60, %61, %62, %63, "
" %64, %65, %66, %67, %68, %69, %70, %71, "
" %72, %73, %74, %75, %76, %77, %78, %79, "
" %80, %81, %82, %83, %84, %85, %86, %87, "
" %88, %89, %90, %91, %92, %93, %94, %95}, "
" %96,"
" %97,"
" p , 1, 1;\n"
"}\n"
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
"+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
"+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55),
"+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63),
"+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71),
"+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79),
"+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87),
"+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95)
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
}
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
wgmma(desc_a, desc_b,
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31],
d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39],
d[40], d[41], d[42], d[43], d[44], d[45], d[46], d[47],
d[48], d[49], d[50], d[51], d[52], d[53], d[54], d[55],
d[56], d[57], d[58], d[59], d[60], d[61], d[62], d[63],
d[64], d[65], d[66], d[67], d[68], d[69], d[70], d[71],
d[72], d[73], d[74], d[75], d[76], d[77], d[78], d[79],
d[80], d[81], d[82], d[83], d[84], d[85], d[86], d[87],
d[88], d[89], d[90], d[91], d[92], d[93], d[94], d[95],
scale_d);
}
static constexpr int M = 64;
static constexpr int N = 192;
static constexpr int K = 32;
static constexpr int kNumAccum = M * N / 128;
};
template <typename dtype_t>
struct SM90_U32x2_STSM_N {
__device__ __forceinline__ static void
copy(dtype_t src_0, dtype_t src_1, void* smem_dst) {
const uint32_t src[2] = {*reinterpret_cast<uint32_t*>(&src_0), *reinterpret_cast<uint32_t*>(&src_1)};
asm volatile("stmatrix.sync.aligned.x2.m8n8.shared.b16 [%0], {%1, %2};\n"
:: "l"(smem_dst), "r"(src[0]), "r"(src[1]));
}
};
template <typename dtype_t>
struct SM90_U32x4_STSM_N {
__device__ __forceinline__ static void
copy(dtype_t src_0, dtype_t src_1, dtype_t src_2, dtype_t src_3, void* smem_dst) {
const uint32_t src[4] = {*reinterpret_cast<uint32_t*>(&src_0), *reinterpret_cast<uint32_t*>(&src_1),
*reinterpret_cast<uint32_t*>(&src_2), *reinterpret_cast<uint32_t*>(&src_3)};
asm volatile("stmatrix.sync.aligned.x4.m8n8.shared.b16 [%0], {%1, %2, %3, %4};\n"
:: "l"(smem_dst), "r"(src[0]), "r"(src[1]), "r"(src[2]), "r"(src[3]));
}
};
__device__ void warpgroup_arrive() {
asm volatile("wgmma.fence.sync.aligned;\n" ::: "memory");
}
__device__ void warpgroup_commit_batch() {
asm volatile("wgmma.commit_group.sync.aligned;\n" ::: "memory");
}
__device__ void warpgroup_fence_operand(float& reg) {
asm volatile("" : "+f"(reg) :: "memory");
}
__forceinline__ __device__ uint32_t get_lane_id() {
uint32_t lane_id;
asm("mov.u32 %0, %laneid;" : "=r"(lane_id));
return lane_id;
}
__device__ __forceinline__ uint32_t ld_shared(const uint32_t* __restrict__ ptr) {
uint32_t ret;
asm volatile("ld.shared.u32 %0, [%1];" : "=r"(ret) : "l"(ptr));
return ret;
}
__device__ __forceinline__ int4 ld_shared(const int4* __restrict__ ptr) {
int4 ret;
asm volatile("ld.shared.v4.s32 {%0, %1, %2, %3}, [%4];" : "=r"(ret.x), "=r"(ret.y), "=r"(ret.z), "=r"(ret.w) : "l"(ptr));
return ret;
}
__device__ __forceinline__ float ld_shared(const float* __restrict__ ptr) {
float ret;
asm volatile("ld.shared.f32 %0, [%1];" : "=f"(ret) : "l"(ptr));
return ret;
}
__device__ __forceinline__ void st_shared(const float* ptr, float val) {
asm volatile("st.shared.f32 [%0], %1;" :: "l"(ptr), "f"(val));
}
__device__ __forceinline__ void st_shared(const uint32_t* ptr, uint32_t val) {
asm volatile("st.shared.u32 [%0], %1;" :: "l"(ptr), "r"(val));
}
template <int N>
__device__ void warpgroup_wait() {
DG_STATIC_ASSERT(N >= 0 and N <= 7, "WGMMA wait: N must be in range [0, 7]");
asm volatile("wgmma.wait_group.sync.aligned %0;\n" :: "n"(N) : "memory");
}
union GmmaDescriptor {
__host__ __device__ constexpr GmmaDescriptor() noexcept: desc_(0) {}
__host__ __device__ constexpr GmmaDescriptor(uint64_t desc) noexcept: desc_(desc) {}
__host__ __device__ constexpr GmmaDescriptor(GmmaDescriptor const &t) noexcept: desc_(t.desc_) {}
__host__ __device__ constexpr GmmaDescriptor(GmmaDescriptor &&t) noexcept: desc_(t.desc_) {}
__host__ __device__ constexpr GmmaDescriptor &operator=(GmmaDescriptor const &t) noexcept {
desc_ = t.desc_;
return *this;
}
__host__ __device__ constexpr GmmaDescriptor &operator=(GmmaDescriptor &&t) noexcept {
desc_ = t.desc_;
return *this;
}
uint64_t desc_;
uint32_t reg32_[2];
uint16_t reg16_[4];
struct {
uint16_t start_address_: 14, : 2;
uint16_t leading_byte_offset_: 14, : 2;
uint16_t stride_byte_offset_: 14, : 2;
uint8_t : 1, base_offset_: 3, : 4;
uint8_t : 6, layout_type_: 2;
} bitfield;
// Decay to an `uint64_t`
__host__ __device__ constexpr operator uint64_t() const noexcept { return desc_; }
};
template <class PointerType>
__device__ GmmaDescriptor make_smem_desc(PointerType smem_ptr, int layout_type,
int leading_byte_offset = 0,
int stride_byte_offset = 1024) {
GmmaDescriptor desc;
auto uint_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
desc.bitfield.start_address_ = uint_ptr >> 4;
desc.bitfield.layout_type_ = layout_type;
desc.bitfield.leading_byte_offset_ = leading_byte_offset >> 4;
desc.bitfield.stride_byte_offset_ = stride_byte_offset >> 4;
desc.bitfield.base_offset_ = 0;
return desc;
}
template <int N>
struct FP8MMASelector {
static constexpr auto select_type() {
if constexpr (N == 16) return SM90_64x16x32_F32E4M3E4M3_SS();
if constexpr (N == 24) return SM90_64x24x32_F32E4M3E4M3_SS();
if constexpr (N == 32) return SM90_64x32x32_F32E4M3E4M3_SS();
if constexpr (N == 40) return SM90_64x40x32_F32E4M3E4M3_SS();
if constexpr (N == 48) return SM90_64x48x32_F32E4M3E4M3_SS();
if constexpr (N == 56) return SM90_64x56x32_F32E4M3E4M3_SS();
if constexpr (N == 64) return SM90_64x64x32_F32E4M3E4M3_SS();
if constexpr (N == 72) return SM90_64x72x32_F32E4M3E4M3_SS();
if constexpr (N == 80) return SM90_64x80x32_F32E4M3E4M3_SS();
if constexpr (N == 88) return SM90_64x88x32_F32E4M3E4M3_SS();
if constexpr (N == 96) return SM90_64x96x32_F32E4M3E4M3_SS();
if constexpr (N == 104) return SM90_64x104x32_F32E4M3E4M3_SS();
if constexpr (N == 112) return SM90_64x112x32_F32E4M3E4M3_SS();
if constexpr (N == 120) return SM90_64x120x32_F32E4M3E4M3_SS();
if constexpr (N == 128) return SM90_64x128x32_F32E4M3E4M3_SS();
if constexpr (N == 192) return SM90_64x192x32_F32E4M3E4M3_SS();
}
using type = decltype(select_type());
};
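// Editorial sketch (not part of the original header): how the selector, the shared-memory
// descriptors and the warpgroup primitives above combine to issue one WGMMA. BLOCK_N, the
// layout type (assumed 128B swizzle) and the scale_d convention are illustrative assumptions.
template <int BLOCK_N>
__device__ __forceinline__ void example_issue_wgmma(const uint8_t* smem_a, const uint8_t* smem_b,
                                                    float* accum, bool first_k_iter) {
    using MMA = typename FP8MMASelector<BLOCK_N>::type;
    // `accum` is assumed to hold at least MMA::kNumAccum floats
    auto desc_a = make_smem_desc(smem_a, 1);  // layout type 1: assumed 128B swizzle
    auto desc_b = make_smem_desc(smem_b, 1);
    warpgroup_arrive();                                 // fence before issuing WGMMA
    MMA::wgmma(desc_a, desc_b, accum, !first_k_iter);   // scale_d = 0 overwrites D on the first K step
    warpgroup_commit_batch();                           // close the async WGMMA group
    warpgroup_wait<0>();                                // wait for all committed groups
}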
} // namespace deep_gemm

View File

@@ -1,121 +0,0 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// The file has been adapted from DeepSeek DeepGEMM project
// Copyright (c) 2025 DeepSeek
// Licensed under the MIT License - https://github.com/deepseek-ai/DeepGEMM/blob/main/LICENSE
#include "utils.cuh"
namespace deep_gemm {
enum class GemmType {
Normal,
GroupedContiguous,
GroupedMasked
};
#pragma clang diagnostic push
#pragma ide diagnostic ignored "cppcoreguidelines-pro-type-member-init"
template <GemmType kGemmType,
uint32_t SHAPE_N, uint32_t BLOCK_M, uint32_t BLOCK_N,
uint32_t kNumGroups, uint32_t kNumTMAMulticast,
uint32_t kNumNBlocks = ceil_div(SHAPE_N, BLOCK_N),
uint32_t kNumNBlocksPerGroup = 16>
struct Scheduler {
int current_iter = -1;
uint32_t num_aligned_m_blocks;
// For normal GEMM
// Maybe not used in the masked grouped GEMM
uint32_t num_blocks;
// For grouped GEMM
int* grouped_layout;
// Only used for masked layout
uint32_t curr_group_idx, curr_cumsum;
__device__ __forceinline__ explicit Scheduler(const uint32_t shape_m,
int* grouped_layout = nullptr) {
num_aligned_m_blocks = ceil_div(shape_m, BLOCK_M);
if constexpr (kGemmType == GemmType::Normal) {
num_blocks = num_aligned_m_blocks * kNumNBlocks;
} else if (kGemmType == GemmType::GroupedContiguous) {
num_blocks = num_aligned_m_blocks * kNumNBlocks;
this->grouped_layout = grouped_layout;
} else if (kGemmType == GemmType::GroupedMasked) {
curr_group_idx = curr_cumsum = 0;
this->grouped_layout = grouped_layout;
}
}
__device__ __forceinline__ void get_swizzled_block_idx(const uint32_t num_m_blocks, int block_idx, uint32_t& m_block_idx, uint32_t& n_block_idx) {
DG_STATIC_ASSERT(kNumNBlocksPerGroup % kNumTMAMulticast == 0, "Invalid group size");
// Swizzle for better L2 usages
auto num_blocks_per_group = num_m_blocks * kNumNBlocksPerGroup;
auto group_idx = block_idx / num_blocks_per_group;
auto first_n_block_idx = group_idx * kNumNBlocksPerGroup;
auto num_n_blocks_in_group = min(kNumNBlocksPerGroup, kNumNBlocks - first_n_block_idx);
auto in_group_idx = block_idx % num_blocks_per_group;
m_block_idx = in_group_idx / num_n_blocks_in_group;
n_block_idx = first_n_block_idx + in_group_idx % num_n_blocks_in_group;
}
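    // Worked example (illustrative numbers, assuming kNumNBlocksPerGroup = 16,
    // kNumNBlocks = 32 and num_m_blocks = 4): block_idx = 70 gives
    //   num_blocks_per_group = 4 * 16 = 64, group_idx = 1, first_n_block_idx = 16,
    //   in_group_idx = 70 % 64 = 6, so m_block_idx = 0 and n_block_idx = 22.
    // Consecutive block indices therefore sweep the same 16-wide N stripe before
    // advancing in M, which keeps the corresponding B tiles resident in L2.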
template <bool kIgnoreGroupedForGroupedContiguous=true>
__device__ __forceinline__ uint32_t get_global_idx(const uint32_t shape_dim, const uint32_t block_size,
const uint32_t& block_idx, const uint32_t& m_block_idx=0) {
if constexpr (kGemmType == GemmType::Normal) {
return block_idx * block_size;
} else if (kGemmType == GemmType::GroupedContiguous) {
auto offset = kIgnoreGroupedForGroupedContiguous ? 0 : __ldg(grouped_layout + m_block_idx * BLOCK_M);
return offset * shape_dim + block_idx * block_size;
} else if (kGemmType == GemmType::GroupedMasked) {
return curr_group_idx * shape_dim + block_idx * block_size;
}
}
__device__ __forceinline__ bool get_next_block(uint32_t& m_block_idx, uint32_t& n_block_idx) {
const auto next_block_idx = (++ current_iter) * gridDim.x + blockIdx.x;
if constexpr (kGemmType == GemmType::GroupedMasked) {
uint32_t num_m_blocks;
while (true) {
// End of the task
if (curr_group_idx == kNumGroups)
return false;
// Within current group
num_m_blocks = ceil_div(static_cast<uint32_t>(__ldg(grouped_layout + curr_group_idx)), BLOCK_M);
auto current_m_block_cumsum = curr_cumsum + num_m_blocks;
if (next_block_idx < current_m_block_cumsum * kNumNBlocks)
break;
// Move to check the next group
curr_group_idx ++, curr_cumsum = current_m_block_cumsum;
}
get_swizzled_block_idx(num_m_blocks, next_block_idx - curr_cumsum * kNumNBlocks, m_block_idx, n_block_idx);
} else {
if (next_block_idx >= num_blocks)
return false;
get_swizzled_block_idx(num_aligned_m_blocks, next_block_idx, m_block_idx, n_block_idx);
}
return true;
}
};
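// Editorial sketch (template arguments and kernel structure are assumptions): every CTA
// runs a persistent loop over scheduler-assigned tiles, e.g.
//
//   auto scheduler = Scheduler<GemmType::Normal, SHAPE_N, BLOCK_M, BLOCK_N,
//                              1, kNumTMAMulticast>(shape_m);
//   uint32_t m_block_idx, n_block_idx;
//   while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
//       // issue TMA loads for this tile, run the WGMMA mainloop over K,
//       // then store the BLOCK_M x BLOCK_N output tile
//   }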
#pragma clang diagnostic pop
} // namespace deep_gemm

View File

@@ -1,116 +0,0 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// The file has been adapted from DeepSeek DeepGEMM project
// Copyright (c) 2025 DeepSeek
// Licensed under the MIT License - https://github.com/deepseek-ai/DeepGEMM/blob/main/LICENSE
#pragma once
#include <cassert>
#include <stdexcept>
#include <cuda.h>
#include <cudaTypedefs.h>
#include <cuda_fp8.h>
#include <cuda_runtime.h>
#include <cuda/barrier>
#include "utils.cuh"
namespace deep_gemm {
template <class T>
constexpr CUtensorMapDataType get_CUtensorMapDataType() {
if constexpr (std::is_same<T, uint8_t>::value) {
return CU_TENSOR_MAP_DATA_TYPE_UINT8;
} else if constexpr (std::is_same<T, __nv_fp8_e4m3>::value) {
return CU_TENSOR_MAP_DATA_TYPE_UINT8;
} else if constexpr (std::is_same<T, __nv_fp8_e5m2>::value) {
return CU_TENSOR_MAP_DATA_TYPE_UINT8;
} else if constexpr (std::is_same<T, uint16_t>::value) {
return CU_TENSOR_MAP_DATA_TYPE_UINT16;
} else if constexpr (std::is_same<T, uint32_t>::value) {
return CU_TENSOR_MAP_DATA_TYPE_UINT32;
} else if constexpr (std::is_same<T, uint64_t>::value) {
return CU_TENSOR_MAP_DATA_TYPE_UINT64;
} else if constexpr (std::is_same<T, int32_t>::value) {
return CU_TENSOR_MAP_DATA_TYPE_INT32;
} else if constexpr (std::is_same<T, int64_t>::value) {
return CU_TENSOR_MAP_DATA_TYPE_INT64;
} else if constexpr (std::is_same<T, __half>::value) {
return CU_TENSOR_MAP_DATA_TYPE_FLOAT16;
} else if constexpr (std::is_same<T, float>::value) {
return CU_TENSOR_MAP_DATA_TYPE_FLOAT32;
} else if constexpr (std::is_same<T, __nv_bfloat16>::value) {
return CU_TENSOR_MAP_DATA_TYPE_BFLOAT16;
} else if constexpr (std::is_same<T, double>::value) {
return CU_TENSOR_MAP_DATA_TYPE_FLOAT64;
}
}
PFN_cuTensorMapEncodeTiled get_cuTensorMapEncodeTiled() {
// Get pointer to `cuTensorMapEncodeTiled`
cudaDriverEntryPointQueryResult driver_status;
void* cuTensorMapEncodeTiled_ptr = nullptr;
/*
#if CUDA_VERSION >= 12050
cudaGetDriverEntryPointByVersion("cuTensorMapEncodeTiled", &cuTensorMapEncodeTiled_ptr, 12000,
cudaEnableDefault, &driver_status);
#else
*/
cudaGetDriverEntryPoint("cuTensorMapEncodeTiled", &cuTensorMapEncodeTiled_ptr,
cudaEnableDefault, &driver_status);
//#endif
if (driver_status != cudaDriverEntryPointSuccess)
throw std::runtime_error("driver_status != cudaDriverEntryPointSuccess");
return reinterpret_cast<PFN_cuTensorMapEncodeTiled>(cuTensorMapEncodeTiled_ptr);
}
template <typename T>
CUtensorMap make_2d_tma_copy_desc(T* global_address, uint64_t gmem_dim[2],
uint64_t stride_in_bytes, uint32_t smem_dim[2],
CUtensorMapSwizzle swizzle_type,
PFN_cuTensorMapEncodeTiled encode_func = nullptr) {
CUtensorMap tensor_map{};
constexpr uint32_t rank = 2;
uint64_t global_stride[rank - 1] = {stride_in_bytes};
uint32_t elem_strides[rank] = {1, 1};
if (encode_func == nullptr)
encode_func = get_cuTensorMapEncodeTiled();
auto result = encode_func(
&tensor_map, get_CUtensorMapDataType<typename std::remove_cv<T>::type>(), rank,
global_address, gmem_dim, global_stride, smem_dim, elem_strides,
CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE, swizzle_type,
CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_L2_256B,
CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);
DG_HOST_ASSERT(result == CUDA_SUCCESS);
return tensor_map;
}
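// Editorial sketch (shapes are assumptions): a row-major [kRows, kCols] FP8 matrix copied
// in [kSmemRows, kSmemCols] tiles. Note that the innermost dimension comes first, and the
// swizzle mode must match the consumer's shared-memory layout.
template <uint32_t kRows, uint32_t kCols, uint32_t kSmemRows, uint32_t kSmemCols>
CUtensorMap example_make_fp8_tma_desc(__nv_fp8_e4m3* global_address) {
    uint64_t gmem_dim[2] = {kCols, kRows};
    uint32_t smem_dim[2] = {kSmemCols, kSmemRows};
    return make_2d_tma_copy_desc(global_address, gmem_dim,
                                 kCols * sizeof(__nv_fp8_e4m3), smem_dim,
                                 CU_TENSOR_MAP_SWIZZLE_128B);
}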
template <uint32_t kNumTMAMulticast = 1>
__device__ __forceinline__ void
tma_copy(void const* desc_ptr, uint64_t* barrier_ptr, void* smem_ptr,
int32_t const& crd_0, int32_t const& crd_1) {
constexpr auto cache_hint = static_cast<uint64_t>(cute::TMA::CacheHintSm90::EVICT_NORMAL);
if constexpr (kNumTMAMulticast == 1) {
cute::SM90_TMA_LOAD_2D::copy(desc_ptr, barrier_ptr, cache_hint, smem_ptr, crd_0, crd_1);
} else if (cute::block_rank_in_cluster() == 0) {
cute::SM90_TMA_LOAD_MULTICAST_2D::copy(desc_ptr, barrier_ptr, (1 << kNumTMAMulticast) - 1, cache_hint, smem_ptr, crd_0, crd_1);
}
}
} // namespace deep_gemm

View File

@@ -1,66 +0,0 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// The file has been adapted from DeepSeek DeepGEMM project
// Copyright (c) 2025 DeepSeek
// Licensed under the MIT License - https://github.com/deepseek-ai/DeepGEMM/blob/main/LICENSE
#pragma once
#include <exception>
#include <string>
#ifdef __CLION_IDE__
__host__ __device__ __forceinline__ void host_device_printf(const char* format, ...) { asm volatile("trap;"); }
#define printf host_device_printf
#endif
class AssertionException : public std::exception {
private:
std::string message{};
public:
explicit AssertionException(const std::string& message) : message(message) {}
const char *what() const noexcept override { return message.c_str(); }
};
#ifndef DG_HOST_ASSERT
#define DG_HOST_ASSERT(cond) \
do { \
if (not (cond)) { \
printf("Assertion failed: %s:%d, condition: %s\n", \
__FILE__, __LINE__, #cond); \
throw AssertionException("Assertion failed: " #cond); \
} \
} while (0)
#endif
#ifndef DG_DEVICE_ASSERT
#define DG_DEVICE_ASSERT(cond) \
do { \
if (not (cond)) { \
printf("Assertion failed: %s:%d, condition: %s\n", __FILE__, __LINE__, #cond); \
asm("trap;"); \
} \
} while (0)
#endif
#ifndef DG_STATIC_ASSERT
#define DG_STATIC_ASSERT(cond, reason) static_assert(cond, reason)
#endif
template <typename T>
__device__ __host__ constexpr T ceil_div(T a, T b) {
return (a + b - 1) / b;
}

View File

@@ -1,21 +0,0 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# The file has been adapted from DeepSeek DeepGEMM project
# Copyright (c) 2025 DeepSeek
# Licensed under the MIT License - https://github.com/deepseek-ai/DeepGEMM/blob/main/LICENSE
"""Compiler module"""
from .compiler import build, get_nvcc_compiler
from .runtime import Runtime
from .template import cpp_format, generate

View File

@@ -1,208 +0,0 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# The file has been adapted from DeepSeek DeepGEMM project
# Copyright (c) 2025 DeepSeek
# Licensed under the MIT License - https://github.com/deepseek-ai/DeepGEMM/blob/main/LICENSE
"""compiler"""
import functools
import hashlib
import os
import re
import subprocess
import uuid
from typing import Tuple
from ..utils import get_cuda_home
from . import interleave_ffma
from .runtime import Runtime, RuntimeCache
from .template import typename_map
runtime_cache = RuntimeCache()
def hash_to_hex(s: str) -> str:
"""Hash string s into hexadecimal format"""
md5 = hashlib.md5()
md5.update(s.encode("utf-8"))
return md5.hexdigest()[0:12]
@functools.lru_cache(maxsize=None)
def get_jit_include_dir() -> str:
"""Get jit include dir"""
return f"{os.path.dirname(os.path.abspath(__file__))}/../include"
@functools.lru_cache(maxsize=None)
def get_deep_gemm_version() -> str:
"""Get deep gemm version"""
# Update include directories
include_dir = f"{get_jit_include_dir()}/deep_gemm"
assert os.path.exists(
include_dir
), f"Cannot find GEMM include directory {include_dir}"
md5 = hashlib.md5()
for filename in filter(
lambda x: x.endswith(".cuh"), sorted(os.listdir(include_dir))
):
with open(f"{include_dir}/{filename}", "rb") as f:
md5.update(f.read())
# Update `interleave_ffma.py`
with open(
f"{os.path.dirname(os.path.realpath(__file__))}/interleave_ffma.py", "rb"
) as f:
md5.update(f.read())
return md5.hexdigest()[0:12]
@functools.lru_cache(maxsize=None)
def get_nvcc_compiler() -> Tuple[str, str]:
"""Get default NVCC compiler"""
paths = []
if os.getenv("DG_NVCC_COMPILER"):
paths.append(os.getenv("DG_NVCC_COMPILER"))
CUDA_HOME = get_cuda_home()
paths.append(f"{CUDA_HOME}/bin/nvcc")
# Try to find the first available NVCC compiler
least_version_required = "12.3"
version_pattern = re.compile(r"release (\d+\.\d+)")
for path in paths:
if os.path.exists(path):
match = version_pattern.search(os.popen(f"{path} --version").read())
            assert match, f"Cannot get the version of NVCC compiler {path}"
            version = match.group(1)
assert (
version >= least_version_required
), f"NVCC {path} version {version} is lower than {least_version_required}"
return path, version
raise RuntimeError("Cannot find any available NVCC compiler")
@functools.lru_cache(maxsize=None)
def get_default_user_dir():
"""Get default user dir"""
if "DG_CACHE_DIR" in os.environ:
path = os.getenv("DG_CACHE_DIR")
os.makedirs(path, exist_ok=True)
return path
return os.path.expanduser("~") + "/.deep_gemm"
@functools.lru_cache(maxsize=None)
def get_tmp_dir():
"""Get temporary dir"""
return f"{get_default_user_dir()}/tmp"
@functools.lru_cache(maxsize=None)
def get_cache_dir():
"""Get cache dir"""
return f"{get_default_user_dir()}/cache"
def make_tmp_dir():
"""Make temporary dir"""
tmp_dir = get_tmp_dir()
os.makedirs(tmp_dir, exist_ok=True)
return tmp_dir
def put(path, data, is_binary=False):
"""Put data to specified path"""
# Write and do POSIX atomic replace
tmp_file_path = f"{make_tmp_dir()}/file.tmp.{str(uuid.uuid4())}.{hash_to_hex(path)}"
with open(tmp_file_path, "wb" if is_binary else "w") as f:
f.write(data)
os.replace(tmp_file_path, path)
def build(name: str, arg_defs: tuple, code: str) -> Runtime:
"""Build kernel"""
# Compiler flags
nvcc_flags = [
"-std=c++17",
"-shared",
"-O3",
"--expt-relaxed-constexpr",
"--expt-extended-lambda",
"-gencode=arch=compute_90a,code=sm_90a",
"--ptxas-options=--register-usage-level=10"
+ (",--verbose" if "DG_PTXAS_VERBOSE" in os.environ else ""),
# Suppress some unnecessary warnings, such as unused variables for certain `constexpr` branch cases
"--diag-suppress=177,174,940",
]
cxx_flags = ["-fPIC", "-O3", "-Wno-deprecated-declarations", "-Wno-abi"]
flags = [*nvcc_flags, f'--compiler-options={",".join(cxx_flags)}']
include_dirs = [get_jit_include_dir()]
# Build signature
enable_sass_opt = (
get_nvcc_compiler()[1] <= "12.8"
and int(os.getenv("DG_DISABLE_FFMA_INTERLEAVE", 0)) == 0
)
signature = f"{name}$${get_deep_gemm_version()}$${code}$${get_nvcc_compiler()}$${flags}$${enable_sass_opt}"
name = f"kernel.{name}.{hash_to_hex(signature)}"
path = f"{get_cache_dir()}/{name}"
# Check runtime cache or file system hit
global runtime_cache
if runtime_cache[path] is not None:
if os.getenv("DG_JIT_DEBUG", None):
print(f"Using cached JIT runtime {name} during build")
return runtime_cache[path]
# Write the code
os.makedirs(path, exist_ok=True)
args_path = f"{path}/kernel.args"
src_path = f"{path}/kernel.cu"
put(
args_path,
", ".join(
[f"('{arg_def[0]}', {typename_map[arg_def[1]]})" for arg_def in arg_defs]
),
)
put(src_path, code)
# Compile into a temporary SO file
so_path = f"{path}/kernel.so"
tmp_so_path = (
f"{make_tmp_dir()}/nvcc.tmp.{str(uuid.uuid4())}.{hash_to_hex(so_path)}.so"
)
# Compile
command = [
get_nvcc_compiler()[0],
src_path,
"-o",
tmp_so_path,
*flags,
*[f"-I{d}" for d in include_dirs],
]
if os.getenv("DG_JIT_DEBUG", None) or os.getenv("DG_JIT_PRINT_NVCC_COMMAND", False):
print(f"Compiling JIT runtime {name} with command {command}")
return_code = subprocess.check_call(command)
assert return_code == 0, f"Failed to compile {src_path}"
# Interleave FFMA reuse
if enable_sass_opt:
interleave_ffma.process(tmp_so_path)
# Atomic replace SO file
os.replace(tmp_so_path, so_path)
# Put cache and return
runtime_cache[path] = Runtime(path)
return runtime_cache[path]
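def _example_build_usage():
    # Editorial sketch, not used by the library: shows how a kernel source is assembled
    # with the template module and compiled into a callable Runtime. The kernel name and
    # body are assumptions for illustration only.
    from .template import generate

    arg_defs = (("m", int),)
    code = generate((), arg_defs, "// a real kernel would launch a CUDA grid here\n")
    return build("demo_kernel", arg_defs, code)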

View File

@@ -1,173 +0,0 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# The file has been adapted from DeepSeek DeepGEMM project
# Copyright (c) 2025 DeepSeek
# Licensed under the MIT License - https://github.com/deepseek-ai/DeepGEMM/blob/main/LICENSE
"""interleave ffma"""
import argparse
import mmap
import os
import re
import subprocess
from ..utils import get_cuda_home
def run_cuobjdump(file_path):
"""Run cuobjdump on the given file path and returns its output as a string."""
CUDA_HOME = get_cuda_home()
command = [f"{CUDA_HOME}/bin/cuobjdump", "-sass", file_path]
result = subprocess.run(
command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True
)
assert result.returncode == 0
return result.stdout
def extract_ffma(sass):
"""Extract all FFMA instructions from the SASS code."""
lines = sass.splitlines()
collected = []
current = []
arch_name, func_name = "N/A", "N/A"
skip_next_line = False
for line in lines:
if "code for" in line:
arch_name = line.lstrip().lstrip("code for ").rstrip()
elif "Function :" in line:
func_name = line.lstrip().lstrip("Function :").rstrip()
elif "FFMA" in line:
current.append(line)
skip_next_line = True
elif skip_next_line:
current.append(line)
skip_next_line = False
else:
if len(current) >= 16:
assert len(current) % 2 == 0
collected.append((f"{arch_name}::{func_name}", current))
current = []
if os.getenv("DG_PRINT_REG_REUSE", None):
print(f"Found {len(collected)} FFMA segments")
return collected
def extract_hex_from_line(line):
"""Extract hexadecimal number from the given line using regular expression."""
match = re.search(r"/\*\s*(0x[0-9a-fA-F]+)\s*\*/", line)
assert match
return int(match.group(1), 16)
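# Editorial example (the SASS line below is illustrative): the hex payload in the trailing
# comment is what gets extracted, e.g.
#   extract_hex_from_line("FFMA R16, R12, R80, R16 ;   /* 0x000fc800078e0210 */")
#   -> 0x000fc800078e0210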
def validate(m, offset, le_bytes, num_lines):
"""Validate that the memory region contains the expected bytes starting from the specified offset"""
assert len(le_bytes) == num_lines // 2
assert m[offset : offset + 16] == le_bytes[0]
for i in range(1, num_lines // 2):
if m[offset + i * 16 : offset + i * 16 + 16] != le_bytes[i]:
return False
return True
def parse_registers(line):
"""Parse register names from the given line"""
line = re.sub(r"/\*.*?\*/", "", line)
line = line.replace(";", "")
tokens = line.strip().split(",")
registers = []
for token in tokens:
token = token.strip()
words = token.split()
for word in words:
if word.startswith("R"):
reg = word.split(".")[0]
registers.append(reg)
return registers
def modify_segment(m, name, ffma_lines):
"""Modify the segment so it can be used multiple times"""
num_lines = len(ffma_lines)
assert num_lines % 2 == 0
le_bytes, new_le_bytes = [], []
reused_list = []
dst_reg_set = set()
last_reused, last_dst_reg = False, ""
num_changed = 0
for i in range(num_lines // 2):
dst_reg = parse_registers(ffma_lines[i * 2])[-2]
low_line, high_line = ffma_lines[i * 2], ffma_lines[i * 2 + 1]
low_hex, high_hex = extract_hex_from_line(low_line), extract_hex_from_line(
high_line
)
le_bytes.append(low_hex.to_bytes(8, "little") + high_hex.to_bytes(8, "little"))
reused = (high_hex & 0x0800000000000000) != 0
if reused:
is_first_occurred = dst_reg not in dst_reg_set
if is_first_occurred or (last_reused and dst_reg == last_dst_reg):
# Modify the `reuse` and `yield` bits
assert high_hex & 0x0800200000000000, f"{hex(high_hex)}"
high_hex ^= 0x0800200000000000
reused = False
num_changed += 1
else:
reused_list.append(i)
dst_reg_set.add(dst_reg)
new_le_bytes.append(
low_hex.to_bytes(8, "little") + high_hex.to_bytes(8, "little")
)
last_reused, last_dst_reg = reused, dst_reg
if os.getenv("DG_PRINT_REG_REUSE", None):
print(
f" > segment `{name}` new reused list ({num_changed} changed): {reused_list}"
)
# Find the offset
offsets = []
offset = m.find(le_bytes[0])
while offset != -1:
offsets.append(offset)
offset = m.find(le_bytes[0], offset + 1)
offsets = list(filter(lambda x: validate(m, x, le_bytes, num_lines), offsets))
# Replace with `new_le_bytes`
for offset in offsets:
for i in range(num_lines // 2):
m[offset + i * 16 : offset + i * 16 + 16] = new_le_bytes[i]
def process(path):
"""Process the given path"""
if os.getenv("DG_PRINT_REG_REUSE", None):
print(f"Processing {path}")
output = run_cuobjdump(path)
segments = extract_ffma(output)
with open(path, "r+b") as f:
mm = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_WRITE)
for segment in segments:
modify_segment(mm, *segment)
mm.close()
if __name__ == "__main__":
"""Main function"""
parser = argparse.ArgumentParser(description="Interleave FFMA reg reuse")
parser.add_argument("--so", help="Path to the SO file")
args = parser.parse_args()
process(args.so)

View File

@@ -1,100 +0,0 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# The file has been adapted from DeepSeek DeepGEMM project
# Copyright (c) 2025 DeepSeek
# Licensed under the MIT License - https://github.com/deepseek-ai/DeepGEMM/blob/main/LICENSE
"""Runtime"""
import ctypes
import os
from typing import Optional
import paddle
from paddle import Tensor
from .template import map_ctype
class Runtime:
"""A callable class that wraps CUDA kernel execution"""
def __init__(self, path: str) -> None:
self.path = path
self.lib = None
self.args = None
assert self.is_path_valid(self.path)
@staticmethod
def is_path_valid(path: str) -> bool:
"""Check whether the given path contains all necessary files"""
# Exists and is a directory
if not os.path.exists(path) or not os.path.isdir(path):
return False
# Contains all necessary files
files = ["kernel.cu", "kernel.args", "kernel.so"]
return all(os.path.exists(os.path.join(path, file)) for file in files)
def __call__(self, *args) -> int:
"""Call the wrapped function"""
# Load SO file
if self.lib is None or self.args is None:
self.lib = ctypes.CDLL(os.path.join(self.path, "kernel.so"))
with open(os.path.join(self.path, "kernel.args"), "r") as f:
self.args = eval(f.read(), {"paddle": paddle})
# Check args and launch
assert len(args) == len(
self.args
), f"Expected {len(self.args)} arguments, got {len(args)}"
cargs = []
for arg, (name, dtype) in zip(args, self.args):
if isinstance(arg, Tensor):
assert (
arg.dtype == dtype
), f"Expected tensor dtype `{dtype}` for `{name}`, got `{arg.dtype}`"
else:
assert isinstance(
arg, dtype
), f"Expected built-in type `{dtype}` for `{name}`, got `{type(arg)}`"
cargs.append(map_ctype(arg))
return_code = ctypes.c_int(0)
self.lib.launch(*cargs, ctypes.byref(return_code))
return return_code.value
class RuntimeCache:
"""A cache for Runtimes"""
def __init__(self) -> None:
self.cache = {}
def __getitem__(self, path: str) -> Optional[Runtime]:
"""Get a cached Runtime"""
# In Python runtime
if path in self.cache:
return self.cache[path]
# Already compiled
if os.path.exists(path) and Runtime.is_path_valid(path):
runtime = Runtime(path)
self.cache[path] = runtime
return runtime
return None
def __setitem__(self, path, runtime) -> None:
"""Set a new Runtime"""
self.cache[path] = runtime

View File

@@ -1,150 +0,0 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# The file has been adapted from DeepSeek DeepGEMM project
# Copyright (c) 2025 DeepSeek
# Licensed under the MIT License - https://github.com/deepseek-ai/DeepGEMM/blob/main/LICENSE
"""Template"""
import copy
import ctypes
import os
from typing import Any, Dict, Iterable, Tuple
import paddle
from paddle import Tensor
# Name map for Python `eval`
typename_map: Dict[Any, str] = {
**{t: t.__name__ for t in (bool, int, float)},
paddle.int32: "paddle.int32",
paddle.float32: "paddle.float32",
paddle.bfloat16: "paddle.bfloat16",
paddle.float8_e4m3fn: "paddle.float8_e4m3fn",
paddle.device.cuda.Stream: "paddle.device.cuda.Stream",
}
# `ctype` map for Python casting
ctype_map: Dict[Any, Any] = {
**{t: getattr(ctypes, f"c_{t.__name__}") for t in (bool, int, float)},
**{
t: ctypes.c_void_p
for t in (
paddle.int32,
paddle.float32,
paddle.bfloat16,
paddle.float8_e4m3fn,
paddle.device.cuda.Stream,
)
},
}
# Type map for both Python API and source code usages
genc_map = {
bool: ("bool", "bool"),
int: ("int", "int"),
float: ("float", "float"),
paddle.int32: ("void*", "int*"),
paddle.float32: ("void*", "float*"),
paddle.bfloat16: ("void*", "__nv_bfloat16*"),
paddle.float8_e4m3fn: ("void*", "__nv_fp8_e4m3*"),
paddle.device.cuda.Stream: ("void*", "cudaStream_t"),
}
def map_ctype(value: Any) -> Any:
"""Map python types to corresponding ctypes"""
ctype = ctype_map[value.dtype if isinstance(value, Tensor) else type(value)]
if isinstance(value, Tensor):
return ctype(value.data_ptr())
if isinstance(value, paddle.device.cuda.Stream):
return ctype(value.cuda_stream)
return ctype(value)
def cpp_format(template: str, keys: Dict[str, Any]) -> str:
"""Format template string using given dict"""
# We don't use `str.format` because it's not safe for C++ {} braces
new_template = copy.deepcopy(template)
for key, value in keys.items():
new_template = new_template.replace(f"{{{key}}}", f"{value}")
return new_template
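# Editorial example: only the `{KEY}` placeholders present in `keys` are substituted,
# so genuine C++ braces survive untouched:
#   cpp_format("constexpr auto N = {N}; int x[2] = {0, 1};", {"N": 128})
#   -> "constexpr auto N = 128; int x[2] = {0, 1};"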
def generate(includes: Iterable[str], arg_defs: Iterable[Tuple], body: str) -> str:
"""Generate CPP source code"""
# Common prefix
code = "// DeepGEMM auto-generated JIT CUDA source file\n\n"
# Includes
preload_sys_includes = [
"<cuda.h>",
"<cuda_fp8.h>",
"<cuda_runtime.h>",
"<iostream>",
]
preload_package_includes = ['"cutlass/cutlass.h"']
assert isinstance(includes, list) or isinstance(includes, tuple)
sys_includes = sorted(
list(
set(
preload_sys_includes
+ [include for include in includes if include.startswith("<")]
)
)
)
package_includes = sorted(
list(
set(
preload_package_includes
+ [include for include in includes if include.startswith('"')]
)
)
)
code += "\n".join(f"#include {include}" for include in sys_includes) + "\n\n"
code += "\n".join(f"#include {include}" for include in package_includes) + "\n\n"
# Function signature
raw = "__raw_"
get_def = (
lambda n, t: f"{genc_map[t][0]} "
+ (raw if genc_map[t][0] != genc_map[t][1] else "")
+ n
)
code += 'extern "C" void launch('
code += ", ".join(
[get_def(*arg_def) for arg_def in arg_defs]
+ [
"int& __return_code",
]
)
code += ") {\n"
# Cast raw types
code += " // Cast raw types (if needed)\n"
for arg_name, arg_type in arg_defs:
if genc_map[arg_type][0] != genc_map[arg_type][1]:
code += f" auto {arg_name} = reinterpret_cast<{genc_map[arg_type][1]}>({raw}{arg_name});\n"
# Function body
code += "\n".join([((" " if line else "") + line) for line in body.split("\n")])
# End the function
code += "}\n\n"
# Debug print
if os.getenv("DG_JIT_DEBUG", None):
print(f"Generated code:\n{code}")
return code

View File

@@ -1,31 +0,0 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# The file has been adapted from DeepSeek DeepGEMM project
# Copyright (c) 2025 DeepSeek
# Licensed under the MIT License - https://github.com/deepseek-ai/DeepGEMM/blob/main/LICENSE
"""initialize"""
from .gemm import gemm_fp8_fp8_bf16_nt
from .m_grouped_gemm import (
m_grouped_gemm_fp8_fp8_bf16_nt_contiguous,
m_grouped_gemm_fp8_fp8_bf16_nt_masked,
)
from .utils import (
ceil_div,
get_col_major_tma_aligned_tensor,
get_col_major_tma_aligned_tensor_prefill,
get_m_alignment_for_contiguous_layout,
get_num_sms,
set_num_sms,
)

View File

@@ -1,266 +0,0 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# The file has been adapted from DeepSeek DeepGEMM project
# Copyright (c) 2025 DeepSeek
# Licensed under the MIT License - https://github.com/deepseek-ai/DeepGEMM/blob/main/LICENSE
"""FP8 GEMM kernels"""
import functools
from typing import Tuple
import paddle
from paddle import Tensor
from .tuner import jit_tuner
from .utils import (
ceil_div,
get_m_alignment_for_contiguous_layout,
get_num_sms,
)
# C++ code templates
includes = ('"deep_gemm/fp8_gemm.cuh"',)
template = """
using namespace deep_gemm;
// Templated args from Python JIT call
constexpr auto N = {N}, K = {K};
constexpr auto BLOCK_M = {BLOCK_M};
constexpr auto BLOCK_N = {BLOCK_N};
constexpr auto kNumStages = {NUM_STAGES};
constexpr auto kNumTMAMulticast = {NUM_TMA_MULTICAST};
// Make a templated GEMM
using GemmType = Gemm<N, K, BLOCK_M, BLOCK_N, 128, 1, kNumStages, kNumTMAMulticast, GemmType::Normal>;
// Launch kernel
auto tma_a_desc = GemmType::make_2d_tma_a_desc(lhs, m);
auto tma_b_desc = GemmType::make_2d_tma_b_desc(rhs);
auto tma_scales_a_desc = GemmType::make_2d_tma_scales_a_desc(lhs_scales, m);
auto tma_d_desc = GemmType::make_2d_tma_d_desc(out, m);
GemmType::run(out, rhs_scales, nullptr,
m,
tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc,
stream, num_sms, smem_size);
"""
def is_tma_multicast_legal(
n: int, block_n: int, num_tma_multicast: int, num_sms: int
) -> bool:
"""Check whether it's legal to have multiple multicasts per SM."""
if num_tma_multicast == 1:
return True
return (n % (block_n * num_tma_multicast) == 0) and num_sms % num_tma_multicast == 0
def get_smem_size(
num_stages: int, k: int, block_m: int, block_n: int, block_k: int = 128
) -> int:
"""Get shared memory size needed for each stage"""
smem_d = block_m * block_n * 2
smem_a_per_stage = block_m * block_k
smem_scales_a_per_stage = block_m * 4
smem_b_per_stage = block_n * block_k
smem_scales_b = ceil_div(k, block_k) * 4
smem_barrier = num_stages * 8 * 2
smem_size = 0
smem_size += smem_d
smem_size += num_stages * smem_a_per_stage
smem_size += num_stages * smem_scales_a_per_stage
smem_size += num_stages * smem_b_per_stage
smem_size += ceil_div(smem_scales_b * (1 if block_k % block_n == 0 else 2), 8) * 8
smem_size += smem_barrier
return smem_size
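# Worked example (illustrative sizes): num_stages=5, k=7168, block_m=block_n=block_k=128
#   smem_d                  = 128 * 128 * 2        = 32768
#   smem_a_per_stage        = 128 * 128            = 16384
#   smem_scales_a_per_stage = 128 * 4              = 512
#   smem_b_per_stage        = 128 * 128            = 16384
#   smem_scales_b (padded)  = ceil(7168 / 128) * 4 = 224
#   smem_barrier            = 5 * 8 * 2            = 80
#   total = 32768 + 5 * (16384 + 512 + 16384) + 224 + 80 = 199472 bytes, which fits
#   within the 232448-byte SM90 shared-memory budget used by `get_best_configs`.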
def get_best_configs(
m: int,
n: int,
k: int,
num_groups: int,
num_sms: int,
is_grouped_contiguous: bool = False,
) -> Tuple[int, int, int, int, int]:
"""Find the optimal configuration"""
if not is_grouped_contiguous:
# TODO: for some cases, smaller M block is better, add them into tuning space
block_ms = (64 if m <= 64 else 128,)
else:
block_ms = (get_m_alignment_for_contiguous_layout(),)
block_ns = tuple(range(16, 129, 8))
fix_wave_saturate = lambda x: num_sms if x == 0 else x
get_num_waves = lambda bm, bn: (
ceil_div(ceil_div(m, bm) * ceil_div(n, bn) * num_groups, num_sms)
if bm
else None
)
get_last_wave_util = lambda bm, bn: fix_wave_saturate(
(ceil_div(m, bm) * ceil_div(n, bn) * num_groups) % num_sms
)
# Decide block sizes by waves
best_block_m, best_block_n = None, None
for block_m in block_ms:
for block_n in block_ns:
success = False
num_waves, best_num_waves = get_num_waves(block_m, block_n), get_num_waves(
best_block_m, best_block_n
)
if best_block_m is None or best_block_n is None:
success = True
elif num_waves < best_num_waves:
success = True
elif num_waves == best_num_waves:
# Check last wave utilization
util = get_last_wave_util(block_m, block_n)
best_util = get_last_wave_util(best_block_m, best_block_n)
success = util > best_util or (
util == best_util
and (
block_m > best_block_m
or (block_m == best_block_m and block_n < best_block_n)
)
)
best_block_m, best_block_n = (
(block_m, block_n) if success else (best_block_m, best_block_n)
)
assert best_block_m is not None and best_block_n is not None
# Always pick the longest one
# NOTES: for double B scales, the best number of stages may be reduced
best_num_stages, best_smem_size, sm90_capacity = None, None, 232448
for num_stages in (6, 5, 4) if 128 % best_block_n != 0 else (8, 7, 6, 5, 4):
best_smem_size = get_smem_size(num_stages, k, best_block_m, best_block_n)
if best_smem_size <= sm90_capacity:
best_num_stages = num_stages
break
assert best_num_stages is not None
# Decide the number of TMA multicast
best_num_tma_multicast = 1
if (
m >= 1024
and is_tma_multicast_legal(n, best_block_n, 2, num_sms)
and num_groups == 1
):
best_num_tma_multicast = 2
return (
best_block_m,
best_block_n,
best_num_stages,
best_num_tma_multicast,
best_smem_size,
)
@functools.lru_cache()
def auto_tuning_with_compilation(m, n, k):
"""Compile and tune the GEMM"""
global includes, template
num_sms = get_num_sms()
block_m, block_n, num_stages, num_tma_multicast, smem_size = get_best_configs(
m, n, k, 1, num_sms
)
runtime = jit_tuner.compile_and_tune(
m,
n,
k,
name="gemm_fp8_fp8_bf16_nt",
keys={
"BLOCK_M": block_m,
"BLOCK_N": block_n,
"K": k,
"N": n,
"NUM_STAGES": num_stages,
"NUM_TMA_MULTICAST": num_tma_multicast,
},
space=(),
includes=includes,
arg_defs=(
("lhs", paddle.float8_e4m3fn),
("lhs_scales", paddle.float32),
("rhs", paddle.float8_e4m3fn),
("rhs_scales", paddle.float32),
("out", paddle.bfloat16),
("m", int),
("stream", paddle.device.cuda.Stream),
("num_sms", int),
("smem_size", int),
),
template=template,
)
return runtime, num_sms, smem_size
def gemm_fp8_fp8_bf16_nt(
lhs: Tuple[Tensor, Tensor], rhs: Tuple[Tensor, Tensor], out: Tensor
) -> None:
"""
Do a normal GEMM with FP8 inputs and BF16 output, with 1x128 LHS scaling and 128x128 RHS scaling.
LHS, RHS, RHS scaling factors, and output tensors must be in contiguous format.
RHS and RHS scaling factors are required to be transposed.
    The LHS scaling tensor requires a TMA-aligned, transposed format; if your input does not match this requirement,
    this function will transpose it with a set of slow Paddle operations.
Arguments:
lhs: the first element is an FP8 tensor (typed `paddle.float8_e4m3fn`) of shape `[m, k]`,
the second element is an FP32 1x128 scaling tensor for LHS of shape `[m, ⌈k / 128⌉]`.
rhs: the first element is an FP8 tensor (typed `paddle.float8_e4m3fn`) of shape `[n, k]`.
the second element is an FP32 128x128 scaling tensor for RHS of shape `[⌈n / 128⌉, ⌈k / 128⌉]`.
out: the BF16 output tensor of shape `[m, n]`, representing the result.
"""
lhs, lhs_scales = lhs
rhs, rhs_scales = rhs
m, k = lhs.shape
n, k_ = rhs.shape
# m_, n_ = out.shape
# assert n % 64 == 0 and k % 128 == 0
# Type and shape checks
# assert m == m_ and n == n_ and k == k_
# assert n > 0 and k > 0
# assert lhs_scales.shape == [m, (k + 127) // 128]
# assert rhs_scales.shape == [(n + 127) // 128, (k + 127) // 128]
# assert lhs.dtype == paddle.float8_e4m3fn and lhs_scales.dtype == paddle.float32
# assert rhs.dtype == paddle.float8_e4m3fn and rhs_scales.dtype == paddle.float32
# assert out.dtype == paddle.bfloat16
# assert lhs.is_contiguous() and rhs.is_contiguous() and out.is_contiguous()
# LHS scales must be transposed for TMA load, but not for RHS scales
# NOTES: `get_tma_aligned_lhs_scales` may launch a kernel if not processed by previous kernels
# TODO: NOT NEED get_col_major_tma_aligned_tensor!!!
# lhs_scales = get_col_major_tma_aligned_tensor_prefill(lhs_scales)
# assert rhs_scales.is_contiguous()
# Do nothing if `m` is zero
if m == 0:
return
runtime, num_sms, smem_size = auto_tuning_with_compilation(m, n, k)
args = (
lhs,
lhs_scales,
rhs,
rhs_scales,
out,
m,
paddle.device.cuda.current_stream(),
num_sms,
smem_size,
)
# Run the kernel.
runtime(*args)
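# Editorial usage sketch (tensor names and shapes are assumptions): both operands carry
# their own scaling tensors and the output is preallocated in BF16.
#   m, n, k = 4096, 4096, 7168
#   out = paddle.empty([m, n], dtype=paddle.bfloat16)
#   gemm_fp8_fp8_bf16_nt((x_fp8, x_scales),   # [m, k] e4m3, [m, ceil(k / 128)] fp32
#                        (w_fp8, w_scales),   # [n, k] e4m3, [n / 128, k / 128] fp32
#                        out)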

View File

@@ -1,329 +0,0 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# The file has been adapted from DeepSeek DeepGEMM project
# Copyright (c) 2025 DeepSeek
# Licensed under the MIT License - https://github.com/deepseek-ai/DeepGEMM/blob/main/LICENSE
"""m grouped gemm"""
import functools
from typing import Tuple
import paddle
from paddle import Tensor
from .gemm import get_best_configs
from .tuner import jit_tuner
from .utils import (
get_col_major_tma_aligned_tensor,
get_col_major_tma_aligned_tensor_prefill,
get_num_sms,
)
# C++ code templates
includes = ('"deep_gemm/fp8_gemm.cuh"',)
template = """
using namespace deep_gemm;
// Templated args from Python JIT call
constexpr auto N = {N}, K = {K};
constexpr auto BLOCK_M = {BLOCK_M};
constexpr auto BLOCK_N = {BLOCK_N};
constexpr auto kNumStages = {NUM_STAGES};
constexpr auto kNumTMAMulticast = {NUM_TMA_MULTICAST};
// Make a templated grouped GEMM
using GemmType = Gemm<N, K, BLOCK_M, BLOCK_N, 128, {NUM_GROUPS}, kNumStages, kNumTMAMulticast, GemmType::{GEMM_TYPE}>;
// Launch kernel
auto tma_a_desc = GemmType::make_2d_tma_a_desc(lhs, m);
auto tma_b_desc = GemmType::make_2d_tma_b_desc(rhs);
auto tma_scales_a_desc = GemmType::make_2d_tma_scales_a_desc(lhs_scales, m);
auto tma_d_desc = GemmType::make_2d_tma_d_desc(out, m);
GemmType::run(out, rhs_scales, grouped_layout,
m,
tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc,
stream, num_sms, smem_size);
"""
from collections import OrderedDict
class LRUCache:
"""
    A simple LRU cache backed by an OrderedDict
"""
def __init__(self, capacity: int):
self.capacity = capacity
self.cache = OrderedDict()
def get(self, key):
"""
        Get a value from the LRU cache, marking the key as most recently used
"""
if key in self.cache:
            # The key exists: move it to the end of the OrderedDict to mark it as most recently used
self.cache.move_to_end(key)
return self.cache[key]
return None
def put(self, key, value):
"""
        Insert or update a key-value pair, evicting the least recently used entry if the cache is full
"""
if key in self.cache:
            # The key already exists: remove it first
del self.cache[key]
elif len(self.cache) == self.capacity:
            # The cache is full: evict the oldest (least recently used) item
self.cache.popitem(last=False)
        # Insert the new key-value pair at the end of the OrderedDict
self.cache[key] = value
grouped_gemm_masked_keys = LRUCache(10)
@functools.lru_cache()
def auto_tuning_with_compilation_grouped_gemm_contiguous(m, n, k, num_groups, num_sms):
"""auto tuning gemm"""
global includes, template
block_m, block_n, num_stages, num_tma_multicast, smem_size = get_best_configs(
m, n, k, 1, num_sms, is_grouped_contiguous=True
)
runtime = jit_tuner.compile_and_tune(
m,
n,
k,
name="m_grouped_gemm_fp8_fp8_bf16_nt",
keys={
"BLOCK_M": block_m,
"BLOCK_N": block_n,
"GEMM_TYPE": "GroupedContiguous",
"K": k,
"N": n,
"NUM_GROUPS": num_groups,
"NUM_STAGES": num_stages,
"NUM_TMA_MULTICAST": num_tma_multicast,
},
space=(),
includes=includes,
arg_defs=(
("lhs", paddle.float8_e4m3fn),
("lhs_scales", paddle.float32),
("rhs", paddle.float8_e4m3fn),
("rhs_scales", paddle.float32),
("out", paddle.bfloat16),
("grouped_layout", paddle.int32),
("m", int),
("num_groups", int),
("stream", paddle.device.cuda.Stream),
("num_sms", int),
("smem_size", int),
),
template=template,
)
return runtime, num_sms, smem_size
def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
lhs: Tuple[Tensor, Tensor],
rhs: Tuple[Tensor, Tensor],
out: Tensor,
m_indices: Tensor,
) -> None:
"""
Do a grouped GEMM (contiguous format) with FP8 inputs and BF16 output,
with 1x128 LHS scaling and 128x128 RHS scaling.
LHS, RHS, RHS scaling factors, and output tensors must be in contiguous format.
RHS and RHS scaling factors are required to be transposed.
    The LHS scaling tensor requires a TMA-aligned, transposed format; if your input does not match this requirement,
    this function will transpose it with a set of slow Paddle operations.
    On the M axis, inputs are grouped into several batches, whose batch sizes are aligned to
    `get_m_alignment_for_contiguous_layout()` (128).
Arguments:
lhs: the first element is an FP8 tensor (typed `paddle.float8_e4m3fn`) of shape `[m_sum, k]`,
the second element is an FP32 1x128 scaling tensor for LHS of shape `[m_sum, ⌈k / 128⌉]`.
rhs: the first element is an FP8 tensor (typed `paddle.float8_e4m3fn`) of shape `[num_groups, n, k]`.
the second element is an FP32 128x128 scaling tensor for RHS of shape `[num_groups, ⌈n / 128⌉, ⌈k / 128⌉]`.
out: the BF16 output tensor of shape `[m_sum, n]`, representing the result.
m_indices: a tensor of shape `[m_sum]` with type `paddle.int32`.
            `m_indices[i]` records the group to which the i-th row of the LHS belongs,
            which means that the i-th row of the LHS matrix will be multiplied with `rhs[m_indices[i]]`.
            Within each m-alignment block, the values of `m_indices` must be the same.
`-1` in this tensor indicates no RHS matrix selected,
the kernel will skip the computation for that aligned block.
"""
lhs, lhs_scales = lhs
rhs, rhs_scales = rhs
m, k = lhs.shape
num_groups, n, k_ = rhs.shape
    # TODO: get_col_major_tma_aligned_tensor should not be needed here
lhs_scales = get_col_major_tma_aligned_tensor_prefill(lhs_scales)
# Do nothing if `m` is zero
if m == 0:
return
# Auto-tuning with compilation
global includes, template
num_sms = get_num_sms()
runtime, num_sms, smem_size = auto_tuning_with_compilation_grouped_gemm_contiguous(
m, n, k, num_groups, num_sms
)
args = (
lhs,
lhs_scales,
rhs,
rhs_scales,
out,
m_indices,
m,
num_groups,
paddle.device.cuda.current_stream(),
num_sms,
smem_size,
)
runtime(*args)
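# A minimal call sketch for the contiguous grouped GEMM above (illustrative only;
# this helper is not part of the module's API, and it assumes FP8 tensors and their
# scales were already produced by a separate per-token / per-block quantization step):
def _grouped_contiguous_call_example(lhs_fp8, lhs_scales, rhs_fp8, rhs_scales,
                                     num_groups, rows_per_group, n):
    # rows_per_group must be a multiple of get_m_alignment_for_contiguous_layout() (128)
    m_sum = num_groups * rows_per_group
    out = paddle.empty((m_sum, n), dtype=paddle.bfloat16)
    # Every aligned block of rows_per_group rows maps to one group index
    m_indices = paddle.flatten(
        paddle.expand(
            paddle.unsqueeze(paddle.arange(0, num_groups, dtype=paddle.int32), -1),
            shape=[num_groups, rows_per_group],
        )
    )
    m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
        (lhs_fp8, lhs_scales), (rhs_fp8, rhs_scales), out, m_indices
    )
    return out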
def m_grouped_gemm_fp8_fp8_bf16_nt_masked(
lhs: Tuple[Tensor, Tensor],
rhs: Tuple[Tensor, Tensor],
out: Tensor,
masked_m: Tensor,
expected_m: int,
) -> None:
"""
Do a grouped GEMM (masked format) with FP8 inputs and BF16 output, with 1x128 LHS scaling and 128x128 RHS scaling.
LHS, RHS, RHS scaling factors, and output tensors must be in contiguous format.
RHS and RHS scaling factors are required to be transposed.
    The LHS scaling tensor must be in TMA-aligned transposed format; if your input does not meet this requirement,
    this function will transpose it with a set of slow Paddle operations.
    Moreover, this alignment requirement differs from that of the contiguous-format kernel:
    each batch must be transposed separately.
Arguments:
lhs: the first element is an FP8 tensor (typed `paddle.float8_e4m3fn`) of shape `[num_groups, m_max, k]`,
the second element is an FP32 1x128 scaling tensor for LHS of shape `[num_groups, m_max, ⌈k / 128⌉]`.
rhs: the first element is an FP8 tensor (typed `paddle.float8_e4m3fn`) of shape `[num_groups, n, k]`.
the second element is an FP32 128x128 scaling tensor for RHS of shape `[num_groups, ⌈n / 128⌉, ⌈k / 128⌉]`.
out: the BF16 output tensor of shape `[num_groups, m_max, n]`, representing the result.
        masked_m: a tensor of shape `[num_groups]`; `masked_m[i]` records the actual number of rows of `lhs[i]`
            to compute in the i-th group.
        expected_m: a CPU-side hint for the expected M of each batch;
            setting it correctly may lead to better performance.
"""
lhs, lhs_scales = lhs
rhs, rhs_scales = rhs
num_groups, m, k = lhs.shape
num_groups_, n, k_ = rhs.shape
# num_groups__, m_, n_ = out.shape
# assert len(masked_m.shape) == 1
# num_groups___ = masked_m.shape[0]
# Type and shape checks
# assert num_groups == num_groups_ == num_groups__ == num_groups___
# assert m == m_ and n == n_ and k == k_
# assert expected_m > 0 and m > 0 and n > 0 and k > 0 and num_groups > 0
# assert lhs_scales.shape == [num_groups, m, (k + 127) // 128]
# assert rhs_scales.shape == [num_groups, (n + 127) // 128, (k + 127) // 128]
# assert lhs.dtype == paddle.float8_e4m3fn and lhs_scales.dtype == paddle.float32
# assert rhs.dtype == paddle.float8_e4m3fn and rhs_scales.dtype == paddle.float32
# assert out.dtype == paddle.bfloat16
# assert masked_m.dtype == paddle.int32
# assert lhs.is_contiguous() and rhs.is_contiguous()
# assert out.is_contiguous() and masked_m.is_contiguous()
    # LHS scales must be transposed for TMA load; RHS scales need not be
lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales)
# assert rhs_scales.is_contiguous()
# Auto-tuning with compilation
global includes, template
num_sms = get_num_sms()
input_keys = (expected_m, n, k, num_groups, num_sms)
if grouped_gemm_masked_keys.get(input_keys) is None:
block_m, block_n, num_stages, num_tma_multicast, smem_size = get_best_configs(
expected_m, n, k, num_groups, num_sms
)
args = (
lhs,
lhs_scales,
rhs,
rhs_scales,
out,
masked_m,
m,
paddle.device.cuda.current_stream(),
num_sms,
smem_size,
)
runtime = jit_tuner.compile_and_tune_group_gemm_masked(
name="m_grouped_gemm_fp8_fp8_bf16_nt",
keys={
"N": n,
"K": k,
"BLOCK_M": block_m,
"BLOCK_N": block_n,
"NUM_GROUPS": num_groups,
"NUM_STAGES": num_stages,
"NUM_TMA_MULTICAST": num_tma_multicast,
"GEMM_TYPE": "GroupedMasked",
},
space=(),
includes=includes,
arg_defs=(
("lhs", paddle.float8_e4m3fn),
("lhs_scales", paddle.float32),
("rhs", paddle.float8_e4m3fn),
("rhs_scales", paddle.float32),
("out", paddle.bfloat16),
("grouped_layout", paddle.int32),
("m", int),
("stream", paddle.device.cuda.Stream),
("num_sms", int),
("smem_size", int),
),
template=template,
args=args,
)
grouped_gemm_masked_keys.put(input_keys, (runtime, smem_size))
else:
runtime, smem_size = grouped_gemm_masked_keys.get(input_keys)
args = (
lhs,
lhs_scales,
rhs,
rhs_scales,
out,
masked_m,
m,
paddle.device.cuda.current_stream(),
num_sms,
smem_size,
)
# Extra checks for TMA store
# if num_groups > 1 and m > block_m:
# assert (
# m % block_m == 0
# ), f"For masked grouped GEMM, shape M should be multiple of the block M (current block M: {block_m})"
# Run the kernel
runtime(*args)
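# A minimal call sketch for the masked grouped GEMM above (illustrative only;
# this helper is not part of the module's API, and it assumes per-group FP8 tensors
# of shape [num_groups, m_max, k] and [num_groups, n, k], plus their scales,
# were prepared elsewhere):
def _grouped_masked_call_example(lhs_fp8, lhs_scales, rhs_fp8, rhs_scales,
                                 num_groups, m_max, n, valid_rows):
    out = paddle.empty((num_groups, m_max, n), dtype=paddle.bfloat16)
    # masked_m[i] holds the number of valid rows of lhs[i]; here every group
    # uses the same count, so it also serves as the expected_m hint
    masked_m = paddle.full((num_groups,), valid_rows, dtype=paddle.int32)
    m_grouped_gemm_fp8_fp8_bf16_nt_masked(
        (lhs_fp8, lhs_scales), (rhs_fp8, rhs_scales), out, masked_m,
        expected_m=valid_rows,
    )
    return out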

View File

@@ -1,181 +0,0 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# The file has been adapted from DeepSeek DeepGEMM project
# Copyright (c) 2025 DeepSeek
# Licensed under the MIT License - https://github.com/deepseek-ai/DeepGEMM/blob/main/LICENSE
"""tune gemm kernels"""
import copy
import os
from typing import Any, Dict
import paddle
from ..jit import Runtime, build, cpp_format, generate
class JITTuner:
"""A tuner that compiles and auto-tunes group gemm masked kernels"""
def __init__(self) -> None:
self.tuned = {}
def compile_and_tune_group_gemm_masked(
self,
name: str,
keys: Dict[str, Any],
space: tuple,
includes: tuple,
arg_defs: tuple,
template: str,
args: tuple,
) -> Runtime:
"""Compile and tune a group gemm masked kernel"""
# NOTES: we always assume the space and template will not change
# We also assume the GPU device will not be changed
# NOTES: the function must have no accumulated side effects
keys = {k: keys[k] for k in sorted(keys.keys())}
signature = (name, f"{keys}")
if signature in self.tuned:
if os.getenv("DG_JIT_DEBUG", None):
print(f"Using cached JIT kernel {name} with keys {keys}")
return self.tuned[signature]
if os.getenv("DG_JIT_DEBUG", None):
print(f"Auto-tuning JIT kernel {name} with keys {keys}")
assert signature not in self.tuned
assert args is not None
space = (dict(),) if len(space) == 0 else space
kernels = []
for tuned_keys in space:
assert isinstance(tuned_keys, dict)
full_keys = copy.deepcopy(keys)
full_keys.update(tuned_keys)
code = generate(includes, arg_defs, cpp_format(template, full_keys))
# Illegal build must raise errors
kernels.append((build(name, arg_defs, code), tuned_keys))
best_runtime, best_time, best_keys = None, None, None
for runtime, tuned_keys in kernels:
if len(space) > 1:
# Check kernel validity
return_code = runtime(*args)
if return_code != 0:
                    # Skip illegal kernels, e.g. those requiring more shared memory than available
if os.getenv("DG_JIT_DEBUG", None):
print(
f"Illegal JIT kernel {name} with keys {keys} and tuned keys {tuned_keys}: "
f"error code {return_code}"
)
continue
                # Measure performance after an L2 flush and a large warm-up GEMM to reduce launch overhead between kernels
start_event = paddle.device.cuda.Event(enable_timing=True)
end_event = paddle.device.cuda.Event(enable_timing=True)
                paddle.empty([int(256e6 // 4)], dtype=paddle.int32).zero_()
                paddle.randn((8192, 8192), dtype=paddle.float32) @ paddle.randn(
                    (8192, 8192), dtype=paddle.float32
                )
start_event.record()
for i in range(20):
assert runtime(*args) == 0
end_event.record()
end_event.synchronize()
elapsed_time = start_event.elapsed_time(end_event)
else:
elapsed_time = 0
# Compare if better
if best_time is None or elapsed_time < best_time:
best_runtime, best_time, best_keys = runtime, elapsed_time, tuned_keys
if os.getenv("DG_JIT_DEBUG", None):
print(
f"Tuned JIT kernel {name} with keys {keys} and tuned keys {tuned_keys} has time {elapsed_time}"
)
assert (
best_runtime is not None
), f"Failed to tune JIT kernel {name} with keys {keys}"
# Cache the best runtime and return
if os.getenv("DG_JIT_DEBUG", None) or os.getenv("DG_PRINT_AUTOTUNE", None):
print(
f"Best JIT kernel {name} with keys {keys} has tuned keys {best_keys} and time {best_time}"
)
self.tuned[signature] = best_runtime
return best_runtime
def compile_and_tune(
self,
m,
n,
k,
name: str,
keys: Dict[str, Any],
space: tuple,
includes: tuple,
arg_defs: tuple,
template: str,
# args: tuple,
) -> Runtime:
"""Compile and tune a kernel"""
# NOTES: we always assume the space and template will not change
# We also assume the GPU device will not be changed
# NOTES: the function must have no accumulated side effects
signature = (name, m, k, n)
if signature in self.tuned:
return self.tuned[signature]
# keys = {k: keys[k] for k in sorted(keys.keys())}
# signature = (name, f"{keys}")
# if signature in self.tuned:
# return self.tuned[signature]
space = (dict(),) if len(space) == 0 else space
kernels = []
for tuned_keys in space:
assert isinstance(tuned_keys, dict)
full_keys = copy.deepcopy(keys)
full_keys.update(tuned_keys)
code = generate(includes, arg_defs, cpp_format(template, full_keys))
# Illegal build must raise errors
kernels.append((build(name, arg_defs, code), tuned_keys))
best_runtime, best_time, best_keys = None, None, None
for runtime, tuned_keys in kernels:
elapsed_time = 0
# Compare if better
if best_time is None or elapsed_time < best_time:
best_runtime, best_time, best_keys = runtime, elapsed_time, tuned_keys
if os.getenv("DG_JIT_DEBUG", None):
print(
f"Tuned JIT kernel {name} with keys {keys} and tuned keys {tuned_keys} has time {elapsed_time}"
)
assert (
best_runtime is not None
), f"Failed to tune JIT kernel {name} with keys {keys}"
# Cache the best runtime and return
if os.getenv("DG_JIT_DEBUG", None) or os.getenv("DG_PRINT_AUTOTUNE", None):
print(
f"Best JIT kernel {name} with keys {keys} has tuned keys {best_keys} and time {best_time}"
)
self.tuned[signature] = best_runtime
return best_runtime
jit_tuner = JITTuner()

View File

@@ -1,151 +0,0 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# The file has been adapted from DeepSeek DeepGEMM project
# Copyright (c) 2025 DeepSeek
# Licensed under the MIT License - https://github.com/deepseek-ai/DeepGEMM/blob/main/LICENSE
"""Utility functions"""
import paddle
from paddle import Tensor
_num_sms = None
def set_num_sms(num_sms: int) -> None:
"""
Set the maximum SM count for all GEMM kernels to use.
Arguments:
num_sms: the desired maximum SM count for all GEMM kernels to use.
"""
global _num_sms
assert (
0
< num_sms
<= paddle.device.cuda.get_device_properties(device="cuda").multi_processor_count
)
_num_sms = num_sms
def get_num_sms() -> int:
"""
Get the current maximum limit of SM count for all GEMM kernels to use.
If the count is never specified, the function will return the number of device SMs.
Returns:
Current maximum limit of SM count for all GEMM kernels to use.
"""
global _num_sms
if _num_sms is None:
_num_sms = paddle.device.cuda.get_device_properties().multi_processor_count
return _num_sms
def ceil_div(x: int, y: int) -> int:
"""
Perform ceiling division of two integers.
Args:
x: the dividend.
y: the divisor.
Returns:
The result of the ceiling division.
"""
return (x + y - 1) // y
def get_m_alignment_for_contiguous_layout():
"""
    When we do a grouped GEMM in contiguous format, the LHS is grouped into several batches along the M axis.
    Since each GEMM block deals with exactly one sub-matrix of the RHS, these batch sizes must be aligned to
    the GEMM block shape.
Returns:
Group-level alignment requirement for grouped contiguous layout, which is always 128.
"""
return 128
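# A small sketch (illustrative; the raw row count is an arbitrary example) of padding
# a ragged per-group row count up to the 128-row alignment required by the
# contiguous grouped GEMM:
def _aligned_rows_example(raw_rows: int = 300) -> int:
    alignment = get_m_alignment_for_contiguous_layout()  # 128
    return ceil_div(raw_rows, alignment) * alignment     # 300 -> 384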
def get_tma_aligned_size(x: int, element_size: int) -> int:
"""
    Global memory addresses used by TMA must be 16-byte aligned.
Since we use column-major layout for the LHS scaling tensor,
the M-axis of the LHS scaling tensor needs to be padded to a multiple of 16 bytes.
Arguments:
x: original M-axis shape of the LHS scaling tensor.
element_size: element size of the LHS scaling tensor.
Returns:
M-axis shape of the LHS scaling tensor after padding.
"""
tma_alignment_bytes = 16
assert tma_alignment_bytes % element_size == 0
alignment = tma_alignment_bytes // element_size
return ceil_div(x, alignment) * alignment
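# A worked example (illustrative): FP32 scales have an element size of 4 bytes,
# so the 16-byte TMA alignment corresponds to 4 elements along the M axis.
def _tma_aligned_size_example() -> int:
    return get_tma_aligned_size(130, 4)  # ceil(130 / 4) * 4 = 132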
def get_col_major_tma_aligned_tensor(x: Tensor) -> Tensor:
"""
Returns TMA-aligned transposed format of the input tensor. `paddle.transpose` will be called if necessary.
If the input tensor is already column-major layout and 16-byte aligned along the M axis
(thus meets the requirement of LHS scaling tensor in DeepGEMM), this function will do nothing.
Arguments:
x: usually the LHS scaling tensor in GEMM.
Returns:
The LHS scaling tensor of TMA-aligned transposed format.
"""
return x
def get_col_major_tma_aligned_tensor_prefill(x: Tensor) -> Tensor:
"""
Returns TMA-aligned transposed format of the input tensor. `paddle.transpose` will be called if necessary.
If the input tensor is already column-major layout and 16-byte aligned along the M axis
(thus meets the requirement of LHS scaling tensor in DeepGEMM), this function will do nothing.
Arguments:
x: usually the LHS scaling tensor in GEMM.
Returns:
The LHS scaling tensor of TMA-aligned transposed format.
"""
# NOTES: for the extreme performance, you may rewrite/fuse this function in CUDA
# assert x.dim() in (2, 3)
# remove_dim = False
# if x.dim() == 2:
x, remove_dim = x.unsqueeze(0), True
b, m, n = x.shape
aligned_m = get_tma_aligned_size(m, x.element_size())
# The last kernel gives a column-major TMA aligned layout
# if (
# x.strides[0] == aligned_m * n
# and x.strides[1] == 1
# and x.strides[2] == aligned_m
# ):
# return x.squeeze(0) if remove_dim else x
# Normal layout requires transposing
aligned_x = paddle.transpose(
paddle.empty((b, n, aligned_m), dtype=x.dtype), perm=[0, 2, 1]
)
aligned_x[:, :m, :] = x
aligned_x = aligned_x[:, :m, :]
return aligned_x.squeeze(0)
# return aligned_x.squeeze(0) if remove_dim else aligned_x

View File

@@ -1,137 +0,0 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# The file has been adapted from DeepSeek DeepGEMM project
# Copyright (c) 2025 DeepSeek
# Licensed under the MIT License - https://github.com/deepseek-ai/DeepGEMM/blob/main/LICENSE
"""Utilities"""
import os
import sys
import paddle
def bench(fn, num_warmups: int = 5, num_tests: int = 10, high_precision: bool = False):
"""Benchmark function `fn` using CUDA events."""
# Flush L2 cache with 256 MB data
paddle.device.cuda.synchronize()
    cache = paddle.empty([int(256e6 // 4)], dtype=paddle.int32)
cache.zero_()
# Warmup
for _ in range(num_warmups):
fn()
# Add a large kernel to eliminate the CPU launch overhead
if high_precision:
x = paddle.randn((8192, 8192), dtype=paddle.float32)
y = paddle.randn((8192, 8192), dtype=paddle.float32)
x @ y
# Testing
start_event = paddle.device.cuda.Event(enable_timing=True)
end_event = paddle.device.cuda.Event(enable_timing=True)
start_event.record()
for i in range(num_tests):
fn()
end_event.record()
    paddle.device.cuda.synchronize()
return start_event.elapsed_time(end_event) / num_tests
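# A minimal usage sketch of `bench` (illustrative only; the matmul workload and
# sizes are arbitrary, and this helper is not part of the module's API):
def _bench_example() -> float:
    a = paddle.randn((1024, 1024), dtype=paddle.float32)
    b = paddle.randn((1024, 1024), dtype=paddle.float32)
    return bench(lambda: a @ b, num_warmups=5, num_tests=10)  # average time in ms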
def get_cuda_home():
"""Get Cuda home directory"""
cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH")
if cuda_home:
return cuda_home
try:
which_cmd = "which nvcc"
nvcc_path = os.popen(which_cmd).read().strip()
if nvcc_path:
return os.path.dirname(os.path.dirname(nvcc_path))
except Exception:
pass
return None
class EmptySuppress:
"""Empty context manager"""
def __enter__(self):
"""Empty context manager"""
return self
def __exit__(self, *_):
"""Empty exit method"""
pass
class SuppressStdoutStderr:
"""Context manager that redirects stdout and stderr"""
def __enter__(self):
"""Suppress stdout and stderr"""
self.outnull_file = open(os.devnull, "w")
self.errnull_file = open(os.devnull, "w")
self.old_stdout_fileno_undup = sys.stdout.fileno()
self.old_stderr_fileno_undup = sys.stderr.fileno()
self.old_stdout_fileno = os.dup(sys.stdout.fileno())
self.old_stderr_fileno = os.dup(sys.stderr.fileno())
self.old_stdout = sys.stdout
self.old_stderr = sys.stderr
os.dup2(self.outnull_file.fileno(), self.old_stdout_fileno_undup)
os.dup2(self.errnull_file.fileno(), self.old_stderr_fileno_undup)
sys.stdout = self.outnull_file
sys.stderr = self.errnull_file
return self
def __exit__(self, *_):
"""Restore stdout and stderr"""
sys.stdout = self.old_stdout
sys.stderr = self.old_stderr
os.dup2(self.old_stdout_fileno, self.old_stdout_fileno_undup)
os.dup2(self.old_stderr_fileno, self.old_stderr_fileno_undup)
os.close(self.old_stdout_fileno)
os.close(self.old_stderr_fileno)
self.outnull_file.close()
self.errnull_file.close()
def calc_diff(x, y):
"""Calculate difference between two vectors"""
x, y = x.astype(paddle.float64), y.astype(paddle.float64)
denominator = (x * x + y * y).sum()
sim = 2 * (x * y).sum() / denominator
return 1 - sim
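# A quick numeric sketch (illustrative; the tensors are arbitrary): identical inputs
# give a diff of 0, and the value grows as the tensors diverge.
def _calc_diff_example() -> float:
    x = paddle.to_tensor([1.0, 2.0, 3.0])
    assert float(calc_diff(x, x)) < 1e-6
    return float(calc_diff(x, x * 1.01))  # small but non-zero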
def count_bytes(tensors):
"""Count number of bytes used by tensors"""
total = 0
for t in tensors:
if isinstance(t, tuple):
total += count_bytes(t)
else:
total += t.numel() * t.element_size()
return total

View File

@@ -1,110 +0,0 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# The file has been adapted from DeepSeek DeepGEMM project
# Copyright (c) 2025 DeepSeek
# Licensed under the MIT License - https://github.com/deepseek-ai/DeepGEMM/blob/main/LICENSE
"""Setup script for deep_gemm"""
import os
import shutil
import subprocess
import setuptools
from setuptools.command.build_py import build_py
from setuptools.command.develop import develop
current_dir = os.path.dirname(os.path.realpath(__file__))
jit_include_dirs = ("deep_gemm/include/deep_gemm",)
third_party_include_dirs = (
"../../third_party/cutlass/include/cute",
"../../third_party/cutlass/include/cutlass",
)
class PostDevelopCommand(develop):
"""Custom develop command that makes symbolic links to third-party include directories"""
def run(self):
"""Run the custom develop command"""
develop.run(self)
self.make_jit_include_symlinks()
@staticmethod
def make_jit_include_symlinks():
"""Make symbolic links of jit include directories"""
# Make symbolic links of third-party include directories
for d in third_party_include_dirs:
dirname = d.split("/")[-1]
src_dir = f"{current_dir}/{d}"
dst_dir = f"{current_dir}/deep_gemm/include/{dirname}"
assert os.path.exists(src_dir)
if os.path.exists(dst_dir):
assert os.path.islink(dst_dir)
os.unlink(dst_dir)
os.symlink(src_dir, dst_dir, target_is_directory=True)
class CustomBuildPy(build_py):
"""Custom build command that prepares the include files before building"""
def run(self):
"""Run the custom build command"""
# First, prepare the include directories
self.prepare_includes()
# Then run the regular build
build_py.run(self)
def prepare_includes(self):
"""Prepare the include directories"""
# Create temporary build directory instead of modifying package directory
build_include_dir = os.path.join(self.build_lib, "deep_gemm/include")
os.makedirs(build_include_dir, exist_ok=True)
# Copy third-party includes to the build directory
for d in third_party_include_dirs:
dirname = d.split("/")[-1]
src_dir = os.path.join(current_dir, d)
dst_dir = os.path.join(build_include_dir, dirname)
# Remove existing directory if it exists
if os.path.exists(dst_dir):
shutil.rmtree(dst_dir)
# Copy the directory
shutil.copytree(src_dir, dst_dir)
if __name__ == "__main__":
# noinspection PyBroadException
try:
cmd = ["git", "rev-parse", "--short", "HEAD"]
revision = "+" + subprocess.check_output(cmd).decode("ascii").rstrip()
    except Exception:
revision = ""
setuptools.setup(
name="deep_gemm",
version="1.0.0" + revision,
packages=["deep_gemm", "deep_gemm/jit", "deep_gemm/jit_kernels"],
package_data={
"deep_gemm": [
"include/deep_gemm/**/*",
"include/cute/**/*",
"include/cutlass/**/*",
]
},
cmdclass={
"develop": PostDevelopCommand,
"build_py": CustomBuildPy,
},
)

View File

@@ -1,205 +0,0 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# The file has been adapted from DeepSeek DeepGEMM project
# Copyright (c) 2025 DeepSeek
# Licensed under the MIT License - https://github.com/deepseek-ai/DeepGEMM/blob/main/LICENSE
import random
from typing import Tuple
import fastdeploy.model_executor.ops.gpu.deep_gemm as deep_gemm
import paddle
from fastdeploy.model_executor.ops.gpu.deep_gemm import (
calc_diff,
ceil_div,
get_col_major_tma_aligned_tensor,
)
from paddle import Tensor
def per_token_cast_to_fp8(x: Tensor) -> Tuple[Tensor, Tensor]:
assert x.dim() == 2 and x.shape[1] % 128 == 0
m, n = x.shape
x_view = paddle.view(x, (m, -1, 128))
x_abs = paddle.abs(x_view).astype(paddle.float32)
x_amax = paddle.amax(x_abs, axis=2)
x_amax = paddle.view(x_amax, (m, -1))
x_amax = paddle.clip(x_amax, min=1e-4)
scaled_x = x_view * (448.0 / x_amax.unsqueeze(2))
scaled_x_converted = paddle.view(scaled_x.astype(paddle.float8_e4m3fn), (m, n))
x_amax_scaled = paddle.view((x_amax / 448.0), (m, -1))
result = (scaled_x_converted, x_amax_scaled)
return result
def per_block_cast_to_fp8(x: Tensor) -> Tuple[Tensor, Tensor]:
assert x.dim() == 2
m, n = x.shape
x_padded = paddle.zeros(
(ceil_div(m, 128) * 128, ceil_div(n, 128) * 128), dtype=x.dtype
)
x_padded[:m, :n] = x
x_view = paddle.view(x_padded, (-1, 128, x_padded.shape[1] // 128, 128))
x_abs = paddle.abs(x_view).astype(paddle.float32)
x_amax = paddle.amax(x_abs, axis=(1, 3), keepdim=True)
x_amax = paddle.clip(x_amax, min=1e-4)
x_scaled = (x_view * (448.0 / x_amax)).astype(paddle.float8_e4m3fn)
return x_scaled.view_as(x_padded)[:m, :n].contiguous(), (
paddle.view(x_amax / 448.0, (x_view.shape[0], x_view.shape[2]))
)
def construct(
m: int, k: int, n: int
) -> Tuple[Tuple[Tensor, Tensor], Tuple[Tensor, Tensor], Tensor, Tensor]:
x = paddle.randn((m, k), dtype=paddle.bfloat16)
y = paddle.randn((n, k), dtype=paddle.bfloat16)
out = paddle.empty((m, n), dtype=paddle.bfloat16)
ref_out = x @ y.t()
x_fp8, y_fp8 = per_token_cast_to_fp8(x), per_block_cast_to_fp8(y)
# Transpose earlier so that the testing will not trigger transposing kernels
x_fp8 = (x_fp8[0], get_col_major_tma_aligned_tensor(x_fp8[1]))
return x_fp8, y_fp8, out, ref_out
def construct_grouped(
num_groups: int, m: int, k: int, n: int, is_masked: bool
) -> Tuple[Tuple[Tensor, Tensor], Tuple[Tensor, Tensor], Tensor, Tensor]:
# x_np = np.full((num_groups, m, k), 3)
# y_np = np.full((num_groups, n, k), 2)
# x=paddle.to_tensor(x_np).astype(paddle.bfloat16)
# y=paddle.to_tensor(y_np).astype(paddle.bfloat16)
x = paddle.randn((num_groups, m, k), dtype=paddle.bfloat16)
y = paddle.randn((num_groups, n, k), dtype=paddle.bfloat16)
out = paddle.empty((num_groups, m, n), dtype=paddle.bfloat16)
ref_out = paddle.einsum("gmk,gnk->gmn", x, y)
assert m % 4 == 0, f"TMA alignment error: {m}"
x_fp8 = (
paddle.empty_like(x, dtype=paddle.float8_e4m3fn),
paddle.empty((num_groups, m, k // 128), dtype=paddle.float32),
)
y_fp8 = (
paddle.empty_like(y, dtype=paddle.float8_e4m3fn),
paddle.empty((num_groups, (n + 127) // 128, k // 128), dtype=paddle.float32),
)
for i in range(num_groups):
# x_fp8[0][i], x_fp8[1][i] = per_token_cast_to_fp8(x[i])
# y_fp8[0][i], y_fp8[1][i] = per_block_cast_to_fp8(y[i])
x_fp8_0_i, x_fp8_1_i = per_token_cast_to_fp8(x[i])
paddle.assign(x_fp8_0_i, x_fp8[0][i])
paddle.assign(x_fp8_1_i, x_fp8[1][i])
y_fp8_0_i, y_fp8_1_i = per_block_cast_to_fp8(y[i])
paddle.assign(y_fp8_0_i, y_fp8[0][i])
paddle.assign(y_fp8_1_i, y_fp8[1][i])
# For non-masked input, we must merge the group and M dims
if not is_masked:
x_fp8 = (
paddle.view(x_fp8[0], (-1, k)),
per_token_cast_to_fp8(paddle.view(x, (-1, k)))[1],
)
out, ref_out = paddle.view(out, (-1, n)), paddle.view(ref_out, (-1, n))
# Transpose earlier so that the testing will not trigger transposing kernels
x_fp8 = (x_fp8[0], get_col_major_tma_aligned_tensor(x_fp8[1]))
return x_fp8, y_fp8, out, ref_out
def test_gemm() -> None:
print("Testing GEMM:")
for m in (64,):
for k, n in [
(7168, 2112),
]:
x_fp8, y_fp8, out, ref_out = construct(m, k, n)
deep_gemm.gemm_fp8_fp8_bf16_nt(x_fp8, y_fp8, out)
diff = calc_diff(out, ref_out)
assert diff < 0.001, f"{m=}, {k=}, {n=}, {diff:.5f}"
print()
def test_m_grouped_gemm_contiguous() -> None:
print("Testing grouped contiguous GEMM:")
for num_groups, m, k, n in ((4, 8192, 7168, 4096),):
# TODO: make a stronger test
x_fp8, y_fp8, out, ref_out = construct_grouped(
num_groups, m, k, n, is_masked=False
)
m_indices = paddle.arange(0, num_groups, dtype=paddle.int32)
# m_indices = m_indices.unsqueeze(-1).expand(num_groups, m).contiguous().view(-1)
m_indices = paddle.flatten(
paddle.expand(paddle.unsqueeze(m_indices, -1), shape=[num_groups, m])
)
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
x_fp8, y_fp8, out, m_indices
)
diff = calc_diff(out, ref_out)
print("diff:", diff)
assert diff < 0.001, f"m={m * num_groups}, {k=}, {n=}, {diff:.5f}"
print()
def test_m_grouped_gemm_masked() -> None:
print("Testing grouped masked GEMM:")
for num_groups, m in ((1, 1024),):
for k, n in ((7168, 4096),):
# Test correctness
masked_m_candidates = list(
filter(lambda candidate: candidate <= m, (64, 128, 192, 256, 320, 384))
)
for i in range(10):
x_fp8, y_fp8, out, ref_out = construct_grouped(
num_groups, m, k, n, is_masked=True
)
masked_m = paddle.empty((num_groups,), dtype=paddle.int32)
for j in range(num_groups):
masked_m[j] = random.choice(masked_m_candidates)
# expected_m = min(int(masked_m.float().mean()) + 1, m)
masked_m_float = paddle.cast(masked_m, "float32")
masked_m_mean = paddle.mean(masked_m_float)
masked_m_mean_int = paddle.cast(masked_m_mean, "int32")
expected_m = min(masked_m_mean_int + 1, m)
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(
x_fp8, y_fp8, out, masked_m, expected_m
)
for j in range(num_groups):
diff = calc_diff(
out[j, : masked_m[j].item()], ref_out[j, : masked_m[j].item()]
)
print("diff:", diff)
assert (
diff < 0.001
), f"{m=}, {k=}, {n=}, {j=}, masked_m={masked_m[j]}, {num_groups=}, {diff:.5f}"
print()
if __name__ == "__main__":
paddle.seed(0)
random.seed(0)
print("Library path:")
print(f" > {deep_gemm.__path__}\n")
test_gemm()
test_m_grouped_gemm_contiguous()
test_m_grouped_gemm_masked()

View File

@@ -14,27 +14,21 @@
#include <iostream>
#include "fp8_common.h" // NOLINT
#include "fp8_gemm_fused/fuse_block_gemm_act_template_3x.h"
#include "fp8_common.h" // NOLINT
#include "fp8_gemm_fused/fp8_fp8_gemm_scale_bias_act.h"
#include "fp8_gemm_fused/fuse_block_gemm_act_template_3x.h"
std::vector<paddle::Tensor> cutlass_fp8_fp8_half_block_gemm_fused(
const paddle::Tensor& x,
const paddle::Tensor& y,
const paddle::Tensor& x_scale,
const paddle::Tensor& y_scale,
const paddle::optional<paddle::Tensor>& bias,
bool trans_x,
bool trans_y,
std::string output_dtype,
std::string activation_type) {
const paddle::Tensor &x, const paddle::Tensor &y,
const paddle::Tensor &x_scale, const paddle::Tensor &y_scale,
const paddle::optional<paddle::Tensor> &bias, bool trans_x, bool trans_y,
std::string output_dtype, std::string activation_type) {
paddle::Tensor out;
void* out_ptr = nullptr;
const void* x_ptr = nullptr;
const void* x_scale_ptr = nullptr;
const void* y_ptr = nullptr;
const void* y_scale_ptr = nullptr;
void *out_ptr = nullptr;
const void *x_ptr = nullptr;
const void *x_scale_ptr = nullptr;
const void *y_ptr = nullptr;
const void *y_scale_ptr = nullptr;
auto place = x.place();
cudaStream_t stream = x.stream();
@@ -46,7 +40,7 @@ std::vector<paddle::Tensor> cutlass_fp8_fp8_half_block_gemm_fused(
int K = 0;
int N = 0;
int ldd = 0;
int lda = x.dims()[rank - 1];
int ldb = y.dims()[rank - 1];
@@ -72,16 +66,16 @@ std::vector<paddle::Tensor> cutlass_fp8_fp8_half_block_gemm_fused(
}
std::string input_dtype = "";
x_scale_ptr = reinterpret_cast<const void*>(x_scale.data<float>());
y_scale_ptr = reinterpret_cast<const void*>(y_scale.data<float>());
x_scale_ptr = reinterpret_cast<const void *>(x_scale.data<float>());
y_scale_ptr = reinterpret_cast<const void *>(y_scale.data<float>());
if (x.dtype() == phi::DataType::FLOAT8_E4M3FN) {
input_dtype = "float8_e4m3fn";
x_ptr = reinterpret_cast<const void*>(x.data<phi::dtype::float8_e4m3fn>());
y_ptr = reinterpret_cast<const void*>(y.data<phi::dtype::float8_e4m3fn>());
x_ptr = reinterpret_cast<const void *>(x.data<phi::dtype::float8_e4m3fn>());
y_ptr = reinterpret_cast<const void *>(y.data<phi::dtype::float8_e4m3fn>());
} else if (x.dtype() == phi::DataType::FLOAT8_E5M2) {
input_dtype = "float8_e5m2";
x_ptr = reinterpret_cast<const void*>(x.data<phi::dtype::float8_e5m2>());
y_ptr = reinterpret_cast<const void*>(y.data<phi::dtype::float8_e5m2>());
x_ptr = reinterpret_cast<const void *>(x.data<phi::dtype::float8_e5m2>());
y_ptr = reinterpret_cast<const void *>(y.data<phi::dtype::float8_e5m2>());
} else {
PADDLE_THROW(phi::errors::Fatal(
"fp8_fp8_half_gemm_fused only support e4m3 and e5m2 input"));
@@ -93,10 +87,10 @@ std::vector<paddle::Tensor> cutlass_fp8_fp8_half_block_gemm_fused(
if (output_dtype == "bfloat16") {
out = paddle::empty(out_shape, paddle::DataType::BFLOAT16, x.place());
out_ptr = reinterpret_cast<void*>(out.data<phi::dtype::bfloat16>());
out_ptr = reinterpret_cast<void *>(out.data<phi::dtype::bfloat16>());
} else if (output_dtype == "float16") {
out = paddle::empty(out_shape, paddle::DataType::FLOAT16, x.place());
out_ptr = reinterpret_cast<void*>(out.data<phi::dtype::float16>());
out_ptr = reinterpret_cast<void *>(out.data<phi::dtype::float16>());
} else {
PADDLE_THROW(phi::errors::Fatal(
"fp8_fp8_half_gemm_fused only support bfloat16 and float16 output"));
@@ -110,84 +104,68 @@ std::vector<paddle::Tensor> cutlass_fp8_fp8_half_block_gemm_fused(
std::string fuse_gemm_config =
input_dtype + "_" + output_dtype + "_" + isbias + "_" + act;
void* bias_data = nullptr;
void *bias_data = nullptr;
std::vector<int64_t> bias_dims{};
if (bias) {
bias_dims = common::vectorize(bias.get().dims());
if (output_dtype == "bfloat16") {
bias_data = reinterpret_cast<void*>(const_cast<phi::dtype::bfloat16*>(
bias_data = reinterpret_cast<void *>(const_cast<phi::dtype::bfloat16 *>(
bias.get().data<phi::dtype::bfloat16>()));
} else {
bias_data = reinterpret_cast<void*>(const_cast<phi::dtype::float16*>(
bias_data = reinterpret_cast<void *>(const_cast<phi::dtype::float16 *>(
bias.get().data<phi::dtype::float16>()));
}
}
GemmEpilogueAllParams params = {x_ptr,
y_ptr,
out_ptr,
1,
M,
N,
K,
lda,
ldb,
ldd,
batch_count,
place,
stream,
sm_version,
0.01, // for leaky_relu
bias_data,
bias_dims,
fuse_gemm_config,
0,
nullptr,
nullptr,
x_scale_ptr,
y_scale_ptr};
GemmEpilogueAllParams params = {x_ptr, y_ptr, out_ptr,
1, M, N,
K, lda, ldb,
ldd, batch_count, place,
stream, sm_version,
0.01, // for leaky_relu
bias_data, bias_dims, fuse_gemm_config,
0, nullptr, nullptr,
x_scale_ptr, y_scale_ptr};
fp8_fp8_block_gemm_scale_bias_act(params);
return {out};
}
std::vector<std::vector<int64_t>> CutlassFp8Fp8HalfBlockGemmFusedInferShape(
const std::vector<int64_t>& x_shape,
const std::vector<int64_t>& y_shape,
const std::vector<int64_t>& x_scale_shape,
const std::vector<int64_t>& y_scale_shape,
const paddle::optional<std::vector<int64_t>>& bias_shape,
bool trans_x,
bool trans_y){
PADDLE_ENFORCE_EQ(x_shape.size(),
y_shape.size(),
const std::vector<int64_t> &x_shape, const std::vector<int64_t> &y_shape,
const std::vector<int64_t> &x_scale_shape,
const std::vector<int64_t> &y_scale_shape,
const paddle::optional<std::vector<int64_t>> &bias_shape, bool trans_x,
bool trans_y) {
PADDLE_ENFORCE_EQ(x_shape.size(), y_shape.size(),
phi::errors::InvalidArgument(
"The rank of input X and Y should be equal, but received X's rank is %d, Y's rank is %d.",
x_shape.size(),
y_shape.size()));
PADDLE_ENFORCE_EQ(x_shape.size(),
x_scale_shape.size(),
"The rank of input X and Y should be equal, but "
"received X's rank is %d, Y's rank is %d.",
x_shape.size(), y_shape.size()));
PADDLE_ENFORCE_EQ(x_shape.size(), x_scale_shape.size(),
phi::errors::InvalidArgument(
"The rank of input X and X_scale should be equal, but received X's rank is %d, X_scale's rank is %d.",
x_shape.size(),
x_scale_shape.size()));
PADDLE_ENFORCE_EQ(y_shape.size(),
y_scale_shape.size(),
"The rank of input X and X_scale should be equal, but "
"received X's rank is %d, X_scale's rank is %d.",
x_shape.size(), x_scale_shape.size()));
PADDLE_ENFORCE_EQ(y_shape.size(), y_scale_shape.size(),
phi::errors::InvalidArgument(
"The rank of input Y and Y_scale should be equal, but received Y's rank is %d, Y_scale's rank is %d.",
y_shape.size(),
y_scale_shape.size()));
"The rank of input Y and Y_scale should be equal, but "
"received Y's rank is %d, Y_scale's rank is %d.",
y_shape.size(), y_scale_shape.size()));
int rank = x_shape.size();
int M = 0;
int N = 0;
if ((x_shape[rank - 1] + 127) / 128 != x_scale_shape[rank - 2]){
PADDLE_THROW(phi::errors::Fatal(
"cutlass_fp8_fp8_half_block_gemm_fused only support x_scale's dim[-2] * 128 = x's dim[-1]."));
if ((x_shape[rank - 1] + 127) / 128 != x_scale_shape[rank - 2]) {
PADDLE_THROW(
phi::errors::Fatal("cutlass_fp8_fp8_half_block_gemm_fused only support "
"x_scale's dim[-2] * 128 = x's dim[-1]."));
}
if (((y_shape[rank - 1] + 127) / 128 != y_scale_shape[rank - 1]) || ((y_shape[rank - 2] + 127) / 128 != y_scale_shape[rank - 2])){
PADDLE_THROW(phi::errors::Fatal(
"cutlass_fp8_fp8_half_block_gemm_fused only support input y_scale's dim[-2:] * 128 = y's dim[-2:]."));
if (((y_shape[rank - 1] + 127) / 128 != y_scale_shape[rank - 1]) ||
((y_shape[rank - 2] + 127) / 128 != y_scale_shape[rank - 2])) {
PADDLE_THROW(
phi::errors::Fatal("cutlass_fp8_fp8_half_block_gemm_fused only support "
"input y_scale's dim[-2:] * 128 = y's dim[-2:]."));
}
if (!trans_x) {
@@ -208,31 +186,25 @@ std::vector<std::vector<int64_t>> CutlassFp8Fp8HalfBlockGemmFusedInferShape(
}
std::vector<paddle::DataType> CutlassFp8Fp8HalfBlockGemmFusedInferDtype(
const paddle::DataType& x_type,
const paddle::DataType& y_type,
const paddle::DataType& x_scale_type,
const paddle::DataType& y_scale_type,
const paddle::optional<paddle::DataType>& bias_type,
bool trans_x,
bool trans_y,
std::string output_dtype) {
paddle::DataType data_type;
if (output_dtype == "bfloat16")
data_type = paddle::DataType::BFLOAT16;
else if (output_dtype == "float16")
data_type = paddle::DataType::FLOAT16;
else
PD_THROW(
"cutlass_fp8_fp8_half_gemm only support bfloat16 and float16 output");
return {data_type};
const paddle::DataType &x_type, const paddle::DataType &y_type,
const paddle::DataType &x_scale_type, const paddle::DataType &y_scale_type,
const paddle::optional<paddle::DataType> &bias_type, bool trans_x,
bool trans_y, std::string output_dtype) {
paddle::DataType data_type;
if (output_dtype == "bfloat16")
data_type = paddle::DataType::BFLOAT16;
else if (output_dtype == "float16")
data_type = paddle::DataType::FLOAT16;
else
PD_THROW("cutlass_fp8_fp8_half_block_gemm only support bfloat16 and "
"float16 output");
return {data_type};
}
PD_BUILD_STATIC_OP(cutlass_fp8_fp8_half_block_gemm_fused)
.Inputs({"x", "y", "x_sacle", "y_scale", paddle::Optional("bias")})
.Attrs({"transpose_x: bool",
"transpose_y: bool",
"output_dtype: std::string",
"act: std::string"})
.Attrs({"transpose_x: bool", "transpose_y: bool",
"output_dtype: std::string", "act: std::string"})
.Outputs({"out"})
.SetKernelFn(PD_KERNEL(cutlass_fp8_fp8_half_block_gemm_fused))
.SetInferShapeFn(PD_INFER_SHAPE(CutlassFp8Fp8HalfBlockGemmFusedInferShape))

View File

@@ -16,7 +16,7 @@
#include "helper.h"
namespace {
int sharedMemoryOpen(const char *name, size_t sz, sharedMemoryInfo *info) {
int sharedMemoryOpen2(const char *name, size_t sz, sharedMemoryInfo *info) {
info->size = sz;
info->shmFd = shm_open(name, O_RDWR, 0777);
if (info->shmFd < 0) {
@@ -40,7 +40,7 @@ std::vector<paddle::Tensor> GetDataPtrIpc(const paddle::Tensor &tmp_input,
auto out_data_ptr_tensor_ptr = out_data_ptr_tensor.data<int64_t>();
volatile shmStruct *shm = NULL;
sharedMemoryInfo info;
if (sharedMemoryOpen(shm_name.c_str(), sizeof(shmStruct), &info) != 0) {
if (sharedMemoryOpen2(shm_name.c_str(), sizeof(shmStruct), &info) != 0) {
printf("Failed to create shared memory slab\n");
printf("Func GetDataPtrIpc. Shm_name: %s\n", shm_name.c_str());
exit(EXIT_FAILURE);

View File

@@ -13,6 +13,7 @@
// limitations under the License.
#include "paddle/extension.h"
#include <map>
std::vector<paddle::Tensor> GetMmSplitFuse(const paddle::Tensor& task_input_ids,
const paddle::Tensor& task_image_type_ids,
@@ -60,6 +61,7 @@ std::vector<paddle::Tensor> GetMmSplitFuse(const paddle::Tensor& task_input_ids,
st_idx += cur_st_len;
}
}
while (idx < seq_lens_origin) {
idx = idx + split_fuse_text_size;
if (idx >= seq_lens_origin) {

View File

@@ -70,8 +70,15 @@ void GetOutputEp(const paddle::Tensor& x,
#endif
msg_queue_id = inference_msg_queue_id_from_env;
}
static key_t key = ftok("/dev/shm", msg_queue_id);
static int msgid = msgget(key, IPC_CREAT | 0666);
#ifdef GET_OUTPUT_DEBUG
std::cout << "msg_queue_id is: "
<< msg_queue_id << std::endl;
#endif
// static key_t key = ftok("/dev/shm", msg_queue_id);
// static int msgid = msgget(key, IPC_CREAT | 0666);
key_t key = ftok("/dev/shm", msg_queue_id);
int msgid = msgget(key, IPC_CREAT | 0666);
#ifdef GET_OUTPUT_DEBUG
std::cout << "get_output_key: " << key << std::endl;

View File

@@ -14,6 +14,7 @@
#pragma once
#include "glog/logging.h"
#include <fcntl.h>
#include <stdio.h>
#include <stdlib.h>
@@ -22,25 +23,25 @@
#include <sys/stat.h>
#include <sys/types.h>
#include <unistd.h>
#include "glog/logging.h"
#ifdef PADDLE_WITH_HIP
#include <hip/hip_bfloat16.h>
#include <hip/hip_fp16.h>
#include <hip/hip_runtime.h>
#include <hipcub/hipcub.hpp>
#include <hiprand.h>
#include <hiprand_kernel.h>
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;
#else
#include <cub/cub.cuh>
#endif
#include "nlohmann/json.hpp"
#include <fstream>
#include <iostream>
#include "nlohmann/json.hpp"
#include "paddle/extension.h"
#include "paddle/phi/core/allocator.h"
#include "paddle/phi/core/cuda_stream.h"
#include "paddle/phi/core/dense_tensor.h"
#ifndef PD_BUILD_STATIC_OP
@@ -49,218 +50,211 @@ namespace cub = hipcub;
using json = nlohmann::json;
#define CUDA_CHECK(call) \
do { \
const cudaError_t error_code = call; \
if (error_code != cudaSuccess) { \
std::printf("at %s:%d - %s.\n", \
__FILE__, \
__LINE__, \
cudaGetErrorString(error_code)); \
exit(1); \
} \
} while (0)
#define CUDA_CHECK(call) \
do { \
const cudaError_t error_code = call; \
if (error_code != cudaSuccess) { \
std::printf("at %s:%d - %s.\n", __FILE__, __LINE__, \
cudaGetErrorString(error_code)); \
exit(1); \
} \
} while (0)
#ifdef PADDLE_WITH_HIP
template <size_t kBlockSize = 256, size_t kNumWaves = 16>
inline hipError_t GetNumBlocks(int64_t n, int *num_blocks) {
int dev;
{
hipError_t err = hipGetDevice(&dev);
if (err != hipSuccess) {
return err;
}
int dev;
{
hipError_t err = hipGetDevice(&dev);
if (err != hipSuccess) {
return err;
}
int sm_count;
{
hipError_t err = hipDeviceGetAttribute(
&sm_count, hipDeviceAttributeMultiprocessorCount, dev);
if (err != hipSuccess) {
return err;
}
}
int sm_count;
{
hipError_t err = hipDeviceGetAttribute(
&sm_count, hipDeviceAttributeMultiprocessorCount, dev);
if (err != hipSuccess) {
return err;
}
int tpm;
{
hipError_t err = hipDeviceGetAttribute(
&tpm, hipDeviceAttributeMaxThreadsPerMultiProcessor, dev);
if (err != hipSuccess) {
return err;
}
}
int tpm;
{
hipError_t err = hipDeviceGetAttribute(
&tpm, hipDeviceAttributeMaxThreadsPerMultiProcessor, dev);
if (err != hipSuccess) {
return err;
}
*num_blocks = std::max<int>(
1,
std::min<int64_t>((n + kBlockSize - 1) / kBlockSize,
sm_count * tpm / kBlockSize * kNumWaves));
return hipSuccess;
}
*num_blocks = std::max<int>(
1, std::min<int64_t>((n + kBlockSize - 1) / kBlockSize,
sm_count * tpm / kBlockSize * kNumWaves));
return hipSuccess;
}
#else
template <size_t kBlockSize = 256, size_t kNumWaves = 16>
inline cudaError_t GetNumBlocks(int64_t n, int *num_blocks) {
int dev;
{
cudaError_t err = cudaGetDevice(&dev);
if (err != cudaSuccess) {
return err;
}
int dev;
{
cudaError_t err = cudaGetDevice(&dev);
if (err != cudaSuccess) {
return err;
}
int sm_count;
{
cudaError_t err = cudaDeviceGetAttribute(
&sm_count, cudaDevAttrMultiProcessorCount, dev);
if (err != cudaSuccess) {
return err;
}
}
int sm_count;
{
cudaError_t err =
cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev);
if (err != cudaSuccess) {
return err;
}
int tpm;
{
cudaError_t err = cudaDeviceGetAttribute(
&tpm, cudaDevAttrMaxThreadsPerMultiProcessor, dev);
if (err != cudaSuccess) {
return err;
}
}
int tpm;
{
cudaError_t err = cudaDeviceGetAttribute(
&tpm, cudaDevAttrMaxThreadsPerMultiProcessor, dev);
if (err != cudaSuccess) {
return err;
}
*num_blocks = std::max<int>(
1,
std::min<int64_t>((n + kBlockSize - 1) / kBlockSize,
sm_count * tpm / kBlockSize * kNumWaves));
return cudaSuccess;
}
*num_blocks = std::max<int>(
1, std::min<int64_t>((n + kBlockSize - 1) / kBlockSize,
sm_count * tpm / kBlockSize * kNumWaves));
return cudaSuccess;
}
inline int GetGPUComputeCapability(int id) {
int major, minor;
auto major_error_code =
cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, id);
auto minor_error_code =
cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, id);
return major * 10 + minor;
int major, minor;
auto major_error_code =
cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, id);
auto minor_error_code =
cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, id);
return major * 10 + minor;
}
#endif
template <paddle::DataType D>
class PDTraits;
inline constexpr uint32_t next_pow_2(uint32_t const num) {
if (num <= 1)
return num;
return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
}
template <>
class PDTraits<paddle::DataType::FLOAT32> {
public:
typedef float DataType;
typedef float data_t;
template <paddle::DataType D> class PDTraits;
template <> class PDTraits<paddle::DataType::FLOAT32> {
public:
typedef float DataType;
typedef float data_t;
};
template <>
class PDTraits<paddle::DataType::FLOAT16> {
public:
typedef half DataType;
typedef paddle::float16 data_t;
template <> class PDTraits<paddle::DataType::FLOAT16> {
public:
typedef half DataType;
typedef paddle::float16 data_t;
};
template <>
class PDTraits<paddle::DataType::BFLOAT16> {
public:
template <> class PDTraits<paddle::DataType::BFLOAT16> {
public:
#ifdef PADDLE_WITH_HIP
typedef hip_bfloat16 DataType;
typedef hip_bfloat16 DataType;
#else
typedef __nv_bfloat16 DataType;
typedef __nv_bfloat16 DataType;
#endif
typedef paddle::bfloat16 data_t;
typedef paddle::bfloat16 data_t;
};
template <>
class PDTraits<paddle::DataType::INT8> {
public:
typedef int8_t DataType;
typedef int8_t data_t;
template <> class PDTraits<paddle::DataType::INT8> {
public:
typedef int8_t DataType;
typedef int8_t data_t;
};
template <>
class PDTraits<paddle::DataType::UINT8> {
public:
typedef uint8_t DataType;
typedef uint8_t data_t;
template <> class PDTraits<paddle::DataType::UINT8> {
public:
typedef uint8_t DataType;
typedef uint8_t data_t;
};
template <typename T, int Size>
struct alignas(sizeof(T) * Size) AlignedVector {
T val[Size];
template <typename T, int Size> struct alignas(sizeof(T) * Size) AlignedVector {
T val[Size];
HOSTDEVICE inline const T &operator[](int i) const { return val[i]; }
HOSTDEVICE inline T &operator[](int i) { return val[i]; }
HOSTDEVICE inline const T &operator[](int i) const { return val[i]; }
HOSTDEVICE inline T &operator[](int i) { return val[i]; }
};
template <typename T, int Size>
HOSTDEVICE inline void Load(const T *addr, AlignedVector<T, Size> *vec) {
const AlignedVector<T, Size> *addr_vec =
reinterpret_cast<const AlignedVector<T, Size> *>(addr);
*vec = *addr_vec;
const AlignedVector<T, Size> *addr_vec =
reinterpret_cast<const AlignedVector<T, Size> *>(addr);
*vec = *addr_vec;
}
template <typename T, int Size>
HOSTDEVICE inline void Store(const AlignedVector<T, Size> &vec, T *addr) {
AlignedVector<T, Size> *addr_vec =
reinterpret_cast<AlignedVector<T, Size> *>(addr);
*addr_vec = vec;
AlignedVector<T, Size> *addr_vec =
reinterpret_cast<AlignedVector<T, Size> *>(addr);
*addr_vec = vec;
}
template <int Size>
HOSTDEVICE inline void Store(const AlignedVector<__nv_bfloat16, Size> &vec,
int8_t *addr) {
printf("Error: Store __nv_bfloat16 to int8_t is not supported!");
printf("Error: Store __nv_bfloat16 to int8_t is not supported!");
}
template <int Size>
HOSTDEVICE inline void Store(const AlignedVector<half, Size> &vec,
int8_t *addr) {
printf("Error: Store half to int8_t is not supported!");
printf("Error: Store half to int8_t is not supported!");
}
constexpr int VEC_16B = 16;
template <typename T>
__device__ T max_func(const T a, const T b) {
return a > b ? a : b;
template <typename T> __device__ T max_func(const T a, const T b) {
return a > b ? a : b;
}
template <typename T>
struct MaxOp {
__device__ __forceinline__ T operator()(const T &a, const T &b) const {
return max_func(a, b);
}
template <typename T> struct MaxOp {
__device__ __forceinline__ T operator()(const T &a, const T &b) const {
return max_func(a, b);
}
};
inline int GetBlockSize(int vocab_size) {
if (vocab_size > 512) {
return 1024;
} else if (vocab_size > 256) {
return 512;
} else if (vocab_size > 128) {
return 256;
} else if (vocab_size > 64) {
return 128;
} else {
return 64;
}
if (vocab_size > 512) {
return 1024;
} else if (vocab_size > 256) {
return 512;
} else if (vocab_size > 128) {
return 256;
} else if (vocab_size > 64) {
return 128;
} else {
return 64;
}
}
inline json readJsonFromFile(const std::string &filePath) {
std::ifstream file(filePath);
if (!file.is_open()) {
throw std::runtime_error("Unable to open file: " + filePath);
}
std::ifstream file(filePath);
if (!file.is_open()) {
throw std::runtime_error("Unable to open file: " + filePath);
}
json j;
file >> j;
return j;
json j;
file >> j;
return j;
}
#define cudaCheckError() \
{ \
cudaError_t e = cudaGetLastError(); \
if (e != cudaSuccess) { \
std::cerr << "CUDA Error " << __FILE__ << ":" << __LINE__ << ": " \
<< cudaGetErrorString(e) << std::endl; \
exit(EXIT_FAILURE); \
} \
}
#define cudaCheckError() \
{ \
cudaError_t e = cudaGetLastError(); \
if (e != cudaSuccess) { \
std::cerr << "CUDA Error " << __FILE__ << ":" << __LINE__ << ": " \
<< cudaGetErrorString(e) << std::endl; \
exit(EXIT_FAILURE); \
} \
}
// place must be an existing place object and cannot use paddle::CPUPlace() or
// paddle::GPUPlace()
@@ -269,220 +263,194 @@ inline json readJsonFromFile(const std::string &filePath) {
inline paddle::Tensor GetEmptyTensor(const common::DDim &dims,
const paddle::DataType &dtype,
const paddle::Place &place) {
auto *allocator = paddle::GetAllocator(place);
phi::DenseTensor dense_tensor;
dense_tensor.Resize(dims);
dense_tensor.AllocateFrom(
allocator, dtype, dense_tensor.numel() * phi::SizeOf(dtype));
return paddle::Tensor(std::make_shared<phi::DenseTensor>(dense_tensor));
auto *allocator = paddle::GetAllocator(place);
phi::DenseTensor dense_tensor;
dense_tensor.Resize(dims);
dense_tensor.AllocateFrom(allocator, dtype,
dense_tensor.numel() * phi::SizeOf(dtype));
return paddle::Tensor(std::make_shared<phi::DenseTensor>(dense_tensor));
}
inline paddle::Tensor GetEmptyTensor(const common::DDim &dims,
const common::DDim &strides,
const paddle::DataType &dtype,
const paddle::Place &place) {
auto *allocator = paddle::GetAllocator(place);
phi::DenseTensor dense_tensor;
dense_tensor.Resize(dims);
dense_tensor.AllocateFrom(
allocator, dtype, dense_tensor.numel() * phi::SizeOf(dtype));
dense_tensor.set_strides(strides);
return paddle::Tensor(std::make_shared<phi::DenseTensor>(dense_tensor));
auto *allocator = paddle::GetAllocator(place);
phi::DenseTensor dense_tensor;
dense_tensor.Resize(dims);
dense_tensor.AllocateFrom(allocator, dtype,
dense_tensor.numel() * phi::SizeOf(dtype));
dense_tensor.set_strides(strides);
return paddle::Tensor(std::make_shared<phi::DenseTensor>(dense_tensor));
}
#endif
__global__ void free_and_dispatch_block(bool *stop_flags,
int *seq_lens_this_time,
int *seq_lens_decoder,
int *block_tables,
int *encoder_block_lens,
bool *is_block_step,
int *step_block_list, // [bsz]
int *step_len,
int *recover_block_list,
int *recover_len,
int *need_block_list,
int *need_block_len,
int *used_list_len,
int *free_list,
int *free_list_len,
int64_t *first_token_ids,
const int bsz,
const int block_size,
const int block_num_per_seq,
const int max_decoder_block_num);
__global__ void free_and_dispatch_block(
bool *stop_flags, int *seq_lens_this_time, int *seq_lens_decoder,
int *block_tables, int *encoder_block_lens, bool *is_block_step,
int *step_block_list, // [bsz]
int *step_len, int *recover_block_list, int *recover_len,
int *need_block_list, int *need_block_len, int *used_list_len,
int *free_list, int *free_list_len, int64_t *first_token_ids, const int bsz,
const int block_size, const int block_num_per_seq,
const int max_decoder_block_num);
__global__ void speculate_free_and_dispatch_block(
bool *stop_flags,
int *seq_lens_this_time,
int *seq_lens_decoder,
int *block_tables,
int *encoder_block_lens,
bool *is_block_step,
int *step_block_list, // [bsz]
int *step_len,
int *recover_block_list,
int *recover_len,
int *need_block_list,
int *need_block_len,
int *used_list_len,
int *free_list,
int *free_list_len,
int64_t *first_token_ids,
int *accept_num,
const int bsz,
const int block_size,
const int block_num_per_seq,
const int max_decoder_block_num,
bool *stop_flags, int *seq_lens_this_time, int *seq_lens_decoder,
int *block_tables, int *encoder_block_lens, bool *is_block_step,
int *step_block_list, // [bsz]
int *step_len, int *recover_block_list, int *recover_len,
int *need_block_list, int *need_block_len, int *used_list_len,
int *free_list, int *free_list_len, int64_t *first_token_ids,
int *accept_num, const int bsz, const int block_size,
const int block_num_per_seq, const int max_decoder_block_num,
const int max_draft_tokens);
__device__ bool speculate_free_and_dispatch_block(const int &qid,
int *need_block_list,
const int &need_block_len);
static std::string global_base64_chars = // NOLINT
static std::string global_base64_chars = // NOLINT
"Tokp9lA/BjimRVKx32edMPFftOzsbNQ8C15Xn+YUEGc4WD0uLIq7hyJ6vZaHSwrg";
// Base64 encoding function
inline std::string base64_encode(const std::string &input) {
std::string ret;
int i = 0;
int j = 0;
unsigned char char_array_3[3];
unsigned char char_array_4[4];
std::string ret;
int i = 0;
int j = 0;
unsigned char char_array_3[3];
unsigned char char_array_4[4];
for (const auto &c : input) {
char_array_3[i++] = c;
if (i == 3) {
char_array_4[0] = (char_array_3[0] & 0xfc) >> 2;
char_array_4[1] = ((char_array_3[0] & 0x03) << 4) +
((char_array_3[1] & 0xf0) >> 4);
char_array_4[2] = ((char_array_3[1] & 0x0f) << 2) +
((char_array_3[2] & 0xc0) >> 6);
char_array_4[3] = char_array_3[2] & 0x3f;
for (const auto &c : input) {
char_array_3[i++] = c;
if (i == 3) {
char_array_4[0] = (char_array_3[0] & 0xfc) >> 2;
char_array_4[1] =
((char_array_3[0] & 0x03) << 4) + ((char_array_3[1] & 0xf0) >> 4);
char_array_4[2] =
((char_array_3[1] & 0x0f) << 2) + ((char_array_3[2] & 0xc0) >> 6);
char_array_4[3] = char_array_3[2] & 0x3f;
for (i = 0; i < 4; i++) {
ret += global_base64_chars[char_array_4[i]];
}
i = 0;
}
for (i = 0; i < 4; i++) {
ret += global_base64_chars[char_array_4[i]];
}
i = 0;
}
}
if (i) {
for (j = i; j < 3; j++) {
char_array_3[j] = '\0';
}
if (i) {
for (j = i; j < 3; j++) {
char_array_3[j] = '\0';
}
char_array_4[0] = (char_array_3[0] & 0xfc) >> 2;
char_array_4[1] =
((char_array_3[0] & 0x03) << 4) + ((char_array_3[1] & 0xf0) >> 4);
char_array_4[2] =
((char_array_3[1] & 0x0f) << 2) + ((char_array_3[2] & 0xc0) >> 6);
char_array_4[3] = char_array_3[2] & 0x3f;
char_array_4[0] = (char_array_3[0] & 0xfc) >> 2;
char_array_4[1] =
((char_array_3[0] & 0x03) << 4) + ((char_array_3[1] & 0xf0) >> 4);
char_array_4[2] =
((char_array_3[1] & 0x0f) << 2) + ((char_array_3[2] & 0xc0) >> 6);
char_array_4[3] = char_array_3[2] & 0x3f;
for (j = 0; j < i + 1; j++) {
ret += global_base64_chars[char_array_4[j]];
}
while (i++ < 3) {
ret += '=';
}
for (j = 0; j < i + 1; j++) {
ret += global_base64_chars[char_array_4[j]];
}
return ret;
while (i++ < 3) {
ret += '=';
}
}
return ret;
}
// Base64 decoding function
inline std::string base64_decode(const std::string &encoded_string) {
int in_len = encoded_string.size();
int i = 0;
int j = 0;
int in_ = 0;
unsigned char char_array_4[4], char_array_3[3];
std::string ret;
int in_len = encoded_string.size();
int i = 0;
int j = 0;
int in_ = 0;
unsigned char char_array_4[4], char_array_3[3];
std::string ret;
while (in_len-- && (encoded_string[in_] != '=') &&
(isalnum(encoded_string[in_]) || (encoded_string[in_] == '+') ||
(encoded_string[in_] == '/'))) {
char_array_4[i++] = encoded_string[in_];
in_++;
if (i == 4) {
for (i = 0; i < 4; i++) {
char_array_4[i] = global_base64_chars.find(char_array_4[i]);
}
while (in_len-- && (encoded_string[in_] != '=') &&
(isalnum(encoded_string[in_]) || (encoded_string[in_] == '+') ||
(encoded_string[in_] == '/'))) {
char_array_4[i++] = encoded_string[in_];
in_++;
if (i == 4) {
for (i = 0; i < 4; i++) {
char_array_4[i] = global_base64_chars.find(char_array_4[i]);
}
char_array_3[0] =
(char_array_4[0] << 2) + ((char_array_4[1] & 0x30) >> 4);
char_array_3[1] = ((char_array_4[1] & 0xf) << 4) +
((char_array_4[2] & 0x3c) >> 2);
char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3];
char_array_3[0] =
(char_array_4[0] << 2) + ((char_array_4[1] & 0x30) >> 4);
char_array_3[1] =
((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2);
char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3];
for (i = 0; i < 3; i++) {
ret += char_array_3[i];
}
i = 0;
}
for (i = 0; i < 3; i++) {
ret += char_array_3[i];
}
i = 0;
}
}
if (i) {
for (j = i; j < 4; j++) {
char_array_4[j] = 0;
}
if (i) {
for (j = i; j < 4; j++) {
char_array_4[j] = 0;
}
for (j = 0; j < 4; j++) {
char_array_4[j] = global_base64_chars.find(char_array_4[j]);
}
char_array_3[0] =
(char_array_4[0] << 2) + ((char_array_4[1] & 0x30) >> 4);
char_array_3[1] =
((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2);
char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3];
for (j = 0; j < i - 1; j++) {
ret += char_array_3[j];
}
for (j = 0; j < 4; j++) {
char_array_4[j] = global_base64_chars.find(char_array_4[j]);
}
return ret;
char_array_3[0] = (char_array_4[0] << 2) + ((char_array_4[1] & 0x30) >> 4);
char_array_3[1] =
((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2);
char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3];
for (j = 0; j < i - 1; j++) {
ret += char_array_3[j];
}
}
return ret;
}
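// A minimal round-trip sketch (illustrative, not part of this diff). It
// assumes base64_encode, defined earlier in this header, takes the raw bytes
// and their length, which is the usual signature for this implementation.
inline bool base64_roundtrip_sketch() {
  const std::string original = "hello world";
  const std::string encoded = base64_encode(
      reinterpret_cast<const unsigned char *>(original.data()),
      original.size());
  return base64_decode(encoded) == original;  // expect true
}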
template <typename T>
inline T get_relative_best(nlohmann::json *json_data,
const std::string &target_key,
const T &default_value) {
if (json_data->contains(target_key)) {
return json_data->at(target_key);
} else {
// std::cerr << "The key " << target_key << " is not found in the JSON
// data." << std::endl;
return default_value;
}
}
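// Usage sketch (illustrative; the "block_size" key is hypothetical): read an
// optional tuning knob from a JSON config, falling back to a default when the
// key is absent.
inline int read_block_size_sketch(nlohmann::json *config) {
  return get_relative_best<int>(config, "block_size", 128);
}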
__device__ inline bool is_in_end(const int64_t id, const int64_t *end_ids,
                                 int length) {
  bool flag = false;
  for (int i = 0; i < length; i++) {
    if (id == end_ids[i]) {
      return true;
    }
  }
  return flag;
}
template <typename T> inline __device__ __host__ T div_up(T m, T n) {
  return (m + n - 1) / n;
}
template <typename T>
__device__ __inline__ T ClipFunc(const T v, const T min, const T max) {
  if (v > max) return max;
  if (v < min) return min;
  return v;
}
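// A small illustration (assumption, not part of this diff): ClipFunc is
// __device__-only, so it is used inside kernels (e.g. to clamp an index into
// a valid range), while div_up typically sizes the launch on the host, e.g.
// div_up(n, 256) blocks of 256 threads.
__global__ void clamp_indices_sketch(const int *in, int *out, int n,
                                     int max_idx) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    out[i] = ClipFunc<int>(in[i], 0, max_idx);  // clamp into [0, max_idx]
  }
}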
template <typename T>
@@ -499,31 +467,49 @@ static void PrintMatrix3(const T *mat_d, int num, std::string name) {
if (std::is_same<T, int8_t>::value || std::is_same<T, uint8_t>::value) {
ss << static_cast<int>(tmp[i]) << std::endl;
} else {
      ss << std::setprecision(8) << (float)(tmp[i]) << std::endl;  // NOLINT
}
}
outfile << ss.str();
outfile.close();
}
__forceinline__ __device__ uint32_t ld_flag_acquire(uint32_t *flag_addr,
                                                    int mode = 0) {
  uint32_t flag;
  if (mode == 0) {
    asm volatile("ld.acquire.sys.global.b32 %0, [%1];"
                 : "=r"(flag)
                 : "l"(flag_addr));
  } else if (mode == 1) {
    asm volatile("ld.acquire.gpu.global.b32 %0, [%1];"
                 : "=r"(flag)
                 : "l"(flag_addr));
  } else {
    asm volatile("ld.acquire.cta.global.b32 %0, [%1];"
                 : "=r"(flag)
                 : "l"(flag_addr));
  }
  return flag;
}
__forceinline__ __device__ void st_flag_release(uint32_t *flag_addr,
                                                uint32_t flag, int mode = 0) {
  if (mode == 0) {
    asm volatile("st.release.sys.global.b32 [%1], %0;" ::"r"(flag),
                 "l"(flag_addr));
  } else if (mode == 1) {
    asm volatile("st.release.gpu.global.b32 [%1], %0;" ::"r"(flag),
                 "l"(flag_addr));
  } else {
    asm volatile("st.release.cta.global.b32 [%1], %0;" ::"r"(flag),
                 "l"(flag_addr));
  }
}
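// A minimal pairing sketch (assumption, not part of this diff) for the flag
// helpers above: one block produces data and releases a flag, another block
// spins with an acquire load until the release is visible. The flag is
// assumed to be zero-initialized and both blocks co-resident on the device.
__global__ void flag_handshake_sketch(uint32_t *data, uint32_t *flag) {
  if (blockIdx.x == 0 && threadIdx.x == 0) {
    data[0] = 42u;                          // produce the payload first
    st_flag_release(flag, 1u, /*mode=*/0);  // then publish it system-wide
  } else if (blockIdx.x == 1 && threadIdx.x == 0) {
    while (ld_flag_acquire(flag, /*mode=*/0) != 1u) {
    }                                       // spin until the flag is set
    uint32_t v = data[0];                   // safe to read after acquire
    data[1] = v + 1u;                       // consume it
  }
}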
inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) {
int max_shared_mem_per_block_opt_in = 0;
cudaDeviceGetAttribute(&max_shared_mem_per_block_opt_in,
cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
return max_shared_mem_per_block_opt_in;
}
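// A usage sketch (assumption, not part of this diff): kernels that want more
// dynamic shared memory than the default per-block limit must opt in per
// kernel, and the helper above reports the per-device ceiling for that opt-in.
template <typename KernelT>
inline cudaError_t request_max_dynamic_smem_sketch(KernelT *kernel,
                                                   int device) {
  int max_bytes = get_cuda_max_shared_memory_per_block_opt_in(device);
  return cudaFuncSetAttribute(
      kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, max_bytes);
}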

View File

@@ -1,3 +1,4 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
@@ -11,19 +12,67 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fstream"
#include "helper.h"
#include "iomanip"
#include "fstream"
#include "iostream"
#include "iomanip"
#include <nvml.h>
#include <iostream>
#include <nvml.h>
// #define PRINT_GPU_MEMORY
// Helper that queries NVIDIA GPU memory usage via NVML
bool getNvidiaGPUMemoryUsage(int callLine) {
#ifndef PRINT_GPU_MEMORY
return true;
#endif
  // Initialize NVML
nvmlReturn_t result;
result = nvmlInit();
if (NVML_SUCCESS != result) {
std::cerr << callLine << ": Failed to initialize NVML: " << nvmlErrorString(result) << std::endl;
return false;
}
  // Get the number of GPU devices
unsigned int deviceCount;
result = nvmlDeviceGetCount(&deviceCount);
if (NVML_SUCCESS != result) {
std::cerr << callLine << ": Failed to get device count: " << nvmlErrorString(result) << std::endl;
nvmlShutdown();
return false;
}
  // Iterate over each GPU device
for (unsigned int i = 0; i < deviceCount; ++i) {
nvmlDevice_t device;
result = nvmlDeviceGetHandleByIndex(i, &device);
if (NVML_SUCCESS != result) {
std::cerr << callLine << ": Failed to get device handle for device " << i << ": " << nvmlErrorString(result) << std::endl;
continue;
}
    // Query the device's memory info
nvmlMemory_t memory;
result = nvmlDeviceGetMemoryInfo(device, &memory);
if (NVML_SUCCESS != result) {
std::cerr << callLine << ": Failed to get memory info for device " << i << ": " << nvmlErrorString(result) << std::endl;
continue;
}
    // Print a single summary line tagged with the caller's line number
std::cout << callLine << ": GPU " << i << " - Total: " << memory.total / (1024 * 1024)
<< " MiB, Used: " << memory.used / (1024 * 1024)
<< " MiB, Free: " << memory.free / (1024 * 1024) << " MiB" << std::endl;
}
  // Release NVML resources
nvmlShutdown();
return true;
}
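// Usage sketch (illustrative, not part of this diff): the helper above is a
// no-op unless PRINT_GPU_MEMORY is defined at compile time (and the binary is
// linked against libnvidia-ml); it is then called with __LINE__ so each trace
// line can be attributed to its call site.
inline void trace_gpu_memory_example() {
  getNvidiaGPUMemoryUsage(__LINE__);  // snapshot before allocations
  // ... allocate KV cache blocks, launch kernels, etc. ...
  getNvidiaGPUMemoryUsage(__LINE__);  // snapshot after, to compare used/free
}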
// #define DEBUG_IPC_SENT
// #define DEBUG_IPC_SENT_SYNC_AND_PRINT
template <typename T>
void sent_key_value_by_remote_ptr(
    const T* local_key_tensor_base_ptr,    // gpu ptr
    const T* local_value_tensor_base_ptr,  // gpu ptr
    const int32_t* local_block_ids_ptr,    // cpu ptr
const int32_t* remote_block_ids_ptr,
const int32_t block_num,
const int64_t block_idx_stride,
@@ -32,170 +81,163 @@ void sent_key_value_by_remote_ptr(
const int32_t remote_device_id,
T* remote_key_tensor_base_ptr, // gpu ptr
T* remote_value_tensor_base_ptr, // gpu ptr
cudaStream_t stream) {
for (int block_idx = 0; block_idx < block_num; ++block_idx) {
const T* local_key_tensor_sent_ptr =
local_key_tensor_base_ptr +
local_block_ids_ptr[block_idx] * block_idx_stride;
T* remote_key_tensor_sent_ptr =
remote_key_tensor_base_ptr +
remote_block_ids_ptr[block_idx] * block_idx_stride;
#ifdef DEBUG_IPC_SENT
std::cout << "remote_key_tensor_sent_ptr:"
<< (int64_t)remote_key_tensor_sent_ptr
<< " local_key_tensor_sent_ptr:"
<< (int64_t)local_key_tensor_sent_ptr
<< " local_device_id:" << local_device_id
<< " remote_device_id:" << remote_device_id
<< " block_idx_stride:" << block_idx_stride
<< " block_size_byte:" << block_size_byte
<< " stream: " << stream
<< " local_block_ids: " << local_block_ids_ptr[block_idx]
<< " remote_block_ids: " << remote_block_ids_ptr[block_idx]
<< std::endl;
#endif
cudaStream_t stream){
for(int block_idx=0;block_idx < block_num; ++block_idx){
const T* local_key_tensor_sent_ptr = local_key_tensor_base_ptr + local_block_ids_ptr[block_idx] * block_idx_stride;
T* remote_key_tensor_sent_ptr = remote_key_tensor_base_ptr + remote_block_ids_ptr[block_idx] * block_idx_stride;
#ifdef DEBUG_IPC_SENT
std::cout<<"remote_key_tensor_sent_ptr:"<<(int64_t)remote_key_tensor_sent_ptr
<<" local_key_tensor_sent_ptr:"<<(int64_t)local_key_tensor_sent_ptr
<<" local_device_id:" << local_device_id
<<" remote_device_id:" << remote_device_id
<<" block_idx_stride:" << block_idx_stride
<<" block_size_byte:" << block_size_byte
<<" stream: " << stream
<<" local_block_ids: " << local_block_ids_ptr[block_idx]
<<" remote_block_ids: " << remote_block_ids_ptr[block_idx]
<<std::endl;
#endif
#ifdef DEBUG_IPC_SENT_SYNC_AND_PRINT
cudaDeviceSynchronize();
PrintMatrix<T>(
reinterpret_cast<const T*>(local_key_tensor_sent_ptr),
128 * 1,
"ipc_send_src_key.datatxt." + std::to_string(local_device_id),
128 * 1);
cudaDeviceSynchronize();
cudaDeviceSynchronize();
PrintMatrix<T>(reinterpret_cast<const T*>(local_key_tensor_sent_ptr),
128 * 1,
"ipc_send_src_key.datatxt." + std::to_string(local_device_id),
128 * 1);
cudaDeviceSynchronize();
#endif
#ifndef DEBUG_IPC_SENT_SYNC_AND_PRINT
cudaMemcpyPeerAsync(
reinterpret_cast<void*>(remote_key_tensor_sent_ptr),
remote_device_id,
reinterpret_cast<const void*>(local_key_tensor_sent_ptr),
local_device_id,
block_size_byte,
stream);
#endif
#ifdef DEBUG_IPC_SENT_SYNC_AND_PRINT
cudaMemcpyPeer(reinterpret_cast<void*>(remote_key_tensor_sent_ptr),
remote_device_id,
reinterpret_cast<const void*>(local_key_tensor_sent_ptr),
local_device_id,
block_size_byte);
#endif
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess) {
printf("CUDA Error: %s\n", cudaGetErrorString(err));
}
#ifdef DEBUG_IPC_SENT_SYNC_AND_PRINT
cudaDeviceSynchronize();
PrintMatrix<T>(
reinterpret_cast<T*>(remote_key_tensor_sent_ptr),
128 * 1,
"ipc_send_tgt_key.datatxt." + std::to_string(local_device_id),
128 * 1);
cudaDeviceSynchronize();
#endif
const T* local_value_tensor_sent_ptr =
local_value_tensor_base_ptr +
local_block_ids_ptr[block_idx] * block_idx_stride;
T* remote_value_tensor_sent_ptr =
remote_value_tensor_base_ptr +
remote_block_ids_ptr[block_idx] * block_idx_stride;
#ifdef DEBUG_IPC_SENT
std::cout << "remote_value_tensor_sent_ptr:"
<< (int64_t)remote_value_tensor_sent_ptr
<< " local_value_tensor_sent_ptr:"
<< (int64_t)local_value_tensor_sent_ptr
<< " local_device_id:" << local_device_id
<< " remote_device_id:" << remote_device_id
<< " block_idx_stride:" << block_idx_stride
<< " block_size_byte:" << block_size_byte
<< " stream: " << stream
<< " local_block_ids: " << local_block_ids_ptr[block_idx]
<< " remote_block_ids: " << remote_block_ids_ptr[block_idx]
<< std::endl;
#endif
#ifdef DEBUG_IPC_SENT_SYNC_AND_PRINT
cudaDeviceSynchronize();
PrintMatrix<T>(
reinterpret_cast<const T*>(local_value_tensor_sent_ptr),
128 * 1,
"ipc_send_src_value.datatxt." + std::to_string(local_device_id),
128 * 1);
cudaDeviceSynchronize();
#endif
#ifndef DEBUG_IPC_SENT_SYNC_AND_PRINT
cudaMemcpyPeerAsync(
reinterpret_cast<void*>(remote_value_tensor_sent_ptr),
remote_device_id,
reinterpret_cast<const void*>(local_value_tensor_sent_ptr),
local_device_id,
block_size_byte,
reinterpret_cast<void*>(remote_key_tensor_sent_ptr),
remote_device_id,
reinterpret_cast<const void*>(local_key_tensor_sent_ptr),
local_device_id,
block_size_byte,
stream);
#endif
#ifdef DEBUG_IPC_SENT_SYNC_AND_PRINT
cudaMemcpyPeer(
reinterpret_cast<void*>(remote_value_tensor_sent_ptr),
remote_device_id,
reinterpret_cast<const void*>(local_value_tensor_sent_ptr),
local_device_id,
reinterpret_cast<void*>(remote_key_tensor_sent_ptr),
remote_device_id,
reinterpret_cast<const void*>(local_key_tensor_sent_ptr),
local_device_id,
block_size_byte);
#endif
cudaError_t err = cudaGetLastError();
if ( err != cudaSuccess )
{
printf("CUDA Error: %s\n", cudaGetErrorString(err));
}
#ifdef DEBUG_IPC_SENT_SYNC_AND_PRINT
cudaDeviceSynchronize();
PrintMatrix<T>(reinterpret_cast<T*>(remote_key_tensor_sent_ptr),
128 * 1,
"ipc_send_tgt_key.datatxt." + std::to_string(local_device_id),
128 * 1);
cudaDeviceSynchronize();
#endif
const T* local_value_tensor_sent_ptr = local_value_tensor_base_ptr + local_block_ids_ptr[block_idx] * block_idx_stride;
T* remote_value_tensor_sent_ptr = remote_value_tensor_base_ptr + remote_block_ids_ptr[block_idx] * block_idx_stride;
#ifdef DEBUG_IPC_SENT
std::cout<<"remote_value_tensor_sent_ptr:"<<(int64_t)remote_value_tensor_sent_ptr
<<" local_value_tensor_sent_ptr:"<<(int64_t)local_value_tensor_sent_ptr
<<" local_device_id:" << local_device_id
<<" remote_device_id:" << remote_device_id
<<" block_idx_stride:" << block_idx_stride
<<" block_size_byte:" << block_size_byte
<<" stream: " << stream
<<" local_block_ids: " << local_block_ids_ptr[block_idx]
<<" remote_block_ids: " << remote_block_ids_ptr[block_idx]
<<std::endl;
#endif
#ifdef DEBUG_IPC_SENT_SYNC_AND_PRINT
cudaDeviceSynchronize();
PrintMatrix<T>(reinterpret_cast<const T*>(local_value_tensor_sent_ptr),
128 * 1,
"ipc_send_src_value.datatxt." + std::to_string(local_device_id),
128 * 1);
cudaDeviceSynchronize();
#endif
#ifndef DEBUG_IPC_SENT_SYNC_AND_PRINT
cudaMemcpyPeerAsync(
reinterpret_cast<void*>(remote_value_tensor_sent_ptr),
remote_device_id,
reinterpret_cast<const void*>(local_value_tensor_sent_ptr),
local_device_id,
block_size_byte,
stream);
#endif
#ifdef DEBUG_IPC_SENT_SYNC_AND_PRINT
cudaMemcpyPeer(
reinterpret_cast<void*>(remote_value_tensor_sent_ptr),
remote_device_id,
reinterpret_cast<const void*>(local_value_tensor_sent_ptr),
local_device_id,
block_size_byte);
cudaDeviceSynchronize();
#endif
err = cudaGetLastError();
if (err != cudaSuccess) {
printf("CUDA Error: %s\n", cudaGetErrorString(err));
if ( err != cudaSuccess )
{
printf("CUDA Error: %s\n", cudaGetErrorString(err));
}
#ifdef DEBUG_IPC_SENT_SYNC_AND_PRINT
PrintMatrix<T>(
reinterpret_cast<T*>(remote_value_tensor_sent_ptr),
128 * 1,
"ipc_send_tgt_value.datatxt." + std::to_string(local_device_id),
128 * 1);
PrintMatrix<T>(reinterpret_cast<T*>(remote_value_tensor_sent_ptr),
128 * 1,
"ipc_send_tgt_value.datatxt." + std::to_string(local_device_id),
128 * 1);
cudaDeviceSynchronize();
#endif
}
}
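// Note (an assumption about deployment, not asserted by this diff): the
// cudaMemcpyPeerAsync calls above move KV-cache blocks directly between two
// GPUs, which works best when peer access has been enabled once per device
// pair during setup, e.g.:
inline void enable_peer_access_sketch(int local_device_id,
                                      int remote_device_id) {
  int can_access = 0;
  cudaDeviceCanAccessPeer(&can_access, local_device_id, remote_device_id);
  if (can_access) {
    cudaSetDevice(local_device_id);
    // Returns cudaErrorPeerAccessAlreadyEnabled on repeat calls; harmless.
    cudaDeviceEnablePeerAccess(remote_device_id, 0);
  }
}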
void SentKeyValueByRemotePtr(const paddle::Tensor& local_key_tensor,
const paddle::Tensor& local_value_tensor,
const paddle::Tensor& local_block_ids, // cpu
const paddle::Tensor& remote_block_ids, // cpu
const paddle::Tensor& local_block_ids, // cpu
const paddle::Tensor& remote_block_ids, // cpu
const paddle::Tensor& remote_key_tensor,
const paddle::Tensor& remote_value_tensor,
const int& block_num,
const int& local_device_id,
const int& remote_device_id) {
const int& remote_device_id,
const int64_t& cuda_stream_raw) {
std::vector<int64_t> cache_key_tensor_shape = local_key_tensor.shape();
auto cuda_stream = local_key_tensor.stream();
// const cudaStream_t cuda_stream = *(reinterpret_cast<const
// cudaStream_t*>(&stream));
#ifdef DEBUG_IPC_SENT
std::cout << "#### 000" << std::endl;
#endif
getNvidiaGPUMemoryUsage(__LINE__);
// auto cuda_stream = local_key_tensor.stream();
cudaStream_t cuda_stream = (cudaStream_t)cuda_stream_raw;
getNvidiaGPUMemoryUsage(__LINE__);
// const cudaStream_t cuda_stream = *(reinterpret_cast<const cudaStream_t*>(&stream));
#ifdef DEBUG_IPC_SENT
std::cout<<"#### 000"<<std::endl;
#endif
int32_t total_block_num_local = cache_key_tensor_shape[0];
int32_t kv_num_head_local = cache_key_tensor_shape[1];
int32_t block_size_local = cache_key_tensor_shape[2];
int32_t hidden_size_local = cache_key_tensor_shape[3];
getNvidiaGPUMemoryUsage(__LINE__);
auto local_block_ids_ptr = local_block_ids.data<int32_t>(); // cpu
auto remote_block_ids_ptr = remote_block_ids.data<int32_t>(); // cpu
auto remote_key_ptr = remote_key_tensor.data<int64_t>()[0];
auto remote_value_ptr = remote_value_tensor.data<int64_t>()[0];
auto local_block_ids_ptr = local_block_ids.data<int32_t>(); // cpu
auto remote_block_ids_ptr = remote_block_ids.data<int32_t>(); // cpu
auto remote_key_ptr = remote_key_tensor.data<int64_t>()[0];
auto remote_value_ptr = remote_value_tensor.data<int64_t>()[0];
getNvidiaGPUMemoryUsage(__LINE__);
#ifdef DEBUG_IPC_SENT
std::cout << "#### 1111"
<< " remote_key_ptr: " << remote_key_ptr
<< " remote_value_ptr: " << remote_value_ptr << std::endl;
#endif
int64_t block_idx_stride =
kv_num_head_local * block_size_local * hidden_size_local;
#ifdef DEBUG_IPC_SENT
std::cout<<"#### 1111"
<< " remote_key_ptr: "<<remote_key_ptr
<< " remote_value_ptr: "<<remote_value_ptr<<std::endl;
#endif
getNvidiaGPUMemoryUsage(__LINE__);
int64_t block_idx_stride = kv_num_head_local*block_size_local*hidden_size_local;
auto local_key_tensor_ptr = local_key_tensor.data();
auto local_value_tensor_ptr = local_value_tensor.data();
#ifdef DEBUG_IPC_SENT
std::cout << "#### 2222" << std::endl;
#endif
getNvidiaGPUMemoryUsage(__LINE__);
#ifdef DEBUG_IPC_SENT
std::cout<<"#### 2222"<<std::endl;
#endif
switch (local_key_tensor.type()) {
case paddle::DataType::BFLOAT16: {
using dataT = __nv_bfloat16;
using dataT=__nv_bfloat16;
// std::cout<<"#### cache type __nv_bfloat16" << std::endl;
return sent_key_value_by_remote_ptr<dataT>(
reinterpret_cast<const dataT*>(local_key_tensor_ptr),
@@ -209,10 +251,11 @@ void SentKeyValueByRemotePtr(const paddle::Tensor& local_key_tensor,
remote_device_id,
reinterpret_cast<dataT*>((void*)remote_key_ptr),
reinterpret_cast<dataT*>((void*)remote_value_ptr),
cuda_stream);
cuda_stream
);
}
case paddle::DataType::FLOAT16: {
using dataT = half;
using dataT=half;
return sent_key_value_by_remote_ptr<dataT>(
reinterpret_cast<const dataT*>(local_key_tensor_ptr),
reinterpret_cast<const dataT*>(local_value_tensor_ptr),
@@ -225,10 +268,11 @@ void SentKeyValueByRemotePtr(const paddle::Tensor& local_key_tensor,
remote_device_id,
reinterpret_cast<dataT*>((void*)remote_key_ptr),
reinterpret_cast<dataT*>((void*)remote_value_ptr),
cuda_stream);
cuda_stream
);
}
case paddle::DataType::INT8: {
using dataT = int8_t;
using dataT=int8_t;
return sent_key_value_by_remote_ptr<dataT>(
reinterpret_cast<const dataT*>(local_key_tensor_ptr),
reinterpret_cast<const dataT*>(local_value_tensor_ptr),
@@ -241,10 +285,11 @@ void SentKeyValueByRemotePtr(const paddle::Tensor& local_key_tensor,
remote_device_id,
reinterpret_cast<dataT*>((void*)remote_key_ptr),
reinterpret_cast<dataT*>((void*)remote_value_ptr),
cuda_stream);
cuda_stream
);
}
case paddle::DataType::UINT8: {
using dataT = uint8_t;
using dataT=uint8_t;
// std::cout<<"#### cache type uint8" << std::endl;
return sent_key_value_by_remote_ptr<dataT>(
reinterpret_cast<const dataT*>(local_key_tensor_ptr),
@@ -258,20 +303,33 @@ void SentKeyValueByRemotePtr(const paddle::Tensor& local_key_tensor,
remote_device_id,
reinterpret_cast<dataT*>((void*)remote_key_ptr),
reinterpret_cast<dataT*>((void*)remote_value_ptr),
cuda_stream);
cuda_stream
);
}
}
// using dataT=std::remove_pointer<decltype(local_block_ids_ptr)>;
}
void SentKeyValueByRemotePtrBlockSync(const paddle::Tensor& local_key_tensor,
const paddle::Tensor& local_value_tensor,
const int64_t& cuda_stream_raw) {
cudaStream_t cuda_stream = (cudaStream_t)cuda_stream_raw;
cudaStreamSynchronize(cuda_stream);
}
PD_BUILD_STATIC_OP(ipc_sent_key_value_cache_by_remote_ptr)
.Inputs({"local_key_tensor",
"local_value_tensor",
"local_block_ids",
"remote_block_ids",
"remote_key_tensor",
"remote_value_tensor"})
.Attrs({"block_num: int", "local_device_id: int", "remote_device_id: int"})
.Outputs({"remote_block_ids_out"})
.SetInplaceMap({{"remote_block_ids", "remote_block_ids_out"}})
.Inputs({"local_key_tensor", "local_value_tensor", "local_block_ids", "remote_block_ids", "remote_key_tensor", "remote_value_tensor"})
.Attrs({ "block_num: int",
"local_device_id: int",
"remote_device_id: int",
"cuda_stream_raw: int64_t"})
.Outputs({"local_key_tensor_out", "local_value_tensor_out"})
.SetInplaceMap({{"local_key_tensor", "local_key_tensor_out"},{"local_value_tensor","local_value_tensor_out"}})
.SetKernelFn(PD_KERNEL(SentKeyValueByRemotePtr));
PD_BUILD_STATIC_OP(ipc_sent_key_value_cache_by_remote_ptr_block_sync)
.Inputs({"local_key_tensor", "local_value_tensor"})
.Attrs({"cuda_stream_raw: int64_t"})
.Outputs({"local_key_tensor_out", "local_value_tensor_out"})
.SetInplaceMap({{"local_key_tensor", "local_key_tensor_out"},{"local_value_tensor","local_value_tensor_out"}})
.SetKernelFn(PD_KERNEL(SentKeyValueByRemotePtrBlockSync));

View File

@@ -0,0 +1,61 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "helper.h"
#include "paddle/extension.h"
template <typename scalar_t>
__global__ void
cuda_kernel(const scalar_t *__restrict__ topk_ids, int32_t *__restrict__ res,
int32_t *__restrict__ res_padded, size_t numel, int num_experts) {
extern __shared__ int32_t tokens_per_ep[];
for (size_t i = threadIdx.x; i < num_experts; i += blockDim.x) {
tokens_per_ep[i] = 0;
}
__syncthreads();
for (size_t i = threadIdx.x; i < numel; i += blockDim.x) {
int32_t expert_id = topk_ids[i];
if(expert_id >= 0) atomicAdd(&tokens_per_ep[expert_id], 1);
}
__syncthreads();
for (size_t i = threadIdx.x; i < num_experts; i += blockDim.x) {
res[i] = tokens_per_ep[i];
res_padded[i] = (res[i] + 127) / 128 * 128;
}
}
paddle::Tensor count_tokens_per_expert_func(const paddle::Tensor &topk_ids,
int64_t num_experts) {
int topk_ids_numel = topk_ids.shape()[0] * topk_ids.shape()[1];
auto token_nums_per_expert = paddle::empty(
{2, num_experts}, paddle::DataType::INT32, topk_ids.place());
auto stream = topk_ids.stream();
using scalar_t = int64_t;
cuda_kernel<<<1, 1024, num_experts * sizeof(int32_t), stream>>>(
topk_ids.data<scalar_t>(), token_nums_per_expert.data<int32_t>(),
token_nums_per_expert.data<int32_t>() + num_experts, topk_ids_numel,
num_experts);
return token_nums_per_expert;
}
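// Worked example (illustrative, not part of this diff): the kernel above
// counts tokens per expert in shared memory using a single block, then writes
// both the raw count and the count rounded up to a multiple of 128 so each
// expert's segment stays tile-aligned. E.g. counts {3, 130, 0, 128} pad to
// {128, 256, 0, 128}.
inline int32_t round_up_to_128_sketch(int32_t n) {
  return (n + 127) / 128 * 128;
}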

View File

@@ -30,9 +30,9 @@ namespace cg = cooperative_groups;
template<typename T>
__device__ T warpReduceSum(T val){
for(int lane_mask = 16; lane_mask > 0; lane_mask /=2){
val += __shfl_down_sync(0xffffffff, val, lane_mask);
val += __shfl_down_sync(0xffffffff, val, lane_mask);
}
return val;
return val;
}
__global__ void get_expert_token_num(
@@ -88,7 +88,7 @@ __global__ void get_expert_token_num(
sum = (threadIdx.x < KNWARPS) ? warp_sum[laneId] : 0;
sum_padded = (threadIdx.x < KNWARPS) ? warp_sum[laneId + KNWARPS] : 0;
if (warpId == 0) {
sum = warpReduceSum<int>(sum);
sum = warpReduceSum<int>(sum);
sum_padded = warpReduceSum<int>(sum_padded);
}
if (threadIdx.x == 0) {
@@ -167,7 +167,7 @@ __global__ void combine_prmt_back_kernel(
#pragma unroll
for (int vid = 0; vid < VEC_SIZE; vid++) {
res_vec[vid] += static_cast<T>(
row_scale * static_cast<float>(load_vec[vid]) +
row_scale * static_cast<float>(load_vec[vid]) +
static_cast<float>(bias_vec[vid]));
}
} else {
@@ -497,7 +497,7 @@ std::vector<paddle::Tensor> EPMoeExpertDispatch(
place);
auto num_experts_per_rank_tensor = GetEmptyTensor(
{num_experts_per_rank},
paddle::DataType::INT32,
paddle::DataType::INT32,
place);
auto expert_idx_per_token = GetEmptyTensor(
{token_nums_this_rank}, paddle::DataType::INT64, place);
@@ -619,7 +619,7 @@ PD_BUILD_STATIC_OP(ep_moe_expert_dispatch)
"cumsum_idx_gpu",
"expert_idx_per_token"})
.Attrs({
"token_nums_per_expert: std::vector<int>",
"token_nums_per_expert: std::vector<int>",
"token_nums_this_rank: int",
"moe_quant_type: std::string"
})
@@ -672,18 +672,21 @@ __global__ void permute_x_fp8_kernel(const T *src_x,
const int hidden_size_int4 = hidden_size / vec_size;
const int hidden_size_scale = hidden_size / 128;
const int hidden_size_scale_int4 = hidden_size_scale / scale_vec_size;
const int token_nums_feed_to_ffn = token_nums_per_expert_cum[NUM_EXPERTS_PER_RANK-1];
// prmt
for (int64_t s_token_idx = src_token_idx; s_token_idx < token_nums_this_rank_padded; s_token_idx += gridDim.x) {
if (tid == 0) {
for (int i = 0; i < NUM_EXPERTS_PER_RANK; i++) {
const int start_idx = i == 0 ? 0 : token_nums_per_expert_cum[i - 1];
const int end_idx = token_nums_per_expert_cum[i];
if (s_token_idx >= start_idx && s_token_idx < end_idx) {
m_indices[s_token_idx] = i;
break;
}
for (int64_t s_token_idx = src_token_idx; s_token_idx < token_nums_feed_to_ffn; s_token_idx += gridDim.x) {
      // m_indices[s_token_idx] must be some value `i` in [0, NUM_EXPERTS_PER_RANK);
      // here we search for that `i` in parallel across the block.
for (int i = threadIdx.x; i < NUM_EXPERTS_PER_RANK; i+= blockDim.x) {
const int start_idx = i == 0 ? 0 : token_nums_per_expert_cum[i - 1];
const int end_idx = token_nums_per_expert_cum[i];
if (s_token_idx >= start_idx && s_token_idx < end_idx) {
if ((s_token_idx - start_idx) < token_nums_per_expert[i]) m_indices[s_token_idx] = i;
break;
}
}
if (s_token_idx < num_rows) {
const int64_t *topk_idx_now = topk_idx + s_token_idx * moe_topk;
#pragma unroll
@@ -738,7 +741,8 @@ void EPMoeDispatchFP8Kernel(const paddle::Tensor& input,
paddle::Tensor* m_indices) {
auto stream = input.stream();
auto place = input.place();
const int gridx = min(132 * 8, num_rows);
// const int gridx = min(132 * 8, num_rows);
const int gridx = 132 * 8;
if (num_experts_per_rank == 8) {
permute_x_fp8_kernel<phi::dtype::float8_e4m3fn, 8><<<gridx, 512, 0, stream>>>(
input.data<phi::dtype::float8_e4m3fn>(),
@@ -831,8 +835,31 @@ void EPMoeDispatchFP8Kernel(const paddle::Tensor& input,
token_nums_per_expert_padded_cumsum->data<int64_t>(),
m_indices->data<int>()
);
} else if (num_experts_per_rank == 128) {
permute_x_fp8_kernel<phi::dtype::float8_e4m3fn, 128><<<gridx, 512, 0, stream>>>(
input.data<phi::dtype::float8_e4m3fn>(),
scale.data<float>(),
topk_ids.data<int64_t>(),
topk_weights.data<float>(),
token_nums_per_expert.data<int>(),
token_nums_per_expert_padded.data<int>(),
moe_topk,
num_rows,
token_nums_this_rank,
token_nums_this_rank_padded,
hidden_size,
permute_input->data<phi::dtype::float8_e4m3fn>(),
permute_scale->data<float>(),
permute_indices_per_token->data<int>(),
dst_weights->data<float>(),
dst_indices->data<int>(),
cumsum_idx_gpu->data<int>(),
token_nums_per_expert_cumsum->data<int64_t>(),
token_nums_per_expert_padded_cumsum->data<int64_t>(),
m_indices->data<int>()
);
} else {
PD_THROW("Not dispatching this num_experts_per_rank for EPMoeDispatchFP8Kernel");
PD_THROW("Not dispatching this num_experts_per_rank(", num_experts_per_rank, ") for EPMoeDispatchFP8Kernel");
}
}
@@ -842,10 +869,8 @@ std::vector<paddle::Tensor> EPMoeExpertDispatchFP8(
const paddle::Tensor& scale,
const paddle::Tensor& topk_ids,
const paddle::Tensor& topk_weights,
const std::vector<int>& token_nums_per_expert,
const std::vector<int>& token_nums_per_expert_padded,
const int token_nums_this_rank,
const int token_nums_this_rank_padded) {
const paddle::Tensor& num_experts_per_rank_tensor,
const paddle::Tensor& num_experts_per_rank_padded_tensor) {
const auto input_type = input.dtype();
const int moe_topk = topk_ids.dims()[1];
auto place = input.place();
@@ -859,7 +884,10 @@ std::vector<paddle::Tensor> EPMoeExpertDispatchFP8(
}
const int num_rows = token_rows;
const int hidden_size = input.dims()[input_dims.size() - 1];
const int num_experts_per_rank = token_nums_per_expert.size();
const int num_experts_per_rank = num_experts_per_rank_tensor.dims()[0];
int32_t token_nums_this_rank_padded = token_rows * moe_topk + num_experts_per_rank * (128-1);
// token_nums_this_rank_padded = token_nums_this_rank_padded_useless;
auto permute_input = GetEmptyTensor(
{token_nums_this_rank_padded, hidden_size},
@@ -869,30 +897,8 @@ std::vector<paddle::Tensor> EPMoeExpertDispatchFP8(
{token_nums_this_rank_padded, hidden_size / 128},
paddle::DataType::FLOAT32,
place);
auto num_experts_per_rank_tensor = GetEmptyTensor(
{num_experts_per_rank},
paddle::DataType::INT32,
place);
auto num_experts_per_rank_padded_tensor = GetEmptyTensor(
{num_experts_per_rank},
paddle::DataType::INT32,
place);
auto m_indices = GetEmptyTensor(
{token_nums_this_rank_padded},
paddle::DataType::INT32,
place);
cudaMemcpyAsync(
num_experts_per_rank_tensor.data<int>(),
token_nums_per_expert.data(),
num_experts_per_rank * sizeof(int),
cudaMemcpyHostToDevice,
input.stream());
cudaMemcpyAsync(
num_experts_per_rank_padded_tensor.data<int>(),
token_nums_per_expert_padded.data(),
num_experts_per_rank * sizeof(int),
cudaMemcpyHostToDevice,
input.stream());
auto m_indices = paddle::full({token_nums_this_rank_padded}, -1, paddle::DataType::INT32, place);
auto token_nums_per_expert_cumsum = GetEmptyTensor({num_experts_per_rank}, paddle::DataType::INT64, place);
auto token_nums_per_expert_padded_cumsum = GetEmptyTensor({num_experts_per_rank}, paddle::DataType::INT64, place);
auto dst_weights = GetEmptyTensor({token_nums_this_rank_padded}, paddle::DataType::FLOAT32, place);
@@ -908,8 +914,8 @@ std::vector<paddle::Tensor> EPMoeExpertDispatchFP8(
num_experts_per_rank_padded_tensor,
moe_topk,
num_rows,
token_nums_this_rank,
token_nums_this_rank_padded,
-1,
-1,
hidden_size,
num_experts_per_rank,
&permute_input,
@@ -932,61 +938,8 @@ std::vector<paddle::Tensor> EPMoeExpertDispatchFP8(
m_indices};
}
std::vector<std::vector<int64_t>> EPMoeExpertDispatchFP8InferShape(
const std::vector<int64_t>& input_shape,
const std::vector<int64_t>& scale_shape,
const std::vector<int64_t>& topk_ids_shape,
const std::vector<int64_t>& topk_weights_shape,
const std::vector<int>& token_nums_per_expert,
const std::vector<int>& token_nums_per_expert_padded,
const int token_nums_this_rank,
const int token_nums_this_rank_padded) {
int token_rows = -1; // real token row
int moe_topk = topk_ids_shape[1];
if (input_shape.size() == 3) {
token_rows = input_shape[0] * input_shape[1];
} else {
token_rows = input_shape[0];
}
  const int expert_num = token_nums_per_expert.size();  // number of local experts
const int num_rows = token_rows;
const int hidden_size = input_shape[input_shape.size() - 1];
return {{token_nums_this_rank_padded, hidden_size}, // x
{token_nums_this_rank_padded, hidden_size / 128}, // scale
{expert_num, num_rows},
{expert_num},
{expert_num},
{token_nums_this_rank_padded},
{num_rows, expert_num},
{expert_num},
{token_nums_this_rank_padded}}; // dst_idx per expert
}
std::vector<paddle::DataType> EPMoeExpertDispatchFP8InferDtype(
const paddle::DataType& input_dtype,
const paddle::DataType& scale_dtype,
const paddle::DataType& topk_ids_dtype,
const paddle::DataType& topk_weights_dtype,
const std::vector<int>& token_nums_per_expert,
const std::vector<int>& token_nums_per_expert_padded,
const int token_nums_this_rank,
const int token_nums_this_rank_padded) {
return {input_dtype,
paddle::DataType::FLOAT32,
paddle::DataType::INT32,
paddle::DataType::INT64,
paddle::DataType::INT64,
paddle::DataType::FLOAT32,
paddle::DataType::INT32,
paddle::DataType::INT32,
paddle::DataType::INT32};
}
PD_BUILD_STATIC_OP(ep_moe_expert_dispatch_fp8)
.Inputs({"input", "scale", "topk_ids", "topk_weights"})
.Inputs({"input", "scale", "topk_ids", "topk_weights", "num_experts_per_rank_tensor", "num_experts_per_rank_padded_tensor"})
.Outputs({"permute_input",
"permute_scale",
"permute_indices_per_token",
@@ -996,12 +949,4 @@ PD_BUILD_STATIC_OP(ep_moe_expert_dispatch_fp8)
"dst_indices",
"cumsum_idx_gpu",
"m_indices"})
.Attrs({
"token_nums_per_expert: std::vector<int>",
"token_nums_per_expert_padded: std::vector<int>",
"token_nums_this_rank: int",
"token_nums_this_rank_padded: int",
})
.SetKernelFn(PD_KERNEL(EPMoeExpertDispatchFP8))
.SetInferShapeFn(PD_INFER_SHAPE(EPMoeExpertDispatchFP8InferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(EPMoeExpertDispatchFP8InferDtype));
.SetKernelFn(PD_KERNEL(EPMoeExpertDispatchFP8));

View File

@@ -17,6 +17,7 @@
#pragma once
#include "cutlass/numeric_conversion.h"
#include "cutlass_extensions/wint_type_traits.h"
#include "helper.h"
#include "moe/fused_moe_helper.h"
@@ -71,9 +72,9 @@ void FusedMoeKernel(const paddle::Tensor& input,
auto* output_data = output->data<data_t>();
auto fp16_moe_gemm_runner = MoeGemmRunner<DataType_, DataType_>();
auto int8_moe_gemm_runner = MoeGemmRunner<DataType_, uint8_t>();
auto int4_moe_gemm_runner = MoeGemmRunner<DataType_, cutlass::uint4b_t>();
auto fp16_moe_gemm_runner = MoeGemmRunner<DataType_, cutlass::WintQuantTraits<DataType_, cutlass::WintQuantMethod::kNone>>();
auto int8_moe_gemm_runner = MoeGemmRunner<DataType_, cutlass::WintQuantTraits<DataType_, cutlass::WintQuantMethod::kWeightOnlyInt8>>();
auto int4_moe_gemm_runner = MoeGemmRunner<DataType_, cutlass::WintQuantTraits<DataType_, cutlass::WintQuantMethod::kWeightOnlyInt4>>();
using NvType = typename traits_::DataType;
auto moe_compute = MoeHelper<data_t, NvType>(quant_method,

View File

@@ -15,6 +15,7 @@ limitations under the License. */
#pragma once
#include "cutlass_extensions/wint_type_traits.h"
#include "cutlass_kernels/moe_gemm/fused_moe_gemm_kernels.h"
#include "moe/fused_moe_op.h"
@@ -53,11 +54,16 @@ void moe_token_type_ids_kernelLauncher(T *gating_output,
template <typename T, typename NvType> class MoeHelper {
public:
MoeHelper(const std::string gemm_method,
MoeGemmRunner<NvType, NvType> *fp16_moe_gemm_runner,
MoeGemmRunner<NvType, uint8_t> *int8_moe_gemm_runner,
MoeGemmRunner<NvType, cutlass::uint4b_t> *int4_moe_gemm_runner,
int layernum = 0)
using Fp16Traits = cutlass::WintQuantTraits<NvType, cutlass::WintQuantMethod::kNone>;
using Int8Traits = cutlass::WintQuantTraits<NvType, cutlass::WintQuantMethod::kWeightOnlyInt8>;
using Int4Traits = cutlass::WintQuantTraits<NvType, cutlass::WintQuantMethod::kWeightOnlyInt4>;
MoeHelper(
const std::string gemm_method,
MoeGemmRunner<NvType, Fp16Traits> *fp16_moe_gemm_runner,
MoeGemmRunner<NvType, Int8Traits> *int8_moe_gemm_runner,
MoeGemmRunner<NvType, Int4Traits> *int4_moe_gemm_runner,
int layernum = 0)
: gemm_method_(gemm_method), fp16_moe_gemm_runner_(fp16_moe_gemm_runner),
int8_moe_gemm_runner_(int8_moe_gemm_runner),
int4_moe_gemm_runner_(int4_moe_gemm_runner), layernum_(layernum) {}
@@ -254,6 +260,7 @@ public:
total_rows_before_expert_, stream);
if (gemm_method_ == "weight_only_int8") {
typename Int8Traits::Arguments ffn1_quant_args;
int8_moe_gemm_runner_->moe_gemm_bias_act(
reinterpret_cast<NvType *>(permuted_data_),
reinterpret_cast<const uint8_t *>(ffn1_weight->data<int8_t>()),
@@ -262,8 +269,9 @@ public:
reinterpret_cast<NvType *>(fc1_out), total_rows_before_expert_,
-1, // useless
expanded_active_expert_rows, inter_size, hidden_size, num_experts,
"none", stream);
ffn1_quant_args, "none", stream);
} else if (gemm_method_ == "weight_only_int4") {
typename Int4Traits::Arguments ffn1_quant_args;
int4_moe_gemm_runner_->moe_gemm_bias_act(
reinterpret_cast<NvType *>(permuted_data_),
reinterpret_cast<const cutlass::uint4b_t *>(
@@ -273,8 +281,9 @@ public:
reinterpret_cast<NvType *>(fc1_out), total_rows_before_expert_,
-1, // useless
expanded_active_expert_rows, inter_size, hidden_size, num_experts,
"none", stream);
ffn1_quant_args, "none", stream);
} else {
typename Fp16Traits::Arguments ffn1_quant_args;
fp16_moe_gemm_runner_->moe_gemm_bias_act(
reinterpret_cast<NvType *>(permuted_data_),
reinterpret_cast<const NvType *>(ffn1_weight->data<T>()), nullptr,
@@ -282,7 +291,7 @@ public:
reinterpret_cast<NvType *>(fc1_out), total_rows_before_expert_,
-1, // useless
expanded_active_expert_rows, inter_size, hidden_size, num_experts,
"none", stream);
ffn1_quant_args, "none", stream);
}
if (moe_type == "ffn") {
@@ -295,6 +304,7 @@ public:
T *fc2_result = fc2_output_tensor.data<T>();
if (gemm_method_ == "weight_only_int8") {
typename Int8Traits::Arguments ffn2_quant_args;
int8_moe_gemm_runner_->moe_gemm(
reinterpret_cast<NvType *>(act_out),
reinterpret_cast<const uint8_t *>(ffn2_weight->data<int8_t>()),
@@ -302,8 +312,9 @@ public:
reinterpret_cast<NvType *>(fc2_result), total_rows_before_expert_,
-1, // useless
expanded_active_expert_rows, hidden_size, inter_size / 2,
num_experts, stream);
num_experts, ffn2_quant_args, stream);
} else if (gemm_method_ == "weight_only_int4") {
typename Int4Traits::Arguments ffn2_quant_args;
int4_moe_gemm_runner_->moe_gemm(
reinterpret_cast<NvType *>(act_out),
reinterpret_cast<const cutlass::uint4b_t *>(
@@ -312,15 +323,16 @@ public:
reinterpret_cast<NvType *>(fc2_result), total_rows_before_expert_,
-1, // useless
expanded_active_expert_rows, hidden_size, inter_size / 2,
num_experts, stream);
num_experts, ffn2_quant_args, stream);
} else {
typename Fp16Traits::Arguments ffn2_quant_args;
fp16_moe_gemm_runner_->moe_gemm(
reinterpret_cast<NvType *>(act_out),
reinterpret_cast<const NvType *>(ffn2_weight->data<T>()), nullptr,
reinterpret_cast<NvType *>(fc2_result), total_rows_before_expert_,
-1, // useless
expanded_active_expert_rows, hidden_size, inter_size / 2,
num_experts, stream);
num_experts, ffn2_quant_args, stream);
}
finalize_moe_routing_kernelLauncher<T>::run(
@@ -343,9 +355,9 @@ public:
private:
std::string gemm_method_;
MoeGemmRunner<NvType, NvType> *fp16_moe_gemm_runner_;
MoeGemmRunner<NvType, uint8_t> *int8_moe_gemm_runner_;
MoeGemmRunner<NvType, cutlass::uint4b_t> *int4_moe_gemm_runner_;
MoeGemmRunner<NvType, Fp16Traits> *fp16_moe_gemm_runner_;
MoeGemmRunner<NvType, Int8Traits> *int8_moe_gemm_runner_;
MoeGemmRunner<NvType, Int4Traits> *int4_moe_gemm_runner_;
int layernum_;
CubKeyValueSorter sorter_;
};

View File

@@ -17,11 +17,11 @@
#pragma once
#include "cutlass/numeric_conversion.h"
#include "moe/fused_moe_helper.h"
#include "moe/fused_moe_imp_op.h"
#include <cuda.h>
#include <cuda_fp16.h>
#include "moe/fused_moe_imp_op.h"
#include "moe/fused_moe_helper.h"
#include "cutlass/numeric_conversion.h"
// Ignore CUTLASS warnings about type punning
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
@@ -37,8 +37,8 @@
namespace phi {
struct GpuLaunchConfig {
dim3 block_per_grid;
dim3 thread_per_block;
dim3 block_per_grid;
dim3 thread_per_block;
};
inline GpuLaunchConfig Get1DBlocksAnd2DGridsMoe(const int64_t cols) {
@@ -57,13 +57,30 @@ inline GpuLaunchConfig Get1DBlocksAnd2DGridsMoe(const int64_t cols) {
return config;
}
constexpr static int FINALIZE_THREADS_PER_BLOCK = 256;
template <class T, class U>
__host__ __device__ constexpr static U arrayConvert(T const& input)
{
using Type = typename U::Element;
static_assert(T::kElements == U::kElements);
U u;
#pragma unroll
for (int i = 0; i < U::kElements; i++)
{
u[i] = static_cast<Type>(input[i]);
}
return u;
}
// ====================== Softmax things ===============================
// We have our own implementation of softmax here so we can support transposing
// the output in the softmax kernel when we extend this module to support
// expert-choice routing.
template <typename T, int TPB>
__launch_bounds__(TPB) __global__
void group_moe_softmax(const T *input, T *output, T *softmax_max_prob,
void group_moe_softmax(const T* input,
T* output,
T* softmax_max_prob,
const int64_t num_cols,
const int64_t softmax_num_rows) {
using BlockReduce = cub::BlockReduce<float, TPB>;
@@ -82,6 +99,7 @@ __launch_bounds__(TPB) __global__
cub::Sum sum;
float threadData(-FLT_MAX);
for (int ii = threadIdx.x; ii < num_cols; ii += TPB) {
const int idx = thread_row_offset + ii;
threadData = max(static_cast<float>(input[idx]), threadData);
@@ -133,11 +151,14 @@ __launch_bounds__(TPB) __global__
}
template <typename T, int TPB, typename IdxT = int>
__launch_bounds__(TPB) __global__
void moe_top_k(const T *inputs_after_softmax, T *output, IdxT *indices,
int *source_rows, T *softmax_max_prob,
const int64_t num_experts, const int64_t k,
const int64_t num_rows) {
__launch_bounds__(TPB) __global__ void moe_top_k(const T* inputs_after_softmax,
T* output,
IdxT* indices,
int* source_rows,
T* softmax_max_prob,
const int64_t num_experts,
const int64_t k,
const int64_t num_rows) {
using cub_kvp = cub::KeyValuePair<int, T>;
using BlockReduce = cub::BlockReduce<cub_kvp, TPB>;
__shared__ typename BlockReduce::TempStorage tmpStorage;
@@ -155,7 +176,7 @@ __launch_bounds__(TPB) __global__
for (int k_idx = 0; k_idx < k; ++k_idx) {
thread_kvp.key = 0;
thread_kvp.value = T(-1.f); // This is OK because inputs are probabilities
thread_kvp.value = T(-1.f); // This is OK because inputs are probabilities
cub_kvp inp_kvp;
for (int expert = threadIdx.x; expert < num_experts; expert += TPB) {
@@ -188,9 +209,10 @@ __launch_bounds__(TPB) __global__
}
template <typename T, int TPB>
__launch_bounds__(TPB) __global__
void moe_softmax(const T *input, T *output, const int64_t num_cols,
const int64_t num_rows) {
__launch_bounds__(TPB) __global__ void moe_softmax(const T* input,
T* output,
const int64_t num_cols,
const int64_t num_rows) {
using BlockReduce = cub::BlockReduce<float, TPB>;
__shared__ typename BlockReduce::TempStorage tmpStorage;
@@ -240,10 +262,14 @@ __launch_bounds__(TPB) __global__
}
template <typename T, int TPB, typename IdxT = int>
__launch_bounds__(TPB) __global__
void moe_top_k(const T *inputs_after_softmax, const T *bias, T *output,
IdxT *indices, int *source_rows, const int64_t num_experts,
const int64_t k, const int64_t num_rows) {
__launch_bounds__(TPB) __global__ void moe_top_k(const T* inputs_after_softmax,
const T* bias,
T* output,
IdxT* indices,
int* source_rows,
const int64_t num_experts,
const int64_t k,
const int64_t num_rows) {
using cub_kvp = cub::KeyValuePair<int, T>;
using BlockReduce = cub::BlockReduce<cub_kvp, TPB>;
__shared__ typename BlockReduce::TempStorage tmpStorage;
@@ -261,14 +287,13 @@ __launch_bounds__(TPB) __global__
for (int k_idx = 0; k_idx < k; ++k_idx) {
thread_kvp.key = 0;
thread_kvp.value = T(-1.f); // This is OK because inputs are probabilities
thread_kvp.value = T(-1.f); // This is OK because inputs are probabilities
cub_kvp inp_kvp;
for (int expert = threadIdx.x; expert < num_experts; expert += TPB) {
const int idx = thread_read_offset + expert;
inp_kvp.key = expert;
inp_kvp.value = bias ? inputs_after_softmax[idx] + bias[expert]
: inputs_after_softmax[idx];
inp_kvp.value = bias ? inputs_after_softmax[idx] + bias[expert] : inputs_after_softmax[idx] ;
for (int prior_k = 0; prior_k < k_idx; ++prior_k) {
const IdxT prior_winning_expert = indices[k * block_row + prior_k];
@@ -285,9 +310,7 @@ __launch_bounds__(TPB) __global__
BlockReduce(tmpStorage).Reduce(thread_kvp, arg_max);
if (threadIdx.x == 0) {
const int idx = k * block_row + k_idx;
output[idx] =
bias ? inputs_after_softmax[thread_read_offset + result_kvp.key]
: result_kvp.value;
output[idx] = bias ? inputs_after_softmax[thread_read_offset + result_kvp.key]: result_kvp.value;
indices[idx] = should_process_row ? result_kvp.key : num_experts;
source_rows[idx] = k_idx * num_rows + block_row;
}
@@ -296,11 +319,14 @@ __launch_bounds__(TPB) __global__
}
template <typename T, int TPB, typename IdxT = int>
__launch_bounds__(TPB) __global__
void moe_softmax_top_k_fused(const T *input, const T *bias, T *output,
IdxT *indices, int *source_rows,
const int64_t num_experts, const int64_t k,
const int64_t num_rows) {
__launch_bounds__(TPB) __global__ void moe_softmax_top_k_fused(const T* input,
const T* bias,
T* output,
IdxT* indices,
int* source_rows,
const int64_t num_experts,
const int64_t k,
const int64_t num_rows) {
// softmax
using BlockReduce = cub::BlockReduce<float, TPB>;
__shared__ typename BlockReduce::TempStorage tmpStorage;
@@ -313,12 +339,11 @@ __launch_bounds__(TPB) __global__
return;
}
const int64_t thread_row_offset = globalIdx * num_experts;
const int64_t idx = thread_row_offset + threadIdx.x;
const int64_t idx = thread_row_offset+threadIdx.x;
cub::Sum sum;
float threadData =
(threadIdx.x < num_experts) ? static_cast<float>(input[idx]) : (-FLT_MAX);
float threadData = (threadIdx.x < num_experts) ? static_cast<float>(input[idx]) :(-FLT_MAX);
const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, cub::Max());
if (threadIdx.x == 0) {
@@ -335,10 +360,10 @@ __launch_bounds__(TPB) __global__
normalizing_factor = 1.f / Z;
}
__syncthreads();
T val = T(threadDataExp * normalizing_factor);
// top_k
// top_k
using cub_kvp = cub::KeyValuePair<int, T>;
using BlockReduceP = cub::BlockReduce<cub_kvp, TPB>;
__shared__ typename BlockReduceP::TempStorage tmpStorageP;
@@ -348,11 +373,11 @@ __launch_bounds__(TPB) __global__
for (int k_idx = 0; k_idx < k; ++k_idx) {
thread_kvp.key = 0;
thread_kvp.value = T(-1.f); // This is OK because inputs are probabilities
thread_kvp.value = T(-1.f); // This is OK because inputs are probabilities
if (threadIdx.x < num_experts) {
cub_kvp inp_kvp;
int expert = threadIdx.x;
int expert = threadIdx.x;
inp_kvp.key = expert;
inp_kvp.value = bias ? val + bias[expert] : val;
@@ -370,8 +395,7 @@ __launch_bounds__(TPB) __global__
BlockReduceP(tmpStorageP).Reduce(thread_kvp, arg_max);
if (threadIdx.x == 0) {
const int cur_idx = k * globalIdx + k_idx;
output[cur_idx] =
bias ? (result_kvp.value - bias[result_kvp.key]) : result_kvp.value;
output[cur_idx] = bias ? (result_kvp.value - bias[result_kvp.key]) : result_kvp.value;
indices[cur_idx] = result_kvp.key;
source_rows[cur_idx] = k_idx * num_rows + globalIdx;
}
@@ -380,11 +404,14 @@ __launch_bounds__(TPB) __global__
}
template <typename T, int TPB, typename IdxT = int>
__launch_bounds__(TPB) __global__
void moe_top_k_normed(const T *inputs_after_softmax, const T *bias,
T *output, IdxT *indices, int *source_rows,
const int64_t num_experts, const int64_t k,
const int64_t num_rows) {
__launch_bounds__(TPB) __global__ void moe_top_k_normed(const T* inputs_after_softmax,
const T* bias,
T* output,
IdxT* indices,
int* source_rows,
const int64_t num_experts,
const int64_t k,
const int64_t num_rows) {
using cub_kvp = cub::KeyValuePair<int, T>;
using BlockReduce = cub::BlockReduce<cub_kvp, TPB>;
__shared__ typename BlockReduce::TempStorage tmpStorage;
@@ -403,18 +430,17 @@ __launch_bounds__(TPB) __global__
extern __shared__ char smem[];
T *row_outputs = reinterpret_cast<T *>(smem);
T* row_outputs = reinterpret_cast<T*>(smem);
for (int k_idx = 0; k_idx < k; ++k_idx) {
thread_kvp.key = 0;
thread_kvp.value = T(-1.f); // This is OK because inputs are probabilities
thread_kvp.value = T(-1.f); // This is OK because inputs are probabilities
cub_kvp inp_kvp;
for (int expert = threadIdx.x; expert < num_experts; expert += TPB) {
const int idx = thread_read_offset + expert;
inp_kvp.key = expert;
inp_kvp.value = bias ? inputs_after_softmax[idx] + bias[expert]
: inputs_after_softmax[idx];
inp_kvp.value = bias ? inputs_after_softmax[idx] + bias[expert] : inputs_after_softmax[idx] ;
for (int prior_k = 0; prior_k < k_idx; ++prior_k) {
const int prior_winning_expert = indices[k * block_row + prior_k];
@@ -431,14 +457,11 @@ __launch_bounds__(TPB) __global__
BlockReduce(tmpStorage).Reduce(thread_kvp, arg_max);
if (threadIdx.x == 0) {
const int idx = k * block_row + k_idx;
// output[idx] = bias ? inputs_after_softmax[thread_read_offset +
// result_kvp.key]: result_kvp.value;
// output[idx] = bias ? inputs_after_softmax[thread_read_offset + result_kvp.key]: result_kvp.value;
indices[idx] = should_process_row ? result_kvp.key : num_experts;
source_rows[idx] = k_idx * num_rows + block_row;
T row_out =
bias ? inputs_after_softmax[thread_read_offset + result_kvp.key]
: result_kvp.value;
T row_out = bias ? inputs_after_softmax[thread_read_offset + result_kvp.key]: result_kvp.value;
row_outputs[k_idx] = row_out;
weight_sum += row_out;
}
@@ -453,10 +476,16 @@ __launch_bounds__(TPB) __global__
}
}
template <typename T, int TPB, typename IdxT = int>
__launch_bounds__(TPB) __global__ void moe_softmax_top_k_normed_fused(
const T *input, const T *bias, T *output, IdxT *indices, int *source_rows,
const int64_t num_experts, const int64_t k, const int64_t num_rows) {
__launch_bounds__(TPB) __global__ void moe_softmax_top_k_normed_fused(const T* input,
const T* bias,
T* output,
IdxT* indices,
int* source_rows,
const int64_t num_experts,
const int64_t k,
const int64_t num_rows) {
// softmax
using BlockReduce = cub::BlockReduce<float, TPB>;
__shared__ typename BlockReduce::TempStorage tmpStorage;
@@ -469,12 +498,11 @@ __launch_bounds__(TPB) __global__ void moe_softmax_top_k_normed_fused(
return;
}
const int64_t thread_row_offset = globalIdx * num_experts;
const int64_t idx = thread_row_offset + threadIdx.x;
const int64_t idx = thread_row_offset+threadIdx.x;
cub::Sum sum;
float threadData =
(threadIdx.x < num_experts) ? static_cast<float>(input[idx]) : (-FLT_MAX);
float threadData = (threadIdx.x < num_experts) ? static_cast<float>(input[idx]) :(-FLT_MAX);
const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, cub::Max());
if (threadIdx.x == 0) {
@@ -490,12 +518,12 @@ __launch_bounds__(TPB) __global__ void moe_softmax_top_k_normed_fused(
if (threadIdx.x == 0) {
normalizing_factor = 1.f / Z;
}
__syncthreads();
T val = T(threadDataExp * normalizing_factor);
// top_k
// top_k
using cub_kvp = cub::KeyValuePair<int, T>;
using BlockReduceP = cub::BlockReduce<cub_kvp, TPB>;
__shared__ typename BlockReduceP::TempStorage tmpStorageP;
@@ -505,15 +533,15 @@ __launch_bounds__(TPB) __global__ void moe_softmax_top_k_normed_fused(
T weight_sum = static_cast<T>(0);
extern __shared__ char smem[];
T *row_outputs = reinterpret_cast<T *>(smem);
T* row_outputs = reinterpret_cast<T*>(smem);
for (int k_idx = 0; k_idx < k; ++k_idx) {
thread_kvp.key = 0;
thread_kvp.value = T(-1.f); // This is OK because inputs are probabilities
thread_kvp.value = T(-1.f); // This is OK because inputs are probabilities
if (threadIdx.x < num_experts) {
cub_kvp inp_kvp;
int expert = threadIdx.x;
int expert = threadIdx.x;
inp_kvp.key = expert;
inp_kvp.value = bias ? val + bias[expert] : val;
@@ -532,8 +560,7 @@ __launch_bounds__(TPB) __global__ void moe_softmax_top_k_normed_fused(
if (threadIdx.x == 0) {
const int cur_idx = k * globalIdx + k_idx;
T row_out =
bias ? (result_kvp.value - bias[result_kvp.key]) : result_kvp.value;
T row_out = bias ? (result_kvp.value - bias[result_kvp.key]) : result_kvp.value;
row_outputs[k_idx] = row_out;
weight_sum += row_out;
@@ -665,11 +692,19 @@ __launch_bounds__(TPB) __global__ void moe_redundant_top_k_normed(const T* input
k.
*/
template <typename T, int VPT, int NUM_EXPERTS, int WARPS_PER_CTA,
int BYTES_PER_LDG, typename IdxT = int>
__launch_bounds__(WARPS_PER_CTA *WARP_SIZE) __global__
void topk_gating_softmax(const T *input, T *output, const int64_t num_rows,
IdxT *indices, int *source_rows, const int64_t k) {
template <typename T,
int VPT,
int NUM_EXPERTS,
int WARPS_PER_CTA,
int BYTES_PER_LDG,
typename IdxT = int>
__launch_bounds__(WARPS_PER_CTA * WARP_SIZE) __global__
void topk_gating_softmax(const T* input,
T* output,
const int64_t num_rows,
IdxT* indices,
int* source_rows,
const int64_t k) {
// We begin by enforcing compile time assertions and setting up compile time
// constants.
static_assert(VPT == (VPT & -VPT), "VPT must be power of 2");
@@ -722,19 +757,18 @@ __launch_bounds__(WARPS_PER_CTA *WARP_SIZE) __global__
const int thread_row = warp_base_row + thread_row_in_warp;
// Threads with indices out of bounds should early exit here.
if (thread_row >= num_rows)
return;
if (thread_row >= num_rows) return;
const bool should_process_row = true;
// We finally start setting up the read pointers for each thread. First, each
// thread jumps to the start of the row it will read.
const T *thread_row_ptr = input + thread_row * ELTS_PER_ROW;
const T* thread_row_ptr = input + thread_row * ELTS_PER_ROW;
// Now, we compute the group each thread belong to in order to determine the
// first column to start loads.
const int thread_group_idx = threadIdx.x % THREADS_PER_ROW;
const int first_elt_read_by_thread = thread_group_idx * ELTS_PER_LDG;
const T *thread_read_ptr = thread_row_ptr + first_elt_read_by_thread;
const T* thread_read_ptr = thread_row_ptr + first_elt_read_by_thread;
// Determine the pointer type to use to read in the data depending on the
// BYTES_PER_LDG template param. In theory, this can support all powers of 2
@@ -743,10 +777,10 @@ __launch_bounds__(WARPS_PER_CTA *WARP_SIZE) __global__
// Finally, we pull in the data from global mem
cutlass::Array<T, VPT> row_chunk_input;
AccessType *row_chunk_vec_ptr =
reinterpret_cast<AccessType *>(&row_chunk_input);
const AccessType *vec_thread_read_ptr =
reinterpret_cast<const AccessType *>(thread_read_ptr);
AccessType* row_chunk_vec_ptr =
reinterpret_cast<AccessType*>(&row_chunk_input);
const AccessType* vec_thread_read_ptr =
reinterpret_cast<const AccessType*>(thread_read_ptr);
#pragma unroll
for (int ii = 0; ii < LDG_PER_THREAD; ++ii) {
row_chunk_vec_ptr[ii] = vec_thread_read_ptr[ii * THREADS_PER_ROW];
@@ -771,8 +805,9 @@ __launch_bounds__(WARPS_PER_CTA *WARP_SIZE) __global__
// threads. We use a butterfly reduce.
#pragma unroll
for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) {
thread_max = max(thread_max, __shfl_xor_sync(0xFFFFFFFF, thread_max, mask,
THREADS_PER_ROW));
thread_max =
max(thread_max,
__shfl_xor_sync(0xFFFFFFFF, thread_max, mask, THREADS_PER_ROW));
}
// From this point, thread max in all the threads have the max within the row.
@@ -885,7 +920,8 @@ __launch_bounds__(WARPS_PER_CTA *WARP_SIZE) __global__
namespace detail {
// Constructs some constants needed to partition the work across threads at
// compile time.
template <typename T, int EXPERTS, int BYTES_PER_LDG> struct TopkConstants {
template <typename T, int EXPERTS, int BYTES_PER_LDG>
struct TopkConstants {
static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(T);
static_assert(EXPERTS / (ELTS_PER_LDG * WARP_SIZE) == 0 ||
EXPERTS % (ELTS_PER_LDG * WARP_SIZE) == 0,
@@ -896,14 +932,17 @@ template <typename T, int EXPERTS, int BYTES_PER_LDG> struct TopkConstants {
static constexpr int THREADS_PER_ROW = EXPERTS / VPT;
static constexpr int ROWS_PER_WARP = WARP_SIZE / THREADS_PER_ROW;
};
} // namespace detail
} // namespace detail
template <typename T, int EXPERTS, int WARPS_PER_TB, typename IdxT = int>
void topk_gating_softmax_launcher_helper(const T *input, T *output,
IdxT *indices, int *source_row,
void topk_gating_softmax_launcher_helper(const T* input,
T* output,
IdxT* indices,
int* source_row,
const int64_t num_rows,
const int64_t num_experts,
const int64_t k, cudaStream_t stream) {
const int64_t k,
cudaStream_t stream) {
static constexpr uint64_t MAX_BYTES_PER_LDG = 16;
static constexpr int BYTES_PER_LDG =
std::min(MAX_BYTES_PER_LDG, sizeof(T) * EXPERTS);
@@ -915,46 +954,52 @@ void topk_gating_softmax_launcher_helper(const T *input, T *output,
dim3 block_dim(WARP_SIZE, WARPS_PER_TB);
topk_gating_softmax<T, VPT, EXPERTS, WARPS_PER_TB, BYTES_PER_LDG>
<<<num_blocks, block_dim, 0, stream>>>(input, output, num_rows, indices,
source_row, k);
<<<num_blocks, block_dim, 0, stream>>>(
input, output, num_rows, indices, source_row, k);
}
template <typename T, typename IdxT = int>
struct topk_gating_softmax_kernelLauncher {
static void run(const T *input, const T *gating_correction_bias, T *output,
T *softmax, IdxT *indices, int *source_row,
T *softmax_max_prob, const int64_t num_rows,
const int64_t num_experts, const int64_t k,
const bool group_moe, cudaStream_t stream,
const bool topk_only_mode = false) {
if (topk_only_mode) {
static constexpr int TPB = 256;
const auto config_topk = Get1DBlocksAnd2DGridsMoe(num_rows);
moe_top_k<T, TPB><<<config_topk.block_per_grid, TPB, 0, stream>>>(
input, gating_correction_bias, output, indices, source_row,
num_experts, k, num_rows);
return;
}
static constexpr int WARPS_PER_TB = 4;
struct topk_gating_softmax_kernelLauncher{
#define LAUNCH_TOPK_GATING_SOFTMAX_HELPER(N) \
static void run(const T* input,
const T* gating_correction_bias,
T* output,
T* softmax,
IdxT* indices,
int* source_row,
T* softmax_max_prob,
const int64_t num_rows,
const int64_t num_experts,
const int64_t k,
const bool group_moe,
cudaStream_t stream,
const bool topk_only_mode = false) {
if (topk_only_mode) {
static constexpr int TPB = 256;
const auto config_topk = Get1DBlocksAnd2DGridsMoe(num_rows);
moe_top_k<T, TPB><<<config_topk.block_per_grid, TPB, 0, stream>>>(
input, gating_correction_bias, output, indices, source_row, num_experts, k, num_rows);
return;
}
static constexpr int WARPS_PER_TB = 4;
#define LAUNCH_TOPK_GATING_SOFTMAX_HELPER(N) \
case N: { \
topk_gating_softmax_launcher_helper<T, N, WARPS_PER_TB>( \
input, output, indices, source_row, num_rows, num_experts, k, stream); \
break; \
}
int64_t tem_num_experts = num_experts;
if (gating_correction_bias != nullptr)
tem_num_experts = 0;
switch (tem_num_experts) {
LAUNCH_TOPK_GATING_SOFTMAX_HELPER(2)
LAUNCH_TOPK_GATING_SOFTMAX_HELPER(4)
LAUNCH_TOPK_GATING_SOFTMAX_HELPER(8)
LAUNCH_TOPK_GATING_SOFTMAX_HELPER(16)
LAUNCH_TOPK_GATING_SOFTMAX_HELPER(32)
LAUNCH_TOPK_GATING_SOFTMAX_HELPER(64)
LAUNCH_TOPK_GATING_SOFTMAX_HELPER(128)
LAUNCH_TOPK_GATING_SOFTMAX_HELPER(256)
int64_t tem_num_experts = num_experts;
if(gating_correction_bias != nullptr) tem_num_experts = 0;
switch (tem_num_experts) {
LAUNCH_TOPK_GATING_SOFTMAX_HELPER(2)
LAUNCH_TOPK_GATING_SOFTMAX_HELPER(4)
LAUNCH_TOPK_GATING_SOFTMAX_HELPER(8)
LAUNCH_TOPK_GATING_SOFTMAX_HELPER(16)
LAUNCH_TOPK_GATING_SOFTMAX_HELPER(32)
LAUNCH_TOPK_GATING_SOFTMAX_HELPER(64)
LAUNCH_TOPK_GATING_SOFTMAX_HELPER(128)
LAUNCH_TOPK_GATING_SOFTMAX_HELPER(256)
default: {
static constexpr int TPB = 256;
@@ -964,24 +1009,40 @@ struct topk_gating_softmax_kernelLauncher {
const auto config_softmax = Get1DBlocksAnd2DGridsMoe(softmax_num_rows);
group_moe_softmax<T, TPB>
<<<config_softmax.block_per_grid, TPB, 0, stream>>>(
                input, softmax, softmax_max_prob, group_experts,
softmax_num_rows);
const auto config_topk = Get1DBlocksAnd2DGridsMoe(num_rows);
        moe_top_k<T, TPB><<<config_topk.block_per_grid, TPB, 0, stream>>>(
            softmax, output, indices, source_row, softmax_max_prob,
            num_experts, k, num_rows);
} else {
const auto config_topk = Get1DBlocksAnd2DGridsMoe(num_rows);
moe_softmax<T, TPB><<<config_topk.block_per_grid, TPB, 0, stream>>>(
input, softmax, num_experts, num_rows);
        moe_top_k<T, TPB><<<config_topk.block_per_grid, TPB, 0, stream>>>(
            softmax, gating_correction_bias, output, indices, source_row,
            num_experts, k, num_rows);
        }
      }
    }
  }
};
// ========================== Permutation things
// =======================================
@@ -999,13 +1060,18 @@ struct topk_gating_softmax_kernelLauncher {
// to row 0 in the original matrix. Thus, to know where to read in the source
// matrix, we simply take the modulus of the expanded index.
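// Illustrative example (added comment): with num_rows = 4 and k = 2, expanded
// source rows 0..7 all map back to source rows 0..3 via
// expanded_source_row % num_rows, so each input row is read k times.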
template <typename T, int VecSize, typename OutT = T>
__global__ void initialize_moe_routing_kernel(
    const T* unpermuted_input,
    OutT* permuted_output,
    const int* expanded_dest_row_to_expanded_source_row,
    const int* expert_idx_per_token,
    const float* w4a8_in_scale,
    int* expanded_source_row_to_expanded_dest_row,
    const int64_t num_rows,
    const int64_t active_rows,
    const int64_t cols,
    const int64_t num_rows_k) {
using LoadT = AlignedVector<T, VecSize>;
LoadT src_vec;
@@ -1015,24 +1081,21 @@ __global__ void initialize_moe_routing_kernel(
// allow each threadblock to do 1 k-way reduce without atomics later in MoE. 1
// thread block will be responsible for all k summations.
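  // Clarifying note (added comment): duplicates of the same source token are
  // placed num_rows apart (expanded rows r, r + num_rows, ..., r + (k-1) * num_rows),
  // which is what lets the finalize kernel sum a token's k expert outputs in one block.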
const int expanded_dest_row = blockIdx.x + blockIdx.y * gridDim.x;
  if (expanded_dest_row >= num_rows_k) return;
const int expanded_source_row =
expanded_dest_row_to_expanded_source_row[expanded_dest_row];
if (threadIdx.x == 0) {
expanded_source_row_to_expanded_dest_row[expanded_source_row] =
expanded_dest_row;
}
if (expanded_dest_row < active_rows) {
const int expert_idx = expert_idx_per_token[expanded_dest_row];
const float scale = w4a8_in_scale ? w4a8_in_scale[expert_idx] : -1;
// Duplicate and permute rows
const int source_row = expanded_source_row % num_rows;
    const T* source_row_ptr = unpermuted_input + source_row * cols;
OutT *dest_row_ptr = permuted_output + expanded_dest_row * cols;
for (int tid = threadIdx.x * VecSize; tid < cols;
@@ -1061,37 +1124,56 @@ __global__ void initialize_moe_routing_kernel(
}
template <typename T, typename OutT = T>
struct initialize_moe_routing_kernelLauncher {
  static void run(const T* unpermuted_input,
                  OutT* permuted_output,
                  const int* expanded_dest_row_to_expanded_source_row,
                  const int* expert_idx_per_token,
                  const float* w4a8_in_scale,
                  int* expanded_source_row_to_expanded_dest_row,
                  const int64_t num_rows,
                  const int64_t active_rows,
                  const int64_t cols,
                  const int64_t k,
                  cudaStream_t stream) {
    const int threads = std::min(cols, int64_t(1024));
    constexpr int max_pack_size = 16 / sizeof(T);
    const auto config_initialize = Get1DBlocksAnd2DGridsMoe(num_rows * k);
    if (cols % max_pack_size == 0) {
      initialize_moe_routing_kernel<T, max_pack_size, OutT>
          <<<config_initialize.block_per_grid, threads, 0, stream>>>(
              unpermuted_input,
              permuted_output,
              expanded_dest_row_to_expanded_source_row,
              expert_idx_per_token,
              w4a8_in_scale,
              expanded_source_row_to_expanded_dest_row,
              num_rows,
              k * active_rows,
              cols,
              num_rows * k);
    } else {
      initialize_moe_routing_kernel<T, 1, OutT>
          <<<config_initialize.block_per_grid, threads, 0, stream>>>(
              unpermuted_input,
              permuted_output,
              expanded_dest_row_to_expanded_source_row,
              expert_idx_per_token,
              w4a8_in_scale,
              expanded_source_row_to_expanded_dest_row,
              num_rows,
              k * active_rows,
              cols,
              num_rows * k);
    }
  }
};
// ============================== Infer GEMM sizes
// =================================
__device__ inline int find_total_elts_leq_target(int *sorted_indices,
__device__ inline int find_total_elts_leq_target(int* sorted_indices,
const int64_t arr_length,
const int64_t target) {
int64_t low = 0, high = arr_length - 1, target_location = -1;
@@ -1108,10 +1190,10 @@ __device__ inline int find_total_elts_leq_target(int *sorted_indices,
return target_location + 1;
}
void compute_total_rows_before_expert(int* sorted_indices,
const int64_t total_indices,
const int64_t num_experts,
                                      int64_t* total_rows_before_expert,
cudaStream_t stream);
// Final kernel to unpermute and scale
@@ -1119,72 +1201,117 @@ void compute_total_rows_before_expert(int *sorted_indices,
// performs the final skip connection.
template <typename T, int RESIDUAL_NUM>
__global__ void finalize_moe_routing_kernel(
    const T* expanded_permuted_rows,
    T* reduced_unpermuted_output,
    const T* bias,
    const float* scales,
    const int* expanded_source_row_to_expanded_dest_row,
    const int* expert_for_source_row,
    const int64_t cols,
    const int64_t k,
    const int64_t compute_bias,
    const bool norm_topk_prob,
    const float routed_scaling_factor,
    const int64_t num_rows) {
  const int original_row = blockIdx.x;
  auto const offset = original_row * cols;
  T* reduced_row_ptr = reduced_unpermuted_output + offset;

  constexpr int64_t FINALIZE_ELEM_PER_THREAD =
      128 / cutlass::sizeof_bits<T>::value;

  int64_t const start_offset = threadIdx.x;
  int64_t const stride = FINALIZE_THREADS_PER_BLOCK;
  int64_t const num_elems_in_col = cols / FINALIZE_ELEM_PER_THREAD;

  using BiasElem = cutlass::Array<T, FINALIZE_ELEM_PER_THREAD>;
  using InputElem = cutlass::Array<T, FINALIZE_ELEM_PER_THREAD>;
  using OutputElem = cutlass::Array<T, FINALIZE_ELEM_PER_THREAD>;
  using ComputeElem = cutlass::Array<float, FINALIZE_ELEM_PER_THREAD>;

  auto const* bias_v = reinterpret_cast<BiasElem const*>(bias);
  auto const* expanded_permuted_rows_v =
      reinterpret_cast<InputElem const*>(expanded_permuted_rows);
  auto* reduced_row_ptr_v = reinterpret_cast<OutputElem*>(reduced_row_ptr);

#pragma unroll
  for (int elem_index = start_offset; elem_index < num_elems_in_col;
       elem_index += stride) {
    ComputeElem thread_output;
    thread_output.fill(0);
    float row_rescale{0.f};
    for (int k_idx = 0; k_idx < k; ++k_idx) {
      int64_t const expanded_original_row = original_row + k_idx * num_rows;
      int64_t const expanded_permuted_row =
          expanded_source_row_to_expanded_dest_row[expanded_original_row];

      int64_t const k_offset = original_row * k + k_idx;
      const float row_scale = scales[k_offset];
      row_rescale = row_rescale + row_scale;

      auto const* expanded_permuted_rows_row_ptr =
          expanded_permuted_rows_v + expanded_permuted_row * num_elems_in_col;

      int const expert_idx = expert_for_source_row[k_offset];
      auto const* bias_ptr = bias_v + expert_idx * num_elems_in_col;
      ComputeElem bias_value;
      if (bias) {
        bias_value = arrayConvert<BiasElem, ComputeElem>(bias_ptr[elem_index]);
      } else {
        bias_value.fill(0);
      }

      ComputeElem expert_result = arrayConvert<InputElem, ComputeElem>(
          expanded_permuted_rows_row_ptr[elem_index]);
      thread_output = thread_output + row_scale * (expert_result + bias_value);
    }

    for (auto& elem : thread_output) {
      elem = elem / (norm_topk_prob ? row_rescale : 1.0f) *
             routed_scaling_factor;
    }

    OutputElem output_elem = arrayConvert<ComputeElem, OutputElem>(thread_output);
    reduced_row_ptr_v[elem_index] = output_elem;
  }
}
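// Editor's illustrative sketch (not part of this commit): per token row and
// output element the kernel above computes
//   out = routed_scaling_factor * sum_k scale_k * (expert_out_k + bias_k)
//         / (norm_topk_prob ? sum_k scale_k : 1)
// A plain host-side reference, assuming unvectorized float buffers:
//   void finalize_reference(const float* rows, const float* bias,
//                           const float* scales, const int* dest_row,
//                           const int* expert_id, float* out, int num_rows,
//                           int cols, int k, bool norm, float factor) {
//     for (int r = 0; r < num_rows; ++r) {
//       for (int c = 0; c < cols; ++c) {
//         float acc = 0.f, denom = 0.f;
//         for (int j = 0; j < k; ++j) {
//           int perm = dest_row[r + j * num_rows];
//           float s = scales[r * k + j];
//           float b = bias ? bias[expert_id[r * k + j] * cols + c] : 0.f;
//           acc += s * (rows[perm * cols + c] + b);
//           denom += s;
//         }
//         out[r * cols + c] = acc / (norm ? denom : 1.f) * factor;
//       }
//     }
//   }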
template <typename T>
struct finalize_moe_routing_kernelLauncher {
  static void run(const T* expanded_permuted_rows,
                  T* reduced_unpermuted_output,
                  const T* bias,
                  const float* scales,
                  const int* expanded_source_row_to_expanded_dest_row,
                  const int* expert_for_source_row,
                  const int64_t num_rows,
                  const int64_t cols,
                  const int64_t k,
                  const int64_t compute_bias,
                  const bool norm_topk_prob,
                  const float routed_scaling_factor,
                  cudaStream_t stream) {
    const int blocks = num_rows;
    const int threads = FINALIZE_THREADS_PER_BLOCK;
    finalize_moe_routing_kernel<T, 1>
        <<<blocks, threads, 0, stream>>>(
            expanded_permuted_rows,
            reduced_unpermuted_output,
            bias,
            scales,
            expanded_source_row_to_expanded_dest_row,
            expert_for_source_row,
            cols,
            k,
            compute_bias,
            norm_topk_prob,
            routed_scaling_factor,
            num_rows);
  }
}  // namespace phi


@@ -0,0 +1,368 @@
#include "moe_wna16_marlin_utils/marlin.cuh"
#include "paddle/phi/core/enforce.h"
#include "helper.h"
namespace MARLIN_NAMESPACE_NAME {
template <int const num_threads, int const num_bits, bool const has_perm>
__global__ void gptq_marlin_repack_kernel(
uint32_t const* __restrict__ b_q_weight_ptr,
uint32_t const* __restrict__ perm_ptr, uint32_t* __restrict__ out_ptr,
int size_k, int size_n) {
constexpr int pack_factor = 32 / num_bits;
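  // Worked example (comment added for clarity): num_bits == 4 gives
  // pack_factor == 8, i.e. eight quantized values per 32-bit word;
  // num_bits == 8 gives pack_factor == 4.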
int k_tiles = size_k / tile_k_size;
int n_tiles = size_n / tile_n_size;
int block_k_tiles = div_ceil(k_tiles, gridDim.x);
auto start_k_tile = blockIdx.x * block_k_tiles;
if (start_k_tile >= k_tiles) {
return;
}
int finish_k_tile = min(start_k_tile + block_k_tiles, k_tiles);
// Wait until the next thread tile has been loaded to shared memory.
auto wait_for_stage = [&]() {
// We only have `stages - 2` active fetches since we are double buffering
// and can only issue the next fetch when it is guaranteed that the previous
// shared memory load is fully complete (as it may otherwise be
// overwritten).
cp_async_wait<repack_stages - 2>();
__syncthreads();
};
extern __shared__ int4 sh[];
constexpr int perm_size = tile_k_size / 4;
int4* sh_perm_ptr = sh;
int4* sh_pipe_ptr = sh_perm_ptr;
if constexpr (has_perm) {
sh_pipe_ptr += perm_size;
}
constexpr int tile_ints = tile_k_size / pack_factor;
constexpr int stage_n_threads = tile_n_size / 4;
constexpr int stage_k_threads = has_perm ? tile_k_size : tile_ints;
constexpr int stage_size = stage_k_threads * stage_n_threads;
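  // Note on the fetch geometry (added comment): each thread issues one 16-byte
  // cp.async (an int4, i.e. four consecutive uint32 columns), so stage_n_threads
  // (= tile_n_size / 4) threads cover one n-tile row; along k, one thread is
  // needed per packed uint32 (no perm) or per source row (with perm).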
auto load_perm_to_shared = [&](int k_tile_id) {
int first_k_int4 = (k_tile_id * tile_k_size) / 4;
int4 const* perm_int4_ptr = reinterpret_cast<int4 const*>(perm_ptr);
if (threadIdx.x < perm_size) {
sh_perm_ptr[threadIdx.x] = perm_int4_ptr[first_k_int4 + threadIdx.x];
}
__syncthreads();
};
auto fetch_to_shared = [&](int pipe, int k_tile_id, int n_tile_id) {
if (n_tile_id >= n_tiles) {
cp_async_fence();
return;
}
int first_n = n_tile_id * tile_n_size;
int4* sh_ptr = sh_pipe_ptr + stage_size * pipe;
if constexpr (has_perm) {
if (threadIdx.x < stage_size) {
auto k_id = threadIdx.x / stage_n_threads;
auto n_id = threadIdx.x % stage_n_threads;
uint32_t const* sh_perm_int_ptr =
reinterpret_cast<uint32_t const*>(sh_perm_ptr);
int src_k = sh_perm_int_ptr[k_id];
int src_k_packed = src_k / pack_factor;
cp_async4(
&sh_ptr[k_id * stage_n_threads + n_id],
reinterpret_cast<int4 const*>(&(
b_q_weight_ptr[src_k_packed * size_n + first_n + (n_id * 4)])));
}
} else {
if (threadIdx.x < stage_size) {
auto k_id = threadIdx.x / stage_n_threads;
auto n_id = threadIdx.x % stage_n_threads;
int first_k = k_tile_id * tile_k_size;
int first_k_packed = first_k / pack_factor;
cp_async4(&sh_ptr[k_id * stage_n_threads + n_id],
reinterpret_cast<int4 const*>(
&(b_q_weight_ptr[(first_k_packed + k_id) * size_n +
first_n + (n_id * 4)])));
}
}
cp_async_fence();
};
auto repack_tile = [&](int pipe, int k_tile_id, int n_tile_id) {
if (n_tile_id >= n_tiles) {
return;
}
auto warp_id = threadIdx.x / 32;
auto th_id = threadIdx.x % 32;
if (warp_id >= 4) {
return;
}
int tc_col = th_id / 4;
int tc_row = (th_id % 4) * 2;
constexpr int tc_offsets[4] = {0, 1, 8, 9};
int cur_n = warp_id * 16 + tc_col;
constexpr int sh_stride = 64;
constexpr uint32_t mask = (1 << num_bits) - 1;
int4* sh_stage_ptr = sh_pipe_ptr + stage_size * pipe;
uint32_t* sh_stage_int_ptr = reinterpret_cast<uint32_t*>(sh_stage_ptr);
uint32_t* sh_perm_int_ptr = reinterpret_cast<uint32_t*>(sh_perm_ptr);
uint32_t vals[8];
if constexpr (has_perm) {
for (int i = 0; i < 4; i++) {
int k_idx = tc_row + tc_offsets[i];
uint32_t src_k = sh_perm_int_ptr[k_idx];
uint32_t src_k_pos = src_k % pack_factor;
uint32_t b1_val = sh_stage_int_ptr[k_idx * sh_stride + cur_n];
uint32_t b1_cur_val = (b1_val >> (src_k_pos * num_bits)) & mask;
uint32_t b2_val = sh_stage_int_ptr[k_idx * sh_stride + cur_n + 8];
uint32_t b2_cur_val = (b2_val >> (src_k_pos * num_bits)) & mask;
vals[i] = b1_cur_val;
vals[4 + i] = b2_cur_val;
}
} else {
uint32_t b1_vals[tile_ints];
uint32_t b2_vals[tile_ints];
#pragma unroll
for (int i = 0; i < tile_ints; i++) {
b1_vals[i] = sh_stage_int_ptr[cur_n + sh_stride * i];
b2_vals[i] = sh_stage_int_ptr[cur_n + 8 + sh_stride * i];
}
#pragma unroll
for (int i = 0; i < 4; i++) {
int cur_elem = tc_row + tc_offsets[i];
int cur_int = cur_elem / pack_factor;
int cur_pos = cur_elem % pack_factor;
vals[i] = (b1_vals[cur_int] >> (cur_pos * num_bits)) & mask;
vals[4 + i] = (b2_vals[cur_int] >> (cur_pos * num_bits)) & mask;
}
}
constexpr int tile_size = tile_k_size * tile_n_size / pack_factor;
int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size;
// Result of:
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
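    // Clarifying note (added comment): for num_bits == 4 the loop below places
    // vals[pack_idx[i]] into nibble i of the output word, i.e. the order
    // 0,2,4,6,1,3,5,7 expected by the interleaved numeric converter linked above;
    // the 8-bit branch does the same with byte order 0,2,1,3 split across two words.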
if constexpr (num_bits == 4) {
constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};
uint32_t res = 0;
#pragma unroll
for (int i = 0; i < 8; i++) {
res |= vals[pack_idx[i]] << (i * 4);
}
out_ptr[out_offset + th_id * 4 + warp_id] = res;
} else {
constexpr int pack_idx[4] = {0, 2, 1, 3};
uint32_t res1 = 0;
uint32_t res2 = 0;
#pragma unroll
for (int i = 0; i < 4; i++) {
res1 |= vals[pack_idx[i]] << (i * 8);
res2 |= vals[4 + pack_idx[i]] << (i * 8);
}
out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 0] = res1;
out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 1] = res2;
}
};
auto start_pipes = [&](int k_tile_id, int n_tile_id) {
#pragma unroll
for (int pipe = 0; pipe < repack_stages - 1; pipe++) {
fetch_to_shared(pipe, k_tile_id, n_tile_id + pipe);
}
wait_for_stage();
};
#pragma unroll
for (int k_tile_id = start_k_tile; k_tile_id < finish_k_tile; k_tile_id++) {
int n_tile_id = 0;
if constexpr (has_perm) {
load_perm_to_shared(k_tile_id);
}
start_pipes(k_tile_id, n_tile_id);
while (n_tile_id < n_tiles) {
#pragma unroll
for (int pipe = 0; pipe < repack_stages; pipe++) {
fetch_to_shared((pipe + repack_stages - 1) % repack_stages, k_tile_id,
n_tile_id + pipe + repack_stages - 1);
repack_tile(pipe, k_tile_id, n_tile_id + pipe);
wait_for_stage();
}
n_tile_id += repack_stages;
}
}
}
}  // namespace MARLIN_NAMESPACE_NAME
#define CALL_IF(NUM_BITS, HAS_PERM) \
else if (num_bits == NUM_BITS && has_perm == HAS_PERM) { \
cudaFuncSetAttribute( \
MARLIN_NAMESPACE_NAME::gptq_marlin_repack_kernel<MARLIN_NAMESPACE_NAME::repack_threads, NUM_BITS, \
HAS_PERM>, \
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
MARLIN_NAMESPACE_NAME::gptq_marlin_repack_kernel<MARLIN_NAMESPACE_NAME::repack_threads, NUM_BITS, \
HAS_PERM> \
<<<blocks, MARLIN_NAMESPACE_NAME::repack_threads, max_shared_mem, stream>>>( \
b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n); \
}
std::vector<paddle::Tensor> gptq_marlin_repack(paddle::Tensor& b_q_weight, paddle::Tensor& perm,
int64_t size_k, int64_t size_n,
int64_t num_bits) {
// Verify compatibility with marlin tile of 16x64
PADDLE_ENFORCE(
size_k % MARLIN_NAMESPACE_NAME::tile_k_size == 0,
"size_k = ", size_k,
" is not divisible by tile_k_size = ",
MARLIN_NAMESPACE_NAME::tile_k_size);
PADDLE_ENFORCE(
size_n % MARLIN_NAMESPACE_NAME::tile_n_size == 0,
"size_n = ", size_n,
" is not divisible by tile_n_size = ",
MARLIN_NAMESPACE_NAME::tile_n_size);
PADDLE_ENFORCE(
num_bits == 4 || num_bits == 8,
"num_bits must be 4 or 8. Got = ", num_bits);
int const pack_factor = 32 / num_bits;
// Verify B
// shape checks
PADDLE_ENFORCE(
(size_k / pack_factor) == b_q_weight.dims()[0],
"Shape mismatch: b_q_weight.size(0) = ", b_q_weight.dims()[0]);
PADDLE_ENFORCE(
b_q_weight.dims()[1] == size_n,
"Shape mismatch: b_q_weight.size(1) = ", b_q_weight.dims()[1],
", expected size_n = ", size_n);
// Verify device and strides
PADDLE_ENFORCE(
b_q_weight.is_gpu(),
"b_q_weight is not on GPU");
PADDLE_ENFORCE(
b_q_weight.is_contiguous(),
"b_q_weight is not contiguous");
PADDLE_ENFORCE(
b_q_weight.dtype() == phi::DataType::INT32,
"b_q_weight type is not kInt");
PADDLE_ENFORCE(
perm.is_gpu(),
"perm is not on GPU");
PADDLE_ENFORCE(
perm.is_contiguous(),
"perm is not contiguous");
PADDLE_ENFORCE(
perm.dtype() == phi::DataType::INT32,
"perm type is not kInt");
// Alloc buffers
// const at::cuda::OptionalCUDAGuard device_guard(device_of(b_q_weight));
paddle::Tensor out = paddle::empty(
{size_k / MARLIN_NAMESPACE_NAME::tile_size, size_n * MARLIN_NAMESPACE_NAME::tile_size / pack_factor},
b_q_weight.dtype(),
b_q_weight.place());
// Detect if there is act_order
bool has_perm = perm.dims()[0] != 0;
// Get ptrs
uint32_t const* b_q_weight_ptr =
reinterpret_cast<uint32_t const*>(b_q_weight.data());
uint32_t const* perm_ptr = reinterpret_cast<uint32_t const*>(perm.data());
uint32_t* out_ptr = reinterpret_cast<uint32_t*>(out.data());
// Get dev info
int dev = b_q_weight.place().GetDeviceId();
cudaStream_t stream = b_q_weight.stream();
int blocks;
cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev);
int max_shared_mem = 0;
cudaDeviceGetAttribute(&max_shared_mem,
cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
// TORCH_CHECK(max_shared_mem > 0);
PADDLE_ENFORCE(
max_shared_mem > 0,
"max_shared_mem must be > 0. Got = ", max_shared_mem);
if (false) {
}
CALL_IF(4, false)
CALL_IF(4, true)
CALL_IF(8, false)
CALL_IF(8, true)
else {
// TORCH_CHECK(false, "Unsupported repack config: num_bits = ", num_bits,
// ", has_perm = ", has_perm);
PADDLE_ENFORCE(
false,
"Unsupported repack config: num_bits = ", num_bits,
", has_perm = ", has_perm);
}
return {out};
}
PD_BUILD_STATIC_OP(gptq_marlin_repack)
.Inputs({"b_q_weight", "perm"})
.Outputs({"out"})
.Attrs({
"size_k: int64_t",
"size_n: int64_t",
"num_bits: int64_t"
})
.SetKernelFn(PD_KERNEL(gptq_marlin_repack));


@@ -211,12 +211,14 @@ std::vector<std::vector<int64_t>> MoeExpertDispatchInferShape(
const int expert_num = gating_output_shape[gating_output_shape.size() - 1];
const int num_rows = token_rows;
const int hidden_size = input_shape[input_shape.size() - 1];
const int permuted_rows = num_rows == -1 ? -1 : moe_topk * num_rows;
  return {{permuted_rows, hidden_size},
          {expert_num},
          {moe_topk, num_rows},
          {num_rows, moe_topk},
          {num_rows, moe_topk},
          {permuted_rows}};
}
std::vector<paddle::DataType>
@@ -225,7 +227,7 @@ MoeExpertDispatchInferDtype(const paddle::DataType &input_dtype,
const paddle::optional<paddle::DataType> &bias_type,
const int moe_topk) {
return {input_dtype, paddle::DataType::INT64, paddle::DataType::INT32,
          paddle::DataType::FLOAT32, paddle::DataType::INT32,
          paddle::DataType::INT32};
}
/**
@@ -281,7 +283,8 @@ PD_BUILD_STATIC_OP(moe_expert_dispatch)
paddle::Optional("gating_correction_bias"),
paddle::Optional("w4a8_in_scale")})
.Outputs({"permute_input", "tokens_expert_prefix_sum",
"permute_indices_per_token", "topk_weight", "topk_idx"})
"permute_indices_per_token", "topk_weight", "topk_idx",
"expert_idx_per_token"})
.Attrs({"moe_topk:int", "group_moe:bool", "topk_only_mode:bool"})
.SetKernelFn(PD_KERNEL(MoeExpertDispatch))
.SetInferShapeFn(PD_INFER_SHAPE(MoeExpertDispatchInferShape))


@@ -44,9 +44,9 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
auto place = permute_input.place();
auto stream = permute_input.stream();
  auto fp16_moe_gemm_runner = MoeGemmRunner<DataType_, cutlass::WintQuantTraits<DataType_, cutlass::WintQuantMethod::kNone>>();
  auto int8_moe_gemm_runner = MoeGemmRunner<DataType_, cutlass::WintQuantTraits<DataType_, cutlass::WintQuantMethod::kWeightOnlyInt8>>();
  auto int4_moe_gemm_runner = MoeGemmRunner<DataType_, cutlass::WintQuantTraits<DataType_, cutlass::WintQuantMethod::kWeightOnlyInt4>>();
auto w4a8_moe_gemm_runner = W4A8MoeGemmRunner<DataType_, int8_t, cutlass::uint4b_t>();
assert(permute_input.dims().size() == 3 || permute_input.dims().size() == 2);
@@ -109,6 +109,7 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
const int64_t tune_total_rows = expanded_active_expert_rows;
if (quant_method == "weight_only_int8") {
typename cutlass::WintQuantTraits<DataType_, cutlass::WintQuantMethod::kWeightOnlyInt8>::Arguments quant_args;
int8_moe_gemm_runner.moe_gemm_bias_act(
reinterpret_cast<const NvType*>(permute_input.data<data_t>()),
reinterpret_cast<const uint8_t*>(ffn1_weight.data<int8_t>()),
@@ -123,9 +124,11 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
inter_size,
hidden_size,
num_experts,
quant_args,
"none",
stream);
} else if (quant_method == "weight_only_int4") {
typename cutlass::WintQuantTraits<DataType_, cutlass::WintQuantMethod::kWeightOnlyInt4>::Arguments quant_args;
int4_moe_gemm_runner.moe_gemm_bias_act(
reinterpret_cast<const NvType*>(permute_input.data<data_t>()),
reinterpret_cast<const cutlass::uint4b_t*>(
@@ -141,6 +144,7 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
inter_size,
hidden_size,
num_experts,
quant_args,
"none",
stream);
} else if (quant_method == "w4a8") {
@@ -165,6 +169,7 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
num_experts,
stream);
} else {
typename cutlass::WintQuantTraits<DataType_, cutlass::WintQuantMethod::kNone>::Arguments quant_args;
fp16_moe_gemm_runner.moe_gemm_bias_act(
reinterpret_cast<const NvType*>(permute_input.data<data_t>()),
reinterpret_cast<const NvType*>(ffn1_weight.data<data_t>()),
@@ -177,6 +182,7 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
inter_size,
hidden_size,
num_experts,
quant_args,
"none",
stream);
}
@@ -190,6 +196,7 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
auto act_out = act_out_tensor.data<data_t>();
if (quant_method == "weight_only_int8") {
typename cutlass::WintQuantTraits<DataType_, cutlass::WintQuantMethod::kWeightOnlyInt8>::Arguments quant_args;
int8_moe_gemm_runner.moe_gemm(
reinterpret_cast<const NvType*>(act_out),
reinterpret_cast<const uint8_t*>(ffn2_weight.data<int8_t>()),
@@ -203,9 +210,11 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
hidden_size,
inter_size / 2,
num_experts,
quant_args,
stream);
} else if (quant_method == "weight_only_int4") {
typename cutlass::WintQuantTraits<DataType_, cutlass::WintQuantMethod::kWeightOnlyInt4>::Arguments quant_args;
int4_moe_gemm_runner.moe_gemm(
reinterpret_cast<const NvType*>(act_out),
reinterpret_cast<const cutlass::uint4b_t*>(
@@ -220,6 +229,7 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
hidden_size,
inter_size / 2,
num_experts,
quant_args,
stream);
} else if (quant_method == "w4a8") {
data_t *ffn2_shift = nullptr;
@@ -262,6 +272,7 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
num_experts,
stream);
} else {
typename cutlass::WintQuantTraits<DataType_, cutlass::WintQuantMethod::kNone>::Arguments quant_args;
fp16_moe_gemm_runner.moe_gemm(
reinterpret_cast<const NvType*>(act_out),
reinterpret_cast<const NvType*>(ffn2_weight.data<data_t>()),
@@ -273,6 +284,7 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
hidden_size,
inter_size / 2,
num_experts,
quant_args,
stream);
}
}
@@ -288,6 +300,8 @@ paddle::Tensor MoeExpertFFNFunc(
const paddle::optional<paddle::Tensor>& ffn2_in_scale,
const paddle::optional<paddle::Tensor>& expert_idx_per_token,
const std::string& quant_method, const bool used_in_ep_low_latency) {
cudaCheckError();
const auto t_type = quant_method == "w4a8" ? ffn1_scale.get().dtype() : permute_input.dtype();
auto ffn_out = paddle::empty_like(permute_input, t_type);
@@ -381,14 +395,14 @@ std::vector<paddle::DataType> MoeExpertFFNInferDtype(
/**
* @brief Mixture of Experts (MoE) Feed-Forward Network Operator
 *
* This operator performs the expert computation in MoE architecture, including:
* 1. First linear transformation (FFN1) with optional quantization
* 2. SwiGLU activation function
* 3. Second linear transformation (FFN2) with optional quantization
 *
* Supports multiple quantization methods including weight-only int4/int8 and w4a8 quantization.
 *
* Inputs:
* - permute_input: Permuted input tensor organized by expert
* Shape: [total_tokens * top_k, hidden_size]
@@ -416,18 +430,18 @@ std::vector<paddle::DataType> MoeExpertFFNInferDtype(
* - expert_idx_per_token: Optional expert indices per token (w4a8 only)
* Shape: [total_tokens]
* dtype: int64
 *
* Outputs:
* - output_tensor: Output tensor after MoE FFN computation
* Shape: Same as permute_input
* dtype: Same as input (or ffn1_scale dtype for w4a8)
 *
* Attributes:
* - quant_method: Quantization method to use
* Options: "none", "weight_only_int4", "weight_only_int8", "w4a8"
* - used_in_ep_low_latency: Whether running in low latency mode
* Affects activation function implementation
 *
* Note:
* - w4a8 mode requires additional workspace memory allocation
* - Low latency mode uses specialized grouped SwiGLU implementation


@@ -0,0 +1,377 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "cutlass/numeric_conversion.h"
#include "group_swiglu_with_masked.h"
#include "helper.h"
#include "moe/fast_hardamard_kernel.h"
#include "moe/fused_moe_helper.h"
template <typename DataT, typename NvType, typename WeightSavedT, cutlass::WintQuantMethod QuantMethod>
void WeightOnlyMoeFFNKernel(const paddle::Tensor& permute_input,
const paddle::Tensor& tokens_expert_prefix_sum,
const paddle::Tensor& ffn1_weight,
const paddle::Tensor& ffn2_weight,
const paddle::Tensor* ffn1_bias,
const paddle::Tensor* ffn1_super_scale,
const paddle::Tensor* ffn2_super_scale,
const paddle::Tensor* ffn1_local_scale,
const paddle::Tensor* ffn1_code_scale,
const paddle::Tensor* ffn1_code_zp,
const paddle::Tensor* ffn2_local_scale,
const paddle::Tensor* ffn2_code_scale,
const paddle::Tensor* ffn2_code_zp,
paddle::Tensor fc1_out,
paddle::Tensor ffn_out,
const int64_t total_rows_in_ll_else_minus1,
const int64_t actual_total_rows,
const int64_t inter_size,
const int64_t hidden_size,
const int num_experts,
bool used_in_ep_low_latency) {
using namespace phi;
using WeightOnlyTraits = cutlass::WintQuantTraits<NvType, QuantMethod>;
using WeightType = typename WeightOnlyTraits::WeightType;
typename WeightOnlyTraits::Arguments ffn1_quant_args;
typename WeightOnlyTraits::Arguments ffn2_quant_args;
if constexpr (QuantMethod == cutlass::WintQuantMethod::kWeightOnlyInt2) {
ffn1_quant_args.local_scale_ptr = ffn1_local_scale->data<uint8_t>();
ffn1_quant_args.code_scale_ptr = ffn1_code_scale->data<float>();
ffn1_quant_args.code_zp_ptr = ffn1_code_zp->data<float>();
ffn2_quant_args.local_scale_ptr = ffn2_local_scale->data<uint8_t>();
ffn2_quant_args.code_scale_ptr = ffn2_code_scale->data<float>();
ffn2_quant_args.code_zp_ptr = ffn2_code_zp->data<float>();
}
auto moe_gemm_runner = MoeGemmRunner<NvType, WeightOnlyTraits>();
auto stream = permute_input.stream();
moe_gemm_runner.moe_gemm_bias_act(
reinterpret_cast<const NvType*>(permute_input.data<DataT>()),
reinterpret_cast<const WeightType*>(ffn1_weight.data<WeightSavedT>()),
reinterpret_cast<const NvType*>(ffn1_super_scale ? ffn1_super_scale->data<DataT>() : nullptr),
reinterpret_cast<const NvType*>(ffn1_bias ? ffn1_bias->data<DataT>() : nullptr),
reinterpret_cast<NvType*>(fc1_out.data<DataT>()),
const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
total_rows_in_ll_else_minus1,
actual_total_rows,
inter_size,
hidden_size,
num_experts,
ffn1_quant_args,
"none",
stream);
paddle::Tensor act_out;
if (used_in_ep_low_latency) {
act_out = GroupSwigluWithMasked(fc1_out, tokens_expert_prefix_sum);
} else {
act_out = paddle::experimental::swiglu(fc1_out, nullptr);
}
moe_gemm_runner.moe_gemm(
reinterpret_cast<const NvType*>(act_out.data<DataT>()),
reinterpret_cast<const WeightType*>(ffn2_weight.data<WeightSavedT>()),
reinterpret_cast<const NvType*>(ffn2_super_scale ? ffn2_super_scale->data<DataT>() : nullptr),
reinterpret_cast<NvType*>(ffn_out.data<DataT>()),
const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
total_rows_in_ll_else_minus1,
actual_total_rows,
hidden_size,
inter_size / 2,
num_experts,
ffn2_quant_args,
stream);
}
template <paddle::DataType T>
void MoeFFNWint2Kernel(const paddle::Tensor& permute_input,
const paddle::Tensor& tokens_expert_prefix_sum,
const paddle::Tensor& ffn1_weight,
const paddle::Tensor& ffn2_weight,
const paddle::optional<paddle::Tensor>& ffn1_bias,
const paddle::optional<paddle::Tensor>& ffn1_scale,
const paddle::optional<paddle::Tensor>& ffn2_scale,
const paddle::optional<paddle::Tensor>& ffn1_local_scale,
const paddle::optional<paddle::Tensor>& ffn1_code_scale,
const paddle::optional<paddle::Tensor>& ffn1_code_zp,
const paddle::optional<paddle::Tensor>& ffn2_local_scale,
const paddle::optional<paddle::Tensor>& ffn2_code_scale,
const paddle::optional<paddle::Tensor>& ffn2_code_zp,
paddle::Tensor ffn_out,
bool used_in_ep_low_latency) {
using namespace phi;
using data_t = typename PDTraits<T>::data_t;
using NvType = typename PDTraits<T>::DataType;
auto place = permute_input.place();
assert(permute_input.dims().size() == 3 || permute_input.dims().size() == 2);
assert(ffn1_weight.dims().size() == 3);
const int num_experts = ffn1_weight.dims()[0];
const int hidden_size = permute_input.dims()[permute_input.dims().size() - 1];
int inter_dim = ffn1_weight.dims()[1] * ffn1_weight.dims()[2] / hidden_size;
const int64_t inter_size = inter_dim * 4;
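  // Note (added comment, reasoning inferred from the 2-bit packing): ffn1_weight
  // stores four 2-bit values per byte, so dims[1] * dims[2] / hidden_size is the
  // packed column count and the "* 4" recovers the logical FFN1 output width
  // (gate and up projections combined).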
int num_experts_ = num_experts;
int num_max_tokens_per_expert = 0;
int expanded_active_expert_rows = 0;
paddle::Tensor fc1_out_tensor;
if (permute_input.dims().size() == 3) {
num_experts_ = permute_input.dims()[0];
assert(num_experts == num_experts_);
num_max_tokens_per_expert = permute_input.dims()[1];
expanded_active_expert_rows = num_experts_ * num_max_tokens_per_expert;
fc1_out_tensor = GetEmptyTensor(
{num_experts_, num_max_tokens_per_expert, inter_size}, T, place);
} else {
expanded_active_expert_rows = permute_input.dims()[0];
fc1_out_tensor = GetEmptyTensor(
{expanded_active_expert_rows, inter_size}, T, place);
}
// This is a trick.
// expanded_active_expert_rows is not needed in variable group gemm.
// but is needed in accommodating deepep low latency mode
const int64_t total_rows_in_ll_else_minus1 = used_in_ep_low_latency ? expanded_active_expert_rows : -1;
// When we tune the optimal configuration, we need the actual total_rows.
const int64_t actual_total_rows = expanded_active_expert_rows;
WeightOnlyMoeFFNKernel<data_t, NvType, uint8_t, cutlass::WintQuantMethod::kWeightOnlyInt2>(
permute_input,
tokens_expert_prefix_sum,
ffn1_weight,
ffn2_weight,
const_cast<paddle::Tensor*>(ffn1_bias.get_ptr()),
const_cast<paddle::Tensor*>(ffn1_scale.get_ptr()),
const_cast<paddle::Tensor*>(ffn2_scale.get_ptr()),
const_cast<paddle::Tensor*>(ffn1_local_scale.get_ptr()),
const_cast<paddle::Tensor*>(ffn1_code_scale.get_ptr()),
const_cast<paddle::Tensor*>(ffn1_code_zp.get_ptr()),
const_cast<paddle::Tensor*>(ffn2_local_scale.get_ptr()),
const_cast<paddle::Tensor*>(ffn2_code_scale.get_ptr()),
const_cast<paddle::Tensor*>(ffn2_code_zp.get_ptr()),
fc1_out_tensor,
ffn_out,
total_rows_in_ll_else_minus1,
actual_total_rows,
inter_size,
hidden_size,
num_experts,
used_in_ep_low_latency);
}
paddle::Tensor MoeExpertFFNWint2Func(
const paddle::Tensor& permute_input,
const paddle::Tensor& tokens_expert_prefix_sum,
const paddle::Tensor& ffn1_weight,
const paddle::Tensor& ffn2_weight,
const paddle::optional<paddle::Tensor>& ffn1_bias,
const paddle::optional<paddle::Tensor>& ffn1_scale,
const paddle::optional<paddle::Tensor>& ffn2_scale,
const paddle::optional<paddle::Tensor>& ffn1_local_scale,
const paddle::optional<paddle::Tensor>& ffn1_code_scale,
const paddle::optional<paddle::Tensor>& ffn1_code_zp,
const paddle::optional<paddle::Tensor>& ffn2_local_scale,
const paddle::optional<paddle::Tensor>& ffn2_code_scale,
const paddle::optional<paddle::Tensor>& ffn2_code_zp,
const bool used_in_ep_low_latency) {
const auto dtype = permute_input.dtype();
auto ffn_out = paddle::empty_like(permute_input, dtype);
switch (dtype) {
case paddle::DataType::BFLOAT16:
MoeFFNWint2Kernel<paddle::DataType::BFLOAT16>(permute_input,
tokens_expert_prefix_sum,
ffn1_weight,
ffn2_weight,
ffn1_bias,
ffn1_scale,
ffn2_scale,
ffn1_local_scale,
ffn1_code_scale,
ffn1_code_zp,
ffn2_local_scale,
ffn2_code_scale,
ffn2_code_zp,
ffn_out,
used_in_ep_low_latency);
break;
case paddle::DataType::FLOAT16:
MoeFFNWint2Kernel<paddle::DataType::FLOAT16>(permute_input,
tokens_expert_prefix_sum,
ffn1_weight,
ffn2_weight,
ffn1_bias,
ffn1_scale,
ffn2_scale,
ffn1_local_scale,
ffn1_code_scale,
ffn1_code_zp,
ffn2_local_scale,
ffn2_code_scale,
ffn2_code_zp,
ffn_out,
used_in_ep_low_latency);
break;
default:
PD_THROW("Unsupported data type for MoeExpertFFN");
}
return ffn_out;
}
std::vector<paddle::Tensor> MoeExpertFFNWint2(
const paddle::Tensor& permute_input,
const paddle::Tensor& tokens_expert_prefix_sum,
const paddle::Tensor& ffn1_weight,
const paddle::Tensor& ffn2_weight,
const paddle::optional<paddle::Tensor>& ffn1_bias,
const paddle::optional<paddle::Tensor>& ffn1_scale,
const paddle::optional<paddle::Tensor>& ffn2_scale,
const paddle::optional<paddle::Tensor>& ffn1_local_scale,
const paddle::optional<paddle::Tensor>& ffn1_code_scale,
const paddle::optional<paddle::Tensor>& ffn1_code_zp,
const paddle::optional<paddle::Tensor>& ffn2_local_scale,
const paddle::optional<paddle::Tensor>& ffn2_code_scale,
const paddle::optional<paddle::Tensor>& ffn2_code_zp,
const bool used_in_ep_low_latency) {
return {MoeExpertFFNWint2Func(permute_input,
tokens_expert_prefix_sum,
ffn1_weight,
ffn2_weight,
ffn1_bias,
ffn1_scale,
ffn2_scale,
ffn1_local_scale,
ffn1_code_scale,
ffn1_code_zp,
ffn2_local_scale,
ffn2_code_scale,
ffn2_code_zp,
used_in_ep_low_latency)};
}
std::vector<std::vector<int64_t>> MoeExpertFFNWint2InferShape(
const std::vector<int64_t>& permute_input_shape,
const std::vector<int64_t>& tokens_expert_prefix_sum_shape,
const std::vector<int64_t>& ffn1_weight_shape,
const std::vector<int64_t>& ffn2_weight_shape,
const paddle::optional<std::vector<int64_t>>& ffn1_bias_shape,
const paddle::optional<std::vector<int64_t>>& ffn1_scale_shape,
const paddle::optional<std::vector<int64_t>>& ffn2_scale_shape,
const paddle::optional<std::vector<int64_t>>& ffn1_local_scale_shape,
const paddle::optional<std::vector<int64_t>>& ffn1_code_scale_shape,
const paddle::optional<std::vector<int64_t>>& ffn1_code_zp_shape,
const paddle::optional<std::vector<int64_t>>& ffn2_local_scale_shape,
const paddle::optional<std::vector<int64_t>>& ffn2_code_scale_shape,
const paddle::optional<std::vector<int64_t>>& ffn2_code_zp_shape,
const bool used_in_ep_low_latency) {
return {permute_input_shape};
}
std::vector<paddle::DataType> MoeExpertFFNWint2InferDtype(
const paddle::DataType &permute_input_dtype,
const paddle::DataType &tokens_expert_prefix_sum_dtype,
const paddle::DataType &ffn1_weight_dtype,
const paddle::DataType &ffn2_weight_dtype,
const paddle::optional<paddle::DataType> &ffn1_bias_dtype,
const paddle::optional<paddle::DataType> &ffn1_scale_dtype,
const paddle::optional<paddle::DataType> &ffn2_scale_dtype,
const paddle::optional<paddle::DataType> &ffn1_local_scale_dtype,
const paddle::optional<paddle::DataType> &ffn1_code_scale_dtype,
const paddle::optional<paddle::DataType> &ffn1_code_zp_dtype,
const paddle::optional<paddle::DataType> &ffn2_local_scale_dtype,
const paddle::optional<paddle::DataType> &ffn2_code_scale_dtype,
const paddle::optional<paddle::DataType> &ffn2_code_zp_dtype,
const bool used_in_ep_low_latency) {
return {permute_input_dtype};
}
/**
* @brief Weight-Only Quantized Mixture of Experts (MoE) Feed-Forward Network Operator
*
* This operator performs the expert computation in MoE architecture, including:
* 1. First linear transformation (FFN1) with optional quantization
* 2. SwiGLU activation function
* 3. Second linear transformation (FFN2) with optional quantization
*
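 * Example shape flow (illustrative only, per expert):
 *   [rows, hidden_size] x FFN1 -> [rows, inter_size * 2]
 *   -> SwiGLU -> [rows, inter_size]
 *   x FFN2 -> [rows, hidden_size]
 *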
 * Supports 2-bit weight-only (wint2) quantization of the FFN weights.
*
* Inputs:
* - permute_input: Permuted input tensor organized by expert
* Shape: [total_tokens * top_k, hidden_size]
* dtype: bfloat16/float16 (or int8 for w4a8)
* - tokens_expert_prefix_sum: Prefix sum array of token counts per expert for group_gemm
* Shape: [num_experts]
* dtype: int64
* - ffn1_weight: First FFN layer weights
* Shape: [num_experts, inter_size * 2, hidden_size]
* dtype: Same as input (unquantized) or int8 (quantized)
* - ffn2_weight: Second FFN layer weights
* Shape: [num_experts, hidden_size, inter_size]
* dtype: Same as input (unquantized) or int8 (quantized)
* - ffn1_bias: Optional bias for first FFN layer
* Shape: [num_experts, inter_size * 2]
* dtype: Same as input
* - ffn1_scale: Quantization scales for first FFN layer
* Shape: [num_experts, inter_size * 2]
* dtype: Same as input
* - ffn2_scale: Quantization scales for second FFN layer
* Shape: [num_experts, hidden_size]
* dtype: Same as input
*
* Outputs:
* - output_tensor: Output tensor after MoE FFN computation
* Shape: Same as permute_input
* dtype: Same as input (or ffn1_scale dtype for w4a8)
*
* Attributes:
* - used_in_ep_low_latency: Whether running in low latency mode
* Affects activation function implementation
*
* Note:
* - Low latency mode uses specialized grouped SwiGLU implementation
*/
PD_BUILD_STATIC_OP(moe_expert_ffn_wint2)
.Inputs({"permute_input",
"tokens_expert_prefix_sum",
"ffn1_weight",
"ffn2_weight",
paddle::Optional("ffn1_bias"),
paddle::Optional("ffn1_scale"),
paddle::Optional("ffn2_scale"),
paddle::Optional("ffn1_local_scale"),
paddle::Optional("ffn1_code_scale"),
paddle::Optional("ffn1_code_zp"),
paddle::Optional("ffn2_local_scale"),
paddle::Optional("ffn2_code_scale"),
paddle::Optional("ffn2_code_zp")})
.Outputs({"output_tensor"})
.Attrs({"used_in_ep_low_latency:bool"})
.SetKernelFn(PD_KERNEL(MoeExpertFFNWint2))
.SetInferShapeFn(PD_INFER_SHAPE(MoeExpertFFNWint2InferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(MoeExpertFFNWint2InferDtype));


@@ -96,7 +96,10 @@ std::vector<std::vector<int64_t>> MoeExpertReduceInferShape(
const std::vector<int64_t> &permute_indices_per_token_shape,
const std::vector<int64_t> &top_k_indices_shape,
const paddle::optional<std::vector<int64_t>> &ffn2_bias_shape) {
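  // Added note: ffn_out is laid out as [total_tokens * moe_topk, hidden]; the
  // reduce sums the top-k rows belonging to each token, so the leading dim
  // shrinks by moe_topk (left as -1 when it is dynamic).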
  const int moe_topk = top_k_indices_shape[1];
  auto out_shape = ffn_out_shape;
  if (out_shape[0] != -1) out_shape[0] /= moe_topk;
  return {out_shape};
}
std::vector<paddle::DataType> MoeExpertReduceInferDtype(


@@ -88,6 +88,7 @@ void moe_topk_select_kernel(const T* input,
k,
num_rows);
}
cudaGetLastError();
}
else {
assert(k<=TPB);

File diff suppressed because it is too large


@@ -0,0 +1,37 @@
#pragma once
#ifndef MARLIN_NAMESPACE_NAME
#define MARLIN_NAMESPACE_NAME marlin_moe_wna16
#endif
#include "paddle/phi/api/include/api.h"
#include "paddle/phi/core/enforce.h"
#include "moe/moe_wna16_marlin_utils/kernel.h"
#include "moe/moe_wna16_marlin_utils/types.h"
std::vector<paddle::Tensor> MoeWna16MarlinGemmApi(
const paddle::Tensor& a,
const paddle::optional<paddle::Tensor>& c_or_none,
const paddle::Tensor& b_q_weight,
const paddle::Tensor& b_scales,
const paddle::optional<paddle::Tensor>& global_scale_or_none,
const paddle::optional<paddle::Tensor>& b_zeros_or_none,
const paddle::optional<paddle::Tensor>& g_idx_or_none,
const paddle::optional<paddle::Tensor>& perm_or_none,
const paddle::Tensor& workspace,
const paddle::Tensor& sorted_token_ids,
const paddle::Tensor& expert_ids,
const paddle::Tensor& num_tokens_post_padded,
const paddle::Tensor& topk_weights,
int64_t moe_block_size,
int64_t top_k,
bool mul_topk_weights,
bool is_ep,
const std::string& b_q_type_str,
int64_t size_m,
int64_t size_n,
int64_t size_k,
bool is_k_full,
bool use_atomic_add,
bool use_fp32_reduce,
bool is_zp_float);


@@ -0,0 +1,63 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "glog/logging.h"
#include "paddle/phi/api/include/context_pool.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_info.h"
#include "paddle/phi/core/cuda_stream.h"
namespace MARLIN_NAMESPACE_NAME {
using DeviceIndex = int8_t;
using StreamId = int64_t;
class CUDAStream {
public:
CUDAStream() {}
explicit CUDAStream(const cudaStream_t& stream) : raw_stream_(stream) {}
StreamId id() const { return reinterpret_cast<StreamId>(raw_stream_); }
operator cudaStream_t() const { return raw_stream_; }
const cudaStream_t& raw_stream() const { return raw_stream_; }
private:
cudaStream_t raw_stream_;
};
/**
* Get the current CUDA stream, for the passed CUDA device, or for the
* current device if no device index is passed. The current CUDA stream
* will usually be the default CUDA stream for the device, but it may
* be different if someone called 'setCurrentCUDAStream' or used 'StreamGuard'
* or 'CUDAStreamGuard'.
*/
inline CUDAStream getCurrentCUDAStream(DeviceIndex device_index = -1) {
if (device_index == -1) {
device_index = phi::backends::gpu::GetCurrentDeviceId();
}
return CUDAStream(
paddle::GetCurrentCUDAStream(phi::GPUPlace(device_index))->raw_stream());
// LOG(FATAL) << "getCurrentCUDAStream is not implemented";
// return *(CUDAStream*)nullptr;
}
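// Illustrative usage (added comment): the implicit conversion operator lets the
// returned wrapper feed kernel launches directly, e.g.
//   cudaStream_t stream = getCurrentCUDAStream();
//   my_kernel<<<grid, block, 0, stream>>>(...);
// where my_kernel stands for any __global__ function.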
cudaStream_t GetCalcStreamFromGroup(int context_ring_id);
cudaStream_t GetCommStreamFromGroup(int context_ring_id);
} // namespace MARLIN_NAMESPACE_NAME

Some files were not shown because too many files have changed in this diff