Mirror of https://github.com/PaddlePaddle/FastDeploy.git, synced 2025-10-25 01:20:43 +08:00
Sync v2.0 version of code to github repo
@@ -17,15 +17,12 @@
#include "paddle/phi/core/memory/memcpy.h"

template <int THREADBLOCK_SIZE>
__global__ void GetMaxLenKernel(const int *seq_lens,
const int *seq_lens_this_time,
const int *seq_lens_encoder,
const int *seq_lens_this_time_merged,
const int *seq_lens_encoder_merged,
const int *seq_mapping,
const int *system_lens,
int *max_lens,
const int batch_size) {
__global__ void
GetMaxLenKernel(const int *seq_lens, const int *seq_lens_this_time,
const int *seq_lens_encoder,
const int *seq_lens_this_time_merged,
const int *seq_lens_encoder_merged, const int *seq_mapping,
const int *system_lens, int *max_lens, const int batch_size) {
const int tid = threadIdx.x;

typedef cub::BlockReduce<int, THREADBLOCK_SIZE> BlockReduce;
@@ -41,43 +38,61 @@ __global__ void GetMaxLenKernel(const int *seq_lens,
|
||||
int max_dec_len_without_system_this_thread = 0;
|
||||
for (int i = tid; i < batch_size; i += blockDim.x) {
|
||||
const int seq_len_this_time = seq_lens_this_time[i];
|
||||
max_len_this_time_this_thread = max(seq_len_this_time,
|
||||
max_len_this_time_this_thread);
|
||||
max_len_encoder_this_thread = max(seq_lens_encoder[i],
|
||||
max_len_encoder_this_thread);
|
||||
max_len_this_time_this_thread =
|
||||
max(seq_len_this_time, max_len_this_time_this_thread);
|
||||
max_len_encoder_this_thread =
|
||||
max(seq_lens_encoder[i], max_len_encoder_this_thread);
|
||||
max_len_decoder_this_thread = max(seq_lens[i], max_len_decoder_this_thread);
|
||||
if (seq_len_this_time <= 0) continue;
|
||||
if (seq_len_this_time <= 0)
|
||||
continue;
|
||||
const int max_just_dec_len_now = seq_lens_encoder[i] > 0 ? 0 : seq_lens[i];
|
||||
max_len_this_thread = max(seq_lens[i] + seq_len_this_time,
|
||||
max_len_this_thread);
|
||||
max_just_dec_len_this_thread = max(max_just_dec_len_this_thread,
|
||||
max_just_dec_len_now);
|
||||
max_len_this_thread =
|
||||
max(seq_lens[i] + seq_len_this_time, max_len_this_thread);
|
||||
max_just_dec_len_this_thread =
|
||||
max(max_just_dec_len_this_thread, max_just_dec_len_now);
|
||||
if (system_lens) {
|
||||
const int real_bid = seq_mapping[i];
|
||||
const int system_len_now = system_lens[real_bid];
|
||||
max_system_len_this_thread = max(max_system_len_this_thread, system_len_now);
|
||||
max_dec_len_without_system_this_thread = max(max_dec_len_without_system_this_thread,
|
||||
max_just_dec_len_now - system_len_now);
|
||||
max_system_len_this_thread =
|
||||
max(max_system_len_this_thread, system_len_now);
|
||||
max_dec_len_without_system_this_thread =
|
||||
max(max_dec_len_without_system_this_thread,
|
||||
max_just_dec_len_now - system_len_now);
|
||||
}
|
||||
}
|
||||
if (system_lens) {
|
||||
for (int i = tid; i < batch_size; i += blockDim.x) {
|
||||
const int ori_seq_len_this_time = seq_lens_this_time_merged[i];
|
||||
if (ori_seq_len_this_time <= 0) continue;
|
||||
const int max_just_dec_merged_len_this_time_now = seq_lens_encoder_merged[i] > 0 ?
|
||||
0 : ori_seq_len_this_time;
|
||||
max_just_dec_merged_len_this_time_this_thread = max(max_just_dec_merged_len_this_time_this_thread,
|
||||
max_just_dec_merged_len_this_time_now);
|
||||
if (ori_seq_len_this_time <= 0)
|
||||
continue;
|
||||
const int max_just_dec_merged_len_this_time_now =
|
||||
seq_lens_encoder_merged[i] > 0 ? 0 : ori_seq_len_this_time;
|
||||
max_just_dec_merged_len_this_time_this_thread =
|
||||
max(max_just_dec_merged_len_this_time_this_thread,
|
||||
max_just_dec_merged_len_this_time_now);
|
||||
}
|
||||
}
|
||||
int total_max_len_this_time = BlockReduce(temp_storage).Reduce(max_len_this_time_this_thread, MaxOp<int>());
|
||||
int total_max_len_encoder = BlockReduce(temp_storage).Reduce(max_len_encoder_this_thread, MaxOp<int>());
|
||||
int total_max_len_decoder = BlockReduce(temp_storage).Reduce(max_len_decoder_this_thread, MaxOp<int>());
|
||||
int total = BlockReduce(temp_storage).Reduce(max_len_this_thread, MaxOp<int>());
|
||||
int total_just_dec = BlockReduce(temp_storage).Reduce(max_just_dec_len_this_thread, MaxOp<int>());
|
||||
int total_just_dec_merged = BlockReduce(temp_storage).Reduce(max_just_dec_merged_len_this_time_this_thread, MaxOp<int>());
|
||||
int total_system_len = BlockReduce(temp_storage).Reduce(max_system_len_this_thread, MaxOp<int>());
|
||||
int total_dec_len_without_system = BlockReduce(temp_storage).Reduce(max_dec_len_without_system_this_thread, MaxOp<int>());
|
||||
int total_max_len_this_time =
|
||||
BlockReduce(temp_storage)
|
||||
.Reduce(max_len_this_time_this_thread, MaxOp<int>());
|
||||
int total_max_len_encoder =
|
||||
BlockReduce(temp_storage)
|
||||
.Reduce(max_len_encoder_this_thread, MaxOp<int>());
|
||||
int total_max_len_decoder =
|
||||
BlockReduce(temp_storage)
|
||||
.Reduce(max_len_decoder_this_thread, MaxOp<int>());
|
||||
int total =
|
||||
BlockReduce(temp_storage).Reduce(max_len_this_thread, MaxOp<int>());
|
||||
int total_just_dec = BlockReduce(temp_storage)
|
||||
.Reduce(max_just_dec_len_this_thread, MaxOp<int>());
|
||||
int total_just_dec_merged =
|
||||
BlockReduce(temp_storage)
|
||||
.Reduce(max_just_dec_merged_len_this_time_this_thread, MaxOp<int>());
|
||||
int total_system_len = BlockReduce(temp_storage)
|
||||
.Reduce(max_system_len_this_thread, MaxOp<int>());
|
||||
int total_dec_len_without_system =
|
||||
BlockReduce(temp_storage)
|
||||
.Reduce(max_dec_len_without_system_this_thread, MaxOp<int>());
|
||||
if (tid == 0) {
|
||||
max_lens[0] = total_max_len_this_time;
|
||||
max_lens[1] = total_max_len_encoder;
|
||||
@@ -90,30 +105,22 @@ __global__ void GetMaxLenKernel(const int *seq_lens,
|
||||
}
|
||||
}
|
||||
|
||||
void GetMaxLen(const paddle::Tensor& seq_lens_tensor,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& seq_lens_encoder,
paddle::Tensor &max_len_tensor,
const int batch_size) {
void GetMaxLen(const paddle::Tensor &seq_lens_tensor,
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &seq_lens_encoder,
paddle::Tensor &max_len_tensor, const int batch_size) {
constexpr int blockSize = 1024;
GetMaxLenKernel<blockSize><<<1, blockSize, 0, seq_lens_encoder.stream()>>>(
seq_lens_tensor.data<int>(),
seq_lens_this_time.data<int>(),
seq_lens_encoder.data<int>(),
nullptr,
nullptr,
nullptr,
nullptr,
max_len_tensor.data<int>(),
batch_size);
seq_lens_tensor.data<int>(), seq_lens_this_time.data<int>(),
seq_lens_encoder.data<int>(), nullptr, nullptr, nullptr, nullptr,
max_len_tensor.data<int>(), batch_size);
}

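// Note (inferred from the kernel above; this annotation is not part of the diff):
// GetMaxLen launches a single 1024-thread block that strides over the batch, and
// passing nullptr for the merged/system inputs disables the system-prompt branches
// guarded by `if (system_lens)` inside GetMaxLenKernel.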
__global__ void split_q_block(const int* __restrict__ seq_lens_q,
const int* __restrict__ seq_lens_encoder,
int* __restrict__ batch_ids,
int* __restrict__ tile_ids_per_batch,
int* __restrict__ num_blocks_x,
const int bsz,
__global__ void split_q_block(const int *__restrict__ seq_lens_q,
const int *__restrict__ seq_lens_encoder,
int *__restrict__ batch_ids,
int *__restrict__ tile_ids_per_batch,
int *__restrict__ num_blocks_x, const int bsz,
const int num_rows_per_block,
const int group_size) {
if (threadIdx.x == 0) {
@@ -124,8 +131,7 @@ __global__ void split_q_block(const int* __restrict__ seq_lens_q,
if (seq_lens_encoder && seq_lens_encoder[bid] > 0) {
seq_len = 0;
}
const int loop_times =
div_up(seq_len * group_size, num_rows_per_block);
const int loop_times = div_up(seq_len * group_size, num_rows_per_block);
for (uint32_t tile_id = 0; tile_id < loop_times; tile_id++) {
batch_ids[index] = bid;
tile_ids_per_batch[index++] = tile_id;
@@ -136,14 +142,12 @@ __global__ void split_q_block(const int* __restrict__ seq_lens_q,
}
}

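// A minimal host-side sketch (an assumption, not part of this diff) of the tile
// accounting split_q_block performs: each batch entry contributes
// div_up(seq_len * group_size, num_rows_per_block) tiles, and the
// (batch_ids[i], tile_ids_per_batch[i]) pairs enumerate them. The *_host array
// names below are hypothetical host copies of the device inputs.
int expected_tiles = 0;
for (int bid = 0; bid < bsz; ++bid) {
  int seq_len = seq_lens_q_host[bid];
  if (seq_lens_encoder_host != nullptr && seq_lens_encoder_host[bid] > 0) {
    seq_len = 0;  // prefill entries contribute no decode tiles
  }
  expected_tiles += (seq_len * group_size + num_rows_per_block - 1) / num_rows_per_block;
}
// expected_tiles should match the num_blocks_x value written by the kernel.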
__global__ void split_kv_block(const int* __restrict__ seq_lens_decoder,
|
||||
const int* __restrict__ seq_lens_encoder,
|
||||
int* __restrict__ batch_ids,
|
||||
int* __restrict__ tile_ids_per_batch,
|
||||
int* __restrict__ num_blocks_x,
|
||||
const int bsz,
|
||||
const int pad_len,
|
||||
const int num_row_per_block) {
|
||||
__global__ void split_kv_block(const int *__restrict__ seq_lens_decoder,
|
||||
const int *__restrict__ seq_lens_encoder,
|
||||
int *__restrict__ batch_ids,
|
||||
int *__restrict__ tile_ids_per_batch,
|
||||
int *__restrict__ num_blocks_x, const int bsz,
|
||||
const int pad_len, const int num_row_per_block) {
|
||||
if (threadIdx.x == 0) {
|
||||
int gridx = 0;
|
||||
int index = 0;
|
||||
@@ -165,50 +169,46 @@ __global__ void split_kv_block(const int* __restrict__ seq_lens_decoder,
|
||||
}
|
||||
|
||||
template <int THREADBLOCK_SIZE>
|
||||
__global__ void get_max_len_kv_ernel(int* max_seq_lens_out,
|
||||
const int* seq_lens_this_time,
|
||||
const int* seq_lens_decoder,
|
||||
const int batch_size) {
|
||||
__global__ void
|
||||
get_max_len_kv_ernel(int *max_seq_lens_out, const int *seq_lens_this_time,
|
||||
const int *seq_lens_decoder, const int batch_size) {
|
||||
const int tid = threadIdx.x;
|
||||
|
||||
|
||||
typedef cub::BlockReduce<int, THREADBLOCK_SIZE> BlockReduce;
|
||||
__shared__ typename BlockReduce::TempStorage temp_storage;
|
||||
|
||||
int max_len_this_thread = 0;
|
||||
for (int i = tid; i < batch_size; i += blockDim.x) {
|
||||
if (seq_lens_decoder[i] == 0) continue;
|
||||
max_len_this_thread = max(seq_lens_this_time[i] + seq_lens_decoder[i], max_len_this_thread);
|
||||
if (seq_lens_decoder[i] == 0)
|
||||
continue;
|
||||
max_len_this_thread =
|
||||
max(seq_lens_this_time[i] + seq_lens_decoder[i], max_len_this_thread);
|
||||
}
|
||||
int total = BlockReduce(temp_storage).Reduce(max_len_this_thread, MaxOp<int>());
int total =
BlockReduce(temp_storage).Reduce(max_len_this_thread, MaxOp<int>());
if (tid == 0) {
*max_seq_lens_out = total;
}
}

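// Note (inferred from the kernel above; this annotation is not part of the diff):
// get_max_len_kv_ernel reduces max(seq_lens_this_time[i] + seq_lens_decoder[i])
// over the batch, skipping entries whose decoder length is zero, so
// max_seq_lens_out ends up holding the longest KV length among decode requests.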
std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& seq_lens_this_time,
|
||||
const paddle::Tensor& cum_offsets,
|
||||
const int encoder_block_shape_q,
|
||||
const int decoder_block_shape_q,
|
||||
const int group_size,
|
||||
const int block_size,
|
||||
const paddle::Tensor &seq_lens_encoder,
|
||||
const paddle::Tensor &seq_lens_decoder,
|
||||
const paddle::Tensor &seq_lens_this_time, const paddle::Tensor &cum_offsets,
|
||||
const int encoder_block_shape_q, const int decoder_block_shape_q,
|
||||
const int group_size, const int block_size,
|
||||
const int decoder_step_token_num) {
|
||||
auto stream = seq_lens_encoder.stream();
|
||||
int bsz = cum_offsets.shape()[0];
|
||||
auto max_len_tensor =
|
||||
GetEmptyTensor({8}, paddle::DataType::INT32, seq_lens_encoder.place());
|
||||
GetMaxLen(
|
||||
seq_lens_decoder,
|
||||
seq_lens_this_time,
|
||||
seq_lens_encoder,
|
||||
max_len_tensor,
|
||||
bsz);
|
||||
GetMaxLen(seq_lens_decoder, seq_lens_this_time, seq_lens_encoder,
max_len_tensor, bsz);

// max_len_this_time, max_enc_len_this_time, max_dec_len_this_time, max_enc_dec_len_this_time,
// max_just_dec_len_this_time, max_just_dec_merged_len_this_time, max_system_len, max_just_dec_len_without_system
// max_len_this_time, max_enc_len_this_time, max_dec_len_this_time,
// max_enc_dec_len_this_time, max_just_dec_len_this_time,
// max_just_dec_merged_len_this_time, max_system_len,
// max_just_dec_len_without_system
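// Illustrative sketch, not part of this diff: the eight max_len_tensor slots can
// be given names matching the comment above. The enum itself is an assumption;
// only the slot order comes from that comment and the kernel's max_lens[] writes.
enum MaxLenSlot : int {
  kMaxLenThisTime = 0,
  kMaxEncLenThisTime = 1,
  kMaxDecLenThisTime = 2,
  kMaxEncDecLenThisTime = 3,
  kMaxJustDecLenThisTime = 4,
  kMaxJustDecMergedLenThisTime = 5,
  kMaxSystemLen = 6,
  kMaxJustDecLenWithoutSystem = 7,
};
// e.g. const int max_len_this_time = max_len_cpu_ptr[kMaxLenThisTime];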
auto max_len_cpu = max_len_tensor.copy_to(paddle::CPUPlace(), false);
|
||||
auto max_len_cpu_ptr = max_len_cpu.data<int>();
|
||||
int max_len_this_time = max_len_cpu_ptr[0];
|
||||
@@ -229,67 +229,67 @@ std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
|
||||
paddle::Tensor decoder_batch_ids;
|
||||
paddle::Tensor decoder_tile_ids_per_batch;
|
||||
paddle::Tensor decoder_num_blocks_x_cpu; /*cpu*/
|
||||
paddle::Tensor max_len_kv_cpu; /*cpu*/
|
||||
paddle::Tensor max_len_kv_cpu; /*cpu*/
|
||||
|
||||
auto max_len_kv =
|
||||
GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_decoder.place());
|
||||
get_max_len_kv_ernel<128><<<1, 128, 0, stream>>>(
|
||||
max_len_kv.data<int>(),
|
||||
seq_lens_this_time.data<int>(),
|
||||
seq_lens_decoder.data<int>(),
|
||||
bsz
|
||||
);
|
||||
max_len_kv.data<int>(), seq_lens_this_time.data<int>(),
|
||||
seq_lens_decoder.data<int>(), bsz);
|
||||
|
||||
max_len_kv_cpu =
|
||||
max_len_kv.copy_to(paddle::CPUPlace(), false);
|
||||
max_len_kv_cpu = max_len_kv.copy_to(paddle::CPUPlace(), false);
|
||||
|
||||
if (max_enc_len_this_time > 0) {
|
||||
const uint32_t max_tile_size_per_bs_kv = div_up(max_enc_dec_len_this_time, block_size);
|
||||
kv_batch_ids = GetEmptyTensor({bsz * max_tile_size_per_bs_kv},
|
||||
paddle::DataType::INT32,
|
||||
seq_lens_encoder.place());
|
||||
kv_tile_ids_per_batch = GetEmptyTensor({bsz * max_tile_size_per_bs_kv},
|
||||
paddle::DataType::INT32,
|
||||
seq_lens_encoder.place());
|
||||
const uint32_t max_tile_size_per_bs_kv =
|
||||
div_up(max_enc_dec_len_this_time, block_size);
|
||||
kv_batch_ids =
|
||||
GetEmptyTensor({bsz * max_tile_size_per_bs_kv}, paddle::DataType::INT32,
|
||||
seq_lens_encoder.place());
|
||||
kv_tile_ids_per_batch =
|
||||
GetEmptyTensor({bsz * max_tile_size_per_bs_kv}, paddle::DataType::INT32,
|
||||
seq_lens_encoder.place());
|
||||
auto kv_num_blocks_x =
|
||||
GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_encoder.place());
|
||||
|
||||
split_kv_block<<<1, 32, 0, seq_lens_encoder.stream()>>>(
|
||||
seq_lens_decoder.data<int>(),
|
||||
// sequence_lengths->data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
kv_batch_ids.data<int>(),
|
||||
kv_tile_ids_per_batch.data<int>(),
|
||||
kv_num_blocks_x.data<int>(),
|
||||
bsz,
|
||||
block_size,
|
||||
block_size
|
||||
);
|
||||
seq_lens_decoder.data<int>(),
|
||||
// sequence_lengths->data<int>(),
|
||||
seq_lens_encoder.data<int>(), kv_batch_ids.data<int>(),
|
||||
kv_tile_ids_per_batch.data<int>(), kv_num_blocks_x.data<int>(), bsz,
|
||||
block_size, block_size);
|
||||
|
||||
kv_num_blocks_x_cpu = kv_num_blocks_x.copy_to(paddle::CPUPlace(), false);
|
||||
|
||||
const uint32_t encoder_max_tile_size_per_bs_q = div_up(
|
||||
(max_enc_dec_len_this_time * group_size), encoder_block_shape_q);
|
||||
const uint32_t encoder_max_tile_size_per_bs_q =
|
||||
div_up((max_enc_dec_len_this_time * group_size), encoder_block_shape_q);
|
||||
encoder_batch_ids =
|
||||
GetEmptyTensor({bsz * encoder_max_tile_size_per_bs_q},
|
||||
paddle::DataType::INT32,
|
||||
seq_lens_encoder.place());
|
||||
paddle::DataType::INT32, seq_lens_encoder.place());
|
||||
encoder_tile_ids_per_batch =
|
||||
GetEmptyTensor({bsz * encoder_max_tile_size_per_bs_q},
|
||||
paddle::DataType::INT32,
|
||||
seq_lens_encoder.place());
|
||||
paddle::DataType::INT32, seq_lens_encoder.place());
|
||||
auto encoder_num_blocks_x =
|
||||
GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_encoder.place());
|
||||
split_q_block<<<1, 32, 0, stream>>>(seq_lens_encoder.data<int>(),
|
||||
nullptr,
|
||||
split_q_block<<<1, 32, 0, stream>>>(seq_lens_encoder.data<int>(), nullptr,
|
||||
encoder_batch_ids.data<int>(),
|
||||
encoder_tile_ids_per_batch.data<int>(),
|
||||
encoder_num_blocks_x.data<int>(),
|
||||
bsz,
|
||||
encoder_block_shape_q,
|
||||
group_size);
|
||||
encoder_num_blocks_x.data<int>(), bsz,
|
||||
encoder_block_shape_q, group_size);
|
||||
encoder_num_blocks_x_cpu =
|
||||
encoder_num_blocks_x.copy_to(paddle::CPUPlace(), false);
|
||||
} else {
|
||||
encoder_batch_ids =
|
||||
GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place());
|
||||
encoder_tile_ids_per_batch =
|
||||
GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place());
|
||||
encoder_num_blocks_x_cpu =
|
||||
GetEmptyTensor({0}, paddle::DataType::INT32, paddle::CPUPlace());
|
||||
kv_batch_ids =
|
||||
GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place());
|
||||
kv_tile_ids_per_batch =
|
||||
GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place());
|
||||
kv_num_blocks_x_cpu =
|
||||
GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place());
|
||||
}
|
||||
if (max_just_dec_len_this_time > 0) {
|
||||
const uint32_t decoder_max_tile_size_per_bs_q =
|
||||
@@ -297,24 +297,26 @@ std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
|
||||
|
||||
decoder_batch_ids =
|
||||
GetEmptyTensor({bsz * decoder_max_tile_size_per_bs_q},
|
||||
paddle::DataType::INT32,
|
||||
seq_lens_encoder.place());
|
||||
paddle::DataType::INT32, seq_lens_encoder.place());
|
||||
decoder_tile_ids_per_batch =
|
||||
GetEmptyTensor({bsz * decoder_max_tile_size_per_bs_q},
|
||||
paddle::DataType::INT32,
|
||||
seq_lens_encoder.place());
|
||||
paddle::DataType::INT32, seq_lens_encoder.place());
|
||||
auto decoder_num_blocks_x =
|
||||
GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_encoder.place());
|
||||
split_q_block<<<1, 32, 0, stream>>>(seq_lens_this_time.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
decoder_batch_ids.data<int>(),
|
||||
decoder_tile_ids_per_batch.data<int>(),
|
||||
decoder_num_blocks_x.data<int>(),
|
||||
bsz,
|
||||
decoder_block_shape_q,
|
||||
group_size);
|
||||
split_q_block<<<1, 32, 0, stream>>>(
|
||||
seq_lens_this_time.data<int>(), seq_lens_encoder.data<int>(),
|
||||
decoder_batch_ids.data<int>(), decoder_tile_ids_per_batch.data<int>(),
|
||||
decoder_num_blocks_x.data<int>(), bsz, decoder_block_shape_q,
|
||||
group_size);
|
||||
decoder_num_blocks_x_cpu =
|
||||
decoder_num_blocks_x.copy_to(paddle::CPUPlace(), false);
|
||||
} else {
|
||||
decoder_batch_ids =
|
||||
GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place());
|
||||
decoder_tile_ids_per_batch =
|
||||
GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place());
|
||||
decoder_num_blocks_x_cpu =
|
||||
GetEmptyTensor({0}, paddle::DataType::INT32, paddle::CPUPlace());
|
||||
}
|
||||
|
||||
return {encoder_batch_ids,
|
||||
@@ -331,28 +333,22 @@ std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
|
||||
}
|
||||
|
||||
std::vector<paddle::DataType> GetBlockShapeAndSplitKVBlockInferDtype(
|
||||
const paddle::DataType& seq_lens_encoder_dtype,
|
||||
const paddle::DataType& seq_lens_decoder_dtype,
|
||||
const paddle::DataType& seq_lens_this_time_dtype,
|
||||
const paddle::DataType& cum_offsets_dtype) {
|
||||
return {paddle::DataType::INT32,
|
||||
paddle::DataType::INT32,
|
||||
paddle::DataType::INT32,
|
||||
paddle::DataType::INT32,
|
||||
paddle::DataType::INT32,
|
||||
paddle::DataType::INT32,
|
||||
paddle::DataType::INT32,
|
||||
paddle::DataType::INT32,
|
||||
paddle::DataType::INT32,
|
||||
paddle::DataType::INT32,
|
||||
paddle::DataType::INT32};
|
||||
const paddle::DataType &seq_lens_encoder_dtype,
|
||||
const paddle::DataType &seq_lens_decoder_dtype,
|
||||
const paddle::DataType &seq_lens_this_time_dtype,
|
||||
const paddle::DataType &cum_offsets_dtype) {
|
||||
return {
|
||||
paddle::DataType::INT32, paddle::DataType::INT32, paddle::DataType::INT32,
|
||||
paddle::DataType::INT32, paddle::DataType::INT32, paddle::DataType::INT32,
|
||||
paddle::DataType::INT32, paddle::DataType::INT32, paddle::DataType::INT32,
|
||||
paddle::DataType::INT32, paddle::DataType::INT32};
|
||||
}
|
||||
|
||||
std::vector<std::vector<int64_t>> GetBlockShapeAndSplitKVBlockInferShape(
|
||||
const std::vector<int64_t>& seq_lens_encoder_shape,
|
||||
const std::vector<int64_t>& seq_lens_decoder_shape,
|
||||
const std::vector<int64_t>& seq_lens_this_time_shape,
|
||||
const std::vector<int64_t>& cum_offsets_shape) {
|
||||
const std::vector<int64_t> &seq_lens_encoder_shape,
|
||||
const std::vector<int64_t> &seq_lens_decoder_shape,
|
||||
const std::vector<int64_t> &seq_lens_this_time_shape,
|
||||
const std::vector<int64_t> &cum_offsets_shape) {
|
||||
std::vector<int64_t> dynamic_shape = {-1};
|
||||
|
||||
return {dynamic_shape,
|
||||
@@ -369,9 +365,7 @@ std::vector<std::vector<int64_t>> GetBlockShapeAndSplitKVBlockInferShape(
|
||||
}
|
||||
|
||||
PD_BUILD_STATIC_OP(get_block_shape_and_split_kv_block)
|
||||
.Inputs({"seq_lens_encoder",
|
||||
"seq_lens_decoder",
|
||||
"seq_lens_this_time",
|
||||
.Inputs({"seq_lens_encoder", "seq_lens_decoder", "seq_lens_this_time",
|
||||
"cum_offsets"})
|
||||
.Outputs({paddle::Optional("encoder_batch_ids"),
|
||||
paddle::Optional("encoder_tile_ids_per_batch"),
|
||||
@@ -382,12 +376,9 @@ PD_BUILD_STATIC_OP(get_block_shape_and_split_kv_block)
|
||||
paddle::Optional("decoder_batch_ids"),
|
||||
paddle::Optional("decoder_tile_ids_per_batch"),
|
||||
paddle::Optional("decoder_num_blocks"),
|
||||
paddle::Optional("max_len_kv"),
|
||||
"set_max_lengths"})
|
||||
.Attrs({"encoder_block_shape_q: int",
|
||||
"decoder_block_shape_q: int",
|
||||
"group_size: int",
|
||||
"block_size: int",
|
||||
paddle::Optional("max_len_kv"), "set_max_lengths"})
|
||||
.Attrs({"encoder_block_shape_q: int", "decoder_block_shape_q: int",
|
||||
"group_size: int", "block_size: int",
|
||||
"decoder_step_token_num: int"})
|
||||
.SetKernelFn(PD_KERNEL(GetBlockShapeAndSplitKVBlock))
|
||||
.SetInferShapeFn(PD_INFER_SHAPE(GetBlockShapeAndSplitKVBlockInferShape))
|
||||
|
||||
@@ -337,6 +337,8 @@ __forceinline__ __host__ __device__ void vec_cast<nv_bfloat16, float>(
|
||||
} else if (deal_each_time == 64) { \
|
||||
constexpr size_t DEAL_EACH_TIME = 64; \
|
||||
__VA_ARGS__ \
|
||||
} else { \
|
||||
PD_THROW("not support the deal_each_time", deal_each_time); \
|
||||
}
|
||||
|
||||
#define DISPATCH_NUM_THREADS(num_threads, NUM_THREADS, ...) \
|
||||
@@ -346,6 +348,8 @@ __forceinline__ __host__ __device__ void vec_cast<nv_bfloat16, float>(
|
||||
} else if (num_threads == 256) { \
|
||||
constexpr size_t NUM_THREADS = 256; \
|
||||
__VA_ARGS__ \
|
||||
} else { \
|
||||
PD_THROW("not support the num_threads", num_threads); \
|
||||
}
|
||||
|
||||
#define DISPATCH_GQA_GROUP_SIZE(group_size, GROUP_SIZE, ...) \
@@ -376,6 +380,11 @@ __forceinline__ __host__ __device__ void vec_cast<nv_bfloat16, float>(
} else if (group_size == 12) { \
constexpr size_t GROUP_SIZE = 12; \
__VA_ARGS__ \
} else if (group_size == 16) { \
constexpr size_t GROUP_SIZE = 16; \
__VA_ARGS__ \
} else { \
PD_THROW("not support the group_size", group_size); \
}

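// Illustrative call site (an assumption, not part of this diff): the dispatch
// macros above map a runtime value onto a constexpr so it can be used as a
// template argument; the launcher name here is hypothetical.
DISPATCH_GQA_GROUP_SIZE(group_size, GROUP_SIZE, {
  launch_attention_kernel<GROUP_SIZE>(params, stream);  // hypothetical launcher
})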
#define DISPATCH_BLOCKSHAPE_Q(block_shape_q, BLOCK_SHAPE_Q, NUM_WARP_Q, ...) \
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/extension.h"
|
||||
|
||||
#include "pybind11/pybind11.h"
|
||||
namespace py = pybind11;
|
||||
|
||||
// Custom exception class used to handle CUDA errors
|
||||
@@ -125,45 +125,40 @@ paddle::Tensor FusedExpertMoeFunc(
|
||||
const bool norm_topk_prob, const bool group_moe);
|
||||
|
||||
std::vector<paddle::Tensor> MoeExpertDispatch(
|
||||
const paddle::Tensor& input,
|
||||
const paddle::Tensor& gating_output,
|
||||
const paddle::optional<paddle::Tensor>& gating_correction_bias,
|
||||
const paddle::optional<paddle::Tensor> &w4a8_in_scale,
|
||||
const int moe_topk,
|
||||
const bool group_moe,
|
||||
const bool topk_only_mode);
|
||||
const paddle::Tensor &input, const paddle::Tensor &gating_output,
|
||||
const paddle::optional<paddle::Tensor> &gating_correction_bias,
|
||||
const paddle::optional<paddle::Tensor> &w4a8_in_scale, const int moe_topk,
|
||||
const bool group_moe, const bool topk_only_mode);
|
||||
|
||||
std::vector<paddle::Tensor>
|
||||
MoETopKSelectKernel(const paddle::Tensor &gating_logits,
|
||||
const paddle::optional<paddle::Tensor> &bias,
|
||||
const int moe_topk, const bool apply_norm_weight,
|
||||
const bool enable_softmax_top_k_fused);
|
||||
const paddle::optional<paddle::Tensor> &bias,
|
||||
const int moe_topk, const bool apply_norm_weight,
|
||||
const bool enable_softmax_top_k_fused);
|
||||
|
||||
std::vector<paddle::Tensor> MoERedundantTopKSelectKernel(
|
||||
const paddle::Tensor& gating_logits,
|
||||
const paddle::Tensor& expert_id_to_ep_rank_array,
|
||||
const paddle::Tensor& expert_in_rank_num_list,
|
||||
paddle::Tensor& tokens_per_expert_stats_list,
|
||||
const paddle::optional<paddle::Tensor>& bias,
|
||||
const int moe_topk,
|
||||
const bool apply_norm_weight,
|
||||
const bool enable_softmax_top_k_fused,
|
||||
const int redundant_ep_rank_num_plus_one);
|
||||
std::vector<paddle::Tensor>
|
||||
MoERedundantTopKSelectKernel(const paddle::Tensor &gating_logits,
|
||||
const paddle::Tensor &expert_id_to_ep_rank_array,
|
||||
const paddle::Tensor &expert_in_rank_num_list,
|
||||
paddle::Tensor &tokens_per_expert_stats_list,
|
||||
const paddle::optional<paddle::Tensor> &bias,
|
||||
const int moe_topk, const bool apply_norm_weight,
|
||||
const bool enable_softmax_top_k_fused,
|
||||
const int redundant_ep_rank_num_plus_one);
|
||||
|
||||
std::vector<paddle::Tensor>
|
||||
EPMoeExpertDispatch(const paddle::Tensor &input, const paddle::Tensor &topk_ids,
|
||||
const paddle::Tensor &topk_weights,
|
||||
const paddle::optional<paddle::Tensor> &ffn1_in_scale,
|
||||
const std::vector<int> &token_nums_per_expert,
|
||||
const int token_nums_this_rank,
|
||||
const std::string &moe_quant_type);
|
||||
const paddle::Tensor &topk_weights,
|
||||
const paddle::optional<paddle::Tensor> &ffn1_in_scale,
|
||||
const std::vector<int> &token_nums_per_expert,
|
||||
const int token_nums_this_rank,
|
||||
const std::string &moe_quant_type);
|
||||
|
||||
std::vector<paddle::Tensor> EPMoeExpertDispatchFP8(
|
||||
const paddle::Tensor &input, const paddle::Tensor &scale,
|
||||
const paddle::Tensor &topk_ids, const paddle::Tensor &topk_weights,
|
||||
const std::vector<int> &token_nums_per_expert,
|
||||
const std::vector<int> &token_nums_per_expert_padded,
|
||||
const int token_nums_this_rank, const int token_nums_this_rank_padded);
|
||||
const paddle::Tensor &token_nums_per_expert,
|
||||
const paddle::Tensor &token_nums_per_expert_padded);
|
||||
|
||||
std::vector<paddle::Tensor> PerTokenQuant(paddle::Tensor &input,
|
||||
const int block_size);
|
||||
@@ -180,20 +175,35 @@ std::vector<paddle::Tensor> EPMoeExpertCombine(
|
||||
const paddle::optional<paddle::Tensor> &ffn2_bias,
|
||||
const bool norm_topk_prob, const float routed_scaling_factor);
|
||||
|
||||
std::vector<std::vector<int>> GetExpertTokenNum(
|
||||
const paddle::Tensor& topk_ids,
|
||||
const int num_experts);
|
||||
std::vector<std::vector<int>> GetExpertTokenNum(const paddle::Tensor &topk_ids,
|
||||
const int num_experts);
|
||||
|
||||
paddle::Tensor MoeExpertFFNFunc(
|
||||
const paddle::Tensor &permute_input,
|
||||
const paddle::Tensor &tokens_expert_prefix_sum,
|
||||
const paddle::Tensor &ffn1_weight, const paddle::Tensor &ffn2_weight,
|
||||
const paddle::optional<paddle::Tensor> &ffn1_bias,
|
||||
const paddle::optional<paddle::Tensor> &ffn1_scale,
|
||||
const paddle::optional<paddle::Tensor> &ffn2_scale,
|
||||
const paddle::optional<paddle::Tensor> &ffn2_in_scale,
|
||||
const paddle::optional<paddle::Tensor> &expert_idx_per_token,
|
||||
const std::string &quant_method, const bool used_in_ep_low_latency);
|
||||
const paddle::Tensor& permute_input,
|
||||
const paddle::Tensor& tokens_expert_prefix_sum,
|
||||
const paddle::Tensor& ffn1_weight, const paddle::Tensor& ffn2_weight,
|
||||
const paddle::optional<paddle::Tensor>& ffn1_bias,
|
||||
const paddle::optional<paddle::Tensor>& ffn1_scale,
|
||||
const paddle::optional<paddle::Tensor>& ffn2_scale,
|
||||
const paddle::optional<paddle::Tensor>& ffn2_in_scale,
|
||||
const paddle::optional<paddle::Tensor>& expert_idx_per_token,
|
||||
const std::string& quant_method, const bool used_in_ep_low_latency);
|
||||
|
||||
paddle::Tensor MoeExpertFFNWint2Func(
|
||||
const paddle::Tensor& permute_input,
|
||||
const paddle::Tensor& tokens_expert_prefix_sum,
|
||||
const paddle::Tensor& ffn1_weight,
|
||||
const paddle::Tensor& ffn2_weight,
|
||||
const paddle::optional<paddle::Tensor>& ffn1_bias,
|
||||
const paddle::optional<paddle::Tensor>& ffn1_scale,
|
||||
const paddle::optional<paddle::Tensor>& ffn2_scale,
|
||||
const paddle::optional<paddle::Tensor>& ffn1_local_scale,
|
||||
const paddle::optional<paddle::Tensor>& ffn1_code_scale,
|
||||
const paddle::optional<paddle::Tensor>& ffn1_code_zp,
|
||||
const paddle::optional<paddle::Tensor>& ffn2_local_scale,
|
||||
const paddle::optional<paddle::Tensor>& ffn2_code_scale,
|
||||
const paddle::optional<paddle::Tensor>& ffn2_code_zp,
|
||||
const bool used_in_ep_low_latency);
|
||||
|
||||
paddle::Tensor MoeExpertReduceFunc(
|
||||
const paddle::Tensor &ffn_out, const paddle::Tensor &top_k_weight,
|
||||
@@ -205,19 +215,16 @@ paddle::Tensor MoeExpertReduceFunc(
|
||||
void InitKVSignalPerQuery(const paddle::Tensor &seq_lens_encoder_tensor,
|
||||
const paddle::Tensor &seq_lens_this_time_tensor,
|
||||
const paddle::Tensor &seq_lens_decoder_tensor,
|
||||
const int rank,
|
||||
const int num_layers);
|
||||
const int rank, const int num_layers);
|
||||
|
||||
void GetOutputKVSignal(const paddle::Tensor& x,
|
||||
int64_t rank_id,
|
||||
void GetOutputKVSignal(const paddle::Tensor &x, int64_t rank_id,
|
||||
bool wait_flag);
|
||||
|
||||
|
||||
paddle::Tensor DequantInt8Func(const paddle::Tensor &input,
|
||||
const paddle::Tensor &out_scale,
|
||||
std::string dtype);
|
||||
|
||||
paddle::Tensor OpenShmAndGetMetaSignalFunc(const int rank,
|
||||
paddle::Tensor OpenShmAndGetMetaSignalFunc(const int rank, const int device_id,
|
||||
const bool keep_pd_step_flag);
|
||||
|
||||
paddle::Tensor InitSignalLayerwiseFunc(const paddle::Tensor &kv_signal_metadata,
|
||||
@@ -286,61 +293,121 @@ std::vector<paddle::Tensor> ExtractTextTokenOutput(
|
||||
const paddle::Tensor &seq_lens_this_time,
|
||||
const paddle::Tensor &cu_seqlens_q, const paddle::Tensor &score_text);
|
||||
|
||||
std::vector<paddle::Tensor> MoEDeepGEMMPermute(
|
||||
const paddle::Tensor& x,
|
||||
const paddle::Tensor& topk_idx,
|
||||
const int num_experts,
|
||||
const int max_tokens_per_expert
|
||||
);
|
||||
std::vector<paddle::Tensor> MoEDeepGEMMPermute(const paddle::Tensor &x,
|
||||
const paddle::Tensor &topk_idx,
|
||||
const int num_experts,
|
||||
const int max_tokens_per_expert);
|
||||
|
||||
std::vector<paddle::Tensor> MoEDeepGEMMDePermute(
|
||||
const paddle::Tensor& ffn_out, // [num_experts, max_tokens_per_expert, hidden]
|
||||
const paddle::Tensor& permute_indices_per_token, // [token_num, topk]
|
||||
const paddle::Tensor& topk_idx,
|
||||
const paddle::Tensor& topk_weights
|
||||
);
|
||||
const paddle::Tensor
|
||||
&ffn_out, // [num_experts, max_tokens_per_expert, hidden]
|
||||
const paddle::Tensor &permute_indices_per_token, // [token_num, topk]
|
||||
const paddle::Tensor &topk_idx, const paddle::Tensor &topk_weights);
|
||||
|
||||
void TextImageIndexOut(const paddle::Tensor &token_type_ids,
|
||||
const paddle::Tensor &text_input,
|
||||
const paddle::Tensor &image_input);
|
||||
|
||||
void TextImageGatherScatter(paddle::Tensor &input, paddle::Tensor &text_input,
|
||||
paddle::Tensor &image_input,
|
||||
paddle::Tensor &token_type_ids,
|
||||
paddle::Tensor &text_index,
|
||||
paddle::Tensor &image_index, const bool is_scatter);
|
||||
|
||||
paddle::Tensor count_tokens_per_expert_func(const paddle::Tensor &topk_ids,
|
||||
int64_t num_experts);
|
||||
|
||||
|
||||
std::vector<paddle::Tensor> tritonmoe_preprocess_kernel(const paddle::Tensor& topk_ids, int64_t num_experts, int64_t GEMM_BLOCK_SIZE_M);
|
||||
|
||||
|
||||
std::vector<paddle::Tensor> MoeWna16MarlinGemmApi(
|
||||
const paddle::Tensor& a,
|
||||
const paddle::optional<paddle::Tensor>& c_or_none,
|
||||
const paddle::Tensor& b_q_weight,
|
||||
const paddle::Tensor& b_scales,
|
||||
const paddle::optional<paddle::Tensor>& global_scale_or_none,
|
||||
const paddle::optional<paddle::Tensor>& b_zeros_or_none,
|
||||
const paddle::optional<paddle::Tensor>& g_idx_or_none,
|
||||
const paddle::optional<paddle::Tensor>& perm_or_none,
|
||||
const paddle::Tensor& workspace,
|
||||
const paddle::Tensor& sorted_token_ids,
|
||||
const paddle::Tensor& expert_ids,
|
||||
const paddle::Tensor& num_tokens_post_padded,
|
||||
const paddle::Tensor& topk_weights,
|
||||
int64_t moe_block_size,
|
||||
int64_t top_k,
|
||||
bool mul_topk_weights,
|
||||
bool is_ep,
|
||||
const std::string& b_q_type_str,
|
||||
int64_t size_m,
|
||||
int64_t size_n,
|
||||
int64_t size_k,
|
||||
bool is_k_full,
|
||||
bool use_atomic_add,
|
||||
bool use_fp32_reduce,
|
||||
bool is_zp_float);
|
||||
void CutlassScaledMm(paddle::Tensor &c, paddle::Tensor const &a,
|
||||
paddle::Tensor const &b, paddle::Tensor const &a_scales,
|
||||
paddle::Tensor const &b_scales,
|
||||
paddle::optional<paddle::Tensor> const &bias);
|
||||
|
||||
void CutlassScaledMmAzp(paddle::Tensor& c, paddle::Tensor const& a,
|
||||
paddle::Tensor const& b,
|
||||
paddle::Tensor const& a_scales,
|
||||
paddle::Tensor const& b_scales,
|
||||
paddle::Tensor const& azp_adj,
|
||||
paddle::optional<paddle::Tensor> const& azp,
|
||||
paddle::optional<paddle::Tensor> const& bias);
|
||||
|
||||
void StaticScaledFp8Quant(paddle::Tensor &out, paddle::Tensor const &input,
|
||||
paddle::Tensor const &scale);
|
||||
|
||||
void DynamicScaledFp8Quant(paddle::Tensor &out, paddle::Tensor const &input,
|
||||
paddle::Tensor &scale);
|
||||
|
||||
void DynamicPerTokenScaledFp8Quant(paddle::Tensor &out,
|
||||
paddle::Tensor const &input,
|
||||
paddle::Tensor &scales, float scale_ub);
|
||||
|
||||
PYBIND11_MODULE(fastdeploy_ops, m) {
|
||||
|
||||
m.def("get_expert_token_num", &GetExpertTokenNum,
|
||||
py::arg("topk_ids"), py::arg("num_experts"),
|
||||
"get expert token num");
|
||||
m.def("get_expert_token_num", &GetExpertTokenNum, py::arg("topk_ids"),
|
||||
py::arg("num_experts"), "get expert token num");
|
||||
|
||||
/**
|
||||
* moe/fused_moe/moe_redundant_topk_select.cu
|
||||
* moe_redundant_topk_select
|
||||
*/
|
||||
m.def("f_moe_redundant_topk_select", &MoERedundantTopKSelectKernel,
|
||||
py::arg("gating_logits"), py::arg("expert_id_to_ep_rank_array"),
|
||||
py::arg("expert_in_rank_num_list"),
|
||||
py::arg("tokens_per_expert_stats_list"), py::arg("bias"),
|
||||
py::arg("moe_topk"), py::arg("apply_norm_weight"),
|
||||
py::arg("enable_softmax_top_k_fused"),
|
||||
py::arg("redundant_ep_rank_num_plus_one"),
|
||||
"moe export RedundantTopKSelect function");
|
||||
|
||||
/**
|
||||
* moe/fused_moe/moe_redundant_topk_select.cu
|
||||
* moe_redundant_topk_select
|
||||
*/
|
||||
m.def("f_moe_redundant_topk_select", &MoERedundantTopKSelectKernel,
|
||||
py::arg("gating_logits"), py::arg("expert_id_to_ep_rank_array"),
|
||||
py::arg("expert_in_rank_num_list"), py::arg("tokens_per_expert_stats_list"),
|
||||
py::arg("bias"), py::arg("moe_topk"), py::arg("apply_norm_weight"),
|
||||
py::arg("enable_softmax_top_k_fused"), py::arg("redundant_ep_rank_num_plus_one"),
|
||||
"moe export RedundantTopKSelect function");
|
||||
/**
|
||||
* open_shm_and_get_meta_signal.cc
|
||||
* InitKVSignalPerQuery
|
||||
*/
|
||||
m.def("init_kv_signal_per_query", &InitKVSignalPerQuery,
|
||||
py::arg("seq_lens_encoder_tensor"),
|
||||
py::arg("seq_lens_this_time_tensor"),
|
||||
py::arg("seq_lens_decoder_tensor"), py::arg("rank"),
|
||||
py::arg("num_layers"), "init_kv_signal_per_query function");
|
||||
|
||||
/**
|
||||
* GetOutputKVSignal
|
||||
*/
|
||||
m.def("get_output_kv_signal", &GetOutputKVSignal, py::arg("x"),
|
||||
py::arg("rank_id"), py::arg("wait_flag"),
|
||||
"get_output_kv_signal function");
|
||||
|
||||
/**
|
||||
* open_shm_and_get_meta_signal.cc
|
||||
* InitKVSignalPerQuery
|
||||
*/
|
||||
m.def("init_kv_signal_per_query", &InitKVSignalPerQuery,
|
||||
py::arg("seq_lens_encoder_tensor"), py::arg("seq_lens_this_time_tensor"),
|
||||
py::arg("seq_lens_decoder_tensor"), py::arg("rank"), py::arg("num_layers"),
|
||||
"init_kv_signal_per_query function");
|
||||
|
||||
/**
|
||||
* GetOutputKVSignal
|
||||
*/
|
||||
m.def("get_output_kv_signal", &GetOutputKVSignal,
|
||||
py::arg("x"), py::arg("rank_id"), py::arg("wait_flag"),
|
||||
"get_output_kv_signal function");
|
||||
|
||||
|
||||
|
||||
m.def("moe_deepgemm_permute", &MoEDeepGEMMPermute, "MoEDeepGEMMPermute");
|
||||
m.def("moe_deepgemm_depermute", &MoEDeepGEMMDePermute, "MoEDeepGEMMDePermute");
|
||||
m.def("moe_deepgemm_permute", &MoEDeepGEMMPermute, "MoEDeepGEMMPermute");
|
||||
m.def("moe_deepgemm_depermute", &MoEDeepGEMMDePermute,
|
||||
"MoEDeepGEMMDePermute");
|
||||
/**
|
||||
* alloc_cache_pinned.cc
|
||||
* cuda_host_alloc
|
||||
@@ -398,12 +465,7 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
|
||||
py::arg("token_nums_per_expert"), py::arg("token_nums_this_rank"),
|
||||
py::arg("moe_quant_type"), "ep moe export dispatch function");
|
||||
|
||||
m.def("ep_moe_expert_dispatch_fp8", &EPMoeExpertDispatchFP8, py::arg("input"),
|
||||
py::arg("scale"), py::arg("topk_ids"), py::arg("topk_weights"),
|
||||
py::arg("token_nums_per_expert"),
|
||||
py::arg("token_nums_per_expert_padded"),
|
||||
py::arg("token_nums_this_rank"), py::arg("token_nums_this_rank_padded"),
|
||||
"ep moe export dispatch function");
|
||||
m.def("ep_moe_expert_dispatch_fp8", &EPMoeExpertDispatchFP8);
|
||||
|
||||
m.def("ep_moe_expert_combine", &EPMoeExpertCombine, py::arg("ffn_out"),
|
||||
py::arg("expert_scales_float"), py::arg("permute_indices_per_token"),
|
||||
@@ -437,6 +499,12 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
|
||||
*/
|
||||
m.def("moe_expert_ffn", &MoeExpertFFNFunc, "moe export ffn function");
|
||||
|
||||
/**
|
||||
* moe/fused_moe/moe_ffn_wint2.cu
|
||||
* moe_expert_ffn_wint2
|
||||
*/
|
||||
m.def("moe_expert_ffn_wint2", &MoeExpertFFNWint2Func, "moe export ffn wint2 function");
|
||||
|
||||
/**
|
||||
* moe/fused_moe/moe_expert_reduce.cu
|
||||
* moe_expert_reduce
|
||||
@@ -523,4 +591,66 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
|
||||
|
||||
m.def("group_swiglu_with_masked", &GroupSwigluWithMasked,
|
||||
"group_swiglu_with_masked function");
|
||||
|
||||
m.def("text_image_index_out", &TextImageIndexOut,
|
||||
"text_image_index_out function");
|
||||
|
||||
m.def("text_image_gather_scatter", &TextImageGatherScatter,
|
||||
"text_image_gather_scatter function");
|
||||
|
||||
m.def("count_tokens_per_expert_func", &count_tokens_per_expert_func);
|
||||
m.def("tritonmoe_preprocess_func", &tritonmoe_preprocess_kernel);
|
||||
|
||||
m.def("MoeWna16MarlinGemmApi", &MoeWna16MarlinGemmApi,
|
||||
py::arg("a"),
|
||||
py::arg("c_or_none"),
|
||||
py::arg("b_q_weight"),
|
||||
py::arg("b_scales"),
|
||||
py::arg("global_scale_or_none"),
|
||||
py::arg("b_zeros_or_none"),
|
||||
py::arg("g_idx_or_none"),
|
||||
py::arg("perm_or_none"),
|
||||
py::arg("workspace"),
|
||||
py::arg("sorted_token_ids"),
|
||||
py::arg("expert_ids"),
|
||||
py::arg("num_tokens_post_padded"),
|
||||
py::arg("topk_weights"),
|
||||
py::arg("moe_block_size"),
|
||||
py::arg("top_k"),
|
||||
py::arg("mul_topk_weights"),
|
||||
py::arg("is_ep"),
|
||||
py::arg("b_q_type_str"),
|
||||
py::arg("size_m"),
|
||||
py::arg("size_n"),
|
||||
py::arg("size_k"),
|
||||
py::arg("is_k_full"),
|
||||
py::arg("use_atomic_add"),
|
||||
py::arg("use_fp32_reduce"),
|
||||
py::arg("is_zp_float"));
|
||||
|
||||
|
||||
/**
|
||||
* cutlass_scaled_mm.cu
|
||||
* cutlass_scaled_mm
|
||||
* cutlass_scaled_mm_azp
|
||||
*/
|
||||
m.def("cutlass_scaled_mm", &CutlassScaledMm, "cutlass_scaled_mm function");
|
||||
m.def("cutlass_scaled_mm_azp", &CutlassScaledMmAzp, "cutlass_scaled_mm_azp function");
|
||||
|
||||
/**
|
||||
* quantization/common.cu
|
||||
* static_scaled_fp8_quant
|
||||
* dynamic_scaled_fp8_quant
|
||||
* dynamic_per_token_scaled_fp8_quant
|
||||
*/
|
||||
m.def("static_scaled_fp8_quant", &StaticScaledFp8Quant, "static_scaled_fp8_quant function",
|
||||
py::arg("out"), py::arg("input"), py::arg("scale"));
|
||||
|
||||
m.def("dynamic_scaled_fp8_quant", &DynamicScaledFp8Quant,
|
||||
"dynamic_scaled_fp8_quant function",
|
||||
py::arg("out"), py::arg("input"), py::arg("scale"));
|
||||
|
||||
m.def("dynamic_per_token_scaled_fp8_quant", &DynamicPerTokenScaledFp8Quant,
|
||||
"dynamic_per_token_scaled_fp8_quant function",
|
||||
py::arg("out"), py::arg("input"), py::arg("scales"), py::arg("scale_ub"));
|
||||
}
|
||||
custom_ops/gpu_ops/cutlass_extensions/arch/memory_copy_sm80.h (new file, 250 lines)
@@ -0,0 +1,250 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
\brief Architecture-specific operators on memory added for SM80
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/complex.h"
|
||||
#include "cutlass/arch/memory.h"
|
||||
#include "cutlass/arch/memory_sm75.h"
|
||||
#include "cutlass/arch/memory_sm80.h"
|
||||
#include "cutlass/arch/cache_operation.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace arch {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Initiates an asynchronous copy from global memory to shared memory.
|
||||
///
|
||||
/// cp.async
|
||||
///
|
||||
template <
|
||||
/// Size of the access in bytes
|
||||
int SizeInBytes,
|
||||
/// Cache operation
|
||||
CacheOperation::Kind cache_op = CacheOperation::Always,
|
||||
bool GlobalToShared = true>
|
||||
struct copy;
|
||||
|
||||
/// Initiates an asynchronous copy from global memory to shared memory. Rather than predicate
|
||||
/// the entire transfer, zeros are written to SMEM if the guard predicate is false.
|
||||
///
|
||||
/// cp.async
|
||||
///
|
||||
template <
|
||||
/// Size of the access in bytes
|
||||
int SizeInBytes,
|
||||
/// Cache operation
|
||||
CacheOperation::Kind cache_op = CacheOperation::Always,
|
||||
bool GlobalToShared = true>
|
||||
struct copy_zfill;
|
||||
|
||||
/// Blocks until all but <N> previous cp.async.commit_group operations have committed.
|
||||
///
|
||||
/// cp.async
|
||||
///
|
||||
template <int N, bool GlobalToShared = true>
|
||||
struct copy_wait;
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Partial specialization
|
||||
template <
|
||||
/// Size of the access in bytes
|
||||
int SizeInBytes>
|
||||
struct copy<SizeInBytes, CacheOperation::Always, true> {
|
||||
|
||||
/// Copy
|
||||
CUTLASS_DEVICE
|
||||
copy(void *smem_ptr, void const *global_ptr, bool pred_guard = true) {
|
||||
cp_async<SizeInBytes, CacheOperation::Always>(smem_ptr, global_ptr, pred_guard);
|
||||
}
|
||||
};
|
||||
|
||||
/// Partial specialization
|
||||
template <
|
||||
/// Size of the access in bytes
|
||||
int SizeInBytes>
|
||||
struct copy<SizeInBytes, CacheOperation::Always, false> {
|
||||
|
||||
/// Copy
|
||||
CUTLASS_DEVICE
|
||||
copy(void *smem_ptr, void const *global_ptr, bool pred_guard = true) {
|
||||
using AccessType = Array<uint8_t, SizeInBytes>;
|
||||
|
||||
if (pred_guard) {
|
||||
*static_cast<AccessType *>(smem_ptr) = *static_cast<AccessType const *>(global_ptr);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/// Partial specialization
|
||||
template <
|
||||
/// Size of the access in bytes
|
||||
int SizeInBytes>
|
||||
struct copy_zfill<SizeInBytes, CacheOperation::Always, true> {
|
||||
|
||||
/// Copy with zero fill
|
||||
CUTLASS_DEVICE
|
||||
copy_zfill(void *smem_ptr, void const *global_ptr, bool pred_guard) {
|
||||
cp_async_zfill<SizeInBytes, CacheOperation::Always>(smem_ptr, global_ptr, pred_guard);
|
||||
}
|
||||
};
|
||||
|
||||
/// Partial specialization
|
||||
template <
|
||||
/// Size of the access in bytes
|
||||
int SizeInBytes>
|
||||
struct copy_zfill<SizeInBytes, CacheOperation::Always, false> {
|
||||
|
||||
/// Copy with zero fill
|
||||
CUTLASS_DEVICE
|
||||
copy_zfill(void *smem_ptr, void const *global_ptr, bool pred_guard) {
|
||||
using AccessType = Array<uint8_t, SizeInBytes>;
|
||||
|
||||
if (pred_guard) {
|
||||
*static_cast<AccessType *>(smem_ptr) = *static_cast<AccessType const *>(global_ptr);
|
||||
}
|
||||
else {
|
||||
AccessType zeros;
|
||||
zeros.clear();
|
||||
*static_cast<AccessType *>(smem_ptr) = zeros;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Partial specialization
|
||||
template <
|
||||
/// Size of the access in bytes
|
||||
int SizeInBytes>
|
||||
struct copy<SizeInBytes, CacheOperation::Global, true> {
|
||||
|
||||
/// Copy
|
||||
CUTLASS_DEVICE
|
||||
copy(void *smem_ptr, void const *global_ptr, bool pred_guard = true) {
|
||||
cp_async<SizeInBytes, CacheOperation::Global>(smem_ptr, global_ptr, pred_guard);
|
||||
}
|
||||
};
|
||||
|
||||
/// Partial specialization
|
||||
template <
|
||||
/// Size of the access in bytes
|
||||
int SizeInBytes>
|
||||
struct copy<SizeInBytes, CacheOperation::Global, false> {
|
||||
|
||||
/// Copy
|
||||
CUTLASS_DEVICE
|
||||
copy(void *smem_ptr, void const *global_ptr, bool pred_guard = true) {
|
||||
using AccessType = Array<uint8_t, SizeInBytes>;
|
||||
|
||||
if (pred_guard) {
|
||||
*static_cast<AccessType *>(smem_ptr) = *static_cast<AccessType const *>(global_ptr);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/// Partial specialization
|
||||
template <
|
||||
/// Size of the access in bytes
|
||||
int SizeInBytes>
|
||||
struct copy_zfill<SizeInBytes, CacheOperation::Global, true> {
|
||||
|
||||
/// Copy with zero fill
|
||||
CUTLASS_DEVICE
|
||||
copy_zfill(void *smem_ptr, void const *global_ptr, bool pred_guard = true) {
|
||||
cp_async_zfill<SizeInBytes, CacheOperation::Global>(smem_ptr, global_ptr, pred_guard);
|
||||
}
|
||||
};
|
||||
|
||||
/// Partial specialization
|
||||
template <
|
||||
/// Size of the access in bytes
|
||||
int SizeInBytes>
|
||||
struct copy_zfill<SizeInBytes, CacheOperation::Global, false> {
|
||||
|
||||
/// Copy with zero fill
|
||||
CUTLASS_DEVICE
|
||||
copy_zfill(void *smem_ptr, void const *global_ptr, bool pred_guard = true) {
|
||||
using AccessType = Array<uint8_t, SizeInBytes>;
|
||||
|
||||
if (pred_guard) {
|
||||
*static_cast<AccessType *>(smem_ptr) = *static_cast<AccessType const *>(global_ptr);
|
||||
}
|
||||
else {
|
||||
AccessType zeros;
|
||||
zeros.clear();
|
||||
*static_cast<AccessType *>(smem_ptr) = zeros;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/// Establishes an ordering w.r.t previously issued cp.async instructions. Does not block.
|
||||
template <bool GlobalToShared>
|
||||
CUTLASS_DEVICE
|
||||
void copy_fence() {}
|
||||
|
||||
template <>
|
||||
CUTLASS_DEVICE
|
||||
void copy_fence<true>() {
|
||||
cp_async_fence();
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Partial specialization
|
||||
template <int N>
|
||||
struct copy_wait<N, false> {
|
||||
|
||||
CUTLASS_DEVICE
|
||||
copy_wait() {}
|
||||
};
|
||||
|
||||
/// Partial specialization
template <int N>
struct copy_wait<N, true> {

CUTLASS_DEVICE
copy_wait() { cp_async_wait<N>(); }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace arch
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
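// Illustrative device-code sketch (an assumption, not part of this diff): a typical
// cp.async pipeline built from these helpers is copy -> copy_fence -> copy_wait.
// The 16-byte access size and the pointer names are hypothetical.
constexpr bool kGlobalToShared = true;
cutlass::arch::copy<16, cutlass::arch::CacheOperation::Global, kGlobalToShared>(
    smem_ptr, gmem_ptr, pred_guard);            // issue the (predicated) async copy
cutlass::arch::copy_fence<kGlobalToShared>();   // commit the cp.async group
cutlass::arch::copy_wait<0, kGlobalToShared>{}; // block until outstanding groups finish
__syncthreads();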
@@ -0,0 +1,460 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights
|
||||
*reserved. SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice,
|
||||
*this list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
||||
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
|
||||
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
|
||||
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
|
||||
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
||||
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
|
||||
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
|
||||
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
|
||||
*POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
//
|
||||
// This file is a modified excerpt of
|
||||
// include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp
|
||||
// from https://github.com/NVIDIA/cutlass v3.5.0
|
||||
// It has been modified to support either row/column or scalar broadcasting
|
||||
// where the tensor being loaded from is always passed in via a device pointer.
|
||||
// This lets one compiled kernel handle all cases of per-tensor or
|
||||
// per-channel/per-token quantization.
|
||||
//
|
||||
// This interface also allows the scales to be passed in as tensors that
|
||||
// consistently reside on the device, which avoids an issue with a previous
|
||||
// implementation where scalars needed to be on the CPU since they
|
||||
// were passed in via float values. This created a potential performance hazard
|
||||
// if scales were initially on the device, and caused torch.compile graphs
|
||||
// breaks when moving scales to the CPU.
|
||||
//
|
||||
|
||||
// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_array_c3x.hpp
|
||||
|
||||
#pragma once
|
||||
|
||||
// Turn off clang-format for the entire file to keep it close to upstream
|
||||
// clang-format off
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/arch/barrier.h"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp"
|
||||
|
||||
namespace cutlass::epilogue::fusion {
|
||||
|
||||
using namespace cute;
|
||||
using namespace detail;
|
||||
|
||||
// Row vector broadcast
|
||||
template<
|
||||
int Stages,
|
||||
class CtaTileShapeMNK,
|
||||
class Element,
|
||||
class StrideMNL = Stride<_0,_1,_0>,
|
||||
int Alignment = 128 / sizeof_bits_v<Element>
|
||||
>
|
||||
struct Sm90RowOrScalarBroadcastArray {
|
||||
static_assert(Stages == 0, "Row broadcast doesn't support smem usage");
|
||||
static_assert(is_static_v<decltype(take<0,2>(StrideMNL{}))>); // batch stride can be dynamic or static
|
||||
static_assert(take<0,2>(StrideMNL{}) == Stride<_0,_1>{});
|
||||
|
||||
struct SharedStorage {
|
||||
array_aligned<Element, size<1>(CtaTileShapeMNK{})> smem;
|
||||
};
|
||||
|
||||
// This struct has been modified to have a bool indicating that ptr_row is a
// scalar that must be broadcast, instead of containing a scalar that is
// valid if ptr_row is null.
struct Arguments {
const Element* const* ptr_row_array = nullptr;
bool row_broadcast = true;
StrideMNL dRow = {};
};

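// Illustrative sketch (an assumption, not part of this diff): per-row scales pass
// a per-group array of device pointers with row_broadcast = true, while per-tensor
// scales point each group at a single device-resident scalar and set
// row_broadcast = false; the variable names below are hypothetical.
Arguments per_row_args{d_row_scale_ptrs, /*row_broadcast=*/true, dRow};
Arguments per_tensor_args{d_scalar_scale_ptrs, /*row_broadcast=*/false, dRow};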
using Params = Arguments;
|
||||
|
||||
template <class ProblemShape>
|
||||
static constexpr Params
|
||||
to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
|
||||
return args;
|
||||
}
|
||||
|
||||
template <class ProblemShape>
|
||||
static bool
|
||||
can_implement(ProblemShape const& problem_shape, Arguments const& args) {
|
||||
return true;
|
||||
}
|
||||
|
||||
template <class ProblemShape>
|
||||
static size_t
|
||||
get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
template <class ProblemShape>
|
||||
static cutlass::Status
|
||||
initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream,
|
||||
CudaHostAdapter* cuda_adapter = nullptr) {
|
||||
return cutlass::Status::kSuccess;
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Sm90RowOrScalarBroadcastArray() { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Sm90RowOrScalarBroadcastArray(Params const& params, SharedStorage const& shared_storage)
|
||||
: params(params)
|
||||
, smem(const_cast<Element*>(shared_storage.smem.data())) { }
|
||||
|
||||
Params params;
|
||||
Element *smem = nullptr;
|
||||
|
||||
CUTLASS_DEVICE bool
|
||||
is_producer_load_needed() const {
|
||||
return false;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE bool
|
||||
is_C_load_needed() const {
|
||||
return false;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE bool
|
||||
is_zero() const {
|
||||
return (!params.row_broadcast && *(params.ptr_row_array[group]) == Element(0));
|
||||
}
|
||||
|
||||
template <class... Args>
|
||||
CUTLASS_DEVICE auto
|
||||
get_producer_load_callbacks(ProducerLoadArgs<Args...> const& args) {
|
||||
return EmptyProducerLoadCallbacks{};
|
||||
}
|
||||
|
||||
template <class GS_GTensor, class GS_STensor, class GS_CTensor, class Tiled_G2S, class SR_STensor, class SR_RTensor, class CTensor, class ThrResidue, class ThrNum>
|
||||
struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks {
|
||||
CUTLASS_DEVICE
|
||||
ConsumerStoreCallbacks(
|
||||
GS_GTensor tGS_gRow_, GS_STensor tGS_sRow_,
|
||||
GS_CTensor tGS_cRow_, Tiled_G2S tiled_g2s_,
|
||||
SR_STensor tSR_sRow_, SR_RTensor tSR_rRow_,
|
||||
CTensor tCcRow_, ThrResidue residue_tCcRow_, ThrNum thr_num_,
|
||||
int group, Params const& params_)
|
||||
: tGS_gRow(tGS_gRow_)
|
||||
, tGS_sRow(tGS_sRow_)
|
||||
, tGS_cRow(tGS_cRow_)
|
||||
, tiled_G2S(tiled_g2s_)
|
||||
, tSR_sRow(tSR_sRow_)
|
||||
, tSR_rRow(tSR_rRow_)
|
||||
, tCcRow(tCcRow_)
|
||||
, residue_tCcRow(residue_tCcRow_)
|
||||
, group(group)
|
||||
, params(params_) {}
|
||||
|
||||
GS_GTensor tGS_gRow; // (CPY,CPY_M,CPY_N)
|
||||
GS_STensor tGS_sRow; // (CPY,CPY_M,CPY_N)
|
||||
GS_CTensor tGS_cRow; // (CPY,CPY_M,CPY_N)
|
||||
Tiled_G2S tiled_G2S;
|
||||
|
||||
SR_STensor tSR_sRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
|
||||
SR_RTensor tSR_rRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
|
||||
|
||||
CTensor tCcRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
|
||||
ThrResidue residue_tCcRow; // (m, n)
|
||||
ThrNum thr_num;
|
||||
int group;
|
||||
Params const& params;
|
||||
|
||||
CUTLASS_DEVICE void
|
||||
begin() {
|
||||
if (!params.row_broadcast) {
|
||||
fill(tSR_rRow, *(params.ptr_row_array[group]));
|
||||
return;
|
||||
}
|
||||
|
||||
auto synchronize = [&] () { cutlass::arch::NamedBarrier::sync(thr_num, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); };
|
||||
Tensor tGS_gRow_flt = filter_zeros(tGS_gRow);
|
||||
Tensor tGS_sRow_flt = filter_zeros(tGS_sRow);
|
||||
Tensor tGS_cRow_flt = make_tensor(tGS_cRow.data(), make_layout(tGS_gRow_flt.shape(), tGS_cRow.stride()));
|
||||
|
||||
for (int i = 0; i < size(tGS_gRow_flt); ++i) {
|
||||
if (get<1>(tGS_cRow_flt(i)) >= size<1>(CtaTileShapeMNK{})) {
|
||||
continue; // OOB of SMEM,
|
||||
}
|
||||
if (elem_less(tGS_cRow_flt(i), make_coord(get<0>(residue_tCcRow), get<1>(residue_tCcRow)))) {
|
||||
tGS_sRow_flt(i) = tGS_gRow_flt(i);
|
||||
}
|
||||
else {
|
||||
tGS_sRow_flt(i) = Element(0); // Set to zero when OOB so the LDS can be issued without any predicates.
|
||||
}
|
||||
}
|
||||
synchronize();
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE void
|
||||
begin_loop(int epi_m, int epi_n) {
|
||||
if (epi_m == 0) { // Assumes M-major subtile loop
|
||||
if (!params.row_broadcast) return; // Do not issue LDS when row is scalar
|
||||
Tensor tSR_sRow_flt = filter_zeros(tSR_sRow(_,_,_,epi_m,epi_n));
|
||||
Tensor tSR_rRow_flt = filter_zeros(tSR_rRow);
|
||||
copy(tSR_sRow_flt, tSR_rRow_flt);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename ElementAccumulator, int FragmentSize>
|
||||
CUTLASS_DEVICE Array<Element, FragmentSize>
|
||||
visit(Array<ElementAccumulator, FragmentSize> const& frg_acc, int epi_v, int epi_m, int epi_n) {
|
||||
Array<Element, FragmentSize> frg_row;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < FragmentSize; ++i) {
|
||||
frg_row[i] = tSR_rRow(epi_v * FragmentSize + i);
|
||||
}
|
||||
|
||||
return frg_row;
|
||||
}
|
||||
};
|
||||
|
||||
template <
|
||||
bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy
|
||||
class... Args
|
||||
>
|
||||
CUTLASS_DEVICE auto
|
||||
get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {
|
||||
auto [M, N, K, L] = args.problem_shape_mnkl;
|
||||
auto [m, n, k, l] = args.tile_coord_mnkl;
|
||||
using ThreadCount = decltype(size(args.tiled_copy));
|
||||
|
||||
Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row_array[l]), make_shape(M,N,1), params.dRow);
|
||||
Tensor gRow = local_tile(mRow(_,_,l), take<0,2>(args.tile_shape_mnk), make_coord(m, n)); // (CTA_M, CTA_N)
|
||||
Tensor sRow = make_tensor(make_smem_ptr(smem),
|
||||
make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})), make_shape(_0{}, _1{})); // (CTA_M, CTA_N)
|
||||
//// G2S: Gmem to Smem
|
||||
auto tiled_g2s = make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
|
||||
Layout< Shape<_1, ThreadCount>,
|
||||
Stride<_0, _1>>{},
|
||||
Layout<_1>{});
|
||||
auto thr_g2s = tiled_g2s.get_slice(args.thread_idx);
|
||||
Tensor tGS_gRow = thr_g2s.partition_S(gRow);
|
||||
Tensor tGS_sRow = thr_g2s.partition_D(sRow);
|
||||
|
||||
//// G2S: Coord
|
||||
auto cRow = make_identity_tensor(make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})));
|
||||
Tensor tGS_cRow = thr_g2s.partition_S(cRow);
|
||||
|
||||
//// S2R: Smem to Reg
|
||||
Tensor tSR_sRow = sm90_partition_for_epilogue<ReferenceSrc>(sRow, args.epi_tile, args.tiled_copy, args.thread_idx);
|
||||
Tensor tSR_rRow = make_tensor_like(take<0,3>(tSR_sRow)); // (CPY,CPY_M,CPY_N)
|
||||
|
||||
return ConsumerStoreCallbacks<decltype(tGS_gRow), decltype(tGS_sRow), decltype(tGS_cRow), decltype(tiled_g2s), decltype(tSR_sRow), decltype(tSR_rRow), decltype(args.tCcD), decltype(args.residue_cD), ThreadCount>(
|
||||
tGS_gRow,
|
||||
tGS_sRow,
|
||||
tGS_cRow, tiled_g2s,
|
||||
tSR_sRow,
|
||||
tSR_rRow,
|
||||
args.tCcD,
|
||||
args.residue_cD,
|
||||
ThreadCount{},
|
||||
l,
|
||||
params);
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Column vector broadcast
|
||||
template<
|
||||
int Stages,
|
||||
class CtaTileShapeMNK,
|
||||
class Element,
|
||||
class StrideMNL = Stride<_1,_0,_0>,
|
||||
int Alignment = 128 / sizeof_bits_v<Element>
|
||||
>
|
||||
struct Sm90ColOrScalarBroadcastArray {
|
||||
static_assert(Stages == 0, "Column broadcast doesn't support smem usage yet");
|
||||
static_assert(Alignment * sizeof_bits_v<Element> % 128 == 0, "sub-16B alignment not supported yet");
|
||||
static_assert(
|
||||
(cute::is_same_v<StrideMNL, Stride<_1,_0, _0>>) || // col vector broadcast, e.g. per-row alpha/bias
|
||||
(cute::is_same_v<StrideMNL, Stride<_1,_0,int>>)); // batched col vector broadcast, e.g. batched per-row bias
|
||||
|
||||
// Accumulator distributes col elements evenly amongst threads so we can just directly load from gmem
|
||||
struct SharedStorage { };
|
||||
|
||||
// This struct has been modified to have a bool indicating that ptr_col is a
|
||||
// scalar that must be broadcast, instead of containing a scalar that is
|
||||
// valid if ptr_col is null.
|
||||
struct Arguments {
|
||||
const Element* const* ptr_col_array = nullptr;
|
||||
bool col_broadcast = true;
|
||||
StrideMNL dCol = {};
|
||||
};
|
||||
|
||||
using Params = Arguments;
|
||||
|
||||
template <class ProblemShape>
|
||||
static constexpr Params
|
||||
to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
|
||||
return args;
|
||||
}
|
||||
|
||||
template <class ProblemShape>
|
||||
static bool
|
||||
can_implement(ProblemShape const& problem_shape, Arguments const& args) {
|
||||
return true;
|
||||
}
|
||||
|
||||
template <class ProblemShape>
|
||||
static size_t
|
||||
get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
template <class ProblemShape>
|
||||
static cutlass::Status
|
||||
initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream,
|
||||
CudaHostAdapter* cuda_adapter = nullptr) {
|
||||
return cutlass::Status::kSuccess;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE bool
|
||||
is_producer_load_needed() const {
|
||||
return false;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE bool
|
||||
is_C_load_needed() const {
|
||||
return false;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE bool
|
||||
is_zero() const {
|
||||
return (!params.col_broadcast && *(params.ptr_col_array[group]) == Element(0));
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Sm90ColOrScalarBroadcastArray() { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Sm90ColOrScalarBroadcastArray(Params const& params, SharedStorage const& shared_storage)
|
||||
: params(params) { }
|
||||
|
||||
Params params;
|
||||
|
||||
template <class... Args>
|
||||
CUTLASS_DEVICE auto
|
||||
get_producer_load_callbacks(ProducerLoadArgs<Args...> const& args) {
|
||||
return EmptyProducerLoadCallbacks{};
|
||||
}
|
||||
|
||||
template<class GTensor, class RTensor, class CTensor, class ProblemShape>
|
||||
struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks {
|
||||
CUTLASS_DEVICE
|
||||
ConsumerStoreCallbacks(
|
||||
GTensor&& tCgCol,
|
||||
RTensor&& tCrCol,
|
||||
CTensor&& tCcCol,
|
||||
ProblemShape problem_shape,
|
||||
int group,
|
||||
Params const& params
|
||||
):
|
||||
tCgCol(cute::forward<GTensor>(tCgCol)),
|
||||
tCrCol(cute::forward<RTensor>(tCrCol)),
|
||||
tCcCol(cute::forward<CTensor>(tCcCol)),
|
||||
m(get<0>(problem_shape)),
|
||||
group(group),
|
||||
params(params) {}
|
||||
|
||||
GTensor tCgCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
|
||||
RTensor tCrCol;
|
||||
CTensor tCcCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
|
||||
Params const& params;
|
||||
int m;
|
||||
int group;
|
||||
|
||||
CUTLASS_DEVICE void
|
||||
begin() {
|
||||
Tensor pred = make_tensor<bool>(shape(tCgCol));
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size(pred); ++i) {
|
||||
pred(i) = get<0>(tCcCol(i)) < m;
|
||||
}
|
||||
|
||||
if (!params.col_broadcast) {
|
||||
fill(tCrCol, *(params.ptr_col_array[group]));
|
||||
return;
|
||||
}
|
||||
|
||||
// Filter so we don't issue redundant copies over stride-0 modes
|
||||
// (only works if 0-strides are in same location, which is by construction)
|
||||
copy_if(pred, filter(tCgCol), filter(tCrCol));
|
||||
}
|
||||
|
||||
template <typename ElementAccumulator, int FragmentSize>
|
||||
CUTLASS_DEVICE Array<Element, FragmentSize>
|
||||
visit(Array<ElementAccumulator, FragmentSize> const& frg_acc, int epi_v, int epi_m, int epi_n) {
|
||||
Array<Element, FragmentSize> frg_col;
|
||||
Tensor tCrCol_mn = tCrCol(_,_,_,epi_m,epi_n);
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < FragmentSize; ++i) {
|
||||
frg_col[i] = tCrCol_mn(epi_v * FragmentSize + i);
|
||||
}
|
||||
|
||||
return frg_col;
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
template <
|
||||
bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy
|
||||
class... Args
|
||||
>
|
||||
CUTLASS_DEVICE auto
|
||||
get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {
|
||||
|
||||
auto [M, N, K, L] = args.problem_shape_mnkl;
|
||||
auto [m, n, k, l] = args.tile_coord_mnkl;
|
||||
|
||||
Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col_array[l]), make_shape(M,N,1), params.dCol);
|
||||
Tensor tCgCol = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
|
||||
mCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx);
|
||||
Tensor tCrCol = make_tensor_like(tCgCol); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
|
||||
|
||||
// Generate an identity tensor matching the shape of the global tensor and
|
||||
// partition the same way, this will be used to generate the predicate
|
||||
// tensor for loading
|
||||
Tensor cCol = make_identity_tensor(mCol.shape());
|
||||
Tensor tCcCol = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
|
||||
cCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx);
|
||||
|
||||
return ConsumerStoreCallbacks(
|
||||
cute::move(tCgCol),
|
||||
cute::move(tCrCol),
|
||||
cute::move(tCcCol),
|
||||
args.problem_shape_mnkl,
|
||||
l,
|
||||
params
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
}
|
||||
@@ -0,0 +1,500 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights
|
||||
*reserved. SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice,
|
||||
*this list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
||||
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
|
||||
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
|
||||
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
|
||||
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
||||
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
|
||||
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
|
||||
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
|
||||
*POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
//
|
||||
// This file is a modified excerpt of
|
||||
// include/cutlass/epilogue/fusion/visitor_load.hpp from
|
||||
// https://github.com/NVIDIA/cutlass v3.5.0
|
||||
// It has been modified to support either
|
||||
// row/column or scalar broadcasting where the tensor being loaded from is
|
||||
// always passed in via a device pointer. This lets one compiled kernel handle
|
||||
// all cases of per-tensor or per-channel/per-token quantization.
|
||||
//
|
||||
// This interface also allows the scales to be passed in as tensors that
|
||||
// consistently reside on the device, which avoids an issue with a previous
|
||||
// implementation where scalars needed to be on the CPU since they
|
||||
// were passed in via float values. This created a potential performance hazard
|
||||
// if scales were initially on the device, and caused torch.compile graph
|
||||
// breaks when moving scales to the CPU.
|
||||
//
|
||||
|
||||
// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_c2x.hpp
|
||||
|
||||
#pragma once
|
||||
|
||||
// Turn off clang-format for the entire file to keep it close to upstream
|
||||
// clang-format off
|
||||
|
||||
#include "cutlass/epilogue/threadblock/fusion/visitor_2x.hpp"
|
||||
#include "cutlass/epilogue/threadblock/fusion/visitors.hpp"
|
||||
#include "cute/tensor.hpp"
|
||||
|
||||
namespace cutlass::epilogue::threadblock {
|
||||
|
||||
using namespace cute;
|
||||
using namespace detail;
|
||||
|
||||
template<
|
||||
class ThreadMap,
|
||||
class Element,
|
||||
class StrideMNL
|
||||
>
|
||||
struct VisitorRowOrScalarBroadcast {
|
||||
|
||||
// This struct has been modified to have a bool indicating that ptr_row is a
|
||||
// scalar that must be broadcast.
|
||||
struct Arguments {
|
||||
Element const* ptr_row = nullptr;
|
||||
bool row_broadcast = true;
|
||||
StrideMNL dRow = {};
|
||||
};
|
||||
|
||||
using Params = Arguments;
|
||||
|
||||
template <class ProblemShape>
|
||||
static constexpr Params
|
||||
to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
|
||||
return args;
|
||||
}
|
||||
|
||||
template <class ProblemShape>
|
||||
static size_t
|
||||
get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
struct SharedStorage {};
|
||||
|
||||
// Global load type
|
||||
static int constexpr vec_bits = ThreadMap::kElementsPerAccess * sizeof_bits<Element>::value;
|
||||
using VecType = uint_bit_t<cute::min(128, vec_bits)>;
|
||||
static int constexpr VecLength = sizeof(VecType) / sizeof(Element);
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
VisitorRowOrScalarBroadcast() { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
VisitorRowOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage)
|
||||
: params_ptr(¶ms) { }
|
||||
|
||||
Params const* params_ptr;
|
||||
|
||||
template <class GTensor, class RTensor, class CTensor, class ProblemShape>
|
||||
struct Callbacks : EmptyCallbacks {
|
||||
CUTLASS_DEVICE
|
||||
Callbacks(
|
||||
GTensor&& tC_gRow,
|
||||
RTensor&& tC_rRow,
|
||||
CTensor&& tC_cRow,
|
||||
ProblemShape problem_shape,
|
||||
Params const* params_ptr
|
||||
):
|
||||
tC_gRow(cute::forward<GTensor>(tC_gRow)),
|
||||
tC_rRow(cute::forward<RTensor>(tC_rRow)),
|
||||
tC_cRow(cute::forward<CTensor>(tC_cRow)),
|
||||
n(get<1>(problem_shape)),
|
||||
params_ptr(params_ptr) { }
|
||||
|
||||
GTensor tC_gRow;
|
||||
RTensor tC_rRow;
|
||||
CTensor tC_cRow;
|
||||
Params const* params_ptr;
|
||||
int n;
|
||||
|
||||
// This function is modified from VisitorRowBroadcast
|
||||
CUTLASS_DEVICE void
|
||||
begin_epilogue() {
|
||||
clear(tC_rRow);
|
||||
auto src_v = filter(tC_gRow);
|
||||
auto coord_v = filter(tC_cRow);
|
||||
auto dst_v = filter(tC_rRow);
|
||||
|
||||
if (params_ptr->row_broadcast) {
|
||||
// In this case we are loading from a row vector and broadcasting
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size(src_v); ++i) {
|
||||
bool guard = get<1>(coord_v(i)) < n;
|
||||
cutlass::arch::global_load<VecType, sizeof(VecType)>(
|
||||
dst_v(i), (void const*)&src_v(i), guard);
|
||||
}
|
||||
} else {
|
||||
// In this case we are loading from a scalar and broadcasting
|
||||
VecType filled_vec;
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < VecLength; i++) {
|
||||
reinterpret_cast<Element*>(&filled_vec)[i] = *(params_ptr->ptr_row);
|
||||
}
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size(src_v); ++i) {
|
||||
if (get<1>(coord_v(i)) < n) {
|
||||
dst_v(i) = filled_vec;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <class ElementAccumulator, int FragmentSize>
|
||||
CUTLASS_DEVICE auto // returns an Array
|
||||
visit(int iter_idx, int row_idx, int column_idx, int frg_idx,
|
||||
Array<ElementAccumulator, FragmentSize> const& frg_acc) {
|
||||
Tensor rRow_frg = recast<Array<Element, FragmentSize>>(coalesce(tC_rRow));
|
||||
return rRow_frg(column_idx);
|
||||
}
|
||||
};
|
||||
|
||||
template <class ProblemShape>
|
||||
CUTLASS_DEVICE auto
|
||||
get_callbacks(
|
||||
gemm::GemmCoord threadblock_tile_offset,
|
||||
int thread_idx,
|
||||
ProblemShape problem_shape
|
||||
) {
|
||||
Tensor mRow = make_tensor(
|
||||
make_gmem_ptr(params_ptr->ptr_row),
|
||||
problem_shape,
|
||||
params_ptr->dRow);
|
||||
|
||||
// VECTOR, FRAGMENT_COLUMN
|
||||
Tensor tC_gRow = recast<VecType>(
|
||||
ThreadMap::partition(mRow, thread_idx, threadblock_tile_offset)
|
||||
)(_,_,_0{},_0{},_0{},_0{});
|
||||
Tensor tC_rRow = make_tensor_like(tC_gRow);
|
||||
|
||||
// Generate the pred tensor
|
||||
Tensor cRow = make_identity_tensor(mRow.shape());
|
||||
Tensor tC_cRow = outer_partition(
|
||||
ThreadMap::partition(cRow, thread_idx, threadblock_tile_offset)(_,_,_0{},_0{},_0{},_0{}),
|
||||
Shape<Int<VecLength>>{},
|
||||
(_0{})
|
||||
);
|
||||
|
||||
return Callbacks<
|
||||
decltype(tC_gRow), decltype(tC_rRow),
|
||||
decltype(tC_cRow), ProblemShape>(
|
||||
cute::move(tC_gRow),
|
||||
cute::move(tC_rRow),
|
||||
cute::move(tC_cRow),
|
||||
problem_shape,
|
||||
params_ptr
|
||||
);
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// This is a modified RowBroadcast that will broadcast 0 if ptr_row is null
|
||||
template<
|
||||
class ThreadMap,
|
||||
class Element,
|
||||
class StrideMNL
|
||||
>
|
||||
struct VisitorRowOrZeroBroadcast {
|
||||
|
||||
// This struct has been modified to remove null_default (because it's always 0)
|
||||
struct Arguments {
|
||||
Element const* ptr_row = nullptr;
|
||||
StrideMNL dRow = {};
|
||||
};
|
||||
|
||||
using Params = Arguments;
|
||||
|
||||
template <class ProblemShape>
|
||||
static constexpr Params
|
||||
to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
|
||||
return args;
|
||||
}
|
||||
|
||||
template <class ProblemShape>
|
||||
static size_t
|
||||
get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
struct SharedStorage {};
|
||||
|
||||
// Global load type
|
||||
static int constexpr vec_bits = ThreadMap::kElementsPerAccess * sizeof_bits<Element>::value;
|
||||
using VecType = uint_bit_t<cute::min(128, vec_bits)>;
|
||||
static int constexpr VecLength = sizeof(VecType) / sizeof(Element);
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
VisitorRowOrZeroBroadcast() { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
VisitorRowOrZeroBroadcast(Params const& params, SharedStorage const& shared_storage)
|
||||
: params_ptr(¶ms) { }
|
||||
|
||||
Params const* params_ptr;
|
||||
|
||||
template <class GTensor, class RTensor, class CTensor, class ProblemShape>
|
||||
struct Callbacks : EmptyCallbacks {
|
||||
CUTLASS_DEVICE
|
||||
Callbacks(
|
||||
GTensor&& tC_gRow,
|
||||
RTensor&& tC_rRow,
|
||||
CTensor&& tC_cRow,
|
||||
ProblemShape problem_shape,
|
||||
Params const* params_ptr
|
||||
):
|
||||
tC_gRow(cute::forward<GTensor>(tC_gRow)),
|
||||
tC_rRow(cute::forward<RTensor>(tC_rRow)),
|
||||
tC_cRow(cute::forward<CTensor>(tC_cRow)),
|
||||
n(get<1>(problem_shape)),
|
||||
params_ptr(params_ptr) { }
|
||||
|
||||
GTensor tC_gRow;
|
||||
RTensor tC_rRow;
|
||||
CTensor tC_cRow;
|
||||
Params const* params_ptr;
|
||||
int n;
|
||||
|
||||
// This function is modified from VisitorRowBroadcast
|
||||
CUTLASS_DEVICE void
|
||||
begin_epilogue() {
|
||||
clear(tC_rRow);
|
||||
auto src_v = filter(tC_gRow);
|
||||
auto coord_v = filter(tC_cRow);
|
||||
auto dst_v = filter(tC_rRow);
|
||||
|
||||
if (params_ptr->ptr_row != nullptr) {
|
||||
// In this case we are loading from a row vector and broadcasting
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size(src_v); ++i) {
|
||||
bool guard = get<1>(coord_v(i)) < n;
|
||||
cutlass::arch::global_load<VecType, sizeof(VecType)>(
|
||||
dst_v(i), (void const*)&src_v(i), guard);
|
||||
}
|
||||
} else {
|
||||
// In this case we are broadcasting 0
|
||||
VecType filled_vec;
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < VecLength; i++) {
|
||||
reinterpret_cast<Element*>(&filled_vec)[i] = Element{0};
|
||||
}
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size(src_v); ++i) {
|
||||
if (get<1>(coord_v(i)) < n) {
|
||||
dst_v(i) = filled_vec;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <class ElementAccumulator, int FragmentSize>
|
||||
CUTLASS_DEVICE auto // returns an Array
|
||||
visit(int iter_idx, int row_idx, int column_idx, int frg_idx,
|
||||
Array<ElementAccumulator, FragmentSize> const& frg_acc) {
|
||||
Tensor rRow_frg = recast<Array<Element, FragmentSize>>(coalesce(tC_rRow));
|
||||
return rRow_frg(column_idx);
|
||||
}
|
||||
};
|
||||
|
||||
template <class ProblemShape>
|
||||
CUTLASS_DEVICE auto
|
||||
get_callbacks(
|
||||
gemm::GemmCoord threadblock_tile_offset,
|
||||
int thread_idx,
|
||||
ProblemShape problem_shape
|
||||
) {
|
||||
Tensor mRow = make_tensor(
|
||||
make_gmem_ptr(params_ptr->ptr_row),
|
||||
problem_shape,
|
||||
params_ptr->dRow);
|
||||
|
||||
// VECTOR, FRAGMENT_COLUMN
|
||||
Tensor tC_gRow = recast<VecType>(
|
||||
ThreadMap::partition(mRow, thread_idx, threadblock_tile_offset)
|
||||
)(_,_,_0{},_0{},_0{},_0{});
|
||||
Tensor tC_rRow = make_tensor_like(tC_gRow);
|
||||
|
||||
// Generate the pred tensor
|
||||
Tensor cRow = make_identity_tensor(mRow.shape());
|
||||
Tensor tC_cRow = outer_partition(
|
||||
ThreadMap::partition(cRow, thread_idx, threadblock_tile_offset)(_,_,_0{},_0{},_0{},_0{}),
|
||||
Shape<Int<VecLength>>{},
|
||||
(_0{})
|
||||
);
|
||||
|
||||
return Callbacks<
|
||||
decltype(tC_gRow), decltype(tC_rRow),
|
||||
decltype(tC_cRow), ProblemShape>(
|
||||
cute::move(tC_gRow),
|
||||
cute::move(tC_rRow),
|
||||
cute::move(tC_cRow),
|
||||
problem_shape,
|
||||
params_ptr
|
||||
);
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Column vector broadcast
|
||||
template<
|
||||
class ThreadMap,
|
||||
class Element,
|
||||
class StrideMNL = Stride<_1,_0,_0>
|
||||
>
|
||||
struct VisitorColOrScalarBroadcast {
|
||||
|
||||
// This struct has been modified to have a bool indicating that ptr_col is a
|
||||
// scalar that must be broadcast.
|
||||
struct Arguments {
|
||||
Element const* ptr_col = nullptr;
|
||||
bool col_broadcast = true;
|
||||
StrideMNL dCol = {};
|
||||
};
|
||||
|
||||
using Params = Arguments;
|
||||
|
||||
template <class ProblemShape>
|
||||
static constexpr Params
|
||||
to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
|
||||
return args;
|
||||
}
|
||||
|
||||
template <class ProblemShape>
|
||||
static size_t
|
||||
get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
struct SharedStorage { };
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
VisitorColOrScalarBroadcast() { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
VisitorColOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage)
|
||||
: params_ptr(¶ms) { }
|
||||
|
||||
Params const* params_ptr;
|
||||
|
||||
template <class GTensor, class RTensor, class CTensor, class ProblemShape>
|
||||
struct Callbacks : EmptyCallbacks {
|
||||
CUTLASS_DEVICE
|
||||
Callbacks(
|
||||
GTensor&& tC_gCol,
|
||||
RTensor&& tC_rCol,
|
||||
CTensor&& tC_cCol,
|
||||
ProblemShape problem_shape,
|
||||
Params const* params_ptr
|
||||
):
|
||||
tC_gCol(cute::forward<GTensor>(tC_gCol)),
|
||||
tC_rCol(cute::forward<RTensor>(tC_rCol)),
|
||||
tC_cCol(cute::forward<CTensor>(tC_cCol)),
|
||||
m(get<0>(problem_shape)),
|
||||
params_ptr(params_ptr) { }
|
||||
|
||||
GTensor tC_gCol;
|
||||
RTensor tC_rCol;
|
||||
CTensor tC_cCol;
|
||||
Params const* params_ptr;
|
||||
int m;
|
||||
|
||||
// This function is modified from VisitorColBroadcast
|
||||
CUTLASS_DEVICE void
|
||||
begin_epilogue() {
|
||||
clear(tC_rCol);
|
||||
|
||||
Tensor pred = make_tensor<bool>(shape(tC_gCol));
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size(pred); ++i) {
|
||||
pred(i) = get<0>(tC_cCol(i)) < m;
|
||||
}
|
||||
|
||||
if (params_ptr->col_broadcast) {
|
||||
// In this case we are loading from a column vector and broadcasting
|
||||
copy_if(pred, tC_gCol, tC_rCol);
|
||||
} else {
|
||||
// In this case we are loading from a scalar and broadcasting
|
||||
auto dst_v = filter(tC_rCol);
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size(dst_v); ++i) {
|
||||
if (pred(i)) {
|
||||
dst_v(i) = *(params_ptr->ptr_col);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <class ElementAccumulator, int FragmentSize>
|
||||
CUTLASS_DEVICE auto // returns an Array
|
||||
visit(int iter_idx, int row_idx, int column_idx, int frg_idx,
|
||||
Array<ElementAccumulator, FragmentSize> const& frg_acc) {
|
||||
Array<Element, FragmentSize> frg_col;
|
||||
frg_col.fill(tC_rCol(row_idx,iter_idx));
|
||||
return frg_col;
|
||||
}
|
||||
};
|
||||
|
||||
template <class ProblemShape>
|
||||
CUTLASS_DEVICE auto
|
||||
get_callbacks(
|
||||
gemm::GemmCoord threadblock_tile_offset,
|
||||
int thread_idx,
|
||||
ProblemShape problem_shape
|
||||
) {
|
||||
Tensor mCol = make_tensor(
|
||||
make_gmem_ptr(params_ptr->ptr_col),
|
||||
problem_shape,
|
||||
params_ptr->dCol);
|
||||
|
||||
// VECTOR, FRAGMENT_COLUMN, FRAGMENT_ROW, ITERATION_ROW, ITERATION_GROUP, ITERATION_CLUSTER
|
||||
Tensor tC_gCol = group_modes<1,4>(
|
||||
ThreadMap::partition(mCol, thread_idx, threadblock_tile_offset)(_0{},_0{},_,_,_,_));
|
||||
Tensor tC_rCol = make_tensor_like(tC_gCol);
|
||||
|
||||
// Generate the pred tensor
|
||||
Tensor cCol = make_identity_tensor(mCol.shape());
|
||||
Tensor tC_cCol = group_modes<1,4>(
|
||||
ThreadMap::partition(cCol, thread_idx, threadblock_tile_offset)(_0{},_0{},_,_,_,_));
|
||||
|
||||
return Callbacks<
|
||||
decltype(tC_gCol), decltype(tC_rCol),
|
||||
decltype(tC_cCol), ProblemShape>(
|
||||
cute::move(tC_gCol),
|
||||
cute::move(tC_rCol),
|
||||
cute::move(tC_cCol),
|
||||
problem_shape,
|
||||
params_ptr
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
}
|
||||
@@ -0,0 +1,450 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights
|
||||
*reserved. SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice,
|
||||
*this list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
||||
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
|
||||
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
|
||||
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
|
||||
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
||||
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
|
||||
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
|
||||
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
|
||||
*POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
//
|
||||
// This file is a modified excerpt of
|
||||
// include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp
|
||||
// from https://github.com/NVIDIA/cutlass v3.5.0
|
||||
// It has been modified to support either row/column or scalar broadcasting
|
||||
// where the tensor being loaded from is always passed in via a device pointer.
|
||||
// This lets one compiled kernel handle all cases of per-tensor or
|
||||
// per-channel/per-token quantization.
|
||||
//
|
||||
// This interface also allows the scales to be passed in as tensors that
|
||||
// consistently reside on the device, which avoids an issue with a previous
|
||||
// implementation where scalars needed to be on the CPU since they
|
||||
// were passed in via float values. This created a potential performance hazard
|
||||
// if scales were initially on the device, and caused torch.compile graph
|
||||
// breaks when moving scales to the CPU.
|
||||
//
|
||||
|
||||
// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp
|
||||
|
||||
#pragma once
|
||||
|
||||
// Turn off clang-format for the entire file to keep it close to upstream
|
||||
// clang-format off
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/arch/barrier.h"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp"
|
||||
|
||||
namespace cutlass::epilogue::fusion {
|
||||
|
||||
using namespace cute;
|
||||
using namespace detail;
|
||||
|
||||
// Row vector broadcast
|
||||
template<
|
||||
int Stages,
|
||||
class CtaTileShapeMNK,
|
||||
class Element,
|
||||
class StrideMNL = Stride<_0,_1,_0>,
|
||||
int Alignment = 128 / sizeof_bits_v<Element>
|
||||
>
|
||||
struct Sm90RowOrScalarBroadcast {
|
||||
static_assert(Stages == 0, "Row broadcast doesn't support smem usage");
|
||||
static_assert(is_static_v<decltype(take<0,2>(StrideMNL{}))>); // batch stride can be dynamic or static
|
||||
static_assert(take<0,2>(StrideMNL{}) == Stride<_0,_1>{});
|
||||
|
||||
struct SharedStorage {
|
||||
array_aligned<Element, size<1>(CtaTileShapeMNK{})> smem;
|
||||
};
|
||||
|
||||
// This struct has been modified to have a bool indicating that ptr_row is a
|
||||
// scalar that must be broadcast, instead of containing a scalar that is
|
||||
// valid if ptr_row is null.
|
||||
struct Arguments {
|
||||
Element const* ptr_row = nullptr;
|
||||
bool row_broadcast = true;
|
||||
StrideMNL dRow = {};
|
||||
};
|
||||
|
||||
using Params = Arguments;
|
||||
|
||||
template <class ProblemShape>
|
||||
static constexpr Params
|
||||
to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
|
||||
return args;
|
||||
}
|
||||
|
||||
template <class ProblemShape>
|
||||
static bool
|
||||
can_implement(ProblemShape const& problem_shape, Arguments const& args) {
|
||||
return true;
|
||||
}
|
||||
|
||||
template <class ProblemShape>
|
||||
static size_t
|
||||
get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
template <class ProblemShape>
|
||||
static cutlass::Status
|
||||
initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream,
|
||||
CudaHostAdapter* cuda_adapter = nullptr) {
|
||||
return cutlass::Status::kSuccess;
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Sm90RowOrScalarBroadcast() { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Sm90RowOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage)
|
||||
: params(params)
|
||||
, smem(const_cast<Element*>(shared_storage.smem.data())) { }
|
||||
|
||||
Params params;
|
||||
Element *smem = nullptr;
|
||||
|
||||
CUTLASS_DEVICE bool
|
||||
is_producer_load_needed() const {
|
||||
return false;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE bool
|
||||
is_C_load_needed() const {
|
||||
return false;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE bool
|
||||
is_zero() const {
|
||||
return (!params.row_broadcast && *(params.ptr_row) == Element(0));
|
||||
}
|
||||
|
||||
template <class... Args>
|
||||
CUTLASS_DEVICE auto
|
||||
get_producer_load_callbacks(ProducerLoadArgs<Args...> const& args) {
|
||||
return EmptyProducerLoadCallbacks{};
|
||||
}
|
||||
|
||||
template <class GS_GTensor, class GS_STensor, class GS_CTensor, class Tiled_G2S, class SR_STensor, class SR_RTensor, class CTensor, class ThrResidue, class ThrNum>
|
||||
struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks {
|
||||
CUTLASS_DEVICE
|
||||
ConsumerStoreCallbacks(
|
||||
GS_GTensor tGS_gRow_, GS_STensor tGS_sRow_,
|
||||
GS_CTensor tGS_cRow_, Tiled_G2S tiled_g2s_,
|
||||
SR_STensor tSR_sRow_, SR_RTensor tSR_rRow_,
|
||||
CTensor tCcRow_, ThrResidue residue_tCcRow_, ThrNum thr_num_, Params const& params_)
|
||||
: tGS_gRow(tGS_gRow_)
|
||||
, tGS_sRow(tGS_sRow_)
|
||||
, tGS_cRow(tGS_cRow_)
|
||||
, tiled_G2S(tiled_g2s_)
|
||||
, tSR_sRow(tSR_sRow_)
|
||||
, tSR_rRow(tSR_rRow_)
|
||||
, tCcRow(tCcRow_)
|
||||
, residue_tCcRow(residue_tCcRow_)
|
||||
, params(params_) {}
|
||||
|
||||
GS_GTensor tGS_gRow; // (CPY,CPY_M,CPY_N)
|
||||
GS_STensor tGS_sRow; // (CPY,CPY_M,CPY_N)
|
||||
GS_CTensor tGS_cRow; // (CPY,CPY_M,CPY_N)
|
||||
Tiled_G2S tiled_G2S;
|
||||
|
||||
SR_STensor tSR_sRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
|
||||
SR_RTensor tSR_rRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
|
||||
|
||||
CTensor tCcRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
|
||||
ThrResidue residue_tCcRow; // (m, n)
|
||||
ThrNum thr_num;
|
||||
Params const& params;
|
||||
|
||||
CUTLASS_DEVICE void
|
||||
begin() {
|
||||
if (!params.row_broadcast) {
|
||||
fill(tSR_rRow, *(params.ptr_row));
|
||||
return;
|
||||
}
|
||||
|
||||
auto synchronize = [&] () { cutlass::arch::NamedBarrier::sync(thr_num, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); };
|
||||
Tensor tGS_gRow_flt = filter_zeros(tGS_gRow);
|
||||
Tensor tGS_sRow_flt = filter_zeros(tGS_sRow);
|
||||
Tensor tGS_cRow_flt = make_tensor(tGS_cRow.data(), make_layout(tGS_gRow_flt.shape(), tGS_cRow.stride()));
|
||||
|
||||
for (int i = 0; i < size(tGS_gRow_flt); ++i) {
|
||||
if (get<1>(tGS_cRow_flt(i)) >= size<1>(CtaTileShapeMNK{})) {
|
||||
continue; // OOB of SMEM,
|
||||
}
|
||||
if (elem_less(tGS_cRow_flt(i), make_coord(get<0>(residue_tCcRow), get<1>(residue_tCcRow)))) {
|
||||
tGS_sRow_flt(i) = tGS_gRow_flt(i);
|
||||
}
|
||||
else {
|
||||
tGS_sRow_flt(i) = Element(0); // Set to zero when OOB so the LDS can be issued without any predicates.
|
||||
}
|
||||
}
|
||||
synchronize();
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE void
|
||||
begin_loop(int epi_m, int epi_n) {
|
||||
if (epi_m == 0) { // Assumes M-major subtile loop
|
||||
if (!params.row_broadcast) return; // Do not issue LDS when row is scalar
|
||||
Tensor tSR_sRow_flt = filter_zeros(tSR_sRow(_,_,_,epi_m,epi_n));
|
||||
Tensor tSR_rRow_flt = filter_zeros(tSR_rRow);
|
||||
copy(tSR_sRow_flt, tSR_rRow_flt);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename ElementAccumulator, int FragmentSize>
|
||||
CUTLASS_DEVICE Array<Element, FragmentSize>
|
||||
visit(Array<ElementAccumulator, FragmentSize> const& frg_acc, int epi_v, int epi_m, int epi_n) {
|
||||
Array<Element, FragmentSize> frg_row;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < FragmentSize; ++i) {
|
||||
frg_row[i] = tSR_rRow(epi_v * FragmentSize + i);
|
||||
}
|
||||
|
||||
return frg_row;
|
||||
}
|
||||
};
|
||||
|
||||
template <
|
||||
bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy
|
||||
class... Args
|
||||
>
|
||||
CUTLASS_DEVICE auto
|
||||
get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {
|
||||
auto [M, N, K, L] = args.problem_shape_mnkl;
|
||||
auto [m, n, k, l] = args.tile_coord_mnkl;
|
||||
using ThreadCount = decltype(size(args.tiled_copy));
|
||||
|
||||
Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row), make_shape(M,N,L), params.dRow);
|
||||
Tensor gRow = local_tile(mRow(_,_,l), take<0,2>(args.tile_shape_mnk), make_coord(m, n)); // (CTA_M, CTA_N)
|
||||
Tensor sRow = make_tensor(make_smem_ptr(smem),
|
||||
make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})), make_shape(_0{}, _1{})); // (CTA_M, CTA_N)
|
||||
//// G2S: Gmem to Smem
|
||||
auto tiled_g2s = make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
|
||||
Layout< Shape<_1, ThreadCount>,
|
||||
Stride<_0, _1>>{},
|
||||
Layout<_1>{});
|
||||
auto thr_g2s = tiled_g2s.get_slice(args.thread_idx);
|
||||
Tensor tGS_gRow = thr_g2s.partition_S(gRow);
|
||||
Tensor tGS_sRow = thr_g2s.partition_D(sRow);
|
||||
|
||||
//// G2S: Coord
|
||||
auto cRow = make_identity_tensor(make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})));
|
||||
Tensor tGS_cRow = thr_g2s.partition_S(cRow);
|
||||
|
||||
//// S2R: Smem to Reg
|
||||
Tensor tSR_sRow = sm90_partition_for_epilogue<ReferenceSrc>(sRow, args.epi_tile, args.tiled_copy, args.thread_idx);
|
||||
Tensor tSR_rRow = make_tensor_like(take<0,3>(tSR_sRow)); // (CPY,CPY_M,CPY_N)
|
||||
|
||||
return ConsumerStoreCallbacks<decltype(tGS_gRow), decltype(tGS_sRow), decltype(tGS_cRow), decltype(tiled_g2s), decltype(tSR_sRow), decltype(tSR_rRow), decltype(args.tCcD), decltype(args.residue_cD), ThreadCount>(
|
||||
tGS_gRow,
|
||||
tGS_sRow,
|
||||
tGS_cRow, tiled_g2s,
|
||||
tSR_sRow,
|
||||
tSR_rRow,
|
||||
args.tCcD,
|
||||
args.residue_cD,
|
||||
ThreadCount{},
|
||||
params);
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Column vector broadcast
|
||||
template<
|
||||
int Stages,
|
||||
class CtaTileShapeMNK,
|
||||
class Element,
|
||||
class StrideMNL = Stride<_1,_0,_0>,
|
||||
int Alignment = 128 / sizeof_bits_v<Element>
|
||||
>
|
||||
struct Sm90ColOrScalarBroadcast {
|
||||
static_assert(Stages == 0, "Column broadcast doesn't support smem usage yet");
|
||||
static_assert(Alignment * sizeof_bits_v<Element> % 128 == 0, "sub-16B alignment not supported yet");
|
||||
static_assert(
|
||||
(cute::is_same_v<StrideMNL, Stride<_1,_0, _0>>) || // col vector broadcast, e.g. per-row alpha/bias
|
||||
(cute::is_same_v<StrideMNL, Stride<_1,_0,int>>)); // batched col vector broadcast, e.g. batched per-row bias
|
||||
|
||||
// Accumulator distributes col elements evenly amongst threads so we can just directly load from gmem
|
||||
struct SharedStorage { };
|
||||
|
||||
// This struct has been modified to have a bool indicating that ptr_col is a
|
||||
// scalar that must be broadcast, instead of containing a scalar that is
|
||||
// valid if ptr_col is null.
|
||||
struct Arguments {
|
||||
Element const* ptr_col = nullptr;
|
||||
bool col_broadcast = true;
|
||||
StrideMNL dCol = {};
|
||||
};
|
||||
|
||||
using Params = Arguments;
|
||||
|
||||
template <class ProblemShape>
|
||||
static constexpr Params
|
||||
to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
|
||||
return args;
|
||||
}
|
||||
|
||||
template <class ProblemShape>
|
||||
static bool
|
||||
can_implement(ProblemShape const& problem_shape, Arguments const& args) {
|
||||
return true;
|
||||
}
|
||||
|
||||
template <class ProblemShape>
|
||||
static size_t
|
||||
get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
template <class ProblemShape>
|
||||
static cutlass::Status
|
||||
initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream,
|
||||
CudaHostAdapter* cuda_adapter = nullptr) {
|
||||
return cutlass::Status::kSuccess;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE bool
|
||||
is_producer_load_needed() const {
|
||||
return false;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE bool
|
||||
is_C_load_needed() const {
|
||||
return false;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE bool
|
||||
is_zero() const {
|
||||
return (!params.col_broadcast && *(params.ptr_col) == Element(0));
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Sm90ColOrScalarBroadcast() { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Sm90ColOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage)
|
||||
: params(params) { }
|
||||
|
||||
Params params;
|
||||
|
||||
template <class... Args>
|
||||
CUTLASS_DEVICE auto
|
||||
get_producer_load_callbacks(ProducerLoadArgs<Args...> const& args) {
|
||||
return EmptyProducerLoadCallbacks{};
|
||||
}
|
||||
|
||||
template<class GTensor, class RTensor, class CTensor, class ProblemShape>
|
||||
struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks {
|
||||
CUTLASS_DEVICE
|
||||
ConsumerStoreCallbacks(
|
||||
GTensor&& tCgCol,
|
||||
RTensor&& tCrCol,
|
||||
CTensor&& tCcCol,
|
||||
ProblemShape problem_shape,
|
||||
Params const& params
|
||||
):
|
||||
tCgCol(cute::forward<GTensor>(tCgCol)),
|
||||
tCrCol(cute::forward<RTensor>(tCrCol)),
|
||||
tCcCol(cute::forward<CTensor>(tCcCol)),
|
||||
m(get<0>(problem_shape)),
|
||||
params(params) {}
|
||||
|
||||
GTensor tCgCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
|
||||
RTensor tCrCol;
|
||||
CTensor tCcCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
|
||||
Params const& params;
|
||||
int m;
|
||||
|
||||
CUTLASS_DEVICE void
|
||||
begin() {
|
||||
Tensor pred = make_tensor<bool>(shape(tCgCol));
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size(pred); ++i) {
|
||||
pred(i) = get<0>(tCcCol(i)) < m;
|
||||
}
|
||||
|
||||
if (!params.col_broadcast) {
|
||||
fill(tCrCol, *(params.ptr_col));
|
||||
return;
|
||||
}
|
||||
|
||||
// Filter so we don't issue redundant copies over stride-0 modes
|
||||
// (only works if 0-strides are in same location, which is by construction)
|
||||
copy_if(pred, filter(tCgCol), filter(tCrCol));
|
||||
}
|
||||
|
||||
template <typename ElementAccumulator, int FragmentSize>
|
||||
CUTLASS_DEVICE Array<Element, FragmentSize>
|
||||
visit(Array<ElementAccumulator, FragmentSize> const& frg_acc, int epi_v, int epi_m, int epi_n) {
|
||||
Array<Element, FragmentSize> frg_col;
|
||||
Tensor tCrCol_mn = tCrCol(_,_,_,epi_m,epi_n);
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < FragmentSize; ++i) {
|
||||
frg_col[i] = tCrCol_mn(epi_v * FragmentSize + i);
|
||||
}
|
||||
|
||||
return frg_col;
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
template <
|
||||
bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy
|
||||
class... Args
|
||||
>
|
||||
CUTLASS_DEVICE auto
|
||||
get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {
|
||||
|
||||
auto [M, N, K, L] = args.problem_shape_mnkl;
|
||||
Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col), make_shape(M,N,L), params.dCol);
|
||||
Tensor tCgCol = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
|
||||
mCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx);
|
||||
Tensor tCrCol = make_tensor_like(tCgCol); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
|
||||
|
||||
// Generate an identity tensor matching the shape of the global tensor and
|
||||
// partition the same way, this will be used to generate the predicate
|
||||
// tensor for loading
|
||||
Tensor cCol = make_identity_tensor(mCol.shape());
|
||||
Tensor tCcCol = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
|
||||
cCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx);
|
||||
|
||||
return ConsumerStoreCallbacks(
|
||||
cute::move(tCgCol),
|
||||
cute::move(tCrCol),
|
||||
cute::move(tCcCol),
|
||||
args.problem_shape_mnkl,
|
||||
params
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
}
|
||||
@@ -0,0 +1,327 @@
// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp

#pragma once

#include "cutlass_extensions/epilogue/broadcast_load_epilogue_c2x.hpp"

/*
   This file defines custom epilogues for fusing channel scales, token scales,
   bias, and activation zero-points onto a GEMM operation using the
   CUTLASS 2.x API, for sm80 (Ampere) NVIDIA GPUs.

   Epilogues must contain a public type named EVTCompute of type Sm80EVT,
   as well as a static prepare_args function that constructs an
   EVTCompute::Arguments struct.
*/

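// Note added for clarity (not in the upstream header): every epilogue below is
// expressed as an Sm80EVT visitor tree whose first template parameter is the
// node's compute op and whose remaining parameters are its child visitors. As
// the prepare_args implementations below suggest, the Arguments structs nest
// in the same order, with the node's own arguments last; e.g. for
//   Sm80EVT<Compute1, ScaleA, EVTCompute0>
// the arguments are {scale_a_args, evt_compute0_args, {}}.
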
namespace fastdeploy::c2x {

using namespace cute;

/*
 * This class provides the common load descriptors for the
 * ScaledEpilogue[...] classes
 */
template <typename ElementD, typename OutputTileThreadMap>
struct ScaledEpilogueBase {
protected:
  using Accum = cutlass::epilogue::threadblock::VisitorAccFetch;

  template <typename T>
  using ColOrScalarLoad =
      cutlass::epilogue::threadblock::VisitorColOrScalarBroadcast<
          OutputTileThreadMap, T, Stride<Int<1>, Int<0>, Int<0>>>;

  template <typename T>
  using RowOrScalarLoad =
      cutlass::epilogue::threadblock::VisitorRowOrScalarBroadcast<
          OutputTileThreadMap, T, Stride<Int<0>, Int<1>, Int<0>>>;

  template <typename T>
  using ColLoad = cutlass::epilogue::threadblock::VisitorColBroadcast<
      OutputTileThreadMap, T, Stride<Int<1>, Int<0>, Int<0>>>;

  template <typename T>
  using RowLoad = cutlass::epilogue::threadblock::VisitorRowBroadcast<
      OutputTileThreadMap, T, Stride<Int<0>, Int<1>, Int<0>>>;

  template <typename T>
  using RowOrZeroLoad =
      cutlass::epilogue::threadblock::VisitorRowOrZeroBroadcast<
          OutputTileThreadMap, T, Stride<Int<0>, Int<1>, Int<0>>>;

  // This utility function constructs the arguments for the load descriptors
  // from a tensor. It can handle both row and column, as well as row/column or
  // scalar cases.
  template <typename Descriptor, typename T>
  static auto args_from_tensor(paddle::Tensor const &tensor) {
    using Arguments = typename Descriptor::Arguments;
    auto *data_ptr = static_cast<T *>(const_cast<void *>(tensor.data()));
    if constexpr (std::is_same_v<Descriptor, ColOrScalarLoad<T>> ||
                  std::is_same_v<Descriptor, RowOrScalarLoad<T>>) {
      return Arguments{data_ptr, tensor.numel() != 1};
    } else {
      // it would technically work, but there is no use case since data_ptr
      // is never nullptr here
      static_assert(!std::is_same_v<Descriptor, RowOrZeroLoad<T>>);
      return Arguments{data_ptr};
    }
  }

  // This overload handles the case where there might not be a tensor, in which
  // case a nullptr is passed and a constant (0) is used.
  template <typename Descriptor, typename T>
  static auto args_from_tensor(paddle::optional<paddle::Tensor> const &tensor) {
    static_assert(std::is_same_v<Descriptor, RowOrZeroLoad<T>>);
    using Arguments = typename Descriptor::Arguments;
    auto *data_ptr =
        tensor ? static_cast<T *>(const_cast<void *>(tensor->data())) : nullptr;
    return Arguments{data_ptr};
  }
};

/*
   This epilogue function defines a quantized GEMM operation similar to
   paddle._scaled_mm.

   A and B may be both either int8 or fp8_e4m3. A can be quantized per-tensor or
   per-row. B can be quantized per-tensor or per-column.
   Any combination of per-tensor and per-row or column is supported.
   A and B must have symmetric quantization (zero point == 0).

   So the GEMM operation is D = (a_scales * A) @ (b_scales * B), where the
   scales are applied elementwise with numpy-style broadcasting.

   ScaleA and ScaleB define the epilogue functions that apply the scales for
   the A and B operands respectively. These scales may be either per-tensor or
   per row or column.
*/
template <typename ElementD, typename OutputTileThreadMap>
struct ScaledEpilogue
    : private ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
private:
  using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
  using Accum = typename SUPER::Accum;
  using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
  using ScaleB = typename SUPER::template RowOrScalarLoad<float>;

  using Compute0 = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiplies, float, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTCompute0 =
      cutlass::epilogue::threadblock::Sm80EVT<Compute0, ScaleB, Accum>;

  using Compute1 = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiplies, ElementD, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

public:
  using EVTCompute =
      cutlass::epilogue::threadblock::Sm80EVT<Compute1, ScaleA, EVTCompute0>;
  using ArgumentType = typename EVTCompute::Arguments;

  static ArgumentType prepare_args(paddle::Tensor const &a_scales,
                                   paddle::Tensor const &b_scales) {
    auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
    auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);

    typename EVTCompute0::Arguments evt0_args{b_args, {}, {}};
    return ArgumentType{a_args, evt0_args, {}};
  }
};

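// Illustrative sketch (not part of the upstream header): the reference math
// that EVTCompute above evaluates per output element, assuming a_scales has
// shape (m, 1) or (1, 1) and b_scales has shape (1, n) or (1, 1):
//
//   // acc is the accumulator tile produced by the GEMM mainloop
//   for (int i = 0; i < m; ++i)
//     for (int j = 0; j < n; ++j)
//       D[i][j] = ElementD(a_scale(i) * (b_scale(j) * acc[i][j]));
//
// where a_scale(i) / b_scale(j) collapse to a single scalar when the
// corresponding tensor holds one element (numel() == 1), matching the
// ColOrScalarLoad / RowOrScalarLoad behaviour.
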
/*
 * This epilogue performs the same operation as ScaledEpilogue, but adds a bias.
 * This bias can also be used in the per-tensor azp case, where the activation
 * zero point (azp) is used to compute an azp correction term,
 * which is folded into the bias.
 *
 * The bias tensor must be per-output channel.
 * ScaleA and ScaleB can be per-tensor or per-token/per-channel.
 */
template <typename ElementD, typename OutputTileThreadMap>
struct ScaledEpilogueBias
    : protected ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
protected:
  using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
  using Accum = typename SUPER::Accum;
  using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
  using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
  using Bias = typename SUPER::template RowLoad<ElementD>;
  using Compute0 = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiplies, float, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTCompute0 =
      cutlass::epilogue::threadblock::Sm80EVT<Compute0, ScaleB, Accum>;

  using Compute1 = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiply_add, ElementD, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

public:
  using EVTCompute = cutlass::epilogue::threadblock::Sm80EVT<Compute1, ScaleA,
                                                             EVTCompute0, Bias>;
  using ArgumentType = typename EVTCompute::Arguments;
  static ArgumentType prepare_args(paddle::Tensor const &a_scales,
                                   paddle::Tensor const &b_scales,
                                   paddle::Tensor const &bias) {
    auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
    auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
    auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);

    typename EVTCompute0::Arguments evt0_args{b_args, {}, {}};
    return ArgumentType{a_args, evt0_args, bias_args, {}};
  }
};

/*
 * This epilogue directly supports per-tensor azp in int32 form.
 * As opposed to the per-token epilogue below, this epilogue only has an azp_adj
 * term, which should already be multiplied with the scalar azp.
 * The azp_adj term is a 1D tensor of shape (1,n), computed as azp * J @ B.
 *
 * This epilogue also supports bias, which remains per-channel.
 */
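/*
 * Added note on the correction term (a sketch of the algebra, assuming
 * per-tensor asymmetric quantization of A with zero point azp):
 *   A_q = A / s_a + azp * J, where J is the all-ones matrix, so
 *   acc = A_q @ B = (A / s_a) @ B + azp * (J @ B).
 * The epilogue below therefore recovers (A / s_a) @ B as acc - azp_adj with
 * azp_adj = azp * (J @ B) (the ComputeAzp / EVTComputeAzp nodes), then applies
 * b_scales, a_scales and the optional bias exactly as ScaledEpilogueBias does.
 */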
template <typename ElementD, typename OutputTileThreadMap>
|
||||
struct ScaledEpilogueBiasAzp
|
||||
: protected ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
|
||||
private:
|
||||
using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
|
||||
using Accum = typename SUPER::Accum;
|
||||
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
|
||||
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
|
||||
using Bias = typename SUPER::template RowOrZeroLoad<ElementD>;
|
||||
|
||||
// This is the full AZP term, azp * J @ B, shape (1,n)
|
||||
using AzpWithAdj = typename SUPER::template RowLoad<int32_t>;
|
||||
|
||||
// Compute float(accum - azp_adj), both operands are int32_t
|
||||
using ComputeAzp = cutlass::epilogue::threadblock::VisitorCompute<
|
||||
cutlass::minus, float, int32_t,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
using EVTComputeAzp =
|
||||
cutlass::epilogue::threadblock::Sm80EVT<ComputeAzp, Accum, AzpWithAdj>;
|
||||
|
||||
using ComputeScaleB = cutlass::epilogue::threadblock::VisitorCompute<
|
||||
cutlass::multiplies, float, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
using EVTComputeScaleB =
|
||||
cutlass::epilogue::threadblock::Sm80EVT<ComputeScaleB, ScaleB,
|
||||
EVTComputeAzp>;
|
||||
|
||||
using ComputeScaleBiasA = cutlass::epilogue::threadblock::VisitorCompute<
|
||||
cutlass::multiply_add, ElementD, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
public:
|
||||
using EVTCompute =
|
||||
cutlass::epilogue::threadblock::Sm80EVT<ComputeScaleBiasA, ScaleA,
|
||||
EVTComputeScaleB, Bias>;
|
||||
|
||||
using ArgumentType = typename EVTCompute::Arguments;
|
||||
|
||||
static ArgumentType
|
||||
prepare_args(paddle::Tensor const &a_scales, paddle::Tensor const &b_scales,
|
||||
paddle::Tensor const &azp_adj,
|
||||
paddle::optional<paddle::Tensor> const &bias) {
|
||||
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
|
||||
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
|
||||
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
|
||||
auto azp_adj_args =
|
||||
SUPER::template args_from_tensor<AzpWithAdj, int32_t>(azp_adj);
|
||||
|
||||
typename EVTComputeAzp::Arguments evt_azp_args{{}, azp_adj_args, {}};
|
||||
typename EVTComputeScaleB::Arguments evt_scale_b_args{
|
||||
b_args, evt_azp_args, {}};
|
||||
return ArgumentType{a_args, evt_scale_b_args, bias_args, {}};
|
||||
}
|
||||
};
|
||||
|
||||
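// Illustrative sketch (not part of the original file): a scalar reference of
// the per-tensor azp fusion above, assuming A was quantized with a single zero
// point azp (A ~ a_scale * (A_q - azp)), b_scale and bias are per-column, and
// the accumulator holds accum = A_q @ B_q. The helper name below is
// hypothetical; it only makes the folded math explicit:
//   D = a_scale * (b_scale * float(accum - azp * (J @ B_q))) + bias
#if 0
#include <cstdint>
#include <vector>

std::vector<float> reference_azp_gemm(const std::vector<int32_t> &A_q,     // (m, k)
                                      const std::vector<int32_t> &B_q,     // (k, n)
                                      int m, int n, int k, int32_t azp,
                                      float a_scale,                       // per-tensor
                                      const std::vector<float> &b_scale,   // per-column, n
                                      const std::vector<float> &bias) {    // per-column, n
  // azp_adj[j] = azp * sum_i B_q[i, j], i.e. azp * (J @ B_q), shape (1, n)
  std::vector<int32_t> azp_adj(n, 0);
  for (int i = 0; i < k; ++i)
    for (int j = 0; j < n; ++j) azp_adj[j] += azp * B_q[i * n + j];

  std::vector<float> D(m * n, 0.f);
  for (int r = 0; r < m; ++r) {
    for (int j = 0; j < n; ++j) {
      int32_t accum = 0;
      for (int i = 0; i < k; ++i) accum += A_q[r * k + i] * B_q[i * n + j];
      // Subtracting azp_adj removes the zero-point contribution of A_q.
      D[r * n + j] =
          a_scale * (b_scale[j] * static_cast<float>(accum - azp_adj[j])) + bias[j];
    }
  }
  return D;
}
#endif
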
/*
 * This epilogue supports per-token azp by computing and applying
 * the correction term using a rank-1 update. If the term were materialized,
 * it would require O(m*n) space, and this way it only requires O(m+n) space.
 * The azp term is a 1D tensor of shape (m,1), and represents the unscaled zero
 * point for each row of A.
 * The azp_adj term is a 1D tensor of shape (1,n), computed as J @ B.
 *
 * This epilogue also supports bias, which remains per-channel.
 */
template <typename ElementD, typename OutputTileThreadMap>
struct ScaledEpilogueBiasAzpToken
    : protected ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
private:
  using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
  using Accum = typename SUPER::Accum;
  using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
  using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
  using Bias = typename SUPER::template RowOrZeroLoad<ElementD>;

  // Per-token azp term, shape (m,1)
  using Azp = typename SUPER::template ColLoad<int32_t>;

  // This is the AZP adjustment term, J @ B, shape (1,n)
  using AzpAdj = typename SUPER::template RowLoad<int32_t>;

  // Compute azp * azp_adj
  using ComputeAzp = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiplies, int32_t, int32_t,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTComputeAzp =
      cutlass::epilogue::threadblock::Sm80EVT<ComputeAzp, Azp, AzpAdj>;

  // Compute float(accum - azp*azp_adj), all operands are int32_t
  using ComputeAcc = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::minus, float, int32_t,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTComputeAcc =
      cutlass::epilogue::threadblock::Sm80EVT<ComputeAcc, Accum, EVTComputeAzp>;

  using ComputeScaleB = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiplies, float, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTComputeScaleB =
      cutlass::epilogue::threadblock::Sm80EVT<ComputeScaleB, ScaleB,
                                              EVTComputeAcc>;

  using ComputeScaleBiasA = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiply_add, ElementD, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

public:
  using EVTCompute =
      cutlass::epilogue::threadblock::Sm80EVT<ComputeScaleBiasA, ScaleA,
                                              EVTComputeScaleB, Bias>;

  using ArgumentType = typename EVTCompute::Arguments;

  static ArgumentType
  prepare_args(paddle::Tensor const &a_scales, paddle::Tensor const &b_scales,
               paddle::Tensor const &azp_adj, paddle::Tensor const &azp,
               paddle::optional<paddle::Tensor> const &bias) {
    auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
    auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
    auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
    auto azp_args = SUPER::template args_from_tensor<Azp, int32_t>(azp);
    auto azp_adj_args =
        SUPER::template args_from_tensor<AzpAdj, int32_t>(azp_adj);

    typename EVTComputeAzp::Arguments evt_azp_args{azp_args, azp_adj_args, {}};
    typename EVTComputeAcc::Arguments evt_acc_args{{}, evt_azp_args, {}};
    typename EVTComputeScaleB::Arguments evt_scale_b_args{
        b_args, evt_acc_args, {}};
    return ArgumentType{a_args, evt_scale_b_args, bias_args, {}};
  }
};

}; // namespace fastdeploy::c2x
@@ -0,0 +1,453 @@
// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp

#pragma once

// clang-format will break include orders
// clang-format off
#include "cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp"
#include "cutlass_extensions/epilogue/broadcast_load_epilogue_array_c3x.hpp"
// clang-format on

/*
   This file defines custom epilogues for fusing channel scales, token scales,
   bias, and activation zero-points onto a GEMM operation using the
   CUTLASS 3.x API, for NVIDIA GPUs with sm90a (Hopper) or later.

   Epilogues must contain a public type named EVTCompute of type Sm90EVT,
   as well as a static prepare_args function that constructs an
   EVTCompute::Arguments struct.
*/

namespace fastdeploy::c3x {

using namespace cute;

template <typename T> struct identity {
  CUTLASS_HOST_DEVICE
  T operator()(T lhs) const { return lhs; }
};

template <typename ElementAcc, typename ElementD, typename TileShape>
struct TrivialEpilogue {
private:
  using Accum = cutlass::epilogue::fusion::Sm90AccFetch;
  using Compute = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::epilogue::thread::Identity, ElementD, ElementAcc,
      cutlass::FloatRoundStyle::round_to_nearest>;

public:
  using EVTCompute = cutlass::epilogue::fusion::Sm90EVT<Compute, Accum>;
  using ArgumentType = typename EVTCompute::Arguments;

  template <typename... Args> static ArgumentType prepare_args(Args... args) {
    return {};
  }
};

/*
 * This class provides the common load descriptors for the
 * ScaledEpilogue[...] classes
 */
template <typename ElementAcc, typename ElementD, typename TileShape>
struct ScaledEpilogueBase {
protected:
  using Accum = cutlass::epilogue::fusion::Sm90AccFetch;

  template <typename T>
  using ColOrScalarLoad = cutlass::epilogue::fusion::Sm90ColOrScalarBroadcast<
      0 /*Stages*/, TileShape, T, Stride<Int<1>, Int<0>, Int<0>>>;

  template <typename T>
  using RowOrScalarLoad = cutlass::epilogue::fusion::Sm90RowOrScalarBroadcast<
      0 /*Stages*/, TileShape, T, Stride<Int<0>, Int<1>, Int<0>>>;

  // Don't want to support nullptr by default
  template <typename T, bool EnableNullPtr = false>
  using ColLoad = cutlass::epilogue::fusion::Sm90ColBroadcast<
      0 /*Stages*/, TileShape, T, T, Stride<Int<1>, Int<0>, Int<0>>,
      128 / sizeof_bits_v<T>, EnableNullPtr>;

  // Don't want to support nullptr by default
  template <typename T, bool EnableNullPtr = false>
  using RowLoad = cutlass::epilogue::fusion::Sm90RowBroadcast<
      0 /*Stages*/, TileShape, T, T, Stride<Int<0>, Int<1>, Int<0>>,
      128 / sizeof_bits_v<T>, EnableNullPtr>;

  template <typename T>
  using ColOrScalarLoadArray =
      cutlass::epilogue::fusion::Sm90ColOrScalarBroadcastArray<
          0 /*Stages*/, TileShape, T, Stride<Int<1>, Int<0>, Int<0>>>;

  template <typename T>
  using RowOrScalarLoadArray =
      cutlass::epilogue::fusion::Sm90RowOrScalarBroadcastArray<
          0 /*Stages*/, TileShape, T, Stride<Int<0>, Int<1>, Int<0>>>;

  // This utility function constructs the arguments for the load descriptors
  // from a tensor. It can handle both row and column, as well as row/column or
  // scalar cases.
  template <typename Descriptor, typename T>
  static auto args_from_tensor(paddle::Tensor const &tensor) {
    using Arguments = typename Descriptor::Arguments;
    auto *data_ptr = static_cast<T *>(const_cast<void *>(tensor.data()));
    if constexpr (std::is_same_v<Descriptor, ColOrScalarLoad<T>> ||
                  std::is_same_v<Descriptor, RowOrScalarLoad<T>>) {
      return Arguments{data_ptr, tensor.numel() != 1};
    } else {
      static_assert(!std::is_same_v<Descriptor, ColLoad<T, true>> &&
                    !std::is_same_v<Descriptor, RowLoad<T, true>>);
      return Arguments{data_ptr};
    }
  }

  // This overload handles the case where there might not be a tensor, in which
  // case a nullptr is passed and a constant (0) is used.
  template <typename Descriptor, typename T>
  static auto args_from_tensor(paddle::optional<paddle::Tensor> const &tensor) {
    using Arguments = typename Descriptor::Arguments;
    auto *data_ptr =
        tensor ? static_cast<T *>(const_cast<void *>(tensor->data())) : nullptr;
    static_assert(std::is_same_v<Descriptor, ColLoad<T, true>> ||
                  std::is_same_v<Descriptor, RowLoad<T, true>>);
    return Arguments{data_ptr};
  }

  template <typename Descriptor, typename T>
  static auto args_from_tensor(const T *const *data_ptr, bool do_broadcast) {
    using Arguments = typename Descriptor::Arguments;
    static_assert(std::is_same_v<Descriptor, ColOrScalarLoadArray<T>> ||
                  std::is_same_v<Descriptor, RowOrScalarLoadArray<T>>);
    return Arguments{data_ptr, do_broadcast};
  }
};

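// Illustrative note (not part of the original source): the *OrScalarLoad
// descriptors are driven by a boolean derived from the scale tensor's element
// count, so the same prepare_args path covers per-tensor and per-row/column
// scales. A minimal sketch of that dispatch, with a hypothetical helper name:
#if 0
template <typename Descriptor, typename T>
auto make_scale_args(T *data_ptr, int64_t numel) {
  // numel == 1 -> treat as a scalar and broadcast it (per-tensor scale)
  // numel  > 1 -> load one value per row/column (per-token / per-channel)
  return typename Descriptor::Arguments{data_ptr, numel != 1};
}
#endif
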
/*
   This epilogue function defines a quantized GEMM operation similar to
   paddle.scaled_mm_.

   A and B may be both either int8 or fp8_e4m3. A can be
   quantized per-tensor or per-row. B can be quantized per-tensor or per-column.
   Any combination of per-tensor and per-row or column is supported.
   A and B must have symmetric quantization (zero point == 0).

   So the GEMM operation is D = (a_scales * A) (b_scales * B), where the
   scales are applied elementwise with numpy-style broadcasting.

   ScaleA and ScaleB define the epilogue functions that apply the scales for
   the A and B operands respectively. These scales may be either per-tensor or
   per row or column.
*/
template <typename ElementAcc, typename ElementD, typename TileShape>
struct ScaledEpilogue
    : private ScaledEpilogueBase<ElementAcc, ElementD, TileShape> {
private:
  using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, TileShape>;
  using Accum = typename SUPER::Accum;
  using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
  using ScaleB = typename SUPER::template RowOrScalarLoad<float>;

  using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiplies, float, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTCompute0 =
      cutlass::epilogue::fusion::Sm90EVT<Compute0, ScaleB, Accum>;

  using Compute1 = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiplies, ElementD, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

public:
  using EVTCompute =
      cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0>;
  using ArgumentType = typename EVTCompute::Arguments;

  static ArgumentType prepare_args(paddle::Tensor const &a_scales,
                                   paddle::Tensor const &b_scales) {
    auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
    auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);

    typename EVTCompute0::Arguments evt0_args{b_args, {}, {}};
    return ArgumentType{a_args, evt0_args, {}};
  }
};

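// Illustrative sketch (not part of the original source): the math the fused
// epilogue above implements, written as a plain scalar reference. a_scale has
// either 1 or m entries and b_scale has either 1 or n entries, mirroring the
// numpy-style broadcasting described in the comment. The function name is
// hypothetical.
#if 0
#include <cstdint>
#include <vector>

std::vector<float> reference_scaled_mm(const std::vector<int32_t> &accum,   // (m, n) int32 accumulator
                                       const std::vector<float> &a_scale,   // 1 or m values
                                       const std::vector<float> &b_scale,   // 1 or n values
                                       int m, int n) {
  std::vector<float> D(m * n);
  for (int i = 0; i < m; ++i) {
    for (int j = 0; j < n; ++j) {
      const float sa = a_scale.size() == 1 ? a_scale[0] : a_scale[i];
      const float sb = b_scale.size() == 1 ? b_scale[0] : b_scale[j];
      // EVTCompute0 applies sb * accum; EVTCompute (Compute1) applies sa * (...)
      D[i * n + j] = sa * (sb * static_cast<float>(accum[i * n + j]));
    }
  }
  return D;
}
#endif
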
/*
 * This epilogue performs the same operation as ScaledEpilogue, but adds a bias.
 * This bias can also be used in the per-tensor azp case, where the activation
 * zero point (azp) is used to compute an azp correction term,
 * which is folded into the bias.
 *
 * The bias tensor must be per-output channel.
 * ScaleA and ScaleB can be per-tensor or per-token/per-channel.
 */
template <typename ElementAcc, typename ElementD, typename TileShape>
struct ScaledEpilogueBias
    : private ScaledEpilogueBase<ElementAcc, ElementD, TileShape> {
private:
  using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, TileShape>;
  using Accum = typename SUPER::Accum;
  using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
  using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
  using Bias = typename SUPER::template RowLoad<ElementD>;

  using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiplies, float, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTCompute0 =
      cutlass::epilogue::fusion::Sm90EVT<Compute0, ScaleB, Accum>;

  using Compute1 = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiply_add, ElementD, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

public:
  using EVTCompute =
      cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0, Bias>;

  using ArgumentType = typename EVTCompute::Arguments;
  static ArgumentType prepare_args(paddle::Tensor const &a_scales,
                                   paddle::Tensor const &b_scales,
                                   paddle::Tensor const &bias) {
    auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
    auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
    auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);

    typename EVTCompute0::Arguments evt0_args{b_args, {}, {}};
    return ArgumentType{a_args, evt0_args, bias_args, {}};
  }
};

/*
 * This epilogue performs the same operation as ScaledEpilogueBias, but the
 * bias is a column vector instead of a row vector. Useful e.g. if we are
 * computing a GEMM via C^T += B^T A^T. This happens in the 2:4 sparse kernels.
 */
template <typename ElementAcc, typename ElementD, typename TileShape>
struct ScaledEpilogueColumnBias
    : private ScaledEpilogueBase<ElementAcc, ElementD, TileShape> {
private:
  using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, TileShape>;
  using Accum = typename SUPER::Accum;
  using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
  using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
  using Bias = typename SUPER::template ColLoad<ElementD>;

  using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiplies, float, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTCompute0 =
      cutlass::epilogue::fusion::Sm90EVT<Compute0, ScaleB, Accum>;

  using Compute1 = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiply_add, ElementD, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

public:
  using EVTCompute =
      cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0, Bias>;

  using ArgumentType = typename EVTCompute::Arguments;
  static ArgumentType prepare_args(paddle::Tensor const &a_scales,
                                   paddle::Tensor const &b_scales,
                                   paddle::Tensor const &bias) {
    auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
    auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
    auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);

    typename EVTCompute0::Arguments evt0_args{b_args, {}, {}};
    return ArgumentType{a_args, evt0_args, bias_args, {}};
  }
};

/*
 * This epilogue directly supports per-tensor azp in int32 form.
 * As opposed to the per-token epilogue below, this epilogue only has an azp_adj
 * term, which should already be multiplied with the scalar azp.
 * The azp_adj term is a 1D tensor of shape (1,n), computed as azp * J @ B.
 *
 * This epilogue also supports bias, which remains per-channel.
 */
template <typename ElementAcc, typename ElementD, typename TileShape>
struct ScaledEpilogueBiasAzp
    : private ScaledEpilogueBase<ElementAcc, ElementD, TileShape> {
private:
  using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, TileShape>;
  using Accum = typename SUPER::Accum;
  using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
  using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
  using Bias = typename SUPER::template RowLoad<ElementD, true>;

  // This is the full AZP term, azp * J @ B, shape (1,n)
  using AzpWithAdj = typename SUPER::template RowLoad<int32_t>;

  // Compute float(accum - azp_adj), both operands are int32_t
  using ComputeAzp = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::minus, float, int32_t,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTComputeAzp =
      cutlass::epilogue::fusion::Sm90EVT<ComputeAzp, Accum, AzpWithAdj>;

  using ComputeScaleB = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiplies, float, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTComputeScaleB =
      cutlass::epilogue::fusion::Sm90EVT<ComputeScaleB, ScaleB, EVTComputeAzp>;

  using ComputeScaleBiasA = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiply_add, ElementD, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

public:
  using EVTCompute =
      cutlass::epilogue::fusion::Sm90EVT<ComputeScaleBiasA, ScaleA,
                                         EVTComputeScaleB, Bias>;
  using ArgumentType = typename EVTCompute::Arguments;

  static ArgumentType
  prepare_args(paddle::Tensor const &a_scales, paddle::Tensor const &b_scales,
               paddle::Tensor const &azp_adj,
               paddle::optional<paddle::Tensor> const &bias) {
    auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
    auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
    auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
    auto azp_adj_args =
        SUPER::template args_from_tensor<AzpWithAdj, int32_t>(azp_adj);

    typename EVTComputeAzp::Arguments evt_azp_args{{}, azp_adj_args, {}};
    typename EVTComputeScaleB::Arguments evt_scale_b_args{
        b_args, evt_azp_args, {}};
    return ArgumentType{a_args, evt_scale_b_args, bias_args, {}};
  }
};

/*
 * This epilogue supports per-token azp by computing and applying
 * the correction term using a rank-1 update. If the term were materialized,
 * it would require O(m*n) space, and this way it only requires O(m+n) space.
 * The azp term is a 1D tensor of shape (m,1), and represents the unscaled zero
 * point for each row of A.
 * The azp_adj term is a 1D tensor of shape (1,n), computed as J @ B.
 *
 * This epilogue also supports bias, which remains per-channel.
 */
template <typename ElementAcc, typename ElementD, typename TileShape>
struct ScaledEpilogueBiasAzpToken
    : private ScaledEpilogueBase<ElementAcc, ElementD, TileShape> {
private:
  using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, TileShape>;
  using Accum = typename SUPER::Accum;
  using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
  using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
  using Bias = typename SUPER::template RowLoad<ElementD, true>;

  // Per-token azp term, shape (m,1)
  using Azp = typename SUPER::template ColLoad<int32_t>;

  // This is the AZP adjustment term, J @ B, shape (1,n)
  using AzpAdj = typename SUPER::template RowLoad<int32_t>;

  // Compute azp * azp_adj
  using ComputeAzp = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiplies, int32_t, int32_t,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTComputeAzp =
      cutlass::epilogue::fusion::Sm90EVT<ComputeAzp, Azp, AzpAdj>;

  // Compute float(accum - azp*azp_adj), all operands are int32_t
  using ComputeAcc = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::minus, float, int32_t,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTComputeAcc =
      cutlass::epilogue::fusion::Sm90EVT<ComputeAcc, Accum, EVTComputeAzp>;

  using ComputeScaleB = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiplies, float, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTComputeScaleB =
      cutlass::epilogue::fusion::Sm90EVT<ComputeScaleB, ScaleB, EVTComputeAcc>;

  using ComputeScaleBiasA = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiply_add, ElementD, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

public:
  using EVTCompute =
      cutlass::epilogue::fusion::Sm90EVT<ComputeScaleBiasA, ScaleA,
                                         EVTComputeScaleB, Bias>;
  using ArgumentType = typename EVTCompute::Arguments;

  static ArgumentType
  prepare_args(paddle::Tensor const &a_scales, paddle::Tensor const &b_scales,
               paddle::Tensor const &azp_adj, paddle::Tensor const &azp,
               paddle::optional<paddle::Tensor> const &bias) {
    auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
    auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
    auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
    auto azp_args = SUPER::template args_from_tensor<Azp, int32_t>(azp);
    auto azp_adj_args =
        SUPER::template args_from_tensor<AzpAdj, int32_t>(azp_adj);

    typename EVTComputeAzp::Arguments evt_azp_args{azp_args, azp_adj_args, {}};
    typename EVTComputeAcc::Arguments evt_acc_args{{}, evt_azp_args, {}};
    typename EVTComputeScaleB::Arguments evt_scale_b_args{
        b_args, evt_acc_args, {}};
    return ArgumentType{a_args, evt_scale_b_args, bias_args, {}};
  }
};

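// Illustrative sketch (not part of the original source): the rank-1 correction
// the per-token epilogue above applies. Instead of materializing the full
// (m, n) correction azp[i] * azp_adj[j], it keeps azp (m values) and
// azp_adj = J @ B_q (n values) and combines them inside the epilogue, so only
// O(m + n) extra storage is needed. The function name is hypothetical.
#if 0
#include <cstdint>
#include <vector>

std::vector<float> reference_azp_token(const std::vector<int32_t> &accum,   // (m, n) = A_q @ B_q
                                       const std::vector<int32_t> &azp,     // (m, 1) per-row zero points
                                       const std::vector<int32_t> &azp_adj, // (1, n) = J @ B_q
                                       const std::vector<float> &a_scale,   // per-row
                                       const std::vector<float> &b_scale,   // per-column
                                       const std::vector<float> &bias,      // per-column
                                       int m, int n) {
  std::vector<float> D(m * n);
  for (int i = 0; i < m; ++i) {
    for (int j = 0; j < n; ++j) {
      // Rank-1 update: outer product azp x azp_adj, evaluated element by element.
      const int32_t corrected = accum[i * n + j] - azp[i] * azp_adj[j];
      D[i * n + j] =
          a_scale[i] * (b_scale[j] * static_cast<float>(corrected)) + bias[j];
    }
  }
  return D;
}
#endif
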
/*
   This epilogue works like ScaledEpilogue, but ScaleA and ScaleB are pointers
   to arrays containing different scales used in group gemm. The number of
   pointers in ScaleA and the number of pointers in ScaleB are equal to the
   group size.
*/
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
struct ScaledEpilogueArray
    : private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
private:
  using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
  using Accum = typename SUPER::Accum;
  using ScaleA = typename SUPER::template ColOrScalarLoadArray<float>;
  using ScaleB = typename SUPER::template RowOrScalarLoadArray<float>;

  using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiplies, float, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTCompute0 =
      cutlass::epilogue::fusion::Sm90EVT<Compute0, ScaleB, Accum>;

  using Compute1 = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiplies, ElementD, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

public:
  using EVTCompute =
      cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0>;
  using ArgumentType = typename EVTCompute::Arguments;

  using ScaleAArray = typename SUPER::template ColOrScalarLoadArray<float>;
  using ScaleBArray = typename SUPER::template RowOrScalarLoadArray<float>;

  static ArgumentType prepare_args(float const *const *a_scales_ptr,
                                   float const *const *b_scales_ptr,
                                   bool a_col_broadcast, bool b_row_broadcast) {
    auto a_args = SUPER::template args_from_tensor<ScaleAArray, float>(
        a_scales_ptr, a_col_broadcast);
    auto b_args = SUPER::template args_from_tensor<ScaleBArray, float>(
        b_scales_ptr, b_row_broadcast);

    typename EVTCompute0::Arguments evt0_args{b_args, {}, {}};
    return ArgumentType{a_args, evt0_args, {}};
  }
};

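// Illustrative sketch (not part of the original source): how a grouped-GEMM
// caller might assemble the per-group scale pointer arrays that prepare_args
// above expects. num_groups, per_group_{a,b}_scales, and
// ScaledEpilogueArrayType are hypothetical host-side names.
#if 0
#include <vector>

std::vector<const float *> a_scale_ptrs;
std::vector<const float *> b_scale_ptrs;
for (int g = 0; g < num_groups; ++g) {
  a_scale_ptrs.push_back(per_group_a_scales[g].data()); // one pointer per group
  b_scale_ptrs.push_back(per_group_b_scales[g].data());
}
// a_col_broadcast / b_row_broadcast indicate whether each group's scales are
// full per-row / per-column vectors (true) or single per-tensor values
// (false); they play the role of `tensor.numel() != 1` in the tensor overload.
auto args = ScaledEpilogueArrayType::prepare_args(
    a_scale_ptrs.data(), b_scale_ptrs.data(),
    /*a_col_broadcast=*/true, /*b_row_broadcast=*/true);
#endif
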
}; // namespace fastdeploy::c3x
@@ -0,0 +1,284 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights
|
||||
*reserved. SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice,
|
||||
*this list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
||||
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
|
||||
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
|
||||
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
|
||||
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
||||
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
|
||||
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
|
||||
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
|
||||
*POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/arch/mma.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
|
||||
#include "cutlass/gemm/collective/builders/sm90_common.inl"
|
||||
|
||||
// SM90 Collective Builders should be used only starting CUDA 12.0
|
||||
#if (__CUDACC_VER_MAJOR__ >= 12)
|
||||
#define CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED
|
||||
#endif
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass::gemm::collective {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace detail {
|
||||
|
||||
// Returns the maximum number of smem tiles that can be used with a given smem
|
||||
// capacity, or overrides with manual count.
|
||||
template <int CapacityBytes, class ElementA, class ElementB, class TileShapeMNK,
|
||||
bool SwapAB, int carveout_bytes>
|
||||
constexpr int compute_stage_count_or_override_gated(
|
||||
StageCountAutoCarveout<carveout_bytes> stage_count) {
|
||||
// 32 bytes to account for barriers etc.
|
||||
constexpr int stage_barrier_bytes = 32;
|
||||
constexpr int a_bits = static_cast<int>(sizeof_bits<ElementA>::value);
|
||||
constexpr int b_bits = static_cast<int>(sizeof_bits<ElementB>::value);
|
||||
constexpr int stage_bytes = [&]() -> int {
|
||||
if constexpr (SwapAB) {
|
||||
return (a_bits * size<0>(TileShapeMNK{}) * size<2>(TileShapeMNK{}) * 2) /
|
||||
8 +
|
||||
(b_bits * size<1>(TileShapeMNK{}) * size<2>(TileShapeMNK{})) / 8 +
|
||||
stage_barrier_bytes;
|
||||
} else {
|
||||
return (a_bits * size<0>(TileShapeMNK{}) * size<2>(TileShapeMNK{})) / 8 +
|
||||
(b_bits * size<1>(TileShapeMNK{}) * size<2>(TileShapeMNK{}) * 2) /
|
||||
8 +
|
||||
stage_barrier_bytes;
|
||||
}
|
||||
}();
|
||||
|
||||
return (CapacityBytes - carveout_bytes) / stage_bytes;
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// GMMA_TMA_WS_SS
|
||||
template <class ElementA, class GmemLayoutA, int AlignmentA, class ElementB,
|
||||
class GmemLayoutB, int AlignmentB, class ElementAccumulator,
|
||||
class TileShape_MNK, class ClusterShape_MNK, class StageCountType,
|
||||
class KernelScheduleType,
|
||||
template <class /* ElementCompute */> class Activation, bool SwapAB>
|
||||
struct CollectiveBuilderGated<
|
||||
arch::Sm90, arch::OpClassTensorOp, ElementA, GmemLayoutA, AlignmentA,
|
||||
ElementB, GmemLayoutB, AlignmentB, ElementAccumulator, TileShape_MNK,
|
||||
ClusterShape_MNK, StageCountType, KernelScheduleType, Activation, SwapAB,
|
||||
cute::enable_if_t<
|
||||
(cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecialized> ||
|
||||
cute::is_same_v<KernelScheduleType,
|
||||
KernelTmaWarpSpecializedPingpong> ||
|
||||
cute::is_same_v<KernelScheduleType,
|
||||
KernelTmaWarpSpecializedCooperative> ||
|
||||
cute::is_same_v<KernelScheduleType,
|
||||
KernelPtrArrayTmaWarpSpecializedCooperative>) &&
|
||||
not detail::is_use_rmem_A<ElementA, GmemLayoutA, ElementB,
|
||||
GmemLayoutB>()>> {
|
||||
static_assert(is_static<TileShape_MNK>::value);
|
||||
static_assert(is_static<ClusterShape_MNK>::value);
|
||||
#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED
|
||||
static_assert(cutlass::detail::dependent_false<ElementA>,
|
||||
"Unsupported Toolkit for SM90 Collective Builder\n");
|
||||
#endif
|
||||
static_assert(detail::is_aligned<ElementA, AlignmentA, ElementB, AlignmentB,
|
||||
detail::tma_alignment_bytes>(),
|
||||
"Should meet TMA alignment requirement\n");
|
||||
|
||||
static constexpr bool IsArrayOfPointersGemm =
|
||||
(cute::is_same_v<KernelScheduleType,
|
||||
KernelPtrArrayTmaWarpSpecializedCooperative>);
|
||||
static constexpr bool IsFP8Input = detail::is_input_fp8<ElementA, ElementB>();
|
||||
static_assert(!IsFP8Input || (IsFP8Input && !IsArrayOfPointersGemm),
|
||||
"Kernel[Array/Group]TmaWarpSpecializedCooperative is only "
|
||||
"compatible with FP8 FastAccum version right now\n");
|
||||
|
||||
// For fp32 types, map to tf32 MMA value type
|
||||
using MmaElementA = cute::conditional_t<cute::is_same_v<ElementA, float>,
|
||||
tfloat32_t, ElementA>;
|
||||
using MmaElementB = cute::conditional_t<cute::is_same_v<ElementB, float>,
|
||||
tfloat32_t, ElementB>;
|
||||
|
||||
static constexpr cute::GMMA::Major GmmaMajorA =
|
||||
detail::gmma_ss_tag_to_major_A<MmaElementA, GmemLayoutA>();
|
||||
static constexpr cute::GMMA::Major GmmaMajorB =
|
||||
detail::gmma_ss_tag_to_major_B<MmaElementB, GmemLayoutB>();
|
||||
|
||||
using AtomLayoutMNK = cute::conditional_t<
|
||||
cute::is_same_v<KernelScheduleType,
|
||||
KernelTmaWarpSpecializedCooperative> ||
|
||||
IsArrayOfPointersGemm,
|
||||
Layout<Shape<_2, _1, _1>>, Layout<Shape<_1, _1, _1>>>;
|
||||
|
||||
using TiledMma = decltype(cute::make_tiled_mma(
|
||||
cute::GMMA::ss_op_selector<MmaElementA, MmaElementB, ElementAccumulator,
|
||||
TileShape_MNK, GmmaMajorA, GmmaMajorB>(),
|
||||
AtomLayoutMNK{}));
|
||||
|
||||
using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom(
|
||||
shape<1>(ClusterShape_MNK{})));
|
||||
using GmemTiledCopyB = decltype(detail::sm90_cluster_shape_to_tma_atom(
|
||||
shape<0>(ClusterShape_MNK{})));
|
||||
|
||||
using SmemLayoutAtomA =
|
||||
decltype(detail::ss_smem_selector<
|
||||
GmmaMajorA, MmaElementA, decltype(cute::get<0>(TileShape_MNK{})),
|
||||
decltype(cute::get<2>(TileShape_MNK{}))>());
|
||||
using SmemLayoutAtomB =
|
||||
decltype(detail::ss_smem_selector<
|
||||
GmmaMajorB, MmaElementB, decltype(cute::get<1>(TileShape_MNK{})),
|
||||
decltype(cute::get<2>(TileShape_MNK{}))>());
|
||||
|
||||
static constexpr int PipelineStages =
|
||||
detail::compute_stage_count_or_override_gated<
|
||||
detail::sm90_smem_capacity_bytes, MmaElementA, MmaElementB,
|
||||
TileShape_MNK, SwapAB>(StageCountType{});
|
||||
using DispatchPolicy = cute::conditional_t<
|
||||
IsArrayOfPointersGemm,
|
||||
MainloopSm90ArrayTmaGmmaWarpSpecialized<PipelineStages, ClusterShape_MNK,
|
||||
KernelScheduleType>,
|
||||
/* For FP8 use a separate mainloop compared to other datatypes */
|
||||
cute::conditional_t<
|
||||
IsFP8Input,
|
||||
MainloopSm90TmaGmmaWarpSpecializedFP8<
|
||||
PipelineStages, ClusterShape_MNK, KernelScheduleType>,
|
||||
MainloopSm90TmaGmmaWarpSpecialized<PipelineStages, ClusterShape_MNK,
|
||||
KernelScheduleType>>>;
|
||||
|
||||
using SmemCopyAtomA = void;
|
||||
using SmemCopyAtomB = void;
|
||||
|
||||
using CollectiveOp = CollectiveMmaGated<
|
||||
DispatchPolicy, TileShape_MNK, ElementA, TagToStrideA_t<GmemLayoutA>,
|
||||
ElementB, TagToStrideB_t<GmemLayoutB>, TiledMma, GmemTiledCopyA,
|
||||
SmemLayoutAtomA, SmemCopyAtomA, cute::identity, GmemTiledCopyB,
|
||||
SmemLayoutAtomB, SmemCopyAtomB, cute::identity, Activation, SwapAB>;
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// GMMA_TMA_WS_FP8_FAST_ACCUM_SS
|
||||
template <class ElementA, class GmemLayoutA, int AlignmentA, class ElementB,
|
||||
class GmemLayoutB, int AlignmentB, class ElementAccumulator,
|
||||
class TileShape_MNK, class ClusterShape_MNK, class StageCountType,
|
||||
class KernelScheduleType,
|
||||
template <class /* ElementCompute */> class Activation, bool SwapAB>
|
||||
struct CollectiveBuilderGated<
|
||||
arch::Sm90, arch::OpClassTensorOp, ElementA, GmemLayoutA, AlignmentA,
|
||||
ElementB, GmemLayoutB, AlignmentB, ElementAccumulator, TileShape_MNK,
|
||||
ClusterShape_MNK, StageCountType, KernelScheduleType, Activation, SwapAB,
|
||||
cute::enable_if_t<
|
||||
cute::is_same_v<KernelScheduleType,
|
||||
KernelTmaWarpSpecializedFP8FastAccum> ||
|
||||
cute::is_same_v<KernelScheduleType,
|
||||
KernelTmaWarpSpecializedPingpongFP8FastAccum> ||
|
||||
cute::is_same_v<KernelScheduleType,
|
||||
KernelTmaWarpSpecializedCooperativeFP8FastAccum> ||
|
||||
cute::is_same_v<
|
||||
KernelScheduleType,
|
||||
KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum>>> {
|
||||
static_assert(is_static<TileShape_MNK>::value);
|
||||
static_assert(is_static<ClusterShape_MNK>::value);
|
||||
static_assert(detail::is_aligned<ElementA, AlignmentA, ElementB, AlignmentB,
|
||||
detail::tma_alignment_bytes>(),
|
||||
"Not meet TMA alignment requirement yet\n");
|
||||
static_assert(
|
||||
detail::is_input_fp8<ElementA, ElementB>(),
|
||||
"Only FP8 datatypes are compatible with these kernel schedules\n");
|
||||
// Dispatch TN fp8 kernels only to TMA warp specialized FP8 builder
|
||||
static_assert(
|
||||
!detail::is_use_rmem_A<ElementA, GmemLayoutA, ElementB, GmemLayoutB>(),
|
||||
"Not supported for fp8 non-TN warp specialized kernels yet\n");
|
||||
#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED
|
||||
static_assert(cutlass::detail::dependent_false<ElementA>,
|
||||
"Unsupported Toolkit for SM90 Collective Builder\n");
|
||||
#endif
|
||||
|
||||
static constexpr cute::GMMA::Major GmmaMajorA =
|
||||
detail::gmma_ss_tag_to_major_A<ElementA, GmemLayoutA>();
|
||||
static constexpr cute::GMMA::Major GmmaMajorB =
|
||||
detail::gmma_ss_tag_to_major_B<ElementB, GmemLayoutB>();
|
||||
|
||||
static constexpr bool IsArrayOfPointersGemm =
|
||||
(cute::is_same_v<
|
||||
KernelScheduleType,
|
||||
KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum>);
|
||||
using AtomLayoutMNK = cute::conditional_t<
|
||||
cute::is_same_v<KernelScheduleType,
|
||||
KernelTmaWarpSpecializedCooperativeFP8FastAccum> ||
|
||||
IsArrayOfPointersGemm,
|
||||
Layout<Shape<_2, _1, _1>>, Layout<Shape<_1, _1, _1>>>;
|
||||
|
||||
using TiledMma = decltype(cute::make_tiled_mma(
|
||||
cute::GMMA::ss_op_selector<ElementA, ElementB, ElementAccumulator,
|
||||
TileShape_MNK, GmmaMajorA, GmmaMajorB>(),
|
||||
AtomLayoutMNK{}));
|
||||
|
||||
using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom(
|
||||
shape<1>(ClusterShape_MNK{})));
|
||||
using GmemTiledCopyB = decltype(detail::sm90_cluster_shape_to_tma_atom(
|
||||
shape<0>(ClusterShape_MNK{})));
|
||||
|
||||
using SmemLayoutAtomA =
|
||||
decltype(detail::ss_smem_selector<
|
||||
GmmaMajorA, ElementA, decltype(cute::get<0>(TileShape_MNK{})),
|
||||
decltype(cute::get<2>(TileShape_MNK{}))>());
|
||||
using SmemLayoutAtomB =
|
||||
decltype(detail::ss_smem_selector<
|
||||
GmmaMajorB, ElementB, decltype(cute::get<1>(TileShape_MNK{})),
|
||||
decltype(cute::get<2>(TileShape_MNK{}))>());
|
||||
|
||||
static constexpr int PipelineStages =
|
||||
detail::compute_stage_count_or_override_gated<
|
||||
detail::sm90_smem_capacity_bytes, ElementA, ElementB, TileShape_MNK,
|
||||
SwapAB>(StageCountType{});
|
||||
using DispatchPolicy = cute::conditional_t<
|
||||
IsArrayOfPointersGemm,
|
||||
MainloopSm90ArrayTmaGmmaWarpSpecialized<PipelineStages, ClusterShape_MNK,
|
||||
KernelScheduleType>,
|
||||
MainloopSm90TmaGmmaWarpSpecialized<PipelineStages, ClusterShape_MNK,
|
||||
KernelScheduleType>>;
|
||||
|
||||
using SmemCopyAtomA = void;
|
||||
using SmemCopyAtomB = void;
|
||||
|
||||
using CollectiveOp = CollectiveMmaGated<
|
||||
DispatchPolicy, TileShape_MNK, ElementA, TagToStrideA_t<GmemLayoutA>,
|
||||
ElementB, TagToStrideB_t<GmemLayoutB>, TiledMma, GmemTiledCopyA,
|
||||
SmemLayoutAtomA, SmemCopyAtomA, cute::identity, GmemTiledCopyB,
|
||||
SmemLayoutAtomB, SmemCopyAtomB, cute::identity, Activation, SwapAB>;
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass::gemm::collective
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@@ -0,0 +1,60 @@
/***************************************************************************************************
 * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights
 *reserved. SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice,
 *this list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
 *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 *POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
#pragma once

/////////////////////////////////////////////////////////////////////////////////////////////////
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass_extensions/gemm/collective/collective_mma_gated.hpp"

namespace cutlass::gemm::collective {

/////////////////////////////////////////////////////////////////////////////////////////////////

template <class ArchTag, class OpClass, class ElementA, class GmemLayoutA,
          int AlignmentA, class ElementB, class GmemLayoutB, int AlignmentB,
          class ElementAccumulator, class TileShape_MNK, class ClusterShape_MNK,
          class StageCountType, class KernelScheduleType,
          template <class /* ElementCompute */> class Activation,
          bool SwapAB = false, class Enable = void>
struct CollectiveBuilderGated {
  static_assert(sizeof(ElementA) == 0,
                "Could not build a collective for given parameters.");
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace cutlass::gemm::collective

/////////////////////////////////////////////////////////////////////////////////////////////////

#include "cutlass_extensions/gemm/collective/builders/sm90_gmma_builder_gated.inl"
/////////////////////////////////////////////////////////////////////////////////////////////////
@@ -0,0 +1,62 @@
/***************************************************************************************************
 * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights
 *reserved. SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice,
 *this list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
 *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 *POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
#pragma once

#include "cutlass/detail/dependent_false.hpp"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass::gemm::collective {

/////////////////////////////////////////////////////////////////////////////////////////////////

template <class DispatchPolicy, class TileShape, class ElementA, class StrideA,
          class ElementB, class StrideB, class TiledMma, class GmemTiledCopyA,
          class SmemLayoutAtomA, class SmemCopyAtomA, class TransformA,
          class GmemTiledCopyB, class SmemLayoutAtomB, class SmemCopyAtomB,
          class TransformB,
          template <class /* ElementCompute */> class Activation,
          bool SwapAB = false>
struct CollectiveMmaGated {
  static_assert(cutlass::detail::dependent_false<ElementA>,
                "Could not find a mainloop specialization.");
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace cutlass::gemm::collective

/////////////////////////////////////////////////////////////////////////////////////////////////

#include "cutlass_extensions/gemm/collective/sm90_mma_gated_tma_gmma_ss_warpspecialized.hpp"
#include "cutlass_extensions/gemm/collective/sm90_mma_gated_tma_gmma_ss_warpspecialized_fp8.hpp"
/////////////////////////////////////////////////////////////////////////////////////////////////
@@ -0,0 +1,713 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
|
||||
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights
|
||||
*reserved. SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice,
|
||||
*this list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
||||
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
|
||||
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
|
||||
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
|
||||
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
||||
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
|
||||
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
|
||||
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
|
||||
*POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include "cute/arch/cluster_sm90.hpp"
|
||||
#include "cute/arch/copy_sm90.hpp"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
|
||||
#include "cute/algorithm/functional.hpp"
|
||||
#include "cute/algorithm/gemm.hpp"
|
||||
#include "cute/atom/mma_atom.hpp"
|
||||
#include "cute/numeric/arithmetic_tuple.hpp"
|
||||
#include "cute/tensor_predicate.hpp"
|
||||
#include "cutlass/pipeline/pipeline.hpp"
|
||||
#include "cutlass/trace.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass::gemm::collective {
|
||||
using namespace cute;
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// WarpSpecialized Mainloop
|
||||
template <int Stages, class ClusterShape, class KernelSchedule,
|
||||
class TileShape_, class ElementA_, class StrideA_, class ElementB_,
|
||||
class StrideB_, class TiledMma_, class GmemTiledCopyA_,
|
||||
class SmemLayoutAtomA_, class SmemCopyAtomA_, class TransformA_,
|
||||
class GmemTiledCopyB_, class SmemLayoutAtomB_, class SmemCopyAtomB_,
|
||||
class TransformB_,
|
||||
template <class /* ElementCompute */> class Activation_, bool SwapAB_>
|
||||
struct CollectiveMmaGated<
|
||||
MainloopSm90TmaGmmaWarpSpecialized<Stages, ClusterShape, KernelSchedule>,
|
||||
TileShape_, ElementA_, StrideA_, ElementB_, StrideB_, TiledMma_,
|
||||
GmemTiledCopyA_, SmemLayoutAtomA_, SmemCopyAtomA_, TransformA_,
|
||||
GmemTiledCopyB_, SmemLayoutAtomB_, SmemCopyAtomB_, TransformB_, Activation_,
|
||||
SwapAB_> {
|
||||
static constexpr bool isGated = true;
|
||||
static constexpr bool SwapAB = SwapAB_;
|
||||
|
||||
//
|
||||
// Type Aliases
|
||||
//
|
||||
using DispatchPolicy =
|
||||
MainloopSm90TmaGmmaWarpSpecialized<Stages, ClusterShape, KernelSchedule>;
|
||||
using TileShape = TileShape_;
|
||||
using ElementA = ElementA_;
|
||||
using StrideA = StrideA_;
|
||||
using ElementB = ElementB_;
|
||||
using StrideB = StrideB_;
|
||||
using TiledMma = TiledMma_;
|
||||
using ElementAccumulator = typename TiledMma::ValTypeC;
|
||||
using GmemTiledCopyA = GmemTiledCopyA_;
|
||||
using GmemTiledCopyB = GmemTiledCopyB_;
|
||||
using SmemLayoutAtomA = SmemLayoutAtomA_;
|
||||
using SmemLayoutAtomB = SmemLayoutAtomB_;
|
||||
using SmemCopyAtomA = SmemCopyAtomA_;
|
||||
using SmemCopyAtomB = SmemCopyAtomB_;
|
||||
using TransformA = TransformA_;
|
||||
using TransformB = TransformB_;
|
||||
using ArchTag = typename DispatchPolicy::ArchTag;
|
||||
using Activation = Activation_<ElementAccumulator>;
|
||||
|
||||
using ElementAux = cute::conditional_t<SwapAB, ElementA_, ElementB_>;
|
||||
using ValTypeAux = cute::conditional_t<SwapAB, typename TiledMma::ValTypeA,
|
||||
typename TiledMma::ValTypeB>;
|
||||
|
||||
using MainloopPipeline = cutlass::PipelineTmaAsync<DispatchPolicy::Stages>;
|
||||
using PipelineState = cutlass::PipelineState<DispatchPolicy::Stages>;
|
||||
|
||||
using PipelineParams = typename MainloopPipeline::Params;
|
||||
|
||||
static_assert(cute::rank(SmemLayoutAtomA{}) == 2,
|
||||
"SmemLayoutAtom must be rank 2 (M/N, K)");
|
||||
static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0,
|
||||
"SmemLayoutAtom must evenly divide tile shape.");
|
||||
static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0,
|
||||
"SmemLayoutAtom must evenly divide tile shape.");
|
||||
|
||||
static_assert(cute::rank(SmemLayoutAtomB{}) == 2,
|
||||
"SmemLayoutAtom must be rank 2 (M/N, K)");
|
||||
static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0,
|
||||
"SmemLayoutAtom must evenly divide tile shape.");
|
||||
static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0,
|
||||
"SmemLayoutAtom must evenly divide tile shape.");
|
||||
|
||||
// Tile along modes in a way that maximizes the TMA box size.
|
||||
using SmemLayoutA = decltype(tile_to_shape(
|
||||
SmemLayoutAtomA{},
|
||||
make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}),
|
||||
Int<DispatchPolicy::Stages>{}),
|
||||
conditional_t<::cutlass::gemm::detail::is_major<0, StrideA>(),
|
||||
Step<_2, _1, _3>, Step<_1, _2, _3>>{}));
|
||||
using SmemLayoutB = decltype(tile_to_shape(
|
||||
SmemLayoutAtomB{},
|
||||
make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}),
|
||||
Int<DispatchPolicy::Stages>{}),
|
||||
conditional_t<::cutlass::gemm::detail::is_major<0, StrideB>(),
|
||||
Step<_2, _1, _3>, Step<_1, _2, _3>>{}));
|
||||
using SmemLayoutAux = cute::conditional_t<SwapAB, SmemLayoutA, SmemLayoutB>;
|
||||
|
||||
static_assert(DispatchPolicy::Stages >= 2,
|
||||
"Specialization requires Stages set to value 2 or more.");
|
||||
static_assert(cute::is_base_of<cute::GMMA::DescriptorIterator,
|
||||
typename TiledMma::FrgTypeA>::value &&
|
||||
cute::is_base_of<cute::GMMA::DescriptorIterator,
|
||||
typename TiledMma::FrgTypeB>::value,
|
||||
"MMA atom must source both A and B operand from smem_desc for "
|
||||
"this mainloop.");
|
||||
static_assert(cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD> ||
|
||||
cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD_MULTICAST>,
|
||||
"GmemTiledCopy - invalid SM90 TMA copy atom specified.");
|
||||
static_assert(cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD> ||
|
||||
cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD_MULTICAST>,
|
||||
"GmemTiledCopy - invalid SM90 TMA copy atom specified.");
|
||||
|
||||
// TMA converts f32 input to tf32 when copying from GMEM to SMEM
|
||||
// For all other types, cast to size equivalent uint type to avoid any
|
||||
// rounding by TMA.
|
||||
static constexpr bool ConvertF32toTF32A = cute::is_same_v<float, ElementA>;
|
||||
static constexpr bool ConvertF32toTF32B = cute::is_same_v<float, ElementB>;
|
||||
using InternalElementA =
|
||||
cute::conditional_t<ConvertF32toTF32A, tfloat32_t,
|
||||
uint_bit_t<sizeof_bits_v<ElementA>>>;
|
||||
using InternalElementB =
|
||||
cute::conditional_t<ConvertF32toTF32B, tfloat32_t,
|
||||
uint_bit_t<sizeof_bits_v<ElementB>>>;
|
||||
using InternalElementAux =
|
||||
cute::conditional_t<SwapAB, InternalElementA, InternalElementB>;
|
||||
|
||||
struct SharedStorage {
|
||||
struct TensorStorage : cute::aligned_struct<128> {
|
||||
cute::array_aligned<typename TiledMma::ValTypeA,
|
||||
cute::cosize_v<SmemLayoutA>>
|
||||
smem_A;
|
||||
cute::array_aligned<typename TiledMma::ValTypeB,
|
||||
cute::cosize_v<SmemLayoutB>>
|
||||
smem_B;
|
||||
cute::array_aligned<ValTypeAux, cute::cosize_v<SmemLayoutAux>> smem_Aux;
|
||||
} tensors;
|
||||
|
||||
using PipelineStorage = typename MainloopPipeline::SharedStorage;
|
||||
PipelineStorage pipeline;
|
||||
};
|
||||
|
||||
using TensorStorage = typename SharedStorage::TensorStorage;
|
||||
using PipelineStorage = typename SharedStorage::PipelineStorage;
|
||||
|
||||
// Host side kernel arguments
|
||||
struct Arguments {
|
||||
ElementA const *ptr_A;
|
||||
StrideA dA;
|
||||
ElementB const *ptr_B0;
|
||||
ElementB const *ptr_B1;
|
||||
StrideB dB;
|
||||
float scale_d0 = 1.0f;
|
||||
float scale_d1 = 1.0f;
|
||||
uint32_t mma_promotion_interval = 4;
|
||||
};
|
||||
|
||||
// Device side kernel params
|
||||
struct Params {
|
||||
// Assumption: StrideA is congruent with Problem_MK
|
||||
using TMA_A = decltype(make_tma_copy(
|
||||
GmemTiledCopyA{},
|
||||
make_tensor(static_cast<InternalElementA const *>(nullptr),
|
||||
repeat_like(StrideA{}, int32_t(0)), StrideA{}),
|
||||
SmemLayoutA{}(_, _, cute::Int<0>{}),
|
||||
make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})),
|
||||
size<1>(ClusterShape{}))); // mcast along N mode for this M load, if any
|
||||
// Assumption: StrideB is congruent with Problem_NK
|
||||
using TMA_B = decltype(make_tma_copy(
|
||||
GmemTiledCopyB{},
|
||||
make_tensor(static_cast<InternalElementB const *>(nullptr),
|
||||
repeat_like(StrideB{}, int32_t(0)), StrideB{}),
|
||||
SmemLayoutB{}(_, _, cute::Int<0>{}),
|
||||
make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})),
|
||||
size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any
|
||||
using TMA_Aux = cute::conditional_t<SwapAB, TMA_A, TMA_B>;
|
||||
TMA_A tma_load_a;
|
||||
TMA_B tma_load_b;
|
||||
TMA_Aux tma_load_aux;
|
||||
float scale_d0 = 1.0f;
|
||||
float scale_d1 = 1.0f;
|
||||
};
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
template <class ProblemShape>
|
||||
static constexpr Params
|
||||
to_underlying_arguments(ProblemShape const &problem_shape,
|
||||
Arguments const &args, void *workspace) {
|
||||
(void)workspace;
|
||||
|
||||
// Optionally append 1s until problem shape is rank-4 (MNKL), in case it is
|
||||
// only rank-3 (MNK)
|
||||
auto problem_shape_MNKL = append<4>(problem_shape, 1);
|
||||
auto [M, N, K, L] = problem_shape_MNKL;
|
||||
|
||||
auto ptr_A = reinterpret_cast<InternalElementA const *>(args.ptr_A);
|
||||
auto ptr_B0 = reinterpret_cast<InternalElementB const *>(args.ptr_B0);
|
||||
|
||||
Tensor tensor_a =
|
||||
make_tensor(ptr_A, make_layout(make_shape(M, K, L), args.dA));
|
||||
Tensor tensor_b =
|
||||
make_tensor(ptr_B0, make_layout(make_shape(N, K, L), args.dB));
|
||||
typename Params::TMA_A tma_load_a = make_tma_copy(
|
||||
GmemTiledCopyA{}, tensor_a, SmemLayoutA{}(_, _, cute::Int<0>{}),
|
||||
make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})),
|
||||
size<1>(ClusterShape{})); // mcast along N mode for this M load, if any
|
||||
typename Params::TMA_B tma_load_b = make_tma_copy(
|
||||
GmemTiledCopyB{}, tensor_b, SmemLayoutB{}(_, _, cute::Int<0>{}),
|
||||
make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})),
|
||||
size<0>(ClusterShape{})); // mcast along M mode for this N load, if any
|
||||
|
||||
if constexpr (SwapAB) {
|
||||
auto ptr_Aux = reinterpret_cast<InternalElementA const *>(args.ptr_B1);
|
||||
Tensor tensor_aux =
|
||||
make_tensor(ptr_Aux, make_layout(make_shape(M, K, L), args.dA));
|
||||
typename Params::TMA_Aux tma_load_aux = make_tma_copy(
|
||||
GmemTiledCopyA{}, tensor_aux, SmemLayoutA{}(_, _, cute::Int<0>{}),
|
||||
make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})),
|
||||
size<1>(
|
||||
ClusterShape{})); // mcast along N mode for this M load, if any
|
||||
return {tma_load_a, tma_load_b, tma_load_aux, args.scale_d0,
|
||||
args.scale_d1};
|
||||
} else {
|
||||
auto ptr_Aux = reinterpret_cast<InternalElementB const *>(args.ptr_B1);
|
||||
Tensor tensor_aux =
|
||||
make_tensor(ptr_Aux, make_layout(make_shape(N, K, L), args.dB));
|
||||
typename Params::TMA_Aux tma_load_aux = make_tma_copy(
|
||||
GmemTiledCopyB{}, tensor_aux, SmemLayoutB{}(_, _, cute::Int<0>{}),
|
||||
make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})),
|
||||
size<0>(
|
||||
ClusterShape{})); // mcast along M mode for this N load, if any
|
||||
return {tma_load_a, tma_load_b, tma_load_aux, args.scale_d0,
|
||||
args.scale_d1};
|
||||
}
|
||||
}
|
||||
|
||||
template <class ProblemShape>
|
||||
static bool can_implement(ProblemShape const &problem_shape,
|
||||
[[maybe_unused]] Arguments const &args) {
|
||||
constexpr int tma_alignment_bits = 128;
|
||||
auto problem_shape_MNKL = append<4>(problem_shape, 1);
|
||||
auto [M, N, K, L] = problem_shape_MNKL;
|
||||
|
||||
bool implementable = true;
|
||||
constexpr int min_tma_aligned_elements_A =
|
||||
tma_alignment_bits / cutlass::sizeof_bits<ElementA>::value;
|
||||
implementable =
|
||||
implementable &&
|
||||
cutlass::detail::check_alignment<min_tma_aligned_elements_A>(
|
||||
cute::make_shape(M, K, L), StrideA{});
|
||||
constexpr int min_tma_aligned_elements_B =
|
||||
tma_alignment_bits / cutlass::sizeof_bits<ElementB>::value;
|
||||
implementable =
|
||||
implementable &&
|
||||
cutlass::detail::check_alignment<min_tma_aligned_elements_B>(
|
||||
cute::make_shape(N, K, L), StrideB{});
|
||||
|
||||
if (!implementable) {
|
||||
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the "
|
||||
"minimum alignment requirements for TMA.\n");
|
||||
}
|
||||
return implementable;
|
||||
}
|
||||
|
||||
static constexpr int K_PIPE_MAX = DispatchPolicy::Stages;
static constexpr int K_PIPE_MMAS = 1;
static constexpr uint32_t TmaTransactionBytes =
(size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) *
static_cast<uint32_t>(sizeof_bits<ElementA>::value)) /
8 +
(size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) *
static_cast<uint32_t>(sizeof_bits<ElementB>::value)) /
8 +
(size<0>(SmemLayoutAux{}) * size<1>(SmemLayoutAux{}) *
static_cast<uint32_t>(sizeof_bits<ElementAux>::value)) /
8;
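// TmaTransactionBytes is the per-stage byte count the producer commits to the
// barrier: one A tile + one B tile + one Aux tile. For example, assuming a
// 128x128x128 tile with 8-bit A/B/Aux, that is 3 * 128 * 128 = 49152 bytes.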

/// Issue Tma Descriptor Prefetch -- ideally from a single thread for best
/// performance
CUTLASS_DEVICE
static void prefetch_tma_descriptors(Params const &mainloop_params) {
cute::prefetch_tma_descriptor(
mainloop_params.tma_load_a.get_tma_descriptor());
cute::prefetch_tma_descriptor(
mainloop_params.tma_load_b.get_tma_descriptor());
cute::prefetch_tma_descriptor(
mainloop_params.tma_load_aux.get_tma_descriptor());
}
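// The auxiliary (gate) operand has its own TMA descriptor, so all three
// descriptors are prefetched here rather than just A and B.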

/// Set up the data needed by this collective for load and mma.
/// Returns a tuple of tensors. The contract between the collective and the
/// kernel layer is that the returned tuple contains at least two elements,
/// the first two being:
///   gA_mkl   - the TMA tensor, A after a local tile, shape (BLK_M,BLK_K,m,k,l)
///   gB_nkl   - the TMA tensor, B after a local tile, shape (BLK_N,BLK_K,n,k,l)
/// This collective additionally returns
///   gAux_xkl - the TMA tensor, A/B after a local tile, shape (BLK_N,BLK_K,m/n,k,l)
/// The rest of the tensors can be specified as needed by this collective.
template <class ProblemShape_MNKL>
|
||||
CUTLASS_DEVICE auto load_init(ProblemShape_MNKL const &problem_shape_MNKL,
|
||||
Params const &mainloop_params) const {
|
||||
using X = Underscore;
|
||||
// Separate out problem shape for convenience
|
||||
auto [M, N, K, L] = problem_shape_MNKL;
|
||||
|
||||
// TMA requires special handling of strides to deal with coord codomain
|
||||
// mapping Represent the full tensors -- get these from TMA
|
||||
Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor(
|
||||
make_shape(M, K, L)); // (m,k,l)
|
||||
Tensor mB_nkl = mainloop_params.tma_load_b.get_tma_tensor(
|
||||
make_shape(N, K, L)); // (n,k,l)
|
||||
|
||||
// Make tiled views, defer the slice
|
||||
Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_, _, _),
|
||||
Step<_1, X, _1>{}); // (BLK_M,BLK_K,m,k,l)
|
||||
Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_, _, _),
|
||||
Step<X, _1, _1>{}); // (BLK_N,BLK_K,n,k,l)
|
||||
|
||||
if constexpr (SwapAB) {
|
||||
Tensor mAux_xkl = mainloop_params.tma_load_aux.get_tma_tensor(
|
||||
make_shape(M, K, L)); // (m,k,l)
|
||||
Tensor gAux_xkl = local_tile(mAux_xkl, TileShape{}, make_coord(_, _, _),
|
||||
Step<_1, X, _1>{}); // (BLK_M,BLK_K,m,k,l)
|
||||
return cute::make_tuple(gA_mkl, gB_nkl, gAux_xkl);
|
||||
} else {
|
||||
Tensor mAux_xkl = mainloop_params.tma_load_aux.get_tma_tensor(
|
||||
make_shape(N, K, L)); // (n,k,l)
|
||||
Tensor gAux_xkl = local_tile(mAux_xkl, TileShape{}, make_coord(_, _, _),
|
||||
Step<X, _1, _1>{}); // (BLK_N,BLK_K,n,k,l)
|
||||
return cute::make_tuple(gA_mkl, gB_nkl, gAux_xkl);
|
||||
}
|
||||
}
|
||||
|
||||
/// Perform a collective-scoped matrix multiply-accumulate
|
||||
/// Producer Perspective
|
||||
template <class TensorA, class TensorB, class TensorAux, class KTileIterator,
|
||||
class BlockCoord>
|
||||
CUTLASS_DEVICE void
|
||||
load(Params const &mainloop_params, MainloopPipeline pipeline,
|
||||
PipelineState smem_pipe_write,
|
||||
cute::tuple<TensorA, TensorB, TensorAux> const &load_inputs,
|
||||
BlockCoord const &blk_coord, KTileIterator k_tile_iter, int k_tile_count,
|
||||
int thread_idx, uint32_t block_rank_in_cluster,
|
||||
TensorStorage &shared_tensors) {
|
||||
int lane_predicate = cute::elect_one_sync();
|
||||
|
||||
if (lane_predicate) {
|
||||
Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()),
|
||||
SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
|
||||
Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()),
|
||||
SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
|
||||
Tensor sAux = make_tensor(make_smem_ptr(shared_tensors.smem_Aux.data()),
|
||||
SmemLayoutAux{});
|
||||
|
||||
//
|
||||
// Prepare the TMA loads for A and B
|
||||
//
|
||||
|
||||
constexpr uint32_t cluster_shape_x =
|
||||
get<0>(typename DispatchPolicy::ClusterShape());
|
||||
uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x,
|
||||
block_rank_in_cluster / cluster_shape_x};
|
||||
|
||||
Tensor gA_mkl = get<0>(load_inputs);
|
||||
Tensor gB_nkl = get<1>(load_inputs);
|
||||
Tensor gAux_xkl = get<2>(load_inputs);
|
||||
|
||||
auto block_tma_a =
|
||||
mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y);
|
||||
auto block_tma_b =
|
||||
mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x);
|
||||
auto block_tma_aux =
|
||||
SwapAB
|
||||
? mainloop_params.tma_load_aux.get_slice(cluster_local_block_id.y)
|
||||
: mainloop_params.tma_load_aux.get_slice(
|
||||
cluster_local_block_id.x);
|
||||
// Partition the inputs based on the current block coordinates.
|
||||
auto [m_coord, n_coord, k_coord, l_coord] = blk_coord;
|
||||
Tensor gA = gA_mkl(_, _, m_coord, _, l_coord); // (BLK_M,BLK_K,k)
|
||||
Tensor gB = gB_nkl(_, _, n_coord, _, l_coord); // (BLK_N,BLK_K,k)
|
||||
Tensor gAux = SwapAB ? gAux_xkl(_, _, m_coord, _, l_coord)
|
||||
: gAux_xkl(_, _, n_coord, _, l_coord);
|
||||
|
||||
// Applies the mapping from block_tma_a
|
||||
Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k)
|
||||
Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE)
|
||||
|
||||
Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k)
|
||||
Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE)
|
||||
|
||||
Tensor tAuxgAux = block_tma_aux.partition_S(gAux);
|
||||
Tensor tAuxsAux = block_tma_aux.partition_D(sAux);
|
||||
|
||||
uint16_t mcast_mask_a = 0;
|
||||
uint16_t mcast_mask_b = 0;
|
||||
uint16_t mcast_mask_aux = 0;
|
||||
|
||||
// Issue TmaLoads
|
||||
// Maps the tile -> block, value
|
||||
if constexpr (cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD_MULTICAST>) {
|
||||
auto block_layout =
|
||||
Layout<typename DispatchPolicy::ClusterShape>{}; // (m,n) ->
|
||||
// block_id
|
||||
for (int n = 0; n < size<1>(block_layout); ++n) {
|
||||
mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x,
|
||||
n, Int<0>{}));
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD_MULTICAST>) {
|
||||
auto block_layout =
|
||||
Layout<typename DispatchPolicy::ClusterShape>{}; // (m,n) ->
|
||||
// block_id
|
||||
for (int m = 0; m < size<0>(block_layout); ++m) {
|
||||
mcast_mask_b |= (uint16_t(1) << block_layout(
|
||||
m, cluster_local_block_id.y, Int<0>{}));
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (SwapAB) {
|
||||
mcast_mask_aux = mcast_mask_a;
|
||||
} else {
|
||||
mcast_mask_aux = mcast_mask_b;
|
||||
}
|
||||
|
||||
// Mainloop
|
||||
CUTLASS_PRAGMA_NO_UNROLL
|
||||
for (; k_tile_count > 0; --k_tile_count) {
|
||||
// LOCK smem_pipe_write for _writing_
|
||||
pipeline.producer_acquire(smem_pipe_write);
|
||||
|
||||
//
|
||||
// Copy gmem to smem for *k_tile_iter
|
||||
//
|
||||
|
||||
using BarrierType = typename MainloopPipeline::ProducerBarrierType;
|
||||
BarrierType *tma_barrier =
|
||||
pipeline.producer_get_barrier(smem_pipe_write);
|
||||
|
||||
int write_stage = smem_pipe_write.index();
|
||||
copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a),
|
||||
tAgA(_, _, _, *k_tile_iter), tAsA(_, _, _, write_stage));
|
||||
copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b),
|
||||
tBgB(_, _, _, *k_tile_iter), tBsB(_, _, _, write_stage));
|
||||
copy(mainloop_params.tma_load_aux.with(*tma_barrier, mcast_mask_aux),
|
||||
tAuxgAux(_, _, _, *k_tile_iter), tAuxsAux(_, _, _, write_stage));
|
||||
++k_tile_iter;
|
||||
|
||||
// Advance smem_pipe_write
|
||||
++smem_pipe_write;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster
|
||||
CUTLASS_DEVICE void load_tail(MainloopPipeline pipeline,
|
||||
PipelineState smem_pipe_write) {
|
||||
int lane_predicate = cute::elect_one_sync();
|
||||
|
||||
// Issue the epilogue waits
|
||||
if (lane_predicate) {
|
||||
/* This helps avoid early exit of blocks in Cluster
|
||||
* Waits for all stages to either be released (all
|
||||
* Consumer UNLOCKs), or if the stage was never used
|
||||
* then would just be acquired since the phase was
|
||||
* still inverted from make_producer_start_state
|
||||
*/
|
||||
pipeline.producer_tail(smem_pipe_write);
|
||||
}
|
||||
}
|
||||
|
||||
/// Perform a collective-scoped matrix multiply-accumulate
|
||||
/// Consumer Perspective
|
||||
template <class FrgTensorC>
|
||||
CUTLASS_DEVICE void
|
||||
mma(MainloopPipeline pipeline, PipelineState smem_pipe_read,
|
||||
FrgTensorC &accum0, FrgTensorC &accum1, int k_tile_count, int thread_idx,
|
||||
TensorStorage &shared_tensors, Params const &mainloop_params) {
|
||||
static_assert(is_rmem<FrgTensorC>::value,
|
||||
"C tensor must be rmem resident.");
|
||||
static_assert(cute::rank(SmemLayoutA{}) == 3,
|
||||
"Smem layout must be rank 3.");
|
||||
static_assert(cute::rank(SmemLayoutB{}) == 3,
|
||||
"Smem layout must be rank 3.");
|
||||
static_assert(cute::rank(SmemLayoutAux{}) == 3,
|
||||
"Smem layout must be rank 3.");
|
||||
static_assert(cute::is_void_v<SmemCopyAtomA>,
|
||||
"SM90 GMMA mainloops cannot have a non-void copy atom for "
|
||||
"smem sourced instructions.");
|
||||
static_assert(cute::is_void_v<SmemCopyAtomB>,
|
||||
"SM90 GMMA mainloops cannot have a non-void copy atom for "
|
||||
"smem sourced instructions.");
|
||||
|
||||
Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()),
|
||||
SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
|
||||
Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()),
|
||||
SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
|
||||
Tensor sAux = make_tensor(make_smem_ptr(shared_tensors.smem_Aux.data()),
|
||||
SmemLayoutAux{});
|
||||
|
||||
//
|
||||
// Define C accumulators and A/B partitioning
|
||||
//
|
||||
|
||||
TiledMma tiled_mma;
|
||||
auto thread_mma = tiled_mma.get_thread_slice(thread_idx);
|
||||
|
||||
Tensor tCsA = thread_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE)
|
||||
Tensor tCsB = thread_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE)
|
||||
|
||||
// Allocate "fragments/descriptors"
|
||||
Tensor tCrA = thread_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE)
|
||||
Tensor tCrB = thread_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE)
|
||||
|
||||
auto tCsAux = [&]() -> auto {
|
||||
if constexpr (SwapAB) {
|
||||
return thread_mma.partition_A(sAux);
|
||||
} else {
|
||||
return thread_mma.partition_B(sAux);
|
||||
}
|
||||
}();
|
||||
auto tCrAux = [&]() -> auto {
|
||||
if constexpr (SwapAB) {
|
||||
return thread_mma.make_fragment_A(tCsAux);
|
||||
} else {
|
||||
return thread_mma.make_fragment_B(tCsAux);
|
||||
}
|
||||
}();
|
||||
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum0)); // M
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum0)); // N
|
||||
CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K
|
||||
CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE
|
||||
if constexpr (SwapAB) {
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tCsAux) == size<1>(accum1)); // M
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum1)); // N
|
||||
CUTE_STATIC_ASSERT_V(size<2>(tCsB) == size<2>(tCsAux)); // K
|
||||
CUTE_STATIC_ASSERT_V(size<3>(tCsB) == size<3>(tCsAux)); // PIPE
|
||||
} else {
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum1)); // M
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tCsAux) == size<2>(accum1)); // N
|
||||
CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsAux)); // K
|
||||
CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsAux)); // PIPE
|
||||
}
|
||||
CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} == size<2>(sA)); // PIPE
|
||||
CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} == size<2>(sB)); // PIPE
|
||||
CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} ==
|
||||
size<2>(sAux)); // PIPE
|
||||
|
||||
//
|
||||
// PIPELINED MAIN LOOP
|
||||
//
|
||||
static_assert((0 <= K_PIPE_MMAS) && (K_PIPE_MMAS < K_PIPE_MAX),
|
||||
"ERROR : Incorrect number of MMAs in flight");
|
||||
|
||||
// We release buffers to producer warps(dma load) with some mmas in flight
|
||||
PipelineState smem_pipe_release = smem_pipe_read;
|
||||
|
||||
// Prologue GMMAs
|
||||
int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count);
|
||||
|
||||
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
|
||||
|
||||
warpgroup_fence_operand(accum0);
|
||||
warpgroup_fence_operand(accum1);
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int k_tile_prologue = prologue_mma_count; k_tile_prologue > 0;
|
||||
--k_tile_prologue) {
|
||||
// WAIT on smem_pipe_read until its data are available (phase bit flips
|
||||
// from rdPhaseBit value)
|
||||
auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
|
||||
pipeline.consumer_wait(smem_pipe_read, barrier_token);
|
||||
|
||||
int read_stage = smem_pipe_read.index();
|
||||
warpgroup_arrive();
|
||||
// Unroll the K mode manually to set scale D to 1
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
|
||||
// (V,M,K) x (V,N,K) => (V,M,N)
|
||||
cute::gemm(tiled_mma, tCrA(_, _, k_block, read_stage),
|
||||
tCrB(_, _, k_block, read_stage), accum0);
|
||||
if constexpr (SwapAB) {
|
||||
cute::gemm(tiled_mma, tCrAux(_, _, k_block, read_stage),
|
||||
tCrB(_, _, k_block, read_stage), accum1);
|
||||
} else {
|
||||
cute::gemm(tiled_mma, tCrA(_, _, k_block, read_stage),
|
||||
tCrAux(_, _, k_block, read_stage), accum1);
|
||||
}
|
||||
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
|
||||
}
|
||||
|
||||
warpgroup_commit_batch();
|
||||
|
||||
++smem_pipe_read;
|
||||
}
|
||||
|
||||
warpgroup_fence_operand(accum0);
|
||||
warpgroup_fence_operand(accum1);
|
||||
// Mainloop GMMAs
|
||||
k_tile_count -= prologue_mma_count;
|
||||
|
||||
CUTLASS_PRAGMA_NO_UNROLL
|
||||
for (; k_tile_count > 0; --k_tile_count) {
|
||||
// WAIT on smem_pipe_read until its data are available (phase bit flips
|
||||
// from rdPhaseBit value)
|
||||
auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
|
||||
pipeline.consumer_wait(smem_pipe_read, barrier_token);
|
||||
|
||||
//
|
||||
// Compute on k_tile
|
||||
//
|
||||
|
||||
int read_stage = smem_pipe_read.index();
|
||||
warpgroup_fence_operand(accum0);
|
||||
warpgroup_fence_operand(accum1);
|
||||
warpgroup_arrive();
|
||||
// Unroll the K mode manually to set scale D to 1
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
|
||||
// (V,M,K) x (V,N,K) => (V,M,N)
|
||||
cute::gemm(tiled_mma, tCrA(_, _, k_block, read_stage),
|
||||
tCrB(_, _, k_block, read_stage), accum0);
|
||||
if constexpr (SwapAB) {
|
||||
cute::gemm(tiled_mma, tCrAux(_, _, k_block, read_stage),
|
||||
tCrB(_, _, k_block, read_stage), accum1);
|
||||
} else {
|
||||
cute::gemm(tiled_mma, tCrA(_, _, k_block, read_stage),
|
||||
tCrAux(_, _, k_block, read_stage), accum1);
|
||||
}
|
||||
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
|
||||
}
|
||||
warpgroup_commit_batch();
|
||||
|
||||
/// Wait on the GMMA barrier for K_PIPE_MMAS (or fewer) outstanding to
|
||||
/// ensure smem_pipe_write is consumed
|
||||
warpgroup_wait<K_PIPE_MMAS>();
|
||||
warpgroup_fence_operand(accum0);
|
||||
warpgroup_fence_operand(accum1);
|
||||
|
||||
// UNLOCK smem_pipe_release, done _computing_ on it
|
||||
pipeline.consumer_release(smem_pipe_release);
|
||||
|
||||
// Advance smem_pipe_read and smem_pipe_release
|
||||
++smem_pipe_read;
|
||||
++smem_pipe_release;
|
||||
}
|
||||
|
||||
warpgroup_fence_operand(accum0);
|
||||
warpgroup_fence_operand(accum1);
|
||||
}
|
||||
|
||||
/// Perform a Consumer Epilogue to release all buffers
|
||||
CUTLASS_DEVICE void mma_tail(MainloopPipeline pipeline,
|
||||
PipelineState smem_pipe_release,
|
||||
int k_tile_count) {
|
||||
// Prologue GMMAs
|
||||
int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count);
|
||||
k_tile_count -= prologue_mma_count;
|
||||
|
||||
smem_pipe_release.advance(k_tile_count);
|
||||
|
||||
// Wait on all GMMAs to complete
|
||||
warpgroup_wait<0>();
|
||||
|
||||
for (int count = 0; count < prologue_mma_count; ++count) {
|
||||
pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release,
|
||||
// done _computing_ on it
|
||||
++smem_pipe_release;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass::gemm::collective
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@@ -0,0 +1,724 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
|
||||
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights
|
||||
*reserved. SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice,
|
||||
*this list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
||||
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
|
||||
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
|
||||
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
|
||||
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
||||
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
|
||||
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
|
||||
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
|
||||
*POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cute/arch/cluster_sm90.hpp"
|
||||
#include "cute/arch/copy_sm90.hpp"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
|
||||
#include "cute/algorithm/functional.hpp"
|
||||
#include "cute/algorithm/gemm.hpp"
|
||||
#include "cute/atom/mma_atom.hpp"
|
||||
#include "cute/numeric/arithmetic_tuple.hpp"
|
||||
#include "cute/tensor_predicate.hpp"
|
||||
#include "cutlass/epilogue/thread/activation.h"
|
||||
#include "cutlass/gemm/collective/fp8_accumulation.hpp"
|
||||
#include "cutlass/trace.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass::gemm::collective {
|
||||
using namespace cute;
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// WarpSpecialized Mainloop
|
||||
template <int Stages, class ClusterShape, class KernelSchedule,
|
||||
class TileShape_, class ElementA_, class StrideA_, class ElementB_,
|
||||
class StrideB_, class TiledMma_, class GmemTiledCopyA_,
|
||||
class SmemLayoutAtomA_, class SmemCopyAtomA_, class TransformA_,
|
||||
class GmemTiledCopyB_, class SmemLayoutAtomB_, class SmemCopyAtomB_,
|
||||
class TransformB_,
|
||||
template <class /* ElementCompute */> class Activation_, bool SwapAB_>
|
||||
struct CollectiveMmaGated<
|
||||
MainloopSm90TmaGmmaWarpSpecializedFP8<Stages, ClusterShape, KernelSchedule>,
|
||||
TileShape_, ElementA_, StrideA_, ElementB_, StrideB_, TiledMma_,
|
||||
GmemTiledCopyA_, SmemLayoutAtomA_, SmemCopyAtomA_, TransformA_,
|
||||
GmemTiledCopyB_, SmemLayoutAtomB_, SmemCopyAtomB_, TransformB_, Activation_,
|
||||
SwapAB_> {
|
||||
static constexpr bool isGated = true;
static constexpr bool SwapAB = SwapAB_;
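// isGated is what the GemmUniversalGated kernel's enable_if (later in this
// diff) keys on, together with the warp-specialized TMA schedule.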
|
||||
//
|
||||
// Type Aliases
|
||||
//
|
||||
using DispatchPolicy =
|
||||
MainloopSm90TmaGmmaWarpSpecializedFP8<Stages, ClusterShape,
|
||||
KernelSchedule>;
|
||||
using TileShape = TileShape_;
|
||||
using ElementA = ElementA_;
|
||||
using StrideA = StrideA_;
|
||||
using ElementB = ElementB_;
|
||||
using StrideB = StrideB_;
|
||||
using TiledMma = TiledMma_;
|
||||
using ElementAccumulator = typename TiledMma::ValTypeC;
|
||||
using GmemTiledCopyA = GmemTiledCopyA_;
|
||||
using GmemTiledCopyB = GmemTiledCopyB_;
|
||||
using SmemLayoutAtomA = SmemLayoutAtomA_;
|
||||
using SmemLayoutAtomB = SmemLayoutAtomB_;
|
||||
using SmemCopyAtomA = SmemCopyAtomA_;
|
||||
using SmemCopyAtomB = SmemCopyAtomB_;
|
||||
using TransformA = TransformA_;
|
||||
using TransformB = TransformB_;
|
||||
using ArchTag = typename DispatchPolicy::ArchTag;
|
||||
using Activation = Activation_<ElementAccumulator>;
|
||||
|
||||
using ElementAux = cute::conditional_t<SwapAB, ElementA_, ElementB_>;
|
||||
using ValTypeAux = cute::conditional_t<SwapAB, typename TiledMma::ValTypeA,
|
||||
typename TiledMma::ValTypeB>;
|
||||
|
||||
using MainloopPipeline = cutlass::PipelineTmaAsync<DispatchPolicy::Stages>;
|
||||
using PipelineState = cutlass::PipelineState<DispatchPolicy::Stages>;
|
||||
|
||||
using PipelineParams = typename MainloopPipeline::Params;
|
||||
|
||||
static_assert(cute::rank(SmemLayoutAtomA{}) == 2,
|
||||
"SmemLayoutAtom must be rank 2 (M/N, K)");
|
||||
static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0,
|
||||
"SmemLayoutAtom must evenly divide tile shape.");
|
||||
static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0,
|
||||
"SmemLayoutAtom must evenly divide tile shape.");
|
||||
|
||||
static_assert(cute::rank(SmemLayoutAtomB{}) == 2,
|
||||
"SmemLayoutAtom must be rank 2 (M/N, K)");
|
||||
static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0,
|
||||
"SmemLayoutAtom must evenly divide tile shape.");
|
||||
static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0,
|
||||
"SmemLayoutAtom must evenly divide tile shape.");
|
||||
|
||||
// Tile along modes in a way that maximizes the TMA box size.
|
||||
using SmemLayoutA = decltype(tile_to_shape(
|
||||
SmemLayoutAtomA{},
|
||||
make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}),
|
||||
Int<DispatchPolicy::Stages>{}),
|
||||
conditional_t<::cutlass::gemm::detail::is_major<0, StrideA>(),
|
||||
Step<_2, _1, _3>, Step<_1, _2, _3>>{}));
|
||||
using SmemLayoutB = decltype(tile_to_shape(
|
||||
SmemLayoutAtomB{},
|
||||
make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}),
|
||||
Int<DispatchPolicy::Stages>{}),
|
||||
conditional_t<::cutlass::gemm::detail::is_major<0, StrideB>(),
|
||||
Step<_2, _1, _3>, Step<_1, _2, _3>>{}));
|
||||
using SmemLayoutAux = cute::conditional_t<SwapAB, SmemLayoutA, SmemLayoutB>;
|
||||
|
||||
static_assert(DispatchPolicy::Stages >= 2,
"Specialization requires Stages set to value 2 or more.");
static_assert(cute::is_base_of<cute::GMMA::DescriptorIterator,
|
||||
typename TiledMma::FrgTypeA>::value &&
|
||||
cute::is_base_of<cute::GMMA::DescriptorIterator,
|
||||
typename TiledMma::FrgTypeB>::value,
|
||||
"MMA atom must source both A and B operand from smem_desc for "
|
||||
"this mainloop.");
|
||||
static_assert(cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD> ||
|
||||
cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD_MULTICAST>,
|
||||
"GmemTiledCopy - invalid SM90 TMA copy atom specified.");
|
||||
static_assert(cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD> ||
|
||||
cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD_MULTICAST>,
|
||||
"GmemTiledCopy - invalid SM90 TMA copy atom specified.");
|
||||
|
||||
struct SharedStorage {
|
||||
struct TensorStorage : cute::aligned_struct<128> {
|
||||
cute::array_aligned<typename TiledMma::ValTypeA,
|
||||
cute::cosize_v<SmemLayoutA>>
|
||||
smem_A;
|
||||
cute::array_aligned<typename TiledMma::ValTypeB,
|
||||
cute::cosize_v<SmemLayoutB>>
|
||||
smem_B;
|
||||
cute::array_aligned<ValTypeAux, cute::cosize_v<SmemLayoutAux>> smem_Aux;
|
||||
} tensors;
|
||||
|
||||
using PipelineStorage = typename MainloopPipeline::SharedStorage;
|
||||
PipelineStorage pipeline;
|
||||
};
|
||||
|
||||
using TensorStorage = typename SharedStorage::TensorStorage;
|
||||
using PipelineStorage = typename SharedStorage::PipelineStorage;
|
||||
|
||||
// Host side kernel arguments
|
||||
struct Arguments {
|
||||
ElementA const *ptr_A;
|
||||
StrideA dA;
|
||||
ElementB const *ptr_B0;
|
||||
ElementB const *ptr_B1;
|
||||
StrideB dB;
|
||||
float scale_d0 = 1.0f;
|
||||
float scale_d1 = 1.0f;
|
||||
uint32_t mma_promotion_interval = 4;
|
||||
};
|
||||
|
||||
// Device side kernel params
|
||||
struct Params {
|
||||
// Assumption: StrideA is congruent with Problem_MK
|
||||
using TMA_A = decltype(make_tma_copy(
|
||||
GmemTiledCopyA{},
|
||||
make_tensor(static_cast<ElementA const *>(nullptr),
|
||||
repeat_like(StrideA{}, int32_t(0)), StrideA{}),
|
||||
SmemLayoutA{}(_, _, 0),
|
||||
make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})),
|
||||
size<1>(ClusterShape{}))); // mcast along N mode for this M load, if any
|
||||
// Assumption: StrideB is congruent with Problem_NK
|
||||
using TMA_B = decltype(make_tma_copy(
|
||||
GmemTiledCopyB{},
|
||||
make_tensor(static_cast<ElementB const *>(nullptr),
|
||||
repeat_like(StrideB{}, int32_t(0)), StrideB{}),
|
||||
SmemLayoutB{}(_, _, 0),
|
||||
make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})),
|
||||
size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any
|
||||
using TMA_Aux = cute::conditional_t<SwapAB, TMA_A, TMA_B>;
|
||||
TMA_A tma_load_a;
|
||||
TMA_B tma_load_b;
|
||||
TMA_Aux tma_load_aux;
|
||||
float scale_d0 = 1.0f;
|
||||
float scale_d1 = 1.0f;
|
||||
uint32_t mma_promotion_interval = 4;
|
||||
};
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
template <class ProblemShape>
|
||||
static constexpr Params
|
||||
to_underlying_arguments(ProblemShape const &problem_shape,
|
||||
Arguments const &args, void *workspace) {
|
||||
(void)workspace;
|
||||
|
||||
// Optionally append 1s until problem shape is rank-4 (MNKL), in case it is
|
||||
// only rank-3 (MNK)
|
||||
auto problem_shape_MNKL = append<4>(problem_shape, 1);
|
||||
auto [M, N, K, L] = problem_shape_MNKL;
|
||||
|
||||
auto ptr_A = reinterpret_cast<ElementA const *>(args.ptr_A);
|
||||
auto ptr_B0 = reinterpret_cast<ElementB const *>(args.ptr_B0);
|
||||
|
||||
Tensor tensor_a =
|
||||
make_tensor(ptr_A, make_layout(make_shape(M, K, L), args.dA));
|
||||
Tensor tensor_b =
|
||||
make_tensor(ptr_B0, make_layout(make_shape(N, K, L), args.dB));
|
||||
typename Params::TMA_A tma_load_a = make_tma_copy(
|
||||
GmemTiledCopyA{}, tensor_a, SmemLayoutA{}(_, _, cute::Int<0>{}),
|
||||
make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})),
|
||||
size<1>(ClusterShape{})); // mcast along N mode for this M load, if any
|
||||
typename Params::TMA_B tma_load_b = make_tma_copy(
|
||||
GmemTiledCopyB{}, tensor_b, SmemLayoutB{}(_, _, cute::Int<0>{}),
|
||||
make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})),
|
||||
size<0>(ClusterShape{})); // mcast along M mode for this N load, if any
|
||||
if constexpr (SwapAB) {
|
||||
auto ptr_Aux = reinterpret_cast<ElementA const *>(args.ptr_B1);
|
||||
Tensor tensor_aux =
|
||||
make_tensor(ptr_Aux, make_layout(make_shape(M, K, L), args.dA));
|
||||
typename Params::TMA_Aux tma_load_aux = make_tma_copy(
|
||||
GmemTiledCopyA{}, tensor_aux, SmemLayoutA{}(_, _, cute::Int<0>{}),
|
||||
make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})),
|
||||
size<1>(
|
||||
ClusterShape{})); // mcast along N mode for this M load, if any
|
||||
return {tma_load_a, tma_load_b, tma_load_aux,
|
||||
args.scale_d0, args.scale_d1, args.mma_promotion_interval};
|
||||
} else {
|
||||
auto ptr_Aux = reinterpret_cast<ElementB const *>(args.ptr_B1);
|
||||
Tensor tensor_aux =
|
||||
make_tensor(ptr_Aux, make_layout(make_shape(N, K, L), args.dB));
|
||||
typename Params::TMA_Aux tma_load_aux = make_tma_copy(
|
||||
GmemTiledCopyB{}, tensor_aux, SmemLayoutB{}(_, _, cute::Int<0>{}),
|
||||
make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})),
|
||||
size<0>(
|
||||
ClusterShape{})); // mcast along M mode for this N load, if any
|
||||
return {tma_load_a, tma_load_b, tma_load_aux,
|
||||
args.scale_d0, args.scale_d1, args.mma_promotion_interval};
|
||||
}
|
||||
}
|
||||
|
||||
template <class ProblemShape>
|
||||
static bool can_implement(ProblemShape const &problem_shape,
|
||||
[[maybe_unused]] Arguments const &args) {
|
||||
constexpr int tma_alignment_bits = 128;
|
||||
auto problem_shape_MNKL = append<4>(problem_shape, 1);
|
||||
auto [M, N, K, L] = problem_shape_MNKL;
|
||||
|
||||
bool implementable = true;
|
||||
constexpr int min_tma_aligned_elements_A =
|
||||
tma_alignment_bits / cutlass::sizeof_bits<ElementA>::value;
|
||||
implementable =
|
||||
implementable &&
|
||||
cutlass::detail::check_alignment<min_tma_aligned_elements_A>(
|
||||
cute::make_shape(M, K, L), StrideA{});
|
||||
constexpr int min_tma_aligned_elements_B =
|
||||
tma_alignment_bits / cutlass::sizeof_bits<ElementB>::value;
|
||||
implementable =
|
||||
implementable &&
|
||||
cutlass::detail::check_alignment<min_tma_aligned_elements_B>(
|
||||
cute::make_shape(N, K, L), StrideB{});
|
||||
/* MMA promotion interval should be a multiple of 4, since each mainloop
* iteration would issue 4 MMA instructions. */
implementable = implementable && (args.mma_promotion_interval % 4 == 0);
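// e.g. the default mma_promotion_interval = 4 promotes after every mainloop
// iteration; larger multiples of 4 promote less often, trading a little
// accumulation precision for fewer promotion steps.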
|
||||
if (!implementable) {
|
||||
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the "
|
||||
"minimum alignment requirements for TMA.\n");
|
||||
}
|
||||
return implementable;
|
||||
}
|
||||
|
||||
static constexpr int K_PIPE_MAX = DispatchPolicy::Stages;
|
||||
static constexpr int K_PIPE_MMAS = 1;
|
||||
static constexpr uint32_t TmaTransactionBytes =
|
||||
(size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) *
|
||||
static_cast<uint32_t>(sizeof_bits<ElementA>::value)) /
|
||||
8 +
|
||||
(size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) *
|
||||
static_cast<uint32_t>(sizeof_bits<ElementB>::value)) /
|
||||
8 +
|
||||
(size<0>(SmemLayoutAux{}) * size<1>(SmemLayoutAux{}) *
|
||||
static_cast<uint32_t>(sizeof_bits<ElementAux>::value)) /
|
||||
8;
|
||||
|
||||
/// Issue Tma Descriptor Prefetch -- ideally from a single thread for best
|
||||
/// performance
|
||||
CUTLASS_DEVICE
|
||||
static void prefetch_tma_descriptors(Params const &mainloop_params) {
|
||||
cute::prefetch_tma_descriptor(
|
||||
mainloop_params.tma_load_a.get_tma_descriptor());
|
||||
cute::prefetch_tma_descriptor(
|
||||
mainloop_params.tma_load_b.get_tma_descriptor());
|
||||
cute::prefetch_tma_descriptor(
|
||||
mainloop_params.tma_load_aux.get_tma_descriptor());
|
||||
}
|
||||
|
||||
/// Set up the data needed by this collective for load and mma.
/// Returns a tuple of tensors. The contract between the collective and the
/// kernel layer is that the returned tuple contains at least two elements,
/// the first two being:
///   gA_mkl   - the TMA tensor, A after a local tile, shape (BLK_M,BLK_K,m,k,l)
///   gB_nkl   - the TMA tensor, B after a local tile, shape (BLK_N,BLK_K,n,k,l)
/// This collective additionally returns
///   gAux_xkl - the TMA tensor, A/B after a local tile, shape (BLK_N,BLK_K,m/n,k,l)
template <class ProblemShape_MNKL>
|
||||
CUTLASS_DEVICE auto load_init(ProblemShape_MNKL const &problem_shape_MNKL,
|
||||
Params const &mainloop_params) const {
|
||||
using X = Underscore;
|
||||
// Separate out problem shape for convenience
|
||||
auto [M, N, K, L] = problem_shape_MNKL;
|
||||
|
||||
// TMA requires special handling of strides to deal with coord codomain
|
||||
// mapping Represent the full tensors -- get these from TMA
|
||||
Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor(
|
||||
make_shape(M, K, L)); // (m,k,l)
|
||||
Tensor mB_nkl = mainloop_params.tma_load_b.get_tma_tensor(
|
||||
make_shape(N, K, L)); // (n,k,l)
|
||||
|
||||
// Make tiled views, defer the slice
|
||||
Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_, _, _),
|
||||
Step<_1, X, _1>{}); // (BLK_M,BLK_K,m,k,l)
|
||||
Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_, _, _),
|
||||
Step<X, _1, _1>{}); // (BLK_N,BLK_K,n,k,l)
|
||||
|
||||
if constexpr (SwapAB) {
|
||||
Tensor mAux_xkl = mainloop_params.tma_load_aux.get_tma_tensor(
|
||||
make_shape(M, K, L)); // (m,k,l)
|
||||
Tensor gAux_xkl = local_tile(mAux_xkl, TileShape{}, make_coord(_, _, _),
|
||||
Step<_1, X, _1>{}); // (BLK_M,BLK_K,m,k,l)
|
||||
return cute::make_tuple(gA_mkl, gB_nkl, gAux_xkl);
|
||||
} else {
|
||||
Tensor mAux_xkl = mainloop_params.tma_load_aux.get_tma_tensor(
|
||||
make_shape(N, K, L)); // (n,k,l)
|
||||
Tensor gAux_xkl = local_tile(mAux_xkl, TileShape{}, make_coord(_, _, _),
|
||||
Step<X, _1, _1>{}); // (BLK_N,BLK_K,n,k,l)
|
||||
return cute::make_tuple(gA_mkl, gB_nkl, gAux_xkl);
|
||||
}
|
||||
}
|
||||
|
||||
/// Perform a collective-scoped matrix multiply-accumulate
|
||||
/// Producer Perspective
|
||||
template <class TensorA, class TensorB, class TensorAux, class KTileIterator,
|
||||
class BlockCoord>
|
||||
CUTLASS_DEVICE void
|
||||
load(Params const &mainloop_params, MainloopPipeline pipeline,
|
||||
PipelineState smem_pipe_write,
|
||||
cute::tuple<TensorA, TensorB, TensorAux> const &load_inputs,
|
||||
BlockCoord const &blk_coord, KTileIterator k_tile_iter, int k_tile_count,
|
||||
int thread_idx, uint32_t block_rank_in_cluster,
|
||||
TensorStorage &shared_tensors) {
|
||||
int lane_predicate = cute::elect_one_sync();
|
||||
|
||||
if (lane_predicate) {
|
||||
Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()),
|
||||
SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
|
||||
Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()),
|
||||
SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
|
||||
Tensor sAux = make_tensor(make_smem_ptr(shared_tensors.smem_Aux.data()),
|
||||
SmemLayoutAux{});
|
||||
|
||||
//
|
||||
// Prepare the TMA loads for A and B
|
||||
//
|
||||
|
||||
constexpr uint32_t cluster_shape_x = get<0>(ClusterShape());
|
||||
uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x,
|
||||
block_rank_in_cluster / cluster_shape_x};
|
||||
|
||||
Tensor gA_mkl = get<0>(load_inputs);
|
||||
Tensor gB_nkl = get<1>(load_inputs);
|
||||
Tensor gAux_xkl = get<2>(load_inputs);
|
||||
|
||||
auto block_tma_a =
|
||||
mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y);
|
||||
auto block_tma_b =
|
||||
mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x);
|
||||
auto block_tma_aux =
|
||||
SwapAB
|
||||
? mainloop_params.tma_load_aux.get_slice(cluster_local_block_id.y)
|
||||
: mainloop_params.tma_load_aux.get_slice(
|
||||
cluster_local_block_id.x);
|
||||
|
||||
// Partition the inputs based on the current block coordinates.
|
||||
auto [m_coord, n_coord, k_coord, l_coord] = blk_coord;
|
||||
Tensor gA = gA_mkl(_, _, m_coord, _, l_coord); // (BLK_M,BLK_K,k)
|
||||
Tensor gB = gB_nkl(_, _, n_coord, _, l_coord); // (BLK_N,BLK_K,k)
|
||||
Tensor gAux = SwapAB ? gAux_xkl(_, _, m_coord, _, l_coord)
|
||||
: gAux_xkl(_, _, n_coord, _, l_coord);
|
||||
|
||||
// Applies the mapping from block_tma_a
|
||||
Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k)
|
||||
Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE)
|
||||
|
||||
Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k)
|
||||
Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE)
|
||||
|
||||
Tensor tAuxgAux = block_tma_aux.partition_S(gAux);
|
||||
Tensor tAuxsAux = block_tma_aux.partition_D(sAux);
|
||||
|
||||
uint16_t mcast_mask_a = 0;
|
||||
uint16_t mcast_mask_b = 0;
|
||||
uint16_t mcast_mask_aux = 0;
|
||||
|
||||
// Issue TmaLoads
|
||||
// Maps the tile -> block, value
|
||||
if constexpr (cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD_MULTICAST>) {
|
||||
auto block_layout =
|
||||
Layout<typename DispatchPolicy::ClusterShape>{}; // (m,n) ->
|
||||
// block_id
|
||||
for (int n = 0; n < size<1>(block_layout); ++n) {
|
||||
mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x,
|
||||
n, Int<0>{}));
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD_MULTICAST>) {
|
||||
auto block_layout =
|
||||
Layout<typename DispatchPolicy::ClusterShape>{}; // (m,n) ->
|
||||
// block_id
|
||||
for (int m = 0; m < size<0>(block_layout); ++m) {
|
||||
mcast_mask_b |= (uint16_t(1) << block_layout(
|
||||
m, cluster_local_block_id.y, Int<0>{}));
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (SwapAB) {
|
||||
mcast_mask_aux = mcast_mask_a;
|
||||
} else {
|
||||
mcast_mask_aux = mcast_mask_b;
|
||||
}
|
||||
|
||||
// Mainloop
|
||||
CUTLASS_PRAGMA_NO_UNROLL
|
||||
for (; k_tile_count > 0; --k_tile_count) {
|
||||
// LOCK smem_pipe_write for _writing_
|
||||
pipeline.producer_acquire(smem_pipe_write);
|
||||
|
||||
//
|
||||
// Copy gmem to smem for *k_tile_iter
|
||||
//
|
||||
|
||||
using BarrierType = typename MainloopPipeline::ProducerBarrierType;
|
||||
BarrierType *tma_barrier =
|
||||
pipeline.producer_get_barrier(smem_pipe_write);
|
||||
|
||||
int write_stage = smem_pipe_write.index();
|
||||
copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a),
|
||||
tAgA(_, _, _, *k_tile_iter), tAsA(_, _, _, write_stage));
|
||||
copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b),
|
||||
tBgB(_, _, _, *k_tile_iter), tBsB(_, _, _, write_stage));
|
||||
copy(mainloop_params.tma_load_aux.with(*tma_barrier, mcast_mask_aux),
|
||||
tAuxgAux(_, _, _, *k_tile_iter), tAuxsAux(_, _, _, write_stage));
|
||||
++k_tile_iter;
|
||||
|
||||
// Advance smem_pipe_write
|
||||
++smem_pipe_write;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster
|
||||
CUTLASS_DEVICE void load_tail(MainloopPipeline pipeline,
|
||||
PipelineState smem_pipe_write) {
|
||||
int lane_predicate = cute::elect_one_sync();
|
||||
|
||||
// Issue the epilogue waits
|
||||
if (lane_predicate) {
|
||||
/* This helps avoid early exit of blocks in Cluster
|
||||
* Waits for all stages to either be released (all
|
||||
* Consumer UNLOCKs), or if the stage was never used
|
||||
* then would just be acquired since the phase was
|
||||
* still inverted from make_producer_start_state
|
||||
*/
|
||||
pipeline.producer_tail(smem_pipe_write);
|
||||
}
|
||||
}
|
||||
|
||||
/// Perform a collective-scoped matrix multiply-accumulate
|
||||
/// Consumer Perspective
|
||||
template <class FrgTensorC>
|
||||
CUTLASS_DEVICE void
|
||||
mma(MainloopPipeline pipeline, PipelineState smem_pipe_read,
|
||||
FrgTensorC &accum0, FrgTensorC &accum1, int k_tile_count, int thread_idx,
|
||||
TensorStorage &shared_tensors, Params const &mainloop_params) {
|
||||
|
||||
static_assert(is_rmem<FrgTensorC>::value,
|
||||
"C tensor must be rmem resident.");
|
||||
static_assert(cute::rank(SmemLayoutA{}) == 3,
|
||||
"Smem layout must be rank 3.");
|
||||
static_assert(cute::rank(SmemLayoutB{}) == 3,
|
||||
"Smem layout must be rank 3.");
|
||||
static_assert(cute::is_void_v<SmemCopyAtomA>,
|
||||
"SM90 GMMA mainloops cannot have a non-void copy atom for "
|
||||
"smem sourced instructions.");
|
||||
static_assert(cute::is_void_v<SmemCopyAtomB>,
|
||||
"SM90 GMMA mainloops cannot have a non-void copy atom for "
|
||||
"smem sourced instructions.");
|
||||
|
||||
Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()),
|
||||
SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
|
||||
Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()),
|
||||
SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
|
||||
Tensor sAux = make_tensor(make_smem_ptr(shared_tensors.smem_Aux.data()),
|
||||
SmemLayoutAux{});
|
||||
|
||||
//
|
||||
// Define C accumulators and A/B partitioning
|
||||
//
|
||||
|
||||
TiledMma tiled_mma;
|
||||
auto thread_mma = tiled_mma.get_thread_slice(thread_idx);
|
||||
|
||||
Tensor tCsA = thread_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE)
|
||||
Tensor tCsB = thread_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE)
|
||||
|
||||
// Allocate "fragments/descriptors"
|
||||
Tensor tCrA = thread_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE)
|
||||
Tensor tCrB = thread_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE)
|
||||
|
||||
auto tCsAux = [&]() -> auto {
|
||||
if constexpr (SwapAB) {
|
||||
return thread_mma.partition_A(sAux);
|
||||
} else {
|
||||
return thread_mma.partition_B(sAux);
|
||||
}
|
||||
}();
|
||||
auto tCrAux = [&]() -> auto {
|
||||
if constexpr (SwapAB) {
|
||||
return thread_mma.make_fragment_A(tCsAux);
|
||||
} else {
|
||||
return thread_mma.make_fragment_B(tCsAux);
|
||||
}
|
||||
}();
|
||||
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum0)); // M
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum0)); // N
|
||||
CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K
|
||||
CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE
|
||||
if constexpr (SwapAB) {
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tCsAux) == size<1>(accum1)); // M
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum1)); // N
|
||||
CUTE_STATIC_ASSERT_V(size<2>(tCsB) == size<2>(tCsAux)); // K
|
||||
CUTE_STATIC_ASSERT_V(size<3>(tCsB) == size<3>(tCsAux)); // PIPE
|
||||
} else {
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum1)); // M
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tCsAux) == size<2>(accum1)); // N
|
||||
CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsAux)); // K
|
||||
CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsAux)); // PIPE
|
||||
}
|
||||
CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} == size<2>(sA)); // PIPE
|
||||
CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} == size<2>(sB)); // PIPE
|
||||
CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} ==
|
||||
size<2>(sAux)); // PIPE
|
||||
|
||||
//
|
||||
// PIPELINED MAIN LOOP
|
||||
//
|
||||
static_assert((0 <= K_PIPE_MMAS) && (K_PIPE_MMAS < K_PIPE_MAX),
|
||||
"ERROR : Incorrect number of MMAs in flight");
|
||||
|
||||
// We release buffers to producer warps (DMA load) with some MMAs in flight
PipelineState smem_pipe_release = smem_pipe_read;

// Prologue GMMAs
int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count);

tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;

GmmaFP8Accumulation accumulation0(
accum0, mainloop_params.mma_promotion_interval, size<2>(tCrA));
GmmaFP8Accumulation accumulation1(
accum1, mainloop_params.mma_promotion_interval, size<2>(tCrA));
warpgroup_fence_operand(accumulation0());
warpgroup_fence_operand(accumulation1());
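// accumulation0/1 stage partial sums and promote them into accum0/accum1 every
// mma_promotion_interval k-blocks (promote_if_needed() below); any remainder is
// folded in after the mainloop by promote_residue_if_needed().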
CUTLASS_PRAGMA_UNROLL
|
||||
for (int k_tile_prologue = prologue_mma_count; k_tile_prologue > 0;
|
||||
--k_tile_prologue) {
|
||||
// WAIT on smem_pipe_read until its data are available (phase bit flips
|
||||
// from rdPhaseBit value)
|
||||
auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
|
||||
pipeline.consumer_wait(smem_pipe_read, barrier_token);
|
||||
|
||||
if (accumulation0.prepare_if_needed()) {
|
||||
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
|
||||
}
|
||||
|
||||
int read_stage = smem_pipe_read.index();
|
||||
warpgroup_arrive();
|
||||
// Unroll the K mode manually to set scale D to 1
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
|
||||
// (V,M,K) x (V,N,K) => (V,M,N)
|
||||
cute::gemm(tiled_mma, tCrA(_, _, k_block, read_stage),
|
||||
tCrB(_, _, k_block, read_stage), accumulation0());
|
||||
if constexpr (SwapAB) {
|
||||
cute::gemm(tiled_mma, tCrAux(_, _, k_block, read_stage),
|
||||
tCrB(_, _, k_block, read_stage), accumulation1());
|
||||
} else {
|
||||
cute::gemm(tiled_mma, tCrA(_, _, k_block, read_stage),
|
||||
tCrAux(_, _, k_block, read_stage), accumulation1());
|
||||
}
|
||||
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
|
||||
}
|
||||
warpgroup_commit_batch();
|
||||
|
||||
accumulation0.promote_if_needed();
|
||||
accumulation1.promote_if_needed();
|
||||
|
||||
++smem_pipe_read;
|
||||
}
|
||||
|
||||
warpgroup_fence_operand(accumulation0());
|
||||
warpgroup_fence_operand(accumulation1());
|
||||
// Mainloop GMMAs
|
||||
k_tile_count -= prologue_mma_count;
|
||||
|
||||
CUTLASS_PRAGMA_NO_UNROLL
|
||||
for (; k_tile_count > 0; --k_tile_count) {
|
||||
// WAIT on smem_pipe_read until its data are available (phase bit flips
|
||||
// from rdPhaseBit value)
|
||||
auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
|
||||
pipeline.consumer_wait(smem_pipe_read, barrier_token);
|
||||
|
||||
//
|
||||
// Compute on k_tile
|
||||
//
|
||||
|
||||
int read_stage = smem_pipe_read.index();
|
||||
|
||||
if (accumulation0.prepare_if_needed()) {
|
||||
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
|
||||
}
|
||||
|
||||
warpgroup_fence_operand(accumulation0());
|
||||
warpgroup_fence_operand(accumulation1());
|
||||
warpgroup_arrive();
|
||||
// Unroll the K mode manually to set scale D to 1
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
|
||||
// (V,M,K) x (V,N,K) => (V,M,N)
|
||||
cute::gemm(tiled_mma, tCrA(_, _, k_block, read_stage),
|
||||
tCrB(_, _, k_block, read_stage), accumulation0());
|
||||
if constexpr (SwapAB) {
|
||||
cute::gemm(tiled_mma, tCrAux(_, _, k_block, read_stage),
|
||||
tCrB(_, _, k_block, read_stage), accumulation1());
|
||||
} else {
|
||||
cute::gemm(tiled_mma, tCrA(_, _, k_block, read_stage),
|
||||
tCrAux(_, _, k_block, read_stage), accumulation1());
|
||||
}
|
||||
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
|
||||
}
|
||||
warpgroup_commit_batch();
|
||||
|
||||
/// Wait on the GMMA barrier for K_PIPE_MMAS (or fewer) outstanding to
|
||||
/// ensure smem_pipe_write is consumed
|
||||
warpgroup_wait<K_PIPE_MMAS>();
|
||||
warpgroup_fence_operand(accumulation0());
|
||||
warpgroup_fence_operand(accumulation1());
|
||||
|
||||
accumulation0.promote_if_needed();
|
||||
accumulation1.promote_if_needed();
|
||||
|
||||
pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release,
|
||||
// done _computing_ on it
|
||||
|
||||
// Advance smem_pipe_read and smem_pipe_release
|
||||
++smem_pipe_read;
|
||||
++smem_pipe_release;
|
||||
}
|
||||
|
||||
accumulation0.promote_residue_if_needed();
|
||||
accumulation1.promote_residue_if_needed();
|
||||
|
||||
warpgroup_fence_operand(accumulation0());
|
||||
warpgroup_fence_operand(accumulation1());
|
||||
}
|
||||
|
||||
/// Perform a Consumer Epilogue to release all buffers
|
||||
CUTLASS_DEVICE void mma_tail(MainloopPipeline pipeline,
|
||||
PipelineState smem_pipe_release,
|
||||
int k_tile_count) {
|
||||
// Prologue GMMAs
|
||||
int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count);
|
||||
k_tile_count -= prologue_mma_count;
|
||||
|
||||
smem_pipe_release.advance(k_tile_count);
|
||||
|
||||
// Wait on all GMMAs to complete
|
||||
warpgroup_wait<0>();
|
||||
|
||||
for (int count = 0; count < prologue_mma_count; ++count) {
|
||||
pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release,
|
||||
// done _computing_ on it
|
||||
++smem_pipe_release;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass::gemm::collective
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@@ -0,0 +1,71 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights
|
||||
*reserved. SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice,
|
||||
*this list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
||||
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
|
||||
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
|
||||
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
|
||||
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
||||
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
|
||||
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
|
||||
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
|
||||
*POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/gemm/kernel/tile_scheduler.hpp"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass::gemm::kernel {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/*
* Stateless universal device GEMM kernel type that treats GEMM as
* a composition of a collective mainloop and a collective epilogue.
*
* Supports both the 2.x and 3.x APIs based on whether the first type is
* a cute::tuple<> or not.
* 2.x API implementation: cutlass/gemm/kernel/gemm_universal.h
* 3.x API implementation: cutlass/gemm/kernel/gemm_*.hpp
*
* In the following declaration, the name preceding the 'Or' refers to
* 3.x API type argument order, and the name succeeding the 'Or' refers to
* 2.x API type argument order. Template arguments without two names
* belong to the 3.x API only.
**/
template <class ProblemShapeOrThreadblockMma_, // (m, n, k) or (m, n, k, l)
class CollectiveMainloopOrEpilogue_,
class CollectiveEpilogueOrThreadblockSwizzle_,
class TileScheduler_ = void, class Enable = void>
class GemmUniversalGated;
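// For example (3.x API ordering; type names here are illustrative):
//   using GatedKernel = cutlass::gemm::kernel::GemmUniversalGated<
//       cute::Shape<int, int, int, int>,  // ProblemShape (m, n, k, l)
//       CollectiveMainloop,               // a CollectiveMmaGated instantiation
//       CollectiveEpilogue>;              // TileScheduler_ defaults to void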

////////////////////////////////////////////////////////////////////////////////

} // namespace cutlass::gemm::kernel

////////////////////////////////////////////////////////////////////////////////

#include "cutlass_extensions/gemm/kernel/sm90_gemm_gated_tma_warpspecialized_cooperative.hpp"
#include "cutlass_extensions/gemm/kernel/sm90_gemm_gated_tma_warpspecialized_pingpong.hpp"
////////////////////////////////////////////////////////////////////////////////
@@ -130,6 +130,15 @@ public:
using Operator = cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA;
};

template <typename TypeA, typename Arch>
struct LayoutDetailsB<TypeA, uint2b_t, Arch, typename platform::enable_if<Arch::kMinComputeCapability >= 75>::type>
{
static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits<TypeA>::value;
using Layout = layout::RowMajor;
static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<TypeA>::value;
using Operator = cutlass::arch::OpMultiplyAdd;
};
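// For example, with TypeA = half_t (16 bits): ThreadblockK = 128 * 8 / 16 = 64
// and ElementsPerAccess = 128 / 16 = 8 (one 128-bit access per thread).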

template <typename TypeA, typename Arch>
struct LayoutDetailsB<TypeA, uint8_t, Arch, typename platform::enable_if<Arch::kMinComputeCapability >= 90>::type>
{

@@ -0,0 +1,705 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
|
||||
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights
|
||||
*reserved. SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice,
|
||||
*this list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
||||
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
|
||||
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
|
||||
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
|
||||
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
||||
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
|
||||
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
|
||||
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
|
||||
*POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include "cute/arch/cluster_sm90.hpp"
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/arch/mma_sm90.h"
|
||||
#include "cutlass/arch/reg_reconfig.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/epilogue/collective/detail.hpp"
|
||||
#include "cutlass/fast_math.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/gemm/kernel/tile_scheduler.hpp"
|
||||
#include "cutlass/kernel_hardware_info.hpp"
|
||||
#include "cutlass/pipeline/pipeline.hpp"
|
||||
#include "cutlass/trace.h"
|
||||
#include "cutlass/workspace.h"
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass::gemm::kernel {
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <class ProblemShape_, class CollectiveMainloop_,
|
||||
class CollectiveEpilogue_, class TileScheduler_>
|
||||
class GemmUniversalGated<
|
||||
ProblemShape_, CollectiveMainloop_, CollectiveEpilogue_, TileScheduler_,
|
||||
cute::enable_if_t<cute::is_base_of_v<KernelTmaWarpSpecializedCooperative,
|
||||
typename CollectiveMainloop_::
|
||||
DispatchPolicy::Schedule> &&
|
||||
CollectiveMainloop_::isGated>> {
|
||||
public:
|
||||
//
|
||||
// Type Aliases
|
||||
//
|
||||
using ProblemShape = ProblemShape_;
|
||||
static_assert(cute::rank(ProblemShape{}) == 3 or
|
||||
cute::rank(ProblemShape{}) == 4,
|
||||
"ProblemShape{} should be <M,N,K> or <M,N,K,L>");
|
||||
// Mainloop derived types
|
||||
using CollectiveMainloop = CollectiveMainloop_;
|
||||
using TileShape = typename CollectiveMainloop::TileShape;
|
||||
using TiledMma = typename CollectiveMainloop::TiledMma;
|
||||
using ArchTag = typename CollectiveMainloop::ArchTag;
|
||||
using ElementA = typename CollectiveMainloop::ElementA;
|
||||
using StrideA = typename CollectiveMainloop::StrideA;
|
||||
using ElementB = typename CollectiveMainloop::ElementB;
|
||||
using StrideB = typename CollectiveMainloop::StrideB;
|
||||
using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy;
|
||||
using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator;
|
||||
using ClusterShape = typename DispatchPolicy::ClusterShape;
|
||||
using MainloopArguments = typename CollectiveMainloop::Arguments;
|
||||
using MainloopParams = typename CollectiveMainloop::Params;
|
||||
using Activation = typename CollectiveMainloop::Activation;
|
||||
|
||||
// Epilogue derived types
|
||||
using CollectiveEpilogue = CollectiveEpilogue_;
|
||||
using ElementC = typename CollectiveEpilogue::ElementC;
|
||||
using StrideC = typename CollectiveEpilogue::StrideC;
|
||||
using ElementD = typename CollectiveEpilogue::ElementD;
|
||||
using StrideD = typename CollectiveEpilogue::StrideD;
|
||||
using EpilogueArguments = typename CollectiveEpilogue::Arguments;
|
||||
using EpilogueParams = typename CollectiveEpilogue::Params;
|
||||
|
||||
static_assert(ArchTag::kMinComputeCapability >= 90);
|
||||
|
||||
using TileSchedulerTag = TileScheduler_;
|
||||
using TileScheduler =
|
||||
typename detail::TileSchedulerSelector<TileScheduler_, ArchTag, TileShape,
|
||||
ClusterShape>::Scheduler;
|
||||
using TileSchedulerArguments = typename TileScheduler::Arguments;
|
||||
using TileSchedulerParams = typename TileScheduler::Params;
|
||||
|
||||
static constexpr uint32_t NumLoadWarpGroups = 1;
|
||||
static constexpr uint32_t NumMmaWarpGroups =
|
||||
CUTE_STATIC_V(size(TiledMma{})) / NumThreadsPerWarpGroup;
|
||||
static constexpr uint32_t MaxThreadsPerBlock =
|
||||
CUTE_STATIC_V(size(TiledMma{})) +
|
||||
(NumLoadWarpGroups * NumThreadsPerWarpGroup);
|
||||
static constexpr uint32_t MinBlocksPerMultiprocessor = 1;
|
||||
|
||||
/// Register requirement for Load and Math WGs
|
||||
static constexpr uint32_t LoadRegisterRequirement = 40;
|
||||
static constexpr uint32_t MmaRegisterRequirement = 232;
|
||||
|
||||
// 1 stage ordered sequence between mainloop and epilogue producer load
|
||||
// threads
|
||||
using LoadWarpOrderBarrier = cutlass::OrderedSequenceBarrier<1, 2>;
|
||||
|
||||
// Kernel level shared memory storage
|
||||
struct SharedStorage {
|
||||
struct TensorStorage : cute::aligned_struct<128> {
|
||||
using MainloopTensorStorage = typename CollectiveMainloop::TensorStorage;
|
||||
using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage;
|
||||
|
||||
MainloopTensorStorage mainloop;
|
||||
EpilogueTensorStorage epilogue;
|
||||
} tensors;
|
||||
|
||||
struct PipelineStorage : cute::aligned_struct<16> {
|
||||
using MainloopPipelineStorage =
|
||||
typename CollectiveMainloop::PipelineStorage;
|
||||
using EpiLoadPipelineStorage =
|
||||
typename CollectiveEpilogue::PipelineStorage;
|
||||
|
||||
alignas(16) MainloopPipelineStorage mainloop;
|
||||
alignas(16) EpiLoadPipelineStorage epi_load;
|
||||
alignas(16) typename LoadWarpOrderBarrier::SharedStorage load_order;
|
||||
} pipelines;
|
||||
};
|
||||
|
||||
static constexpr int SharedStorageSize = sizeof(SharedStorage);
|
||||
|
||||
// Device side arguments
|
||||
struct Arguments {
|
||||
GemmUniversalMode mode{};
|
||||
ProblemShape problem_shape{};
|
||||
MainloopArguments mainloop{};
|
||||
EpilogueArguments epilogue{};
|
||||
KernelHardwareInfo hw_info{};
|
||||
TileSchedulerArguments scheduler{};
|
||||
};
|
||||
|
||||
// Kernel entry point API
|
||||
struct Params {
|
||||
GemmUniversalMode mode{};
|
||||
ProblemShape problem_shape{};
|
||||
MainloopParams mainloop{};
|
||||
EpilogueParams epilogue{};
|
||||
KernelHardwareInfo hw_info{};
|
||||
TileSchedulerParams scheduler{};
|
||||
void *workspace{nullptr};
|
||||
};
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
// Convert to underlying arguments. In this case, a simple copy for the
|
||||
// aliased type.
|
||||
static Params to_underlying_arguments(Arguments const &args,
|
||||
void *workspace) {
|
||||
CUTLASS_TRACE_HOST("to_underlying_arguments():");
|
||||
|
||||
auto problem_shape = args.problem_shape;
|
||||
// if constexpr (detail::IF_SWAP_AB<CollectiveMainloop>::value) {
|
||||
// // swap M/N
|
||||
// get<0>(problem_shape) = get<1>(args.problem_shape);
|
||||
// get<1>(problem_shape) = get<0>(args.problem_shape);
|
||||
// }
|
||||
auto problem_shape_MNKL = append<4>(problem_shape, 1);
|
||||
|
||||
// Get SM count if needed, otherwise use user supplied SM count
|
||||
int sm_count = args.hw_info.sm_count;
|
||||
if (sm_count <= 0) {
|
||||
CUTLASS_TRACE_HOST(
|
||||
" WARNING: Arguments do not include a valid SM count.\n"
|
||||
" For optimal performance, populate the arguments "
|
||||
"KernelHardwareInfo struct with the SM count.");
|
||||
sm_count = KernelHardwareInfo::query_device_multiprocessor_count(
|
||||
args.hw_info.device_id);
|
||||
}
|
||||
|
||||
CUTLASS_TRACE_HOST(
|
||||
"to_underlying_arguments(): Setting persistent grid SM count to "
|
||||
<< sm_count);
|
||||
|
||||
KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count};
|
||||
|
||||
// Calculate workspace pointers
|
||||
uint8_t *workspace_ptr = reinterpret_cast<uint8_t *>(workspace);
|
||||
size_t workspace_offset = 0;
|
||||
|
||||
void *scheduler_workspace = workspace_ptr;
|
||||
workspace_offset +=
|
||||
TileScheduler::template get_workspace_size<ProblemShape,
|
||||
ElementAccumulator>(
|
||||
args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups);
|
||||
workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment);
|
||||
|
||||
void *epilogue_workspace = workspace_ptr + workspace_offset;
|
||||
workspace_offset += CollectiveEpilogue::get_workspace_size(
|
||||
args.problem_shape, args.epilogue);
|
||||
workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment);
|
||||
|
||||
void *mainloop_workspace = nullptr;
|
||||
// Precompute the sub tiles numbers in epilogue, pass into tile scheduler.
|
||||
// Therefore it will be used in separate reduction scheme for streamk case,
|
||||
// NumEpilogueSubTiles default value is 1, which means subtile will not be
|
||||
// used, therefore separate reduction will not be enabled.
|
||||
constexpr uint32_t NumEpilogueSubTiles =
|
||||
CollectiveEpilogue::get_store_pipe_increment(TileShape{});
|
||||
TileSchedulerParams scheduler = TileScheduler::to_underlying_arguments(
|
||||
problem_shape_MNKL, TileShape{}, ClusterShape{}, hw_info,
|
||||
args.scheduler, scheduler_workspace, NumEpilogueSubTiles);
|
||||
|
||||
return {args.mode,
|
||||
problem_shape,
|
||||
CollectiveMainloop::to_underlying_arguments(
|
||||
args.problem_shape, args.mainloop, mainloop_workspace),
|
||||
CollectiveEpilogue::to_underlying_arguments(
|
||||
args.problem_shape, args.epilogue, epilogue_workspace),
|
||||
hw_info,
|
||||
scheduler,
|
||||
workspace};
|
||||
}
|
||||
|
||||
  static bool can_implement(Arguments const &args) {
    bool implementable = (args.mode == GemmUniversalMode::kGemm) or
                         (args.mode == GemmUniversalMode::kBatched &&
                          cute::rank(ProblemShape{}) == 4);
    if (!implementable) {
      CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't "
                         "meet the requirements.\n");
      return implementable;
    }
    implementable &=
        CollectiveMainloop::can_implement(args.problem_shape, args.mainloop);
    implementable &=
        CollectiveEpilogue::can_implement(args.problem_shape, args.epilogue);
    implementable &= TileScheduler::can_implement(args.scheduler);
    return implementable;
  }

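  // Usage note (explanatory comment, not part of the original diff): callers
  // are expected to gate the launch on this check. A hedged host-side sketch,
  // assuming a `GemmKernel` alias for this class and an `args` Arguments
  // instance built by the caller:
  //
  //   if (!GemmKernel::can_implement(args)) {
  //     return cutlass::Status::kErrorInvalidProblem;   // unsupported mode/shape
  //   }
  //   size_t ws_bytes = GemmKernel::get_workspace_size(args);
  //   // ... allocate ws_bytes, call initialize_workspace(), then launch.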
static size_t get_workspace_size(Arguments const &args) {
|
||||
size_t workspace_size = 0;
|
||||
constexpr uint32_t NumEpilogueSubTiles =
|
||||
CollectiveEpilogue::get_store_pipe_increment(TileShape{});
|
||||
|
||||
workspace_size +=
|
||||
TileScheduler::template get_workspace_size<ProblemShape,
|
||||
ElementAccumulator>(
|
||||
args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups,
|
||||
NumEpilogueSubTiles);
|
||||
workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment);
|
||||
|
||||
workspace_size += CollectiveEpilogue::get_workspace_size(args.problem_shape,
|
||||
args.epilogue);
|
||||
workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment);
|
||||
|
||||
return workspace_size;
|
||||
}
|
||||
|
||||
static cutlass::Status
|
||||
initialize_workspace(Arguments const &args, void *workspace = nullptr,
|
||||
cudaStream_t stream = nullptr,
|
||||
CudaHostAdapter *cuda_adapter = nullptr) {
|
||||
Status status = Status::kSuccess;
|
||||
uint8_t *workspace_ptr = reinterpret_cast<uint8_t *>(workspace);
|
||||
size_t workspace_offset = 0;
|
||||
constexpr uint32_t NumEpilogueSubTiles =
|
||||
CollectiveEpilogue::get_store_pipe_increment(TileShape{});
|
||||
|
||||
status = TileScheduler::template initialize_workspace<ProblemShape,
|
||||
ElementAccumulator>(
|
||||
args.scheduler, workspace_ptr + workspace_offset, stream,
|
||||
args.problem_shape, args.hw_info, NumMmaWarpGroups,
|
||||
NumEpilogueSubTiles);
|
||||
workspace_offset +=
|
||||
TileScheduler::template get_workspace_size<ProblemShape,
|
||||
ElementAccumulator>(
|
||||
args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups,
|
||||
NumEpilogueSubTiles);
|
||||
workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment);
|
||||
if (status != Status::kSuccess) {
|
||||
return status;
|
||||
}
|
||||
|
||||
status = CollectiveEpilogue::initialize_workspace(
|
||||
args.problem_shape, args.epilogue, workspace_ptr + workspace_offset,
|
||||
stream, cuda_adapter);
|
||||
workspace_offset += CollectiveEpilogue::get_workspace_size(
|
||||
args.problem_shape, args.epilogue);
|
||||
workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment);
|
||||
if (status != Status::kSuccess) {
|
||||
return status;
|
||||
}
|
||||
|
||||
return status;
|
||||
}
|
||||
|
||||
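  // Workspace flow (explanatory comment, not part of the original diff): the
  // tile-scheduler workspace is laid out first, then the epilogue workspace,
  // each rounded up to MinWorkspaceAlignment, which is why get_workspace_size()
  // and initialize_workspace() above walk the offsets in the same order. A
  // hedged host-side sketch, assuming a `GemmKernel` alias for this class:
  //
  //   size_t ws_bytes = GemmKernel::get_workspace_size(args);
  //   void *ws = nullptr;
  //   cudaMalloc(&ws, ws_bytes);                                   // device workspace
  //   GemmKernel::initialize_workspace(args, ws, /*stream=*/0);    // scheduler + epilogue init
  //   auto params = GemmKernel::to_underlying_arguments(args, ws);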
  // Computes the kernel launch grid shape based on runtime parameters
  static dim3 get_grid_shape(Params const &params) {
    // Given device SM count, set grid size s.t. we do not launch more thread
    // blocks than we can run concurrently
    TileSchedulerArguments args{};
    if constexpr (!std::is_const_v<decltype(args.max_swizzle_size)>) {
      args.max_swizzle_size = 1 << params.scheduler.log_swizzle_size_;
    }
    args.raster_order =
        params.scheduler.raster_order_ == TileScheduler::RasterOrder::AlongN
            ? TileScheduler::RasterOrderOptions::AlongN
            : TileScheduler::RasterOrderOptions::AlongM;
    return TileScheduler::get_grid_shape(params.scheduler, params.problem_shape,
                                         TileShape{}, ClusterShape{},
                                         params.hw_info, args);
  }

  static dim3 get_block_shape() { return dim3(MaxThreadsPerBlock, 1, 1); }

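  // Launch sketch (explanatory comment, not part of the original diff): the
  // grid is persistent (bounded by SM count through the tile scheduler) and
  // the block is MaxThreadsPerBlock wide. A hedged host-side sketch, assuming
  // `params` was produced by to_underlying_arguments():
  //
  //   dim3 grid  = GemmKernel::get_grid_shape(params);
  //   dim3 block = GemmKernel::get_block_shape();
  //   int  smem  = GemmKernel::SharedStorageSize;
  //   cutlass::device_kernel<GemmKernel><<<grid, block, smem, stream>>>(params);
  //   // When smem exceeds 48 KB, cudaFuncSetAttribute with
  //   // cudaFuncAttributeMaxDynamicSharedMemorySize must be called first,
  //   // as GemmUniversalAdapter does internally.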
  CUTLASS_DEVICE
  void operator()(Params const &params, char *smem_buf) {
using namespace cute;
|
||||
using X = Underscore;
|
||||
|
||||
// Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a.
|
||||
#if !defined(__CUDA_ARCH_FEAT_SM90_ALL)
|
||||
printf("ERROR : Arch conditional MMA instruction used without targeting "
|
||||
"sm90a compute capability. Aborting.\n");
|
||||
#else
|
||||
|
||||
// Preconditions
|
||||
static_assert(
|
||||
size(TiledMma{}) == 256,
|
||||
"Cooperative kernel must have TiledMMA operating using 256 threads.");
|
||||
static_assert(size<0>(TileShape{}) >= 128,
|
||||
"Cooperative kernel requires Tile Size to be greater than or "
|
||||
"equal to 128 along the M-dimension.");
|
||||
|
||||
static_assert(cute::rank(StrideA{}) == 3,
|
||||
"StrideA must be rank-3: [M, K, L]. If batch mode is not "
|
||||
"needed, set L stride to Int<0>.");
|
||||
static_assert(cute::rank(StrideB{}) == 3,
|
||||
"StrideB must be rank-3: [N, K, L]. If batch mode is not "
|
||||
"needed, set L stride to Int<0>.");
|
||||
static_assert(cute::rank(StrideC{}) == 3,
|
||||
"StrideC must be rank-3: [M, N, L]. If batch mode is not "
|
||||
"needed, set L stride to Int<0>.");
|
||||
static_assert(cute::rank(StrideD{}) == 3,
|
||||
"StrideD must be rank-3: [M, N, L]. If batch mode is not "
|
||||
"needed, set L stride to Int<0>.");
|
||||
|
||||
/* In the Cooperative kernel, Consumer0 and Consumer1 collaborate on the
|
||||
* same tile */
|
||||
enum class WarpGroupRole { Producer = 0, Consumer0 = 1, Consumer1 = 2 };
|
||||
enum class ProducerWarpRole {
|
||||
Mainloop = 0,
|
||||
Warp1 = 1,
|
||||
Epilogue = 2,
|
||||
Warp3 = 3
|
||||
};
|
||||
|
||||
// Kernel level shared memory storage
|
||||
SharedStorage &shared_storage =
|
||||
*reinterpret_cast<SharedStorage *>(smem_buf);
|
||||
|
||||
int thread_idx = int(threadIdx.x);
|
||||
int lane_idx = canonical_lane_idx();
|
||||
int warp_idx = canonical_warp_idx_sync();
|
||||
int warp_idx_in_warp_group = warp_idx % NumWarpsPerWarpGroup;
|
||||
int warp_group_thread_idx = thread_idx % NumThreadsPerWarpGroup;
|
||||
int mma_thread_idx = thread_idx % size(TiledMma{});
|
||||
auto warp_group_role = WarpGroupRole(canonical_warp_group_idx());
|
||||
auto producer_warp_role = ProducerWarpRole(warp_idx_in_warp_group);
|
||||
int lane_predicate = cute::elect_one_sync();
|
||||
uint32_t block_rank_in_cluster = cute::block_rank_in_cluster();
|
||||
|
||||
// Issue Tma Descriptor Prefetch from a single thread
|
||||
if ((warp_idx == 0) && lane_predicate) {
|
||||
CollectiveMainloop::prefetch_tma_descriptors(params.mainloop);
|
||||
CollectiveEpilogue::prefetch_tma_descriptors(params.epilogue);
|
||||
}
|
||||
|
||||
// Mainloop Load pipeline
|
||||
using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline;
|
||||
typename MainloopPipeline::Params mainloop_pipeline_params;
|
||||
if (warp_group_role == WarpGroupRole::Producer &&
|
||||
producer_warp_role == ProducerWarpRole::Mainloop) {
|
||||
mainloop_pipeline_params.role =
|
||||
MainloopPipeline::ThreadCategory::Producer;
|
||||
}
|
||||
if (warp_group_role == WarpGroupRole::Consumer0 ||
|
||||
warp_group_role == WarpGroupRole::Consumer1) {
|
||||
mainloop_pipeline_params.role =
|
||||
MainloopPipeline::ThreadCategory::Consumer;
|
||||
}
|
||||
mainloop_pipeline_params.is_leader = warp_group_thread_idx == 0;
|
||||
mainloop_pipeline_params.num_consumers = size(TiledMma{});
|
||||
mainloop_pipeline_params.transaction_bytes =
|
||||
CollectiveMainloop::TmaTransactionBytes;
|
||||
MainloopPipeline mainloop_pipeline(shared_storage.pipelines.mainloop,
|
||||
mainloop_pipeline_params,
|
||||
ClusterShape{});
|
||||
|
||||
// Epilogue Load pipeline
|
||||
using EpiLoadPipeline = typename CollectiveEpilogue::LoadPipeline;
|
||||
typename EpiLoadPipeline::Params epi_load_pipeline_params;
|
||||
if (warp_group_role == WarpGroupRole::Producer &&
|
||||
producer_warp_role == ProducerWarpRole::Epilogue) {
|
||||
epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Producer;
|
||||
}
|
||||
if (warp_group_role == WarpGroupRole::Consumer0 ||
|
||||
warp_group_role == WarpGroupRole::Consumer1) {
|
||||
epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Consumer;
|
||||
}
|
||||
epi_load_pipeline_params.dst_blockid = cute::block_rank_in_cluster();
|
||||
epi_load_pipeline_params.producer_arv_count = NumThreadsPerWarp;
|
||||
epi_load_pipeline_params.consumer_arv_count = size(TiledMma{});
|
||||
epi_load_pipeline_params.transaction_bytes =
|
||||
CollectiveEpilogue::TmaTransactionBytes;
|
||||
EpiLoadPipeline epi_load_pipeline(shared_storage.pipelines.epi_load,
|
||||
epi_load_pipeline_params);
|
||||
|
||||
// Epilogue Store pipeline
|
||||
using EpiStorePipeline = typename CollectiveEpilogue::StorePipeline;
|
||||
typename EpiStorePipeline::Params epi_store_pipeline_params;
|
||||
epi_store_pipeline_params.always_wait = true;
|
||||
EpiStorePipeline epi_store_pipeline(epi_store_pipeline_params);
|
||||
|
||||
typename LoadWarpOrderBarrier::Params params_load_order_barrier;
|
||||
params_load_order_barrier.group_id =
|
||||
producer_warp_role == ProducerWarpRole::Mainloop ? 0 : 1;
|
||||
params_load_order_barrier.group_size = NumThreadsPerWarp;
|
||||
LoadWarpOrderBarrier load_order_barrier(shared_storage.pipelines.load_order,
|
||||
params_load_order_barrier);
|
||||
|
||||
// Initialize starting pipeline states for the collectives
|
||||
// Epilogue store pipe is producer-only (consumer is TMA unit, waits via
|
||||
// scoreboarding)
|
||||
typename CollectiveMainloop::PipelineState mainloop_pipe_consumer_state;
|
||||
typename CollectiveEpilogue::LoadPipelineState epi_load_pipe_consumer_state;
|
||||
|
||||
// For the DMA Load (producer) we start with an opposite phase
|
||||
// i.e., we skip all waits since we know that the buffer is indeed empty
|
||||
PipelineState mainloop_pipe_producer_state =
|
||||
cutlass::make_producer_start_state<MainloopPipeline>();
|
||||
PipelineState epi_load_pipe_producer_state =
|
||||
cutlass::make_producer_start_state<EpiLoadPipeline>();
|
||||
PipelineState epi_store_pipe_producer_state =
|
||||
cutlass::make_producer_start_state<EpiStorePipeline>();
|
||||
|
||||
auto cluster_wait_fn = []() {
|
||||
// We need this to guarantee that the Pipeline init is visible
|
||||
// To all producers and consumer thread blocks in the Cluster
|
||||
if constexpr (size(ClusterShape{}) > 1) {
|
||||
cute::cluster_arrive_relaxed();
|
||||
return []() { cute::cluster_wait(); };
|
||||
} else {
|
||||
__syncthreads();
|
||||
return []() {}; // do nothing
|
||||
}
|
||||
}();
|
||||
|
||||
// Optionally append 1s until problem shape is rank-4 in case it is only
|
||||
// rank-3 (MNK)
|
||||
auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{});
|
||||
|
||||
// Get the appropriate blocks for this thread block -- potential for thread
|
||||
// block locality
|
||||
TiledMma tiled_mma;
|
||||
auto blk_shape = TileShape{}; // (BLK_M,BLK_N,BLK_K)
|
||||
|
||||
TileScheduler scheduler{params.scheduler};
|
||||
auto work_tile_info = scheduler.get_current_work();
|
||||
|
||||
// In a warp specialized kernel, collectives expose data movement and
|
||||
// compute operations separately
|
||||
CollectiveMainloop collective_mainloop;
|
||||
CollectiveEpilogue collective_epilogue(params.epilogue,
|
||||
shared_storage.tensors.epilogue);
|
||||
|
||||
// Prepare and partition the input tensors. Expects a tuple of tensors
|
||||
// where: get<0>(load_inputs) is the tma tensor A after local tiling so that
|
||||
// it has shape (BLK_M,BLK_K,m,k,l) get<1>(load_inputs) is the tma tensor B
|
||||
// after local tiling so that it has shape (BLK_N,BLK_K,n,k,l)
|
||||
auto load_inputs =
|
||||
collective_mainloop.load_init(problem_shape_MNKL, params.mainloop);
|
||||
static_assert(
|
||||
cute::tuple_size_v<decltype(load_inputs)> >= 3,
|
||||
"Output of load_init must have at least three elements (A, B, Aux)");
|
||||
|
||||
// Extract out partitioned A and B.
|
||||
Tensor gA_mkl = get<0>(load_inputs);
|
||||
Tensor gB_nkl = get<1>(load_inputs);
|
||||
Tensor gAux_xkl = get<2>(load_inputs);
|
||||
|
||||
// Get pipeline stage increments from tensor shapes
|
||||
auto k_tile_count = size<3>(gA_mkl);
|
||||
|
||||
// Wait for all thread blocks in the Cluster
|
||||
cluster_wait_fn();
|
||||
|
||||
if (warp_group_role == WarpGroupRole::Producer) {
|
||||
cutlass::arch::warpgroup_reg_dealloc<LoadRegisterRequirement>();
|
||||
|
||||
// Mainloop Producer Warp
|
||||
if (producer_warp_role == ProducerWarpRole::Mainloop) {
|
||||
bool do_load_order_arrive = true;
|
||||
while (work_tile_info.is_valid()) {
|
||||
if (!TileScheduler::valid_warpgroup_in_work_tile(work_tile_info)) {
|
||||
work_tile_info = fetch_next_work(work_tile_info, scheduler);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Compute m_coord, n_coord, l_coord with the post-tiled m-shape and
|
||||
// n-shape
|
||||
auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl));
|
||||
auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl));
|
||||
auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl));
|
||||
auto blk_coord = make_coord(m_coord, n_coord, _, l_coord);
|
||||
|
||||
// Get the number of K tiles to compute for this work as well as the
|
||||
// starting K tile offset of the work.
|
||||
auto work_k_tile_count = TileScheduler::get_work_k_tile_count(
|
||||
work_tile_info, problem_shape_MNKL, blk_shape);
|
||||
auto work_k_tile_start =
|
||||
TileScheduler::get_work_k_tile_start(work_tile_info);
|
||||
auto k_tile_iter = cute::make_coord_iterator(
|
||||
idx2crd(work_k_tile_start, shape<3>(gA_mkl)), shape<3>(gA_mkl));
|
||||
|
||||
collective_mainloop.load(
|
||||
params.mainloop, mainloop_pipeline, mainloop_pipe_producer_state,
|
||||
load_inputs, blk_coord, k_tile_iter, work_k_tile_count, lane_idx,
|
||||
block_rank_in_cluster, shared_storage.tensors.mainloop);
|
||||
// Update starting pipeline state for the next tile
|
||||
mainloop_pipe_producer_state.advance(work_k_tile_count);
|
||||
|
||||
// Signal for the epilogue load warp to begin
|
||||
if (do_load_order_arrive) {
|
||||
load_order_barrier.arrive();
|
||||
do_load_order_arrive = false;
|
||||
}
|
||||
|
||||
// Get next work tile
|
||||
work_tile_info = fetch_next_work(work_tile_info, scheduler);
|
||||
} // Scheduler work fetch loop
|
||||
|
||||
// Make sure all Consumer Warp Groups have been waited upon
|
||||
collective_mainloop.load_tail(mainloop_pipeline,
|
||||
mainloop_pipe_producer_state);
|
||||
} // Mainloop Producer Warp End
|
||||
|
||||
// Epilogue Producer Warp
|
||||
else if (producer_warp_role == ProducerWarpRole::Epilogue &&
|
||||
collective_epilogue.is_producer_load_needed()) {
|
||||
while (work_tile_info.is_valid()) {
|
||||
if (!TileScheduler::requires_separate_reduction(params.scheduler)) {
|
||||
load_order_barrier.wait();
|
||||
}
|
||||
if (TileScheduler::compute_epilogue(work_tile_info,
|
||||
params.scheduler)) {
|
||||
// Compute m_coord, n_coord, l_coord with the post-tiled m-shape and
|
||||
// n-shape
|
||||
auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl));
|
||||
auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl));
|
||||
auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl));
|
||||
auto blk_coord = make_coord(m_coord, n_coord, _, l_coord);
|
||||
|
||||
epi_load_pipe_producer_state = collective_epilogue.load(
|
||||
epi_load_pipeline, epi_load_pipe_producer_state,
|
||||
problem_shape_MNKL, blk_shape, blk_coord, tiled_mma, lane_idx,
|
||||
shared_storage.tensors.epilogue,
|
||||
work_tile_info.reduction_subtile_idx());
|
||||
}
|
||||
|
||||
// Get next work tile
|
||||
work_tile_info = fetch_next_work(work_tile_info, scheduler);
|
||||
} // Scheduler work fetch loop
|
||||
|
||||
// Make sure all Consumer Warp Groups have been waited upon
|
||||
collective_epilogue.load_tail(epi_load_pipeline,
|
||||
epi_load_pipe_producer_state);
|
||||
} // Epilogue Producer Warp End
|
||||
} // Producer Warp Group End
|
||||
|
||||
else if (warp_group_role == WarpGroupRole::Consumer0 ||
|
||||
warp_group_role == WarpGroupRole::Consumer1) {
|
||||
cutlass::arch::warpgroup_reg_alloc<MmaRegisterRequirement>();
|
||||
|
||||
// Do we potentially issue tail arrives for TMA stores, if epilogue load
|
||||
// is waiting for it
|
||||
bool do_store_tail = false;
|
||||
float scale_d0 = params.mainloop.scale_d0;
|
||||
float scale_d1 = params.mainloop.scale_d1;
|
||||
while (work_tile_info.is_valid()) {
|
||||
// Compute m_coord, n_coord, l_coord with the post-tiled m-shape and
|
||||
// n-shape
|
||||
auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl));
|
||||
auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl));
|
||||
auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl));
|
||||
auto blk_coord = make_coord(m_coord, n_coord, _, l_coord);
|
||||
auto work_k_tile_count = TileScheduler::get_work_k_tile_count(
|
||||
work_tile_info, problem_shape_MNKL, blk_shape);
|
||||
|
||||
// Allocate the accumulators for the (M,N) blk_shape
|
||||
//
|
||||
// MSVC CTAD breaks if we say "Tensor" here, so we use "auto" instead.
|
||||
auto accumulators0 = partition_fragment_C(
|
||||
tiled_mma, take<0, 2>(blk_shape)); // (MMA,MMA_M,MMA_N)
|
||||
auto accumulators1 = partition_fragment_C(
|
||||
tiled_mma, take<0, 2>(blk_shape)); // (MMA,MMA_M,MMA_N)
|
||||
if (TileScheduler::valid_warpgroup_in_work_tile(work_tile_info)) {
|
||||
collective_mainloop.mma(
|
||||
mainloop_pipeline, mainloop_pipe_consumer_state, accumulators0,
|
||||
accumulators1, work_k_tile_count, mma_thread_idx,
|
||||
shared_storage.tensors.mainloop, params.mainloop);
|
||||
|
||||
// Make sure the math instructions are done and free buffers before
|
||||
// entering the epilogue
|
||||
collective_mainloop.mma_tail(mainloop_pipeline,
|
||||
mainloop_pipe_consumer_state,
|
||||
work_k_tile_count);
|
||||
|
||||
// Update starting mainloop pipeline state for the next tile
|
||||
mainloop_pipe_consumer_state.advance(work_k_tile_count);
|
||||
}
|
||||
// Index of warp group within consumer warp groups
|
||||
int consumer_warp_group_idx =
|
||||
canonical_warp_group_idx() - NumLoadWarpGroups;
|
||||
|
||||
// Perform reduction across splits, if needed
|
||||
TileScheduler::fixup(params.scheduler, work_tile_info, accumulators0,
|
||||
NumMmaWarpGroups, consumer_warp_group_idx);
|
||||
TileScheduler::fixup(params.scheduler, work_tile_info, accumulators1,
|
||||
NumMmaWarpGroups, consumer_warp_group_idx);
|
||||
|
||||
        Activation elt_op;
        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < size(accumulators0); i++) {
          accumulators0[i] = elt_op(accumulators0[i] * scale_d0) *
                             (scale_d1 * accumulators1[i]);
        }

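        // Gating math (explanatory comment, not part of the original diff):
        // the two accumulators are produced by the same mainloop from the two
        // B operands it loads (B and the auxiliary gate operand, gAux above),
        // and are fused element-wise as
        //   d[i] = Activation(scale_d0 * acc0[i]) * (scale_d1 * acc1[i])
        // before a single epilogue store. A scalar reference, assuming the
        // mainloop's Activation is a SiLU functor (one plausible choice; the
        // actual functor is whatever the Activation alias names):
        //
        //   float silu(float x) { return x / (1.0f + expf(-x)); }
        //   float gated_ref(float acc0, float acc1, float s0, float s1) {
        //     return silu(acc0 * s0) * (s1 * acc1);
        //   }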
if (TileScheduler::compute_epilogue(work_tile_info, params.scheduler)) {
|
||||
// Epilogue and write to gD
|
||||
auto [epi_load_pipe_consumer_state_next,
|
||||
epi_store_pipe_producer_state_next] =
|
||||
collective_epilogue.store(
|
||||
epi_load_pipeline, epi_load_pipe_consumer_state,
|
||||
epi_store_pipeline, epi_store_pipe_producer_state,
|
||||
problem_shape_MNKL, blk_shape, blk_coord, accumulators0,
|
||||
tiled_mma, mma_thread_idx, shared_storage.tensors.epilogue,
|
||||
work_tile_info.reduction_subtile_idx());
|
||||
epi_load_pipe_consumer_state = epi_load_pipe_consumer_state_next;
|
||||
epi_store_pipe_producer_state = epi_store_pipe_producer_state_next;
|
||||
do_store_tail = true;
|
||||
}
|
||||
|
||||
// Get next work tile
|
||||
work_tile_info = fetch_next_work(work_tile_info, scheduler);
|
||||
} // Scheduler work fetch loop
|
||||
|
||||
if (do_store_tail) {
|
||||
collective_epilogue.store_tail(
|
||||
epi_load_pipeline, epi_load_pipe_consumer_state, epi_store_pipeline,
|
||||
epi_store_pipe_producer_state);
|
||||
}
|
||||
} // Consumer Warp Groups End
|
||||
#endif
|
||||
}
|
||||
|
||||
private:
|
||||
// Kernel helper function to get next work unit
|
||||
CUTLASS_DEVICE
|
||||
typename TileScheduler::WorkTileInfo
|
||||
fetch_next_work(typename TileScheduler::WorkTileInfo &work_tile_info,
|
||||
TileScheduler &scheduler) const {
|
||||
// Check whether we should continue on with the current work unit. If this
|
||||
// is the case, the work unit will have been updated in
|
||||
// continue_current_work to reflect the new tile to be computed.
|
||||
if (scheduler.continue_current_work(work_tile_info)) {
|
||||
return work_tile_info;
|
||||
}
|
||||
|
||||
// Get next work tile
|
||||
scheduler.advance_to_next_work();
|
||||
return scheduler.get_current_work();
|
||||
}
|
||||
};

///////////////////////////////////////////////////////////////////////////////

} // namespace cutlass::gemm::kernel
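The ping-pong specialization that follows differs from the cooperative kernel above mainly in how the two math warp groups are scheduled: cooperative has both math warp groups collaborate on one output tile (NumMmaWarpGroups is derived from a 256-thread TiledMma, as asserted above), while ping-pong fixes NumMmaWarpGroups = 2 and gives each math warp group its own tile, serialized through MathWarpGroupOrderBarrier so one group's MMA overlaps the other's epilogue. A small sketch of the resulting CTA sizes, assuming the usual one-warp-group (128-thread) TiledMma for the ping-pong schedule (an assumption for illustration, not stated in this diff):

// Under these assumptions both schedules launch 3 warp groups (384 threads) per CTA:
constexpr int kThreadsPerWarpGroup = 128;                        // 4 warps x 32 threads
constexpr int kCooperativeCta = 256 + 1 * kThreadsPerWarpGroup;  // 256-thread TiledMma + load WG
constexpr int kPingPongCta    = 128 + 2 * kThreadsPerWarpGroup;  // per the MaxThreadsPerBlock formula below
static_assert(kCooperativeCta == 384 && kPingPongCta == 384, "both launch 384-thread CTAs");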
@@ -0,0 +1,680 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
|
||||
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights
|
||||
*reserved. SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice,
|
||||
*this list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
||||
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
|
||||
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
|
||||
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
|
||||
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
||||
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
|
||||
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
|
||||
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
|
||||
*POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include "cute/arch/cluster_sm90.hpp"
|
||||
#include "cutlass/arch/mma_sm90.h"
|
||||
#include "cutlass/arch/reg_reconfig.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/epilogue/collective/detail.hpp"
|
||||
#include "cutlass/fast_math.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/gemm/kernel/sm90_tile_scheduler.hpp"
|
||||
#include "cutlass/kernel_hardware_info.hpp"
|
||||
#include "cutlass/pipeline/pipeline.hpp"
|
||||
#include "cutlass/trace.h"
|
||||
#include "cutlass/workspace.h"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
|
||||
#include "cute/util/debug.hpp"
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass::gemm::kernel {
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <class ProblemShape_, class CollectiveMainloop_,
|
||||
class CollectiveEpilogue_, class TileScheduler_>
|
||||
class GemmUniversalGated<
|
||||
ProblemShape_, CollectiveMainloop_, CollectiveEpilogue_, TileScheduler_,
|
||||
cute::enable_if_t<cute::is_base_of_v<KernelTmaWarpSpecializedPingpong,
|
||||
typename CollectiveMainloop_::
|
||||
DispatchPolicy::Schedule> &&
|
||||
CollectiveMainloop_::isGated>> {
|
||||
public:
|
||||
//
|
||||
// Type Aliases
|
||||
//
|
||||
using ProblemShape = ProblemShape_;
|
||||
static_assert(cute::rank(ProblemShape{}) == 3 or
|
||||
cute::rank(ProblemShape{}) == 4,
|
||||
"ProblemShape{} should be <M,N,K> or <M,N,K,L>");
|
||||
// Mainloop derived types
|
||||
using CollectiveMainloop = CollectiveMainloop_;
|
||||
using TileShape = typename CollectiveMainloop::TileShape;
|
||||
using TiledMma = typename CollectiveMainloop::TiledMma;
|
||||
using ArchTag = typename CollectiveMainloop::ArchTag;
|
||||
using ElementA = typename CollectiveMainloop::ElementA;
|
||||
using StrideA = typename CollectiveMainloop::StrideA;
|
||||
using ElementB = typename CollectiveMainloop::ElementB;
|
||||
using StrideB = typename CollectiveMainloop::StrideB;
|
||||
using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy;
|
||||
using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator;
|
||||
using ClusterShape = typename DispatchPolicy::ClusterShape;
|
||||
using MainloopArguments = typename CollectiveMainloop::Arguments;
|
||||
using MainloopParams = typename CollectiveMainloop::Params;
|
||||
using Activation = typename CollectiveMainloop::Activation;
|
||||
static_assert(ArchTag::kMinComputeCapability >= 90);
|
||||
|
||||
// Epilogue derived types
|
||||
using CollectiveEpilogue = CollectiveEpilogue_;
|
||||
using ElementC = typename CollectiveEpilogue::ElementC;
|
||||
using StrideC = typename CollectiveEpilogue::StrideC;
|
||||
using ElementD = typename CollectiveEpilogue::ElementD;
|
||||
using StrideD = typename CollectiveEpilogue::StrideD;
|
||||
using EpilogueArguments = typename CollectiveEpilogue::Arguments;
|
||||
using EpilogueParams = typename CollectiveEpilogue::Params;
|
||||
|
||||
static_assert(
|
||||
!cute::is_same_v<TileScheduler_, StreamKScheduler>,
|
||||
"Ping-pong kernel does not currently support stream-K scheduler.");
|
||||
using TileSchedulerTag = TileScheduler_;
|
||||
using TileScheduler =
|
||||
typename detail::TileSchedulerSelector<TileScheduler_, ArchTag, TileShape,
|
||||
ClusterShape>::Scheduler;
|
||||
using TileSchedulerArguments = typename TileScheduler::Arguments;
|
||||
using TileSchedulerParams = typename TileScheduler::Params;
|
||||
|
||||
static constexpr uint32_t NumLoadWarpGroups = 1;
|
||||
static constexpr uint32_t NumMmaWarpGroups = 2;
|
||||
static constexpr uint32_t MaxThreadsPerBlock =
|
||||
CUTE_STATIC_V(size(TiledMma{})) +
|
||||
(NumMmaWarpGroups * NumThreadsPerWarpGroup);
|
||||
static constexpr uint32_t MinBlocksPerMultiprocessor = 1;
|
||||
|
||||
/// Register requirement for Load and Math WGs
|
||||
static constexpr uint32_t LoadRegisterRequirement = 40;
|
||||
static constexpr uint32_t MmaRegisterRequirement = 232;
|
||||
|
||||
// 1 stage ordered sequence between mainloop and epilogue producer load
|
||||
// threads
|
||||
using LoadWarpOrderBarrier = cutlass::OrderedSequenceBarrier<1, 2>;
|
||||
|
||||
// Order Sequence barrier with two stages: one for Mainloop and one for
|
||||
// Epilogue
|
||||
static constexpr uint32_t StagesPerMathWarpGroup = 2;
|
||||
using MathWarpGroupOrderBarrier =
|
||||
cutlass::OrderedSequenceBarrier<StagesPerMathWarpGroup, NumMmaWarpGroups>;
|
||||
|
||||
// Kernel level shared memory storage
|
||||
struct SharedStorage {
|
||||
struct TensorStorage : cute::aligned_struct<128> {
|
||||
using MainloopTensorStorage = typename CollectiveMainloop::TensorStorage;
|
||||
using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage;
|
||||
|
||||
MainloopTensorStorage mainloop;
|
||||
EpilogueTensorStorage epilogue;
|
||||
} tensors;
|
||||
|
||||
struct PipelineStorage : cute::aligned_struct<16> {
|
||||
using MainloopPipelineStorage =
|
||||
typename CollectiveMainloop::PipelineStorage;
|
||||
using EpiLoadPipelineStorage =
|
||||
typename CollectiveEpilogue::PipelineStorage;
|
||||
using MathWarpGroupOrderBarrierStorage =
|
||||
typename MathWarpGroupOrderBarrier::SharedStorage;
|
||||
|
||||
alignas(16) MainloopPipelineStorage mainloop;
|
||||
alignas(16) EpiLoadPipelineStorage epi_load;
|
||||
alignas(16) MathWarpGroupOrderBarrierStorage math_wg_order;
|
||||
alignas(16) typename LoadWarpOrderBarrier::SharedStorage load_order;
|
||||
} pipelines;
|
||||
};
|
||||
|
||||
static constexpr int SharedStorageSize = sizeof(SharedStorage);
|
||||
|
||||
// Device side arguments
|
||||
struct Arguments {
|
||||
GemmUniversalMode mode{};
|
||||
ProblemShape problem_shape{};
|
||||
MainloopArguments mainloop{};
|
||||
EpilogueArguments epilogue{};
|
||||
KernelHardwareInfo hw_info{};
|
||||
TileSchedulerArguments scheduler{};
|
||||
};
|
||||
|
||||
// Kernel entry point API
|
||||
struct Params {
|
||||
GemmUniversalMode mode{};
|
||||
ProblemShape problem_shape{};
|
||||
MainloopParams mainloop{};
|
||||
EpilogueParams epilogue{};
|
||||
KernelHardwareInfo hw_info{};
|
||||
TileSchedulerParams scheduler{};
|
||||
};
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
// Convert to underlying arguments. In this case, a simple copy for the
|
||||
// aliased type.
|
||||
static Params to_underlying_arguments(Arguments const &args,
|
||||
void *workspace) {
|
||||
CUTLASS_TRACE_HOST("to_underlying_arguments():");
|
||||
|
||||
(void)workspace;
|
||||
auto problem_shape = args.problem_shape;
|
||||
// if constexpr (detail::IF_SWAP_AB<CollectiveMainloop>::value) {
|
||||
// // swap M/N
|
||||
// get<0>(problem_shape) = get<1>(args.problem_shape);
|
||||
// get<1>(problem_shape) = get<0>(args.problem_shape);
|
||||
// }
|
||||
auto problem_shape_MNKL = append<4>(problem_shape, 1);
|
||||
|
||||
// Get SM count if needed, otherwise use user supplied SM count
|
||||
int sm_count = args.hw_info.sm_count;
|
||||
if (sm_count <= 0) {
|
||||
CUTLASS_TRACE_HOST(
|
||||
" WARNING: Arguments do not include a valid SM count.\n"
|
||||
" For optimal performance, populate the arguments "
|
||||
"KernelHardwareInfo struct with the SM count.");
|
||||
sm_count = KernelHardwareInfo::query_device_multiprocessor_count(
|
||||
args.hw_info.device_id);
|
||||
}
|
||||
|
||||
CUTLASS_TRACE_HOST(
|
||||
"to_underlying_arguments(): Setting persistent grid SM count to "
|
||||
<< sm_count);
|
||||
KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count};
|
||||
|
||||
// Calculate workspace pointers
|
||||
uint8_t *workspace_ptr = reinterpret_cast<uint8_t *>(workspace);
|
||||
size_t workspace_offset = 0;
|
||||
|
||||
void *scheduler_workspace = workspace_ptr;
|
||||
workspace_offset +=
|
||||
TileScheduler::template get_workspace_size<ProblemShape,
|
||||
ElementAccumulator>(
|
||||
args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups);
|
||||
workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment);
|
||||
|
||||
void *epilogue_workspace = workspace_ptr + workspace_offset;
|
||||
workspace_offset += CollectiveEpilogue::get_workspace_size(
|
||||
args.problem_shape, args.epilogue);
|
||||
workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment);
|
||||
|
||||
void *mainloop_workspace = nullptr;
|
||||
|
||||
return {args.mode,
|
||||
problem_shape,
|
||||
CollectiveMainloop::to_underlying_arguments(
|
||||
args.problem_shape, args.mainloop, mainloop_workspace),
|
||||
CollectiveEpilogue::to_underlying_arguments(
|
||||
args.problem_shape, args.epilogue, epilogue_workspace),
|
||||
hw_info,
|
||||
TileScheduler::to_underlying_arguments(
|
||||
problem_shape_MNKL, TileShape{}, ClusterShape{}, hw_info,
|
||||
args.scheduler, scheduler_workspace)};
|
||||
}
|
||||
|
||||
static bool can_implement(Arguments const &args) {
|
||||
bool implementable = (args.mode == GemmUniversalMode::kGemm) or
|
||||
(args.mode == GemmUniversalMode::kBatched &&
|
||||
cute::rank(ProblemShape{}) == 4);
|
||||
if (!implementable) {
|
||||
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't "
|
||||
"meet the requirements.\n");
|
||||
return implementable;
|
||||
}
|
||||
implementable &=
|
||||
CollectiveMainloop::can_implement(args.problem_shape, args.mainloop);
|
||||
implementable &=
|
||||
CollectiveEpilogue::can_implement(args.problem_shape, args.epilogue);
|
||||
implementable &= TileScheduler::can_implement(args.scheduler);
|
||||
return implementable;
|
||||
}
|
||||
|
||||
static size_t get_workspace_size(Arguments const &args) {
|
||||
size_t workspace_size = 0;
|
||||
workspace_size +=
|
||||
TileScheduler::template get_workspace_size<ProblemShape,
|
||||
ElementAccumulator>(
|
||||
args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups);
|
||||
workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment);
|
||||
|
||||
workspace_size += CollectiveEpilogue::get_workspace_size(args.problem_shape,
|
||||
args.epilogue);
|
||||
workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment);
|
||||
|
||||
return workspace_size;
|
||||
}
|
||||
|
||||
static cutlass::Status
|
||||
initialize_workspace(Arguments const &args, void *workspace = nullptr,
|
||||
cudaStream_t stream = nullptr,
|
||||
CudaHostAdapter *cuda_adapter = nullptr) {
|
||||
Status status = Status::kSuccess;
|
||||
uint8_t *workspace_ptr = reinterpret_cast<uint8_t *>(workspace);
|
||||
size_t workspace_offset = 0;
|
||||
|
||||
status = TileScheduler::template initialize_workspace<ProblemShape,
|
||||
ElementAccumulator>(
|
||||
args.scheduler, workspace_ptr + workspace_offset, stream,
|
||||
args.problem_shape, args.hw_info, NumMmaWarpGroups);
|
||||
workspace_offset +=
|
||||
TileScheduler::template get_workspace_size<ProblemShape,
|
||||
ElementAccumulator>(
|
||||
args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups);
|
||||
workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment);
|
||||
if (status != Status::kSuccess) {
|
||||
return status;
|
||||
}
|
||||
|
||||
status = CollectiveEpilogue::initialize_workspace(
|
||||
args.problem_shape, args.epilogue, workspace_ptr + workspace_offset,
|
||||
stream, cuda_adapter);
|
||||
workspace_offset += CollectiveEpilogue::get_workspace_size(
|
||||
args.problem_shape, args.epilogue);
|
||||
workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment);
|
||||
if (status != Status::kSuccess) {
|
||||
return status;
|
||||
}
|
||||
|
||||
return status;
|
||||
}
|
||||
|
||||
// Computes the kernel launch grid shape based on runtime parameters
|
||||
static dim3 get_grid_shape(Params const &params) {
|
||||
// Given device SM count, set grid size s.t. we do not launch more thread
|
||||
// blocks than we can run concurrently
|
||||
TileSchedulerArguments args{};
|
||||
if constexpr (!std::is_const_v<decltype(args.max_swizzle_size)>) {
|
||||
args.max_swizzle_size = 1 << params.scheduler.log_swizzle_size_;
|
||||
}
|
||||
args.raster_order =
|
||||
params.scheduler.raster_order_ == TileScheduler::RasterOrder::AlongN
|
||||
? TileScheduler::RasterOrderOptions::AlongN
|
||||
: TileScheduler::RasterOrderOptions::AlongM;
|
||||
return TileScheduler::get_grid_shape(params.scheduler, params.problem_shape,
|
||||
TileShape{}, ClusterShape{},
|
||||
params.hw_info, args);
|
||||
}
|
||||
|
||||
static dim3 get_block_shape() { return dim3(MaxThreadsPerBlock, 1, 1); }
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void operator()(Params const &params, char *smem_buf) {
|
||||
using namespace cute;
|
||||
using X = Underscore;
|
||||
|
||||
// Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a.
|
||||
#if !defined(__CUDA_ARCH_FEAT_SM90_ALL)
|
||||
printf("ERROR : Arch conditional MMA instruction used without targeting "
|
||||
"sm90a compute capability. Aborting.\n");
|
||||
#else
|
||||
|
||||
// Preconditions
|
||||
static_assert(cute::rank(StrideA{}) == 3,
|
||||
"StrideA must be rank-3: [M, K, L]. If batch mode is not "
|
||||
"needed, set L stride to Int<0>.");
|
||||
static_assert(cute::rank(StrideB{}) == 3,
|
||||
"StrideB must be rank-3: [N, K, L]. If batch mode is not "
|
||||
"needed, set L stride to Int<0>.");
|
||||
static_assert(cute::rank(StrideC{}) == 3,
|
||||
"StrideC must be rank-3: [M, N, L]. If batch mode is not "
|
||||
"needed, set L stride to Int<0>.");
|
||||
static_assert(cute::rank(StrideD{}) == 3,
|
||||
"StrideD must be rank-3: [M, N, L]. If batch mode is not "
|
||||
"needed, set L stride to Int<0>.");
|
||||
|
||||
enum class WarpGroupRole { Producer = 0, Consumer0 = 1, Consumer1 = 2 };
|
||||
enum class ProducerWarpRole {
|
||||
Mainloop = 0,
|
||||
Warp1 = 1,
|
||||
Epilogue = 2,
|
||||
Warp3 = 3
|
||||
};
|
||||
|
||||
// Kernel level shared memory storage
|
||||
SharedStorage &shared_storage =
|
||||
*reinterpret_cast<SharedStorage *>(smem_buf);
|
||||
|
||||
int thread_idx = int(threadIdx.x);
|
||||
int lane_idx = canonical_lane_idx();
|
||||
int warp_idx = canonical_warp_idx_sync();
|
||||
int warp_idx_in_warp_group = warp_idx % NumWarpsPerWarpGroup;
|
||||
int warp_group_thread_idx = thread_idx % NumThreadsPerWarpGroup;
|
||||
auto warp_group_role = WarpGroupRole(canonical_warp_group_idx());
|
||||
auto producer_warp_role = ProducerWarpRole(warp_idx_in_warp_group);
|
||||
int lane_predicate = cute::elect_one_sync();
|
||||
uint32_t block_rank_in_cluster = cute::block_rank_in_cluster();
|
||||
|
||||
// Issue Tma Descriptor Prefetch from a single thread
|
||||
if ((warp_idx == 0) && lane_predicate) {
|
||||
CollectiveMainloop::prefetch_tma_descriptors(params.mainloop);
|
||||
CollectiveEpilogue::prefetch_tma_descriptors(params.epilogue);
|
||||
}
|
||||
|
||||
// Mainloop Load pipeline
|
||||
using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline;
|
||||
typename MainloopPipeline::Params mainloop_pipeline_params;
|
||||
if (warp_group_role == WarpGroupRole::Producer &&
|
||||
producer_warp_role == ProducerWarpRole::Mainloop) {
|
||||
mainloop_pipeline_params.role =
|
||||
MainloopPipeline::ThreadCategory::Producer;
|
||||
}
|
||||
if (warp_group_role == WarpGroupRole::Consumer0 ||
|
||||
warp_group_role == WarpGroupRole::Consumer1) {
|
||||
mainloop_pipeline_params.role =
|
||||
MainloopPipeline::ThreadCategory::Consumer;
|
||||
}
|
||||
mainloop_pipeline_params.is_leader = warp_group_thread_idx == 0;
|
||||
mainloop_pipeline_params.num_consumers = NumThreadsPerWarpGroup;
|
||||
mainloop_pipeline_params.transaction_bytes =
|
||||
CollectiveMainloop::TmaTransactionBytes;
|
||||
MainloopPipeline mainloop_pipeline(shared_storage.pipelines.mainloop,
|
||||
mainloop_pipeline_params,
|
||||
ClusterShape{});
|
||||
|
||||
// Epilogue Load pipeline
|
||||
using EpiLoadPipeline = typename CollectiveEpilogue::LoadPipeline;
|
||||
typename EpiLoadPipeline::Params epi_load_pipeline_params;
|
||||
if (warp_group_role == WarpGroupRole::Producer &&
|
||||
producer_warp_role == ProducerWarpRole::Epilogue) {
|
||||
epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Producer;
|
||||
}
|
||||
if (warp_group_role == WarpGroupRole::Consumer0 ||
|
||||
warp_group_role == WarpGroupRole::Consumer1) {
|
||||
epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Consumer;
|
||||
}
|
||||
epi_load_pipeline_params.dst_blockid = cute::block_rank_in_cluster();
|
||||
epi_load_pipeline_params.producer_arv_count = NumThreadsPerWarp;
|
||||
epi_load_pipeline_params.consumer_arv_count = NumThreadsPerWarpGroup;
|
||||
epi_load_pipeline_params.transaction_bytes =
|
||||
CollectiveEpilogue::TmaTransactionBytes;
|
||||
EpiLoadPipeline epi_load_pipeline(shared_storage.pipelines.epi_load,
|
||||
epi_load_pipeline_params);
|
||||
|
||||
// Epilogue Store pipeline
|
||||
using EpiStorePipeline = typename CollectiveEpilogue::StorePipeline;
|
||||
typename EpiStorePipeline::Params epi_store_pipeline_params;
|
||||
epi_store_pipeline_params.always_wait = true;
|
||||
EpiStorePipeline epi_store_pipeline(epi_store_pipeline_params);
|
||||
|
||||
typename LoadWarpOrderBarrier::Params params_load_order_barrier;
|
||||
params_load_order_barrier.group_id =
|
||||
producer_warp_role == ProducerWarpRole::Mainloop ? 0 : 1;
|
||||
params_load_order_barrier.group_size = NumThreadsPerWarp;
|
||||
LoadWarpOrderBarrier load_order_barrier(shared_storage.pipelines.load_order,
|
||||
params_load_order_barrier);
|
||||
|
||||
typename MathWarpGroupOrderBarrier::Params params_math_wg_order_barrier;
|
||||
// DMA Load WG will not participate in these Ordered Barrier syncs
|
||||
params_math_wg_order_barrier.group_id =
|
||||
canonical_warp_group_idx() - static_cast<int>(WarpGroupRole::Consumer0);
|
||||
params_math_wg_order_barrier.group_size =
|
||||
NumThreadsPerWarpGroup; // Number of threads / participants in a group
|
||||
MathWarpGroupOrderBarrier math_wg_order_barrier(
|
||||
shared_storage.pipelines.math_wg_order, params_math_wg_order_barrier);
|
||||
|
||||
// Initialize starting pipeline states for the collectives
|
||||
// Epilogue store pipe is producer-only (consumer is TMA unit, waits via
|
||||
// scoreboarding)
|
||||
typename CollectiveMainloop::PipelineState mainloop_pipe_consumer_state;
|
||||
typename CollectiveEpilogue::LoadPipelineState epi_load_pipe_consumer_state;
|
||||
|
||||
// For the DMA Load (producer) we start with an opposite phase
|
||||
// i.e., we skip all waits since we know that the buffer is indeed empty
|
||||
PipelineState mainloop_pipe_producer_state =
|
||||
cutlass::make_producer_start_state<MainloopPipeline>();
|
||||
PipelineState epi_load_pipe_producer_state =
|
||||
cutlass::make_producer_start_state<EpiLoadPipeline>();
|
||||
PipelineState epi_store_pipe_producer_state =
|
||||
cutlass::make_producer_start_state<EpiStorePipeline>();
|
||||
|
||||
auto cluster_wait_fn = [&]() {
|
||||
// We need this to guarantee that the Pipeline init is visible
|
||||
// To all producers and consumer thread blocks in the Cluster
|
||||
if constexpr (size(ClusterShape{}) > 1) {
|
||||
cute::cluster_arrive_relaxed();
|
||||
return []() { cute::cluster_wait(); };
|
||||
} else {
|
||||
__syncthreads();
|
||||
return []() {}; // do nothing
|
||||
}
|
||||
}();
|
||||
|
||||
// Separate out problem shape for convenience
|
||||
// Optionally append 1s until problem shape is rank-4 in case it is only
|
||||
// rank-3 (MNK)
|
||||
auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{});
|
||||
|
||||
// Get the appropriate blocks for this thread block -- potential for thread
|
||||
// block locality
|
||||
TiledMma tiled_mma;
|
||||
auto blk_shape = TileShape{}; // (BLK_M,BLK_N,BLK_K)
|
||||
|
||||
// In a warp specialized kernel, collectives expose data movement and
|
||||
// compute operations separately
|
||||
CollectiveMainloop collective_mainloop;
|
||||
CollectiveEpilogue collective_epilogue(params.epilogue,
|
||||
shared_storage.tensors.epilogue);
|
||||
|
||||
// Prepare and partition the input tensors. Expects a tuple of tensors
|
||||
// where: get<0>(load_inputs) is the tma tensor A after local tiling so that
|
||||
// it has shape (BLK_M,BLK_K,m,k,l) get<1>(load_inputs) is the tma tensor B
|
||||
// after local tiling so that it has shape (BLK_N,BLK_K,n,k,l)
|
||||
auto load_inputs =
|
||||
collective_mainloop.load_init(problem_shape_MNKL, params.mainloop);
|
||||
static_assert(
|
||||
cute::tuple_size_v<decltype(load_inputs)> >= 3,
|
||||
"Output of load_init must have at least three elements (A, B, Aux)");
|
||||
|
||||
// Extract out partitioned A and B.
|
||||
Tensor gA_mkl = get<0>(load_inputs);
|
||||
Tensor gB_nkl = get<1>(load_inputs);
|
||||
Tensor gAux_xkl = get<2>(load_inputs);
|
||||
|
||||
// Get pipeline stage increments from tensor shapes
|
||||
auto k_tile_count = size<3>(gA_mkl);
|
||||
auto c_tile_count = CollectiveEpilogue::get_load_pipe_increment(blk_shape);
|
||||
auto d_tile_count = CollectiveEpilogue::get_store_pipe_increment(blk_shape);
|
||||
|
||||
TileScheduler scheduler{params.scheduler};
|
||||
|
||||
if (warp_group_role == WarpGroupRole::Consumer1) {
|
||||
// Advance 2nd Math WG to the next work tile for the startup
|
||||
scheduler.advance_to_next_work();
|
||||
// Advance 2nd Math WG pipeline states to the end of 1st Math WG
|
||||
mainloop_pipe_consumer_state.advance(k_tile_count);
|
||||
epi_load_pipe_consumer_state.advance(c_tile_count);
|
||||
epi_store_pipe_producer_state.advance(d_tile_count);
|
||||
}
|
||||
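    // Explanatory comment (not part of the original diff): this startup offset
    // is what creates the "ping-pong". Consumer1 begins one work tile ahead of
    // Consumer0 and its pipeline read indices start past Consumer0's share of
    // stages, so the two math warp groups end up owning alternating work tiles
    // (in grid-stride units) while sharing the single producer warp group. The
    // matching per-tile skip is the `k_tile_count * NumMmaWarpGroups` advance
    // of the mainloop consumer state in the consumer loop below.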
auto work_tile_info = scheduler.get_current_work();
|
||||
|
||||
// Wait for all thread blocks in the Cluster
|
||||
cluster_wait_fn();
|
||||
|
||||
if (warp_group_role == WarpGroupRole::Producer) {
|
||||
cutlass::arch::warpgroup_reg_dealloc<LoadRegisterRequirement>();
|
||||
|
||||
// Mainloop Producer Warp
|
||||
if (producer_warp_role == ProducerWarpRole::Mainloop) {
|
||||
bool do_load_order_arrive = true;
|
||||
while (work_tile_info.is_valid()) {
|
||||
// Compute m_coord, n_coord, l_coord with the post-tiled m-shape and
|
||||
// n-shape
|
||||
auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl));
|
||||
auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl));
|
||||
auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl));
|
||||
auto blk_coord = make_coord(m_coord, n_coord, _, l_coord);
|
||||
|
||||
auto k_tile_iter = cute::make_coord_iterator(shape<3>(gA_mkl));
|
||||
|
||||
collective_mainloop.load(
|
||||
params.mainloop, mainloop_pipeline, mainloop_pipe_producer_state,
|
||||
load_inputs, blk_coord, k_tile_iter, k_tile_count, lane_idx,
|
||||
block_rank_in_cluster, shared_storage.tensors.mainloop);
|
||||
// Update starting pipeline state for the next tile
|
||||
mainloop_pipe_producer_state.advance(k_tile_count);
|
||||
|
||||
// Signal for the epilogue load warp to begin
|
||||
if (do_load_order_arrive) {
|
||||
load_order_barrier.arrive();
|
||||
do_load_order_arrive = false;
|
||||
}
|
||||
|
||||
// Get next work tile
|
||||
scheduler.advance_to_next_work();
|
||||
work_tile_info = scheduler.get_current_work();
|
||||
} // Scheduler work fetch loop
|
||||
|
||||
// Make sure all Consumer Warp Groups have been waited upon
|
||||
collective_mainloop.load_tail(mainloop_pipeline,
|
||||
mainloop_pipe_producer_state);
|
||||
} // Mainloop Producer Warp End
|
||||
|
||||
// Epilogue Producer Warp
|
||||
else if (producer_warp_role == ProducerWarpRole::Epilogue &&
|
||||
collective_epilogue.is_producer_load_needed()) {
|
||||
load_order_barrier.wait();
|
||||
while (work_tile_info.is_valid()) {
|
||||
// Compute m_coord, n_coord, l_coord with the post-tiled m-shape and
|
||||
// n-shape
|
||||
auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl));
|
||||
auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl));
|
||||
auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl));
|
||||
auto blk_coord = make_coord(m_coord, n_coord, _, l_coord);
|
||||
|
||||
epi_load_pipe_producer_state = collective_epilogue.load(
|
||||
epi_load_pipeline, epi_load_pipe_producer_state,
|
||||
problem_shape_MNKL, blk_shape, blk_coord, tiled_mma, lane_idx,
|
||||
shared_storage.tensors.epilogue);
|
||||
|
||||
// Get next work tile
|
||||
scheduler.advance_to_next_work();
|
||||
work_tile_info = scheduler.get_current_work();
|
||||
} // Scheduler work fetch loop
|
||||
|
||||
// Make sure all Consumer Warp Groups have been waited upon
|
||||
collective_epilogue.load_tail(epi_load_pipeline,
|
||||
epi_load_pipe_producer_state);
|
||||
} // Epilogue Producer Warp End
|
||||
} // Producer Warp Group End
|
||||
|
||||
else if (warp_group_role == WarpGroupRole::Consumer0 ||
|
||||
warp_group_role == WarpGroupRole::Consumer1) {
|
||||
cutlass::arch::warpgroup_reg_alloc<MmaRegisterRequirement>();
|
||||
|
||||
float scale_d0 = params.mainloop.scale_d0;
|
||||
float scale_d1 = params.mainloop.scale_d1;
|
||||
while (work_tile_info.is_valid()) {
|
||||
// Compute m_coord, n_coord, l_coord with the post-tiled m-shape and
|
||||
// n-shape
|
||||
auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl));
|
||||
auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl));
|
||||
auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl));
|
||||
auto blk_coord = make_coord(m_coord, n_coord, _, l_coord);
|
||||
|
||||
// Allocate the accumulators for the (M,N) blk_shape
|
||||
Tensor accumulators0 = partition_fragment_C(
|
||||
tiled_mma, take<0, 2>(blk_shape)); // (MMA,MMA_M,MMA_N)
|
||||
Tensor accumulators1 = partition_fragment_C(
|
||||
tiled_mma, take<0, 2>(blk_shape)); // (MMA,MMA_M,MMA_N)
|
||||
|
||||
// Order two Math WG's MMA one after the other, helps hide Epilogue
|
||||
math_wg_order_barrier.wait();
|
||||
|
||||
collective_mainloop.mma(
|
||||
mainloop_pipeline, mainloop_pipe_consumer_state, accumulators0,
|
||||
accumulators1, k_tile_count, warp_group_thread_idx,
|
||||
shared_storage.tensors.mainloop, params.mainloop);
|
||||
|
||||
// Cue for next Math WG's MMA to start
|
||||
math_wg_order_barrier.arrive();
|
||||
|
||||
// Make sure the math instructions are done and free buffers before
|
||||
// entering the epilogue
|
||||
collective_mainloop.mma_tail(
|
||||
mainloop_pipeline, mainloop_pipe_consumer_state, k_tile_count);
|
||||
// Update starting mainloop pipeline state for the next tile
|
||||
mainloop_pipe_consumer_state.advance(k_tile_count * NumMmaWarpGroups);
|
||||
|
||||
Activation elt_op;
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size(accumulators0); i++) {
|
||||
accumulators0[i] = elt_op(accumulators0[i] * scale_d0) *
|
||||
(scale_d1 * accumulators1[i]);
|
||||
}
|
||||
|
||||
// Order two Math WG's Epilogue one after the other
|
||||
math_wg_order_barrier.wait();
|
||||
|
||||
// Epilogue and write to gD
|
||||
auto [epi_load_pipe_consumer_state_next,
|
||||
epi_store_pipe_producer_state_next] =
|
||||
collective_epilogue.store(
|
||||
epi_load_pipeline, epi_load_pipe_consumer_state,
|
||||
epi_store_pipeline, epi_store_pipe_producer_state,
|
||||
problem_shape_MNKL, blk_shape, blk_coord, accumulators0,
|
||||
tiled_mma, warp_group_thread_idx,
|
||||
shared_storage.tensors.epilogue);
|
||||
|
||||
// TMA store pipeline wait is only visible to TMA-issuing warp, so for
|
||||
// multiple-consumer kernels we need to wait for all TMA stores to
|
||||
// complete before issuing consumer order barrier arrives to ensure next
|
||||
// math consumer doesn't overwrite smem of in-flight TMA stores of
|
||||
// current consumer.
|
||||
auto [epi_load_pipe_consumer_state_next_,
|
||||
epi_store_pipe_producer_state_next_] =
|
||||
collective_epilogue.store_tail(
|
||||
epi_load_pipeline, epi_load_pipe_consumer_state_next,
|
||||
epi_store_pipeline, epi_store_pipe_producer_state_next);
|
||||
|
||||
// Update starting load/store pipeline states for the next tile
|
||||
// state has already been incremented by 1 tile in collective calls,
|
||||
// advance once again for ping pong
|
||||
epi_load_pipe_consumer_state = epi_load_pipe_consumer_state_next_;
|
||||
epi_store_pipe_producer_state = epi_store_pipe_producer_state_next_;
|
||||
epi_load_pipe_consumer_state.advance(c_tile_count);
|
||||
epi_store_pipe_producer_state.advance(d_tile_count);
|
||||
|
||||
// Cue for next Math WG's Epilogue to start
|
||||
math_wg_order_barrier.arrive();
|
||||
|
||||
// Get next work tile
|
||||
scheduler.advance_to_next_work(NumMmaWarpGroups);
|
||||
work_tile_info = scheduler.get_current_work();
|
||||
} // Scheduler work fetch loop
|
||||
} // Consumer Warp Groups End
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass::gemm::kernel
|
||||
@@ -77,6 +77,7 @@ public:
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int4 weight, mma pipelined (stage=2)
|
||||
template <
|
||||
/// Layout type for A matrix operand
|
||||
@@ -125,6 +126,7 @@ public:
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int8 weight, mma multistage
|
||||
/// (stage>=3)
|
||||
template <
|
||||
@@ -148,7 +150,7 @@ template <
|
||||
typename InstructionShape,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator,
|
||||
///
|
||||
/// Number of stages used in the multistage mainloop
|
||||
int kStages,
|
||||
/// Shared memory clear option
|
||||
SharedMemoryClearOption SharedMemoryClear>
|
||||
@@ -179,6 +181,7 @@ public:
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int4 weight, mma multistage
|
||||
/// (stage>=3)
|
||||
template <
|
||||
@@ -234,6 +237,7 @@ public:
|
||||
|
||||
#ifdef ENABLE_FP8
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Specialization for row-major output (OperatorClass TensorOp), fp8 activation & int4 weight, mma multistage
|
||||
/// (stage>=3)
|
||||
template <
|
||||
@@ -346,6 +350,131 @@ struct DefaultMma<half_t, LayoutA, kAlignmentA, half_t, LayoutB, kAlignmentB, El
|
||||
MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, typename MmaCore::MmaPolicy, 2>;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Specialization for row-major output (OperatorClass TensorOp), fbf16 activation & int2 weight, mma multistage
|
||||
|
||||
template <
|
||||
/// Layout type for A matrix operand
|
||||
typename LayoutA,
|
||||
/// Access granularity of A matrix in units of elements
|
||||
int kAlignmentA,
|
||||
/// Layout type for B matrix operand
|
||||
typename LayoutB,
|
||||
/// Access granularity of B matrix in units of elements
|
||||
int kAlignmentB,
|
||||
/// Element type for internal accumulation
|
||||
typename ElementAccumulator,
|
||||
/// Tag indicating architecture to tune for
|
||||
typename ArchTag,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape,
|
||||
/// Instruction-level tile size (concept: GemmShape)
|
||||
typename InstructionShape,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator>
|
||||
struct DefaultMma<cutlass::half_t, LayoutA, kAlignmentA, uint2b_t, LayoutB, kAlignmentB, ElementAccumulator,
|
||||
layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, 2, Operator>
|
||||
{
|
||||
static cutlass::arch::CacheOperation::Kind const CacheOpA =
|
||||
((sizeof_bits<half_t>::value * kAlignmentA) == 128) ? cutlass::arch::CacheOperation::Global
|
||||
: cutlass::arch::CacheOperation::Always;
|
||||
|
||||
static cutlass::arch::CacheOperation::Kind const CacheOpB =
|
||||
((sizeof_bits<half_t>::value * kAlignmentB) == 128) ? cutlass::arch::CacheOperation::Global
|
||||
: cutlass::arch::CacheOperation::Always;
|
||||
|
||||
// Define the MmaCore components
|
||||
using MmaCore =
|
||||
typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape, WarpShape, InstructionShape, half_t,
|
||||
LayoutA, half_t, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, 3, Operator,
|
||||
false, CacheOpA, CacheOpB>;
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
|
||||
using AccessTypeA = cutlass::Array<half_t, kAlignmentA>;
|
||||
using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator<
|
||||
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>, half_t, LayoutA, 1, ThreadMapA,
|
||||
AccessTypeA>;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
|
||||
using AccessTypeB = cutlass::Array<half_t, kAlignmentB>;
|
||||
using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator<
|
||||
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>, half_t, LayoutB, 0, ThreadMapB,
|
||||
AccessTypeB>;
|
||||
|
||||
// Define the threadblock-scoped multistage matrix multiply
|
||||
using ThreadblockMma = cutlass::gemm::threadblock::Wint2xMmaMultistage<typename MmaCore::Shape, IteratorA,
|
||||
typename MmaCore::SmemIteratorA, MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB,
|
||||
MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, typename MmaCore::MmaPolicy, 2>;
|
||||
};
|
||||
|
||||
template <
|
||||
/// Layout type for A matrix operand
|
||||
typename LayoutA,
|
||||
/// Access granularity of A matrix in units of elements
|
||||
int kAlignmentA,
|
||||
/// Layout type for B matrix operand
|
||||
typename LayoutB,
|
||||
/// Access granularity of B matrix in units of elements
|
||||
int kAlignmentB,
|
||||
/// Element type for internal accumulation
|
||||
typename ElementAccumulator,
|
||||
/// Tag indicating architecture to tune for
|
||||
typename ArchTag,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape,
|
||||
/// Instruction-level tile size (concept: GemmShape)
|
||||
typename InstructionShape,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator,
|
||||
/// Number of stages used in the multistage mainloop
|
||||
int kStages,
|
||||
/// Shared memory clear option
|
||||
SharedMemoryClearOption SharedMemoryClear>
|
||||
struct DefaultMma<half_t, LayoutA, kAlignmentA, uint2b_t, LayoutB, kAlignmentB, ElementAccumulator,
|
||||
layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, kStages, Operator,
|
||||
false, SharedMemoryClear>
|
||||
{
|
||||
static cutlass::arch::CacheOperation::Kind const CacheOpA =
|
||||
((sizeof_bits<half_t>::value * kAlignmentA) == 128) ? cutlass::arch::CacheOperation::Global
|
||||
: cutlass::arch::CacheOperation::Always;
|
||||
|
||||
static cutlass::arch::CacheOperation::Kind const CacheOpB =
|
||||
((sizeof_bits<half_t>::value * kAlignmentB) == 128) ? cutlass::arch::CacheOperation::Global
|
||||
: cutlass::arch::CacheOperation::Always;
|
||||
|
||||
// Define the MmaCore components
|
||||
using MmaCore =
|
||||
typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape, WarpShape, InstructionShape, half_t,
|
||||
LayoutA, half_t, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, kStages, Operator,
|
||||
false, CacheOpA, CacheOpB>;
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
|
||||
using AccessTypeA = cutlass::Array<half_t, kAlignmentA>;
|
||||
using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator<
|
||||
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>, half_t, LayoutA, 1, ThreadMapA,
|
||||
AccessTypeA>;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
|
||||
using AccessTypeB = cutlass::Array<half_t, kAlignmentB>;
|
||||
using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator<
|
||||
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>, half_t, LayoutB, 0, ThreadMapB,
|
||||
AccessTypeB>;
|
||||
|
||||
// Define the threadblock-scoped multistage matrix multiply
|
||||
using ThreadblockMma = cutlass::gemm::threadblock::Wint2xMmaMultistage<typename MmaCore::Shape, IteratorA,
|
||||
typename MmaCore::SmemIteratorA, MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB,
|
||||
MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, typename MmaCore::MmaPolicy, kStages, SharedMemoryClear>;
|
||||
};
|
||||
|
||||
} // namespace threadblock
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
@@ -19,13 +19,11 @@
|
||||
#include "cutlass/gemm/threadblock/default_mma.h"
|
||||
#include "cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h"
|
||||
#include "cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h"
|
||||
#include "cutlass_extensions/gemm/threadblock/wint2x_mma_multistage.h"
|
||||
|
||||
namespace cutlass
|
||||
{
|
||||
namespace gemm
|
||||
{
|
||||
namespace threadblock
|
||||
{
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
namespace threadblock {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@@ -197,6 +195,7 @@ public:
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Specialization for row-major output (OperatorClass TensorOp), bf16 activation & int4 weight
|
||||
template <
|
||||
/// Layout type for A matrix operand
|
||||
@@ -244,6 +243,9 @@ public:
|
||||
using ThreadblockMma = typename Mma::ThreadblockMma;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Specialization for row-major output (OperatorClass TensorOp), bf16 activation & int8 weight
|
||||
template <
|
||||
/// Layout type for A matrix operand
|
||||
typename LayoutA,
|
||||
@@ -265,7 +267,7 @@ template <
|
||||
typename InstructionShape,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator,
|
||||
///
|
||||
/// Number of stages used in the multistage mainloop
|
||||
int kStages,
|
||||
/// Shared memory clear option
|
||||
SharedMemoryClearOption SharedMemoryClear>
|
||||
@@ -296,6 +298,7 @@ public:
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int4 weight
|
||||
template <
|
||||
/// Layout type for A matrix operand
|
||||
@@ -318,11 +321,11 @@ template <
|
||||
typename InstructionShape,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator,
|
||||
///
|
||||
/// Number of stages used in the multistage mainloop
|
||||
int kStages,
|
||||
/// Shared memory clear option
|
||||
SharedMemoryClearOption SharedMemoryClear>
|
||||
struct DefaultMma<cutlass::bfloat16_t, LayoutA, kAlignmentA, uint4b_t, LayoutB, kAlignmentB, ElementAccumulator,
|
||||
struct DefaultMma<bfloat16_t, LayoutA, kAlignmentA, uint4b_t, LayoutB, kAlignmentB, ElementAccumulator,
|
||||
layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, kStages, Operator,
|
||||
false, SharedMemoryClear>
|
||||
{
|
||||
@@ -348,6 +351,131 @@ public:
|
||||
using ThreadblockMma = typename Mma::ThreadblockMma;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Specialization for row-major output (OperatorClass TensorOp), fbf16 activation & int2 weight, mma multistage
|
||||
|
||||
template <
|
||||
/// Layout type for A matrix operand
|
||||
typename LayoutA,
|
||||
/// Access granularity of A matrix in units of elements
|
||||
int kAlignmentA,
|
||||
/// Layout type for B matrix operand
|
||||
typename LayoutB,
|
||||
/// Access granularity of B matrix in units of elements
|
||||
int kAlignmentB,
|
||||
/// Element type for internal accumulation
|
||||
typename ElementAccumulator,
|
||||
/// Tag indicating architecture to tune for
|
||||
typename ArchTag,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape,
|
||||
/// Instruction-level tile size (concept: GemmShape)
|
||||
typename InstructionShape,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator>
|
||||
struct DefaultMma<cutlass::bfloat16_t, LayoutA, kAlignmentA, uint2b_t, LayoutB, kAlignmentB, ElementAccumulator,
|
||||
layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, 2, Operator>
|
||||
{
|
||||
static cutlass::arch::CacheOperation::Kind const CacheOpA =
|
||||
((sizeof_bits<bfloat16_t>::value * kAlignmentA) == 128) ? cutlass::arch::CacheOperation::Global
|
||||
: cutlass::arch::CacheOperation::Always;
|
||||
|
||||
static cutlass::arch::CacheOperation::Kind const CacheOpB =
|
||||
((sizeof_bits<bfloat16_t>::value * kAlignmentB) == 128) ? cutlass::arch::CacheOperation::Global
|
||||
: cutlass::arch::CacheOperation::Always;
|
||||
|
||||
// Define the MmaCore components
|
||||
using MmaCore =
|
||||
typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape, WarpShape, InstructionShape, bfloat16_t,
|
||||
LayoutA, bfloat16_t, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, 3, Operator,
|
||||
false, CacheOpA, CacheOpB>;
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
|
||||
using AccessTypeA = cutlass::Array<bfloat16_t, kAlignmentA>;
|
||||
using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator<
|
||||
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>, bfloat16_t, LayoutA, 1, ThreadMapA,
|
||||
AccessTypeA>;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
|
||||
using AccessTypeB = cutlass::Array<bfloat16_t, kAlignmentB>;
|
||||
using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator<
|
||||
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>, bfloat16_t, LayoutB, 0, ThreadMapB,
|
||||
AccessTypeB>;
|
||||
|
||||
// Define the threadblock-scoped multistage matrix multiply
|
||||
using ThreadblockMma = cutlass::gemm::threadblock::Wint2xMmaMultistage<typename MmaCore::Shape, IteratorA,
|
||||
typename MmaCore::SmemIteratorA, MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB,
|
||||
MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, typename MmaCore::MmaPolicy, 2>;
|
||||
};
|
||||
|
||||
template <
|
||||
/// Layout type for A matrix operand
|
||||
typename LayoutA,
|
||||
/// Access granularity of A matrix in units of elements
|
||||
int kAlignmentA,
|
||||
/// Layout type for B matrix operand
|
||||
typename LayoutB,
|
||||
/// Access granularity of B matrix in units of elements
|
||||
int kAlignmentB,
|
||||
/// Element type for internal accumulation
|
||||
typename ElementAccumulator,
|
||||
/// Tag indicating architecture to tune for
|
||||
typename ArchTag,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape,
|
||||
/// Instruction-level tile size (concept: GemmShape)
|
||||
typename InstructionShape,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator,
|
||||
/// Number of stages used in the multistage mainloop
|
||||
int kStages,
|
||||
/// Shared memory clear option
|
||||
SharedMemoryClearOption SharedMemoryClear>
|
||||
struct DefaultMma<bfloat16_t, LayoutA, kAlignmentA, uint2b_t, LayoutB, kAlignmentB, ElementAccumulator,
|
||||
layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, kStages, Operator,
|
||||
false, SharedMemoryClear>
|
||||
{
|
||||
static cutlass::arch::CacheOperation::Kind const CacheOpA =
|
||||
((sizeof_bits<bfloat16_t>::value * kAlignmentA) == 128) ? cutlass::arch::CacheOperation::Global
|
||||
: cutlass::arch::CacheOperation::Always;
|
||||
|
||||
static cutlass::arch::CacheOperation::Kind const CacheOpB =
|
||||
((sizeof_bits<bfloat16_t>::value * kAlignmentB) == 128) ? cutlass::arch::CacheOperation::Global
|
||||
: cutlass::arch::CacheOperation::Always;
|
||||
|
||||
// Define the MmaCore components
|
||||
using MmaCore =
|
||||
typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape, WarpShape, InstructionShape, bfloat16_t,
|
||||
LayoutA, bfloat16_t, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, kStages, Operator,
|
||||
false, CacheOpA, CacheOpB>;
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
|
||||
using AccessTypeA = cutlass::Array<bfloat16_t, kAlignmentA>;
|
||||
using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator<
|
||||
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>, bfloat16_t, LayoutA, 1, ThreadMapA,
|
||||
AccessTypeA>;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
|
||||
using AccessTypeB = cutlass::Array<bfloat16_t, kAlignmentB>;
|
||||
using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator<
|
||||
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>, bfloat16_t, LayoutB, 0, ThreadMapB,
|
||||
AccessTypeB>;
|
||||
|
||||
// Define the threadblock-scoped multistage matrix multiply
|
||||
using ThreadblockMma = cutlass::gemm::threadblock::Wint2xMmaMultistage<typename MmaCore::Shape, IteratorA,
|
||||
typename MmaCore::SmemIteratorA, MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB,
|
||||
MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, typename MmaCore::MmaPolicy, kStages, SharedMemoryClear>;
|
||||
};
|
||||
|
||||
} // namespace threadblock
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
@@ -0,0 +1,237 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights
|
||||
*reserved. SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice,
|
||||
*this list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
||||
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
|
||||
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
|
||||
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
|
||||
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
||||
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
|
||||
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
|
||||
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
|
||||
*POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Template for a double-buffered threadblock-scoped GEMM kernel.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/aligned_buffer.h"
|
||||
#include "cutlass/arch/memory.h"
|
||||
#include "cutlass/array.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/matrix_shape.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
|
||||
#include "cutlass/gemm/threadblock/mma_base.h"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
namespace threadblock {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Structure to compute the matrix product targeting CUDA cores and SIMT math
|
||||
/// instructions.
|
||||
template <
|
||||
/// Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
typename Shape_,
|
||||
/// Policy describing tuning details (concept: MmaPolicy)
|
||||
typename Policy_,
|
||||
/// Number of stages,
|
||||
int Stages,
|
||||
/// Used for partial specialization
|
||||
typename Enable = bool>
|
||||
class Wint2xMmaBase {
|
||||
public:
|
||||
///< Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
using Shape = Shape_;
|
||||
|
||||
///< Policy describing tuning details
|
||||
using Policy = Policy_;
|
||||
|
||||
//
|
||||
// Dependent types
|
||||
//
|
||||
|
||||
/// Warp-level Mma
|
||||
using Operator = typename Policy::Operator;
|
||||
|
||||
/// Shape describing the overall GEMM computed from shared memory
|
||||
/// by each warp.
|
||||
using WarpGemm = typename Policy::Operator::Shape;
|
||||
|
||||
/// Shape describing the number of warps filling the CTA
|
||||
using WarpCount =
|
||||
GemmShape<Shape::kM / WarpGemm::kM, Shape::kN / WarpGemm::kN,
|
||||
Shape::kK / WarpGemm::kK>;
|
||||
|
||||
/// Number of warp-level GEMM oeprations
|
||||
static int const kWarpGemmIterations =
|
||||
(WarpGemm::kK / Operator::Policy::MmaShape::kK);
|
||||
|
||||
/// Number of stages
|
||||
static int const kStages = Stages;
|
||||
|
||||
/// Tensor reference to the A operand
|
||||
using TensorRefA =
|
||||
TensorRef<typename Operator::ElementA, typename Operator::LayoutA>;
|
||||
|
||||
/// Tensor reference to the B operand
|
||||
using TensorRefB =
|
||||
TensorRef<typename Operator::ElementB, typename Operator::LayoutB>;
|
||||
|
||||
// using TensorRefZippedB = TensorRef<uint8_t, typename Operator::LayoutB>;
|
||||
|
||||
static_assert(kWarpGemmIterations > 1,
|
||||
"The pipelined structure requires at least two warp-level "
|
||||
"GEMM operations.");
|
||||
|
||||
static_assert((kWarpGemmIterations % 2) == 0,
|
||||
"Inner loop iteration must be an even number.");
|
||||
|
||||
//
|
||||
// Nested structs
|
||||
//
|
||||
|
||||
/// Shared storage object needed by threadblock-scoped GEMM
|
||||
class SharedStorage {
|
||||
public:
|
||||
//
|
||||
// Type definitions
|
||||
//
|
||||
|
||||
/// Shape of the A matrix operand in shared memory
|
||||
using ShapeA =
|
||||
MatrixShape<Shape::kM + Policy::SmemPaddingA::kRow,
|
||||
Shape::kK * kStages + Policy::SmemPaddingA::kColumn>;
|
||||
|
||||
/// Shape of the B matrix operand in shared memory
|
||||
using ShapeB = MatrixShape<Shape::kK + Policy::SmemPaddingB::kRow,
|
||||
Shape::kN + Policy::SmemPaddingB::kColumn>;
|
||||
|
||||
// w uint8; local_scale uint8;
|
||||
constexpr static int kZippedRowsPerStages =
|
||||
Shape::kK / 4 + (Shape::kK + 127) / 128;
|
||||
|
||||
// code_scale float; code_zp float; super_scale ElementB
|
||||
constexpr static int kColumnWiseParamsRows = 2 * sizeof(float) +
|
||||
sizeof_bits<typename Operator::ElementB>::value / 8;
|
||||
|
||||
using ZippedShapeB = MatrixShape<kColumnWiseParamsRows + kZippedRowsPerStages * kStages, Shape::kN>;
|
||||
|
||||
using NopaddingShapeB = MatrixShape<Shape::kK, Shape::kN>;
|
||||
|
||||
public:
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Buffer for A operand
|
||||
AlignedBuffer<typename Operator::ElementA, ShapeA::kCount> operand_A;
|
||||
|
||||
/// Buffer for B operand
|
||||
AlignedBuffer<typename Operator::ElementB, ShapeB::kCount> operand_B;
|
||||
|
||||
/// Buffer for quanted B operand
|
||||
AlignedBuffer<uint8_t, ZippedShapeB::kCount> operand_zipped_B;
|
||||
|
||||
/// Buffer for unzip B operand
|
||||
AlignedBuffer<typename Operator::ElementB, NopaddingShapeB::kCount>
|
||||
operand_unzip_B;
|
||||
|
||||
public:
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Returns a layout object for the A matrix
|
||||
CUTLASS_DEVICE
|
||||
static typename Operator::LayoutA LayoutA() {
|
||||
return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn});
|
||||
}
|
||||
|
||||
/// Returns a layout object for the B matrix
|
||||
CUTLASS_HOST_DEVICE
|
||||
static typename Operator::LayoutB LayoutB() {
|
||||
return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn});
|
||||
}
|
||||
|
||||
/// Returns a TensorRef to the A operand
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorRefA operand_A_ref() {
|
||||
return TensorRefA{operand_A.data(), LayoutA()};
|
||||
}
|
||||
|
||||
/// Returns a TensorRef to the B operand
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorRefB operand_B_ref() {
|
||||
return TensorRefB{operand_B.data(), LayoutB()};
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
uint8_t *operand_zipped_B_ptr() { return operand_zipped_B.data(); }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
typename Operator::ElementB *operand_unzip_B_ptr() {
|
||||
return operand_unzip_B.data();
|
||||
}
|
||||
};
|
||||
|
||||
protected:
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Iterator to load a warp-scoped tile of A operand from shared memory
|
||||
typename Operator::IteratorA warp_tile_iterator_A_;
|
||||
|
||||
/// Iterator to load a warp-scoped tile of B operand from shared memory
|
||||
typename Operator::IteratorB warp_tile_iterator_B_;
|
||||
|
||||
public:
|
||||
/// Construct from tensor references
|
||||
CUTLASS_DEVICE
|
||||
Wint2xMmaBase(
|
||||
///< Shared storage needed for internal use by threadblock-scoped GEMM
|
||||
SharedStorage &shared_storage,
|
||||
///< ID within the threadblock
|
||||
int thread_idx,
|
||||
///< ID of warp
|
||||
int warp_idx,
|
||||
///< ID of each thread within a warp
|
||||
int lane_idx)
|
||||
: warp_tile_iterator_A_(shared_storage.operand_A_ref(), lane_idx),
|
||||
warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx) {}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace threadblock
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@@ -0,0 +1,807 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Template for a double-buffered threadblock-scoped GEMM kernel.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/aligned_buffer.h"
|
||||
#include "cutlass/arch/memory.h"
|
||||
#include "cutlass/array.h"
|
||||
#include "cutlass/arch/memory_sm80.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/matrix_shape.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
#include "cutlass_extensions/arch/memory_copy_sm80.h"
|
||||
#include "cutlass_extensions/gemm/threadblock/wint2x_mma_base.h"
|
||||
#include "cutlass_extensions/gemm/threadblock/wint2x_tile_dequanter.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
namespace threadblock {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Structure to compute the matrix product targeting CUDA cores and SIMT math
|
||||
/// instructions.
|
||||
template <
|
||||
/// Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
typename Shape_,
|
||||
/// Iterates over tiles of A operand in global memory
|
||||
// (concept: ReadableTileIterator | ForwardTileIterator |
|
||||
// MaskedTileIterator)
|
||||
typename IteratorA_,
|
||||
/// Iterates over tiles of A operand in shared memory
|
||||
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
|
||||
typename SmemIteratorA_,
|
||||
/// Cache operation for operand A
|
||||
cutlass::arch::CacheOperation::Kind CacheOpA,
|
||||
/// Iterates over tiles of B operand in global memory
|
||||
// (concept: ReadableTileIterator | ForwardTileIterator |
|
||||
// MaskedTileIterator)
|
||||
typename IteratorB_,
|
||||
/// Iterates over tiles of B operand in shared memory
|
||||
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
|
||||
typename SmemIteratorB_,
|
||||
/// Cache operation for operand B
|
||||
cutlass::arch::CacheOperation::Kind CacheOpB,
|
||||
/// Data type of accumulator matrix
|
||||
typename ElementC_,
|
||||
/// Data type of accumulator matrix
|
||||
typename LayoutC_,
|
||||
/// Policy describing tuning details (concept: MmaPolicy)
|
||||
typename Policy_,
|
||||
/// Number of stages,
|
||||
int Stages,
|
||||
/// Use zfill or predicate for out-of-bound cp.async
|
||||
SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone,
|
||||
/// Used for partial specialization
|
||||
typename Enable = bool>
|
||||
class Wint2xMmaMultistage :
|
||||
public Wint2xMmaBase<Shape_, Policy_, Stages> {
|
||||
public:
|
||||
///< Base class
|
||||
using Base = Wint2xMmaBase<Shape_, Policy_, Stages>;
|
||||
///< Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
using Shape = Shape_;
|
||||
///< Iterates over tiles of A operand in global memory
|
||||
using IteratorA = IteratorA_;
|
||||
///< Iterates over tiles of B operand in global memory
|
||||
using IteratorB = IteratorB_;
|
||||
///< Data type of accumulator matrix
|
||||
using ElementC = ElementC_;
|
||||
///< Layout of accumulator matrix
|
||||
using LayoutC = LayoutC_;
|
||||
///< Policy describing tuning details
|
||||
using Policy = Policy_;
|
||||
|
||||
using ZippedShapeB = typename Base::SharedStorage::ZippedShapeB;
|
||||
|
||||
using SmemIteratorA = SmemIteratorA_;
|
||||
using SmemIteratorB = SmemIteratorB_;
|
||||
|
||||
static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA;
|
||||
static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB;
|
||||
|
||||
//
|
||||
// Dependent types
|
||||
//
|
||||
|
||||
/// Fragment of accumulator tile
|
||||
using FragmentC = typename Policy::Operator::FragmentC;
|
||||
|
||||
/// Warp-level Mma
|
||||
using Operator = typename Policy::Operator;
|
||||
|
||||
/// Minimum architecture is Sm80 to support cp.async
|
||||
using ArchTag = arch::Sm80;
|
||||
|
||||
/// Complex transform on A operand
|
||||
static ComplexTransform const kTransformA = Operator::kTransformA;
|
||||
|
||||
/// Complex transform on B operand
|
||||
static ComplexTransform const kTransformB = Operator::kTransformB;
|
||||
|
||||
/// Internal structure exposed for introspection.
|
||||
struct Detail {
|
||||
|
||||
/// Number of cp.async instructions to load one stage of operand A
|
||||
static int const AsyncCopyIterationsPerStageA =
|
||||
IteratorA::ThreadMap::Iterations::kCount;
|
||||
|
||||
/// Number of cp.async instructions to load one stage of operand B
|
||||
static int const AsyncCopyIterationsPerStageB =
|
||||
IteratorB::ThreadMap::Iterations::kCount;
|
||||
|
||||
/// Number of stages
|
||||
static int const kStages = Stages;
|
||||
|
||||
/// Number of cp.async instructions to load on group of operand A
|
||||
static int const kAccessesPerGroupA =
|
||||
(AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations;
|
||||
|
||||
/// Number of cp.async instructions to load on group of operand B
|
||||
static int const kAccessesPerGroupB =
|
||||
(AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations;
|
||||
|
||||
// Optional staged-accumulation (e.g., tf32x3 kernels) for improved numerical
|
||||
// accuracy, where each mainloop iteration first accumulates into a temporary
|
||||
// set of freshly-cleared accumulators, which are subsequently added to the
|
||||
// final accumulator set.
|
||||
static bool const kStagedAccumulation = arch::detail::UseStagedAccumulation<Operator>::value;
|
||||
};
|
||||
|
||||
private:
|
||||
|
||||
// Structure encapsulating pipeline state live from one iteration to the next
|
||||
struct PipeState {
|
||||
|
||||
using WarpLoadedFragmentA = typename Operator::FragmentA;
|
||||
using WarpLoadedFragmentB = typename Operator::FragmentB;
|
||||
using WarpTransformedFragmentA = typename Operator::TransformedFragmentA;
|
||||
using WarpTransformedFragmentB = typename Operator::TransformedFragmentB;
|
||||
|
||||
/// Temporary accumulator to facilitate staged-accumulation
|
||||
FragmentC tmp_accum_;
|
||||
|
||||
/// Pair of A fragments used to overlap shared memory loads and math instructions
|
||||
WarpLoadedFragmentA warp_loaded_frag_A_[2];
|
||||
WarpTransformedFragmentA warp_transformed_frag_A_[2];
|
||||
|
||||
/// Pair of B fragments used to overlap shared memory loads and math instructions
|
||||
WarpLoadedFragmentB warp_loaded_frag_B_[2];
|
||||
WarpTransformedFragmentB warp_transformed_frag_B_[2];
|
||||
};
|
||||
|
||||
|
||||
private:
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Warp-level MMA operator
|
||||
Operator warp_mma_;
|
||||
|
||||
/// Iterator to write threadblock-scoped tile of A operand to shared memory
|
||||
SmemIteratorA smem_iterator_A_;
|
||||
|
||||
/// Iterator to write threadblock-scoped tile of B operand to shared memory
|
||||
SmemIteratorB smem_iterator_B_;
|
||||
|
||||
/// Shared memory write stage index
|
||||
int smem_write_stage_idx_;
|
||||
|
||||
/// Shared memory read stage index
|
||||
int smem_read_stage_idx_;
|
||||
|
||||
uint8_t* column_wise_smem_ptr_B_;
|
||||
|
||||
uint8_t* smem_zipped_ptr_B_;
|
||||
int smem_zipped_bytes_per_stage_B_;
|
||||
|
||||
public:
|
||||
|
||||
/// Construct from tensor references
|
||||
CUTLASS_DEVICE
|
||||
Wint2xMmaMultistage(
|
||||
///< Shared storage needed for internal use by threadblock-scoped GEMM
|
||||
typename Base::SharedStorage &shared_storage,
|
||||
///< ID within the threadblock
|
||||
int thread_idx,
|
||||
///< ID of warp
|
||||
int warp_idx,
|
||||
///< ID of each thread within a warp
|
||||
int lane_idx
|
||||
):
|
||||
Base(shared_storage, thread_idx, warp_idx, lane_idx),
|
||||
smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx),
|
||||
smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx),
|
||||
smem_write_stage_idx_(0),
|
||||
smem_read_stage_idx_(0)
|
||||
{
|
||||
// Compute warp location within threadblock tile by mapping the warp_id to
|
||||
// three coordinates:
|
||||
// _m: the warp's position within the threadblock along the M dimension
|
||||
// _n: the warp's position within the threadblock along the N dimension
|
||||
// _k: the warp's position within the threadblock along the K dimension
|
||||
|
||||
int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN);
|
||||
int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN);
|
||||
|
||||
int warp_idx_m = warp_idx_mn % Base::WarpCount::kM;
|
||||
int warp_idx_n = warp_idx_mn / Base::WarpCount::kM;
|
||||
|
||||
// Add per-warp offsets in units of warp-level tiles
|
||||
this->warp_tile_iterator_A_.add_tile_offset(
|
||||
{warp_idx_m, Base::kWarpGemmIterations * warp_idx_k});
|
||||
this->warp_tile_iterator_B_.add_tile_offset(
|
||||
{Base::kWarpGemmIterations * warp_idx_k, warp_idx_n});
|
||||
|
||||
column_wise_smem_ptr_B_ = shared_storage.operand_zipped_B_ptr();
|
||||
|
||||
smem_zipped_ptr_B_ = column_wise_smem_ptr_B_ + Base::SharedStorage::kColumnWiseParamsRows * ZippedShapeB::kColumn;
|
||||
smem_zipped_bytes_per_stage_B_ = Base::SharedStorage::kZippedRowsPerStages * ZippedShapeB::kColumn;
|
||||
}
|
||||
|
||||
/// Advance shared memory read-iterators to the next stage
|
||||
CUTLASS_DEVICE
|
||||
void advance_smem_read_stage()
|
||||
{
|
||||
++smem_read_stage_idx_;
|
||||
|
||||
if (smem_read_stage_idx_ == Base::kStages) {
|
||||
// Wrap back around to the 'start' of the circular buffer in shared memory
|
||||
this->warp_tile_iterator_A_.add_tile_offset({0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations});
|
||||
// this->warp_tile_iterator_B_.add_tile_offset({-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations, 0});
|
||||
smem_read_stage_idx_ = 0;
|
||||
}
|
||||
this->warp_tile_iterator_B_.add_tile_offset({-Policy::kPartitionsK * Base::kWarpGemmIterations, 0});
|
||||
}
|
||||
|
||||
/// Advance global memory read-iterators and shared memory write-iterators to the stage
|
||||
template <typename TileDequanterB>
|
||||
CUTLASS_DEVICE
|
||||
void advance_smem_write_stage(
|
||||
IteratorA &iterator_A,
|
||||
IteratorB &iterator_B,
|
||||
TileDequanterB &tile_dequanter_B)
|
||||
{
|
||||
// Advance global iterators
|
||||
iterator_A.add_tile_offset({0, 1});
|
||||
//iterator_B.add_tile_offset({1, 0});
|
||||
tile_dequanter_B.AddTileOffset({1, 0});
|
||||
|
||||
// Advance shared iterators
|
||||
smem_iterator_A_.add_tile_offset({0, 1});
|
||||
//smem_iterator_B_.add_tile_offset({1, 0});
|
||||
|
||||
// Increment shared memory write stage index
|
||||
++smem_write_stage_idx_;
|
||||
|
||||
if (smem_write_stage_idx_ == Base::kStages) {
|
||||
// Wrap back around to the 'start' of the circular buffer in shared memory
|
||||
smem_iterator_A_.add_tile_offset({0, -Base::kStages});
|
||||
//smem_iterator_B_.add_tile_offset({-Base::kStages, 0});
|
||||
smem_write_stage_idx_ = 0;
|
||||
}
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void copy_tiles_and_advance_A(IteratorA &iterator_A, int group_start_A = 0) {
|
||||
iterator_A.set_iteration_index(group_start_A *
|
||||
IteratorA::kAccessesPerVector);
|
||||
this->smem_iterator_A_.set_iteration_index(group_start_A);
|
||||
|
||||
// Async Copy for operand A
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) {
|
||||
if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) {
|
||||
typename IteratorA::AccessType *dst_ptr =
|
||||
reinterpret_cast<typename IteratorA::AccessType *>(
|
||||
this->smem_iterator_A_.get());
|
||||
|
||||
int const kSrcBytes = sizeof_bits<typename IteratorA::Element>::value *
|
||||
IteratorA::ThreadMap::kElementsPerAccess /
|
||||
IteratorA::kAccessesPerVector / 8;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) {
|
||||
auto gmem_ptr = iterator_A.get();
|
||||
|
||||
if (SharedMemoryClear == SharedMemoryClearOption::kZfill) {
|
||||
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA>(
|
||||
dst_ptr + v, gmem_ptr, iterator_A.valid());
|
||||
} else {
|
||||
cutlass::arch::cp_async<kSrcBytes, kCacheOpA>(
|
||||
dst_ptr + v, gmem_ptr, iterator_A.valid());
|
||||
}
|
||||
|
||||
++iterator_A;
|
||||
}
|
||||
|
||||
++this->smem_iterator_A_;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <bool GlobalToSharedB>
|
||||
CUTLASS_DEVICE
|
||||
void copy_tiles_and_advance_B(IteratorB &iterator_B, int group_start_B = 0) {
|
||||
iterator_B.set_iteration_index(group_start_B *
|
||||
IteratorB::kAccessesPerVector);
|
||||
this->smem_iterator_B_.set_iteration_index(group_start_B);
|
||||
|
||||
// Async Copy for operand B
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) {
|
||||
if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) {
|
||||
typename IteratorB::AccessType *dst_ptr =
|
||||
reinterpret_cast<typename IteratorB::AccessType *>(
|
||||
this->smem_iterator_B_.get());
|
||||
|
||||
int const kSrcBytes = sizeof_bits<typename IteratorB::Element>::value *
|
||||
IteratorB::ThreadMap::kElementsPerAccess /
|
||||
IteratorB::kAccessesPerVector / 8;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) {
|
||||
auto gmem_ptr = iterator_B.get();
|
||||
|
||||
if (SharedMemoryClear == SharedMemoryClearOption::kZfill) {
|
||||
cutlass::arch::copy_zfill<kSrcBytes, kCacheOpB, GlobalToSharedB>(
|
||||
dst_ptr + v, gmem_ptr, iterator_B.valid());
|
||||
} else {
|
||||
cutlass::arch::copy<kSrcBytes, kCacheOpB, GlobalToSharedB>(
|
||||
dst_ptr + v, gmem_ptr, iterator_B.valid());
|
||||
}
|
||||
|
||||
++iterator_B;
|
||||
}
|
||||
|
||||
++this->smem_iterator_B_;
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void copy_tiles_and_advance_per_stage_A(IteratorA &iterator_A) {
|
||||
iterator_A.set_iteration_index(0);
|
||||
this->smem_iterator_A_.set_iteration_index(0);
|
||||
|
||||
// Async Copy for operand A
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) {
|
||||
typename IteratorA::AccessType *dst_ptr =
|
||||
reinterpret_cast<typename IteratorA::AccessType *>(
|
||||
this->smem_iterator_A_.get());
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) {
|
||||
auto gmem_ptr = iterator_A.get();
|
||||
|
||||
int const kSrcBytes =
|
||||
sizeof_bits<typename IteratorA::Element>::value *
|
||||
IteratorA::ThreadMap::kElementsPerAccess /
|
||||
IteratorA::kAccessesPerVector / 8;
|
||||
|
||||
int src_bytes = (iterator_A.valid() ? kSrcBytes : 0);
|
||||
|
||||
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA>(
|
||||
dst_ptr + v, iterator_A.get(), iterator_A.valid());
|
||||
|
||||
++iterator_A;
|
||||
}
|
||||
|
||||
++this->smem_iterator_A_;
|
||||
}
|
||||
}
|
||||
|
||||
template <bool GlobalToSharedB, bool InitStage>
|
||||
CUTLASS_DEVICE
|
||||
void copy_tiles_and_advance_per_stage_B(IteratorB &iterator_B) {
|
||||
iterator_B.set_iteration_index(0);
|
||||
this->smem_iterator_B_.set_iteration_index(0);
|
||||
|
||||
// Async Copy for operand B
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) {
|
||||
typename IteratorB::AccessType *dst_ptr =
|
||||
reinterpret_cast<typename IteratorB::AccessType *>(
|
||||
this->smem_iterator_B_.get());
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) {
|
||||
auto gmem_ptr = iterator_B.get();
|
||||
|
||||
int const kSrcBytes =
|
||||
sizeof_bits<typename IteratorB::Element>::value *
|
||||
IteratorB::ThreadMap::kElementsPerAccess /
|
||||
IteratorB::kAccessesPerVector / 8;
|
||||
|
||||
if (InitStage) {
|
||||
cutlass::arch::copy_zfill<kSrcBytes, kCacheOpB, GlobalToSharedB>(
|
||||
dst_ptr + v, iterator_B.get(), iterator_B.valid());
|
||||
} else {
|
||||
if (SharedMemoryClear == SharedMemoryClearOption::kZfill) {
|
||||
cutlass::arch::copy_zfill<kSrcBytes, kCacheOpB, GlobalToSharedB>(
|
||||
dst_ptr + v, gmem_ptr, iterator_B.valid());
|
||||
} else {
|
||||
cutlass::arch::copy<kSrcBytes, kCacheOpB, GlobalToSharedB>(
|
||||
dst_ptr + v, gmem_ptr, iterator_B.valid());
|
||||
}
|
||||
}
|
||||
|
||||
++iterator_B;
|
||||
}
|
||||
|
||||
++this->smem_iterator_B_;
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
/// GEMM prologue. Bootstrap the global->shared memory pipeline by fetching
|
||||
/// the global fragments needed by the first kStages-1 threadblock mainloop iterations
|
||||
template <typename TileDequanterB>
|
||||
CUTLASS_DEVICE
|
||||
void prologue(
|
||||
IteratorA &iterator_A, ///< [in|out] iterator over A operand in global memory
|
||||
IteratorB &iterator_B, ///< [in|out] iterator over B operand in global memory
|
||||
TileDequanterB &tile_dequanter_B,
|
||||
int &gemm_k_iterations) ///< [in|out] number of threadblock mainloop iterations remaining
|
||||
{
|
||||
// Issue several complete stages
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int stage = 0; stage < Base::kStages - 1; ++stage, --gemm_k_iterations) {
|
||||
|
||||
// Disable global fetching if done with global fetch iterations
|
||||
iterator_A.clear_mask(gemm_k_iterations == 0);
|
||||
iterator_B.clear_mask(gemm_k_iterations == 0);
|
||||
|
||||
// Async copy zipped B to shared memory.
|
||||
copy_tiles_and_advance_per_stage_A(iterator_A);
|
||||
|
||||
// Async copy zipped B to shared memory.
|
||||
tile_dequanter_B.Load(smem_zipped_ptr_B_ + (stage % Base::kStages) * smem_zipped_bytes_per_stage_B_,
|
||||
column_wise_smem_ptr_B_, stage);
|
||||
|
||||
// Move to the next write stage
|
||||
advance_smem_write_stage(iterator_A, iterator_B, tile_dequanter_B);
|
||||
|
||||
// Defines the boundary of a stage of cp.async.
|
||||
cutlass::arch::cp_async_fence();
|
||||
}
|
||||
|
||||
// Optionally clear the remaining stages of SMEM. This is a functional requirement for
|
||||
// some kernels so that all accumulator elements outside the GEMM footprint are zero.
|
||||
if (SharedMemoryClear == SharedMemoryClearOption::kClearLastStage) {
|
||||
|
||||
/// Iterator to write threadblock-scoped tile of A operand to shared memory
|
||||
SmemIteratorA last_smem_iterator_A(this->smem_iterator_A_);
|
||||
typename IteratorA::AccessType zero_A;
|
||||
|
||||
zero_A.clear();
|
||||
last_smem_iterator_A.set_iteration_index(0);
|
||||
|
||||
// Async Copy for operand A
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) {
|
||||
|
||||
typename IteratorA::AccessType *dst_ptr =
|
||||
reinterpret_cast<typename IteratorA::AccessType *>(
|
||||
last_smem_iterator_A.get());
|
||||
|
||||
*dst_ptr = zero_A;
|
||||
|
||||
++last_smem_iterator_A;
|
||||
}
|
||||
|
||||
/// Iterator to write threadblock-scoped tile of B operand to shared memory
|
||||
SmemIteratorB last_smem_iterator_B(this->smem_iterator_B_);
|
||||
typename IteratorB::AccessType zero_B;
|
||||
|
||||
zero_B.clear();
|
||||
last_smem_iterator_B.set_iteration_index(0);
|
||||
|
||||
// Async Copy for operand B
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) {
|
||||
|
||||
typename IteratorB::AccessType *dst_ptr =
|
||||
reinterpret_cast<typename IteratorB::AccessType *>(
|
||||
last_smem_iterator_B.get());
|
||||
|
||||
*dst_ptr = zero_B;
|
||||
|
||||
++last_smem_iterator_B;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Wait until we have at least one completed global fetch stage
|
||||
CUTLASS_DEVICE
|
||||
void gmem_wait()
|
||||
{
|
||||
// Wait until we have at least one committed global fetch stage. (#uncommitted = Base::kStages - 1 - #committed)
|
||||
cutlass::arch::cp_async_wait<Base::kStages - 2>();
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
/// Perform a threadblock mainloop iteration of matrix multiply-accumulate
|
||||
template <typename TileDequanterB>
|
||||
CUTLASS_DEVICE
|
||||
void mac_loop_iter(
|
||||
PipeState &pipe_state, ///< [in|out] loop-carried pipeline state
|
||||
FragmentC &accum, ///< [in|out] destination accumulator tile
|
||||
IteratorA &iterator_A, ///< [in|out] iterator over A operand in global memory
|
||||
IteratorB &iterator_B, ///< [in|out] iterator over B operand in global memory
|
||||
TileDequanterB &tile_dequanter_B, ///< [in|out] tile dequantizer for B operand
|
||||
int &gemm_k_iterations, ///< [in|out] number of threadblock mainloop iterations remaining
|
||||
int stage)
|
||||
{
|
||||
// Unroll the warp-level MMA tiles of a threadblock's mainloop iteration
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) {
|
||||
// CUTLASS_TRACE_DEVICE(" [MMa] stage=%d, warp_mma_k=%d", stage, warp_mma_k);
|
||||
|
||||
// Load the next warp-tile's A fragment from shared memory
|
||||
this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
|
||||
this->warp_tile_iterator_A_.load(pipe_state.warp_loaded_frag_A_[(warp_mma_k + 1) % 2]);
|
||||
++this->warp_tile_iterator_A_;
|
||||
|
||||
if (warp_mma_k + 1 == Base::kWarpGemmIterations) {
|
||||
// Unpack and dequant the first stage of B.
|
||||
int unpack_stage = stage - Base::kStages + 2;
|
||||
tile_dequanter_B.UnpackAndDequant(smem_zipped_ptr_B_ + (unpack_stage % Base::kStages) * smem_zipped_bytes_per_stage_B_,
|
||||
column_wise_smem_ptr_B_, unpack_stage);
|
||||
|
||||
// Copy dequatized data to shared memory used by mma core.
|
||||
copy_tiles_and_advance_per_stage_B<false, false>(iterator_B);
|
||||
}
|
||||
|
||||
// Load the next warp-tile's B fragment from shared memory
|
||||
this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
|
||||
this->warp_tile_iterator_B_.load(pipe_state.warp_loaded_frag_B_[(warp_mma_k + 1) % 2]);
|
||||
++this->warp_tile_iterator_B_;
|
||||
|
||||
// Except for the first warp-tile, all warp-tiles convert their incoming shared memory fragments as necessary
|
||||
if (warp_mma_k > 0) {
|
||||
warp_mma_.transform(
|
||||
pipe_state.warp_transformed_frag_A_[warp_mma_k % 2],
|
||||
pipe_state.warp_transformed_frag_B_[warp_mma_k % 2],
|
||||
pipe_state.warp_loaded_frag_A_[warp_mma_k % 2],
|
||||
pipe_state.warp_loaded_frag_B_[warp_mma_k % 2]);
|
||||
}
|
||||
|
||||
// Execute the current warp-tile of MMA operations
|
||||
if (Detail::kStagedAccumulation) {
|
||||
warp_mma_(
|
||||
pipe_state.tmp_accum_,
|
||||
pipe_state.warp_transformed_frag_A_[warp_mma_k % 2],
|
||||
pipe_state.warp_transformed_frag_B_[warp_mma_k % 2],
|
||||
pipe_state.tmp_accum_
|
||||
);
|
||||
|
||||
if (warp_mma_k == 0) {
|
||||
plus<FragmentC> plus_accum;
|
||||
accum = plus_accum(accum, pipe_state.tmp_accum_);
|
||||
pipe_state.tmp_accum_.clear();
|
||||
}
|
||||
} else {
|
||||
warp_mma_(
|
||||
accum,
|
||||
pipe_state.warp_transformed_frag_A_[warp_mma_k % 2],
|
||||
pipe_state.warp_transformed_frag_B_[warp_mma_k % 2],
|
||||
accum
|
||||
);
|
||||
}
|
||||
|
||||
// Except for the last warp-tile, all warp-tiles issue their share of
|
||||
// global->shared fragment copies
|
||||
if (warp_mma_k < Base::kWarpGemmIterations - 1) {
|
||||
int group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA;
|
||||
|
||||
copy_tiles_and_advance_A(iterator_A, group_start_iteration_A);
|
||||
|
||||
if (warp_mma_k == 0) {
|
||||
tile_dequanter_B.Load(smem_zipped_ptr_B_ + (stage % Base::kStages) * smem_zipped_bytes_per_stage_B_,
|
||||
column_wise_smem_ptr_B_, stage);
|
||||
}
|
||||
}
|
||||
|
||||
// The second-to-last warp-tile also:
|
||||
// - performs the last warp-tile's share of global->shared fragment copies
|
||||
// - moves to the next global fetch stage
|
||||
if (warp_mma_k + 2 == Base::kWarpGemmIterations) {
|
||||
// Performs the last warp-tile's share of global->shared fragment copies
|
||||
int group_start_iteration_A = (warp_mma_k + 1) * Detail::kAccessesPerGroupA;
|
||||
|
||||
copy_tiles_and_advance_A(iterator_A, group_start_iteration_A);
|
||||
|
||||
// Inserts a memory fence between stages of cp.async instructions.
|
||||
cutlass::arch::cp_async_fence();
|
||||
|
||||
// Wait until we have at least one completed global fetch stage
|
||||
gmem_wait();
|
||||
|
||||
// Move to the next global fetch stage
|
||||
advance_smem_write_stage(iterator_A, iterator_B, tile_dequanter_B);
|
||||
advance_smem_read_stage();
|
||||
|
||||
// Disable global fetching when done with global fetch iterations
|
||||
--gemm_k_iterations;
|
||||
iterator_A.clear_mask(gemm_k_iterations == 0);
|
||||
iterator_B.clear_mask(gemm_k_iterations == (-Base::kStages + 1));
|
||||
}
|
||||
|
||||
// The last warp-tile also converts the shared memory fragments used by
|
||||
// the first warp-tile of the next iteration, if necessary (so we can
|
||||
// immediately start issuing MMA instructions at the top of the loop )
|
||||
if (warp_mma_k + 1 == Base::kWarpGemmIterations) {
|
||||
warp_mma_.transform(
|
||||
pipe_state.warp_transformed_frag_A_[(warp_mma_k + 1) % 2],
|
||||
pipe_state.warp_transformed_frag_B_[(warp_mma_k + 1) % 2],
|
||||
pipe_state.warp_loaded_frag_A_[(warp_mma_k + 1) % 2],
|
||||
pipe_state.warp_loaded_frag_B_[(warp_mma_k + 1) % 2]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Perform the specified number of threadblock mainloop iterations of matrix
|
||||
/// multiply-accumulate. Assumes prologue has been initiated.
|
||||
template <typename TileDequanterB>
|
||||
CUTLASS_DEVICE
|
||||
void gemm_iters(
|
||||
int gemm_k_iterations, ///< number of threadblock mainloop iterations
|
||||
FragmentC &accum, ///< [in|out] accumulator tile
|
||||
IteratorA &iterator_A, ///< [in|out] iterator over A operand in global memory
|
||||
IteratorB &iterator_B,
|
||||
TileDequanterB &tile_dequanter_B) ///< [in|out] iterator over B operand in global memory
|
||||
{
|
||||
PipeState pipe_state;
|
||||
|
||||
// Unpack and dequant the first stage of B.
|
||||
tile_dequanter_B.UnpackAndDequant(smem_zipped_ptr_B_, column_wise_smem_ptr_B_, 0);
|
||||
|
||||
// Disable global fetching if done with global fetch iterations
|
||||
iterator_A.clear_mask(gemm_k_iterations == 0);
|
||||
iterator_B.clear_mask(gemm_k_iterations == (-Base::kStages + 1));
|
||||
|
||||
// Load first warp-tile's A fragment from shared memory
|
||||
this->warp_tile_iterator_A_.set_kgroup_index(0);
|
||||
this->warp_tile_iterator_A_.load(pipe_state.warp_loaded_frag_A_[0]);
|
||||
++this->warp_tile_iterator_A_;
|
||||
|
||||
// Copy dequantized data to shared memory used by mma core.
|
||||
copy_tiles_and_advance_per_stage_B<false, true>(iterator_B);
|
||||
|
||||
// Load first warp-tile's B fragment from shared memory
|
||||
this->warp_tile_iterator_B_.set_kgroup_index(0);
|
||||
this->warp_tile_iterator_B_.load(pipe_state.warp_loaded_frag_B_[0]);
|
||||
++this->warp_tile_iterator_B_;
|
||||
|
||||
// Transform, if necessary, the first warp-tile's shared memory fragments
|
||||
warp_mma_.transform(
|
||||
pipe_state.warp_transformed_frag_A_[0],
|
||||
pipe_state.warp_transformed_frag_B_[0],
|
||||
pipe_state.warp_loaded_frag_A_[0],
|
||||
pipe_state.warp_loaded_frag_B_[0]);
|
||||
|
||||
if (Detail::kStagedAccumulation) {
|
||||
pipe_state.tmp_accum_.clear();
|
||||
}
|
||||
|
||||
int stage = Base::kStages - 1;
|
||||
|
||||
// Mainloop
|
||||
CUTLASS_GEMM_LOOP
|
||||
for (; gemm_k_iterations > (-Base::kStages + 1);) {
|
||||
mac_loop_iter(
|
||||
pipe_state,
|
||||
accum,
|
||||
iterator_A,
|
||||
iterator_B,
|
||||
tile_dequanter_B,
|
||||
gemm_k_iterations,
|
||||
stage);
|
||||
stage += 1;
|
||||
}
|
||||
|
||||
if (Detail::kStagedAccumulation) {
|
||||
plus<FragmentC> plus_accum;
|
||||
accum = plus_accum(accum, pipe_state.tmp_accum_);
|
||||
}
|
||||
|
||||
// Commit and drain all pending and predicated cp.async pnz from the GEMM mainloop
|
||||
cutlass::arch::cp_async_fence();
|
||||
cutlass::arch::cp_async_wait<0>();
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
/// Prepares the class for another prologue.
|
||||
CUTLASS_DEVICE
|
||||
void wind_down()
|
||||
{
|
||||
// Catch-up the smem-read iterator to the smem-write iterator (so this class can be reused for another tile's prologue)
|
||||
|
||||
// First, increment remaining warp tiles to get to the next full stage. (Ideally we would
|
||||
// just decrement one tile, but not all iterators implement --() decrement.)
|
||||
#pragma unroll
|
||||
for (int warp_mma_k = 1; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k)
|
||||
{
|
||||
this->warp_tile_iterator_A_.set_kgroup_index(warp_mma_k);
|
||||
this->warp_tile_iterator_B_.set_kgroup_index(warp_mma_k);
|
||||
|
||||
++this->warp_tile_iterator_A_;
|
||||
++this->warp_tile_iterator_B_;
|
||||
}
|
||||
smem_read_stage_idx_++;
|
||||
|
||||
// Then wrap back two full stages (one for the tile advancing we just did, and one to catch the write iterators)
|
||||
static const int kStageIters = Policy::kPartitionsK * Base::kWarpGemmIterations;
|
||||
if (smem_read_stage_idx_ > 1)
|
||||
{
|
||||
this->warp_tile_iterator_A_.add_tile_offset({0, (-2 * kStageIters)});
|
||||
this->warp_tile_iterator_B_.add_tile_offset({(-2 * kStageIters), 0});
|
||||
}
|
||||
else
|
||||
{
|
||||
this->warp_tile_iterator_A_.add_tile_offset({0, ((Base::kStages - 2) * kStageIters)});
|
||||
//this->warp_tile_iterator_B_.add_tile_offset({((Base::kStages - 2) * kStageIters), 0});
|
||||
this->warp_tile_iterator_B_.add_tile_offset({(-2 * kStageIters), 0});
|
||||
}
|
||||
smem_read_stage_idx_ = smem_write_stage_idx_;
|
||||
}
|
||||
|
||||
/// Perform a threadblock-scoped matrix multiply-accumulate, pre-load B to shared memory.
|
||||
template <typename TileDequanterB>
|
||||
CUTLASS_DEVICE
|
||||
void operator()(
|
||||
///< problem size of GEMM
|
||||
int gemm_k_iterations,
|
||||
///< destination accumulator tile
|
||||
FragmentC &accum,
|
||||
///< iterator over A operand in global memory
|
||||
IteratorA iterator_A,
|
||||
///< iterator over B operand in global memory
|
||||
IteratorB iterator_B,
|
||||
///< pre-load and dequantize B to shared memory
|
||||
TileDequanterB tile_dequanter_B,
|
||||
///< initial value of accumulator
|
||||
FragmentC const &src_accum) {
|
||||
|
||||
// Prologue (start fetching iterations of global fragments into shared memory)
|
||||
prologue(iterator_A, iterator_B, tile_dequanter_B, gemm_k_iterations);
|
||||
|
||||
// Wait until we have at least one completed global fetch stage
|
||||
gmem_wait();
|
||||
|
||||
// Initialize destination accumulators with source accumulators
|
||||
accum = src_accum;
|
||||
|
||||
// Perform the MAC-iterations
|
||||
gemm_iters(gemm_k_iterations, accum, iterator_A, iterator_B, tile_dequanter_B);
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace threadblock
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@@ -0,0 +1,130 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/gemm_coord.h"
|
||||
#include "cutlass/trace.h"
|
||||
|
||||
#include "cutlass_extensions/gemm/threadblock/wint2x_unzip.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
namespace threadblock {
|
||||
|
||||
template <typename ElementT, typename ScaleElementT, int Rows, int Columns,
|
||||
int Stages, int NumThreads, WintQuantMethod Method>
|
||||
struct TileDequanter {
|
||||
using WeightQuantTraits = WintQuantTraits<ElementT, Method>;
|
||||
using MmaElementT = typename WeightQuantTraits::MmaWeightType;
|
||||
using QuantArguments = typename WeightQuantTraits::Arguments;
|
||||
|
||||
using UnzipAndDequantFunctor =
|
||||
UnzipAndDequantFunctor<MmaElementT, Method, Rows, Columns, NumThreads>;
|
||||
|
||||
static constexpr bool kUseSharedMemory = true;
|
||||
|
||||
static constexpr int kRows = Rows;
|
||||
static constexpr int kColumns = Columns;
|
||||
static constexpr int kStages = Stages;
|
||||
|
||||
MmaElementT *out_smem_ptr{nullptr};
|
||||
|
||||
char *pointer{nullptr};
|
||||
int64_t ldm{0};
|
||||
cutlass::MatrixCoord tb_offset;
|
||||
cutlass::MatrixCoord extent;
|
||||
|
||||
ScaleElementT *super_scale_ptr{nullptr};
|
||||
cutlass::MatrixCoord tb_offset_scale;
|
||||
|
||||
QuantArguments quant_args;
|
||||
|
||||
int64_t block_start_rows[kStages];
|
||||
bool need_preload{true};
|
||||
UnzipAndDequantFunctor unzip_functor;
|
||||
|
||||
CUTLASS_DEVICE
|
||||
TileDequanter(MmaElementT *out_smem_ptr, char *pointer, int64_t ldm,
|
||||
const cutlass::MatrixCoord &extent,
|
||||
const cutlass::MatrixCoord &tb_offset,
|
||||
ScaleElementT *super_scale_ptr,
|
||||
const cutlass::MatrixCoord &tb_offset_scale,
|
||||
const QuantArguments &quant_args)
|
||||
: out_smem_ptr(out_smem_ptr), pointer(pointer), ldm(ldm), extent(extent),
|
||||
tb_offset(tb_offset), super_scale_ptr(super_scale_ptr),
|
||||
tb_offset_scale(tb_offset_scale), quant_args(quant_args) {}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
MmaElementT *GetOutPtr() { return out_smem_ptr; }
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void AddTileOffset(const cutlass::MatrixCoord &tile_offset) {
|
||||
tb_offset.row() += tile_offset.row() * kRows;
|
||||
tb_offset.column() += tile_offset.column() * kColumns;
|
||||
tb_offset_scale.column() += tile_offset.column() * kColumns;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void Load(uint8_t *zipped_smem_ptr, uint8_t *column_wise_smem_ptr, int stage) {
|
||||
int zipped_row = WeightQuantTraits::CaclPackedDim(tb_offset.row());
|
||||
if (tb_offset.row() >= extent.row() ||
|
||||
tb_offset.column() >= extent.column()) {
|
||||
return;
|
||||
}
|
||||
|
||||
block_start_rows[stage % kStages] = tb_offset.row();
|
||||
|
||||
using ZippedT = typename WeightQuantTraits::WeightType;
|
||||
ZippedT *in_ptr = reinterpret_cast<ZippedT *>(pointer) + zipped_row * ldm +
|
||||
tb_offset.column();
|
||||
ScaleElementT *scale_ptr = super_scale_ptr + tb_offset_scale.column();
|
||||
|
||||
if constexpr (Method == WintQuantMethod::kWeightOnlyInt2) {
|
||||
const uint8_t *local_scale_ptr = quant_args.local_scale_ptr +
|
||||
(tb_offset.row() / 128) * ldm +
|
||||
tb_offset_scale.column();
|
||||
const float *code_scale_ptr =
|
||||
quant_args.code_scale_ptr + tb_offset_scale.column();
|
||||
const float *code_zp_ptr =
|
||||
quant_args.code_zp_ptr + tb_offset_scale.column();
|
||||
|
||||
typename UnzipAndDequantFunctor::Arguments args(zipped_smem_ptr, column_wise_smem_ptr);
|
||||
unzip_functor.LoadAsync(in_ptr, local_scale_ptr, code_scale_ptr, code_zp_ptr,
|
||||
scale_ptr, &args, ldm, need_preload);
|
||||
need_preload = false;
|
||||
} else {
|
||||
// CUTLASS_TRACE_DEVICE("Not Supported!");
|
||||
}
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void UnpackAndDequant(uint8_t *zipped_smem_ptr, uint8_t *column_wise_smem_ptr, int stage) {
|
||||
int64_t block_start_row = block_start_rows[stage % kStages];
|
||||
if (block_start_row >= extent.row()) {
|
||||
return;
|
||||
}
|
||||
|
||||
if constexpr (Method == WintQuantMethod::kWeightOnlyInt2) {
|
||||
typename UnzipAndDequantFunctor::Arguments args(zipped_smem_ptr, column_wise_smem_ptr);
|
||||
unzip_functor.ComputeVectorized(args, out_smem_ptr, block_start_row);
|
||||
} else {
|
||||
// CUTLASS_TRACE_DEVICE("Not Supported!");
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace threadblock
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
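// ---------------------------------------------------------------------------
// Illustrative usage sketch only (not part of this header): how the wint2x
// mainloop above is expected to drive a TileDequanter per pipeline stage.
// `zipped_smem`, `column_wise_smem`, `zipped_bytes_per_stage` and `kStages`
// are placeholder names for buffers the calling kernel owns.
//
//   TileDequanter<ElementT, ScaleElementT, kRows, kColumns, kStages,
//                 kNumThreads, WintQuantMethod::kWeightOnlyInt2>
//       deq(out_smem_ptr, gmem_B_ptr, ldm, extent, tb_offset,
//           super_scale_ptr, tb_offset_scale, quant_args);
//
//   for (int stage = 0; stage < kStages - 1; ++stage) {
//     // stage the zipped weights plus local/code scales and zero-points
//     deq.Load(zipped_smem + (stage % kStages) * zipped_bytes_per_stage,
//              column_wise_smem, stage);
//     deq.AddTileOffset({1, 0});  // advance one B tile along K
//   }
//   // expand one staged tile into the mma-ready shared-memory layout
//   deq.UnpackAndDequant(zipped_smem, column_wise_smem, /*stage=*/0);
//   MmaElementT *mma_ready_B = deq.GetOutPtr();
// ---------------------------------------------------------------------------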
|
||||
@@ -0,0 +1,447 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cuda.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <stdio.h>
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
#include "cutlass/arch/memory.h"
|
||||
#include "cutlass/trace.h"
|
||||
#include "cutlass_extensions/wint_type_traits.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
namespace threadblock {
|
||||
|
||||
template <typename T, int N>
|
||||
using UnzipArray = cutlass::AlignedArray<T, N, (N * cutlass::sizeof_bits<T>::value / 8)>;
|
||||
|
||||
template <typename T, WintQuantMethod QuantMethod, int TileRows,
|
||||
int TileColumns, int NumThreads = 128>
|
||||
struct UnzipAndDequantFunctor {
|
||||
__device__ void operator()(const T *in_ptr, const T *supper_scale_ptr,
|
||||
T *out_ptr, const int64_t in_stride) {}
|
||||
};
|
||||
|
||||
template <typename T, int TileRows, int TileColumns, int NumThreads>
|
||||
struct UnzipAndDequantFunctor<T, WintQuantMethod::kWeightOnlyInt25, TileRows,
|
||||
TileColumns, NumThreads> {
|
||||
using ZippedT = uint16_t;
|
||||
using ScaleComputeT = float;
|
||||
|
||||
static constexpr int32_t kGroupSize = 64;
|
||||
static constexpr int32_t kZippedGroupSize = 10;
|
||||
static constexpr int32_t kNumPackedValues = 7;
|
||||
|
||||
static constexpr int32_t kWeightMask = 0x7;
|
||||
static constexpr int32_t kLocalScaleMask = 0x1FFF;
|
||||
static constexpr int32_t kBZP = 4;
|
||||
|
||||
__device__ inline T Compute(int32_t zipped_value, int32_t shift_bit,
|
||||
ScaleComputeT scale) {
|
||||
int32_t shifted_value = (zipped_value >> shift_bit) & kWeightMask;
|
||||
int32_t value = shifted_value - kBZP;
|
||||
|
||||
ScaleComputeT scaled_value = static_cast<ScaleComputeT>(value) * scale;
|
||||
return static_cast<T>(scaled_value);
|
||||
}
|
||||
|
||||
__device__ void operator()(const uint16_t *in_ptr, const T *super_scale_ptr,
|
||||
T *out_ptr, const int64_t in_stride) {
|
||||
int32_t shift_bits[7] = {13, 11, 9, 6, 4, 2, 0};
|
||||
|
||||
int tid = threadIdx.x;
|
||||
|
||||
#pragma unroll
|
||||
for (int col = tid; col < TileColumns; col += NumThreads) {
|
||||
ScaleComputeT super_scale =
|
||||
static_cast<ScaleComputeT>(super_scale_ptr[col]);
|
||||
|
||||
#pragma unroll
|
||||
for (int group_id = 0; group_id < TileRows / 64; ++group_id) {
|
||||
// the last row in group
|
||||
int zipped_row_last = group_id * 10 + 9;
|
||||
int zipped_offset_last = zipped_row_last * in_stride + col;
|
||||
int32_t zipped_value_last =
|
||||
static_cast<int32_t>(in_ptr[zipped_offset_last]);
|
||||
|
||||
ScaleComputeT local_scale =
|
||||
static_cast<ScaleComputeT>(zipped_value_last & kLocalScaleMask);
|
||||
ScaleComputeT scale = local_scale * super_scale;
|
||||
|
||||
#pragma unroll
|
||||
for (int zipped_row_in_group = 0; zipped_row_in_group < 9;
|
||||
++zipped_row_in_group) {
|
||||
int zipped_row = group_id * 10 + zipped_row_in_group;
|
||||
int zipped_offset = zipped_row * in_stride + col;
|
||||
int32_t zipped_value = static_cast<int32_t>(in_ptr[zipped_offset]);
|
||||
|
||||
int row_in_group = group_id * 64 + zipped_row_in_group * 7;
|
||||
|
||||
#pragma unroll
|
||||
for (int shift_bit_id = 0; shift_bit_id < 7; ++shift_bit_id) {
|
||||
int32_t shift_bit = shift_bits[shift_bit_id];
|
||||
T value = Compute(zipped_value, shift_bit, scale);
|
||||
out_ptr[(row_in_group + shift_bit_id) * TileColumns + col] = value;
|
||||
}
|
||||
}
|
||||
|
||||
int row_in_group_last = group_id * 64 + 63;
|
||||
T value_last = Compute(zipped_value_last, shift_bits[0], scale);
|
||||
out_ptr[row_in_group_last * TileColumns + col] = value_last;
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
};
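// Illustrative, host-callable reference for the 2.5-bit decode implemented by
// the functor above; a standalone documentation sketch (assumption: it is not
// used by any kernel in this file). One packed uint16_t carries seven 3-bit
// weights at shifts {13, 11, 9, 6, 4, 2, 0}; each field is biased by kBZP = 4,
// and `scale` is local_scale (low 13 bits of the last word of a 64-row group)
// multiplied by the per-column super_scale.
inline void ExampleDecodeWint25Word(uint16_t packed, float scale,
                                    float out[7]) {
  const int shift_bits[7] = {13, 11, 9, 6, 4, 2, 0};
  for (int i = 0; i < 7; ++i) {
    int value = ((packed >> shift_bits[i]) & 0x7) - 4;  // kWeightMask, kBZP
    out[i] = static_cast<float>(value) * scale;
  }
}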
|
||||
|
||||
template <typename T, int TileRows, int TileColumns, int NumThreads>
|
||||
struct UnzipAndDequantFunctor<T, WintQuantMethod::kWeightOnlyInt2, TileRows,
|
||||
TileColumns, NumThreads> {
|
||||
using ZippedT = uint8_t;
|
||||
using ScaleComputeT = float;
|
||||
|
||||
static constexpr int32_t kGroupSize = 64;
|
||||
static constexpr int32_t kPackNum = 4;
|
||||
static constexpr int32_t kWeightMask = 0x3F;
|
||||
static constexpr int32_t kLocalScaleMask = 0xF;
|
||||
static constexpr int32_t kBZP = 32;
|
||||
|
||||
// weight [16, N] uint8_t
|
||||
// local_scale [1, N] uint8_t
|
||||
// code_scale [N] float
|
||||
// code_zp [N] float
|
||||
// super_scale [N] T
|
||||
|
||||
// code_scale, code_zp and super_scale
|
||||
static constexpr int32_t kColumnWiseSmemBytes = (2 * sizeof(float) + sizeof(T)) * TileColumns;
|
||||
// zipped weights and local_scale
|
||||
static constexpr int32_t kZippedSmemBytes = (TileRows / 4 + (TileRows + 127) / 128) * TileColumns;
|
||||
|
||||
struct Arguments {
|
||||
uint8_t *weight_ptr;
|
||||
uint8_t *local_scale_ptr;
|
||||
float *code_scale_ptr;
|
||||
float *code_zp_ptr;
|
||||
T *super_scale_ptr;
|
||||
|
||||
__device__ Arguments() : weight_ptr(nullptr), local_scale_ptr(nullptr), code_scale_ptr(nullptr), code_zp_ptr(nullptr), super_scale_ptr(nullptr) {}
|
||||
|
||||
__device__ explicit Arguments(uint8_t *smem_ptr) {
|
||||
SetZippedPtrs(smem_ptr);
|
||||
SetColumnWisePtrs(smem_ptr + kZippedSmemBytes);
|
||||
}
|
||||
|
||||
__device__ Arguments(uint8_t *zipped_smem_ptr, uint8_t *column_wise_smem_ptr) {
|
||||
SetZippedPtrs(zipped_smem_ptr);
|
||||
SetColumnWisePtrs(column_wise_smem_ptr);
|
||||
}
|
||||
|
||||
__device__ void SetZippedPtrs(uint8_t *zipped_smem_ptr) {
|
||||
weight_ptr = zipped_smem_ptr;
|
||||
local_scale_ptr = zipped_smem_ptr + (TileRows / 4) * TileColumns;
|
||||
}
|
||||
|
||||
__device__ void SetColumnWisePtrs(uint8_t *column_wise_smem_ptr) {
|
||||
code_scale_ptr = reinterpret_cast<float *>(column_wise_smem_ptr);
|
||||
code_zp_ptr = reinterpret_cast<float *>(column_wise_smem_ptr + sizeof(float) * TileColumns);
|
||||
super_scale_ptr = reinterpret_cast<T *>(column_wise_smem_ptr + 2 * sizeof(float) * TileColumns);
|
||||
}
|
||||
};
|
||||
|
||||
__device__ void Load(const uint8_t *g_weight_ptr, const uint8_t *g_local_scale_ptr,
|
||||
const float *g_code_scale_ptr, const float *g_code_zp_ptr,
|
||||
const T *g_super_scale_ptr,
|
||||
Arguments *args, const int64_t in_stride, bool need_preload) {
|
||||
int tid = threadIdx.x;
|
||||
|
||||
#pragma unroll
|
||||
for (int col = tid; col < TileColumns; col += NumThreads) {
|
||||
if (need_preload) {
|
||||
if (g_super_scale_ptr) {
|
||||
args->super_scale_ptr[col] = g_super_scale_ptr[col];
|
||||
} else {
|
||||
args->super_scale_ptr[col] = static_cast<T>(1);
|
||||
}
|
||||
|
||||
args->code_scale_ptr[col] = g_code_scale_ptr[col];
|
||||
args->code_zp_ptr[col] = g_code_zp_ptr[col];
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int ls_row_id = 0; ls_row_id < TileRows / 128; ++ls_row_id) {
|
||||
int local_scale_offset = ls_row_id * in_stride + col;
|
||||
args->local_scale_ptr[ls_row_id * TileColumns + col] = g_local_scale_ptr[local_scale_offset];
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int zipped_row = 0; zipped_row < TileRows / 4; ++zipped_row) {
|
||||
int s_zipped_offset = zipped_row * TileColumns + col;
|
||||
int g_zipped_offset = zipped_row * 4 * in_stride + col;
|
||||
|
||||
args->weight_ptr[s_zipped_offset] = g_weight_ptr[g_zipped_offset];
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
__device__ void LoadAsync(const uint8_t *g_weight_ptr,
|
||||
const uint8_t *g_local_scale_ptr,
|
||||
const float *g_code_scale_ptr,
|
||||
const float *g_code_zp_ptr,
|
||||
const T *g_super_scale_ptr,
|
||||
Arguments *args, const int64_t in_stride, bool need_preload) {
|
||||
int tid = threadIdx.x;
|
||||
|
||||
constexpr int kBytesPerThread = 16; // 16B per thread
|
||||
|
||||
constexpr int weight_size = TileRows / 4 * TileColumns;
|
||||
constexpr int local_scale_size = (TileRows + 127) / 128 * TileColumns;
|
||||
constexpr int code_scale_size = sizeof(float) * TileColumns;
|
||||
constexpr int code_zp_size = sizeof(float) * TileColumns;
|
||||
constexpr int super_scale_size = sizeof(T) * TileColumns;
|
||||
|
||||
constexpr int total_size = weight_size + local_scale_size + code_scale_size + code_zp_size + super_scale_size;
|
||||
constexpr int total_tasks = total_size / kBytesPerThread;
|
||||
|
||||
constexpr int cur_num_threads = total_tasks / ((total_tasks + NumThreads - 1) / NumThreads);
|
||||
|
||||
constexpr int weight_threads = weight_size * cur_num_threads / total_size;
|
||||
constexpr int local_scale_threads = local_scale_size * cur_num_threads / total_size;
|
||||
constexpr int code_scale_threads = code_scale_size * cur_num_threads / total_size;
|
||||
constexpr int code_zp_threads = code_zp_size * cur_num_threads / total_size;
|
||||
constexpr int super_scale_threads = super_scale_size * cur_num_threads / total_size;
|
||||
|
||||
static_assert(TileColumns % weight_threads == 0,
|
||||
"TileColumns must be divisible by weight_threads to ensure correct thread mapping.");
|
||||
|
||||
static_assert(TileColumns % local_scale_threads == 0,
|
||||
"TileColumns must be divisible by local_scale_threads to ensure correct thread mapping.");
|
||||
|
||||
if (tid < weight_threads) {
|
||||
constexpr int weight_per_thread_size = weight_size / weight_threads;
|
||||
constexpr int kIterations = (weight_per_thread_size + kBytesPerThread - 1) / kBytesPerThread;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < kIterations; ++i) {
|
||||
int z_offset = (tid * weight_per_thread_size + i * kBytesPerThread);
|
||||
int g_offset = z_offset / TileColumns * in_stride + z_offset % TileColumns;
|
||||
cutlass::arch::cp_async<kBytesPerThread, cutlass::arch::CacheOperation::Global>(
|
||||
args->weight_ptr + z_offset, g_weight_ptr + g_offset, true);
|
||||
}
|
||||
} else if (tid < weight_threads + local_scale_threads) {
|
||||
constexpr int start_thread_id = weight_threads;
|
||||
constexpr int local_scale_per_thread_size = local_scale_size / local_scale_threads;
|
||||
constexpr int kIterations = (local_scale_per_thread_size + kBytesPerThread - 1) / kBytesPerThread;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < kIterations; ++i) {
|
||||
int z_offset = (tid - start_thread_id) * local_scale_per_thread_size + i * kBytesPerThread;
|
||||
int g_offset = z_offset / TileColumns * in_stride + z_offset % TileColumns;
|
||||
cutlass::arch::cp_async<kBytesPerThread, cutlass::arch::CacheOperation::Global>(
|
||||
args->local_scale_ptr + z_offset, g_local_scale_ptr + g_offset, true);
|
||||
}
|
||||
} else if (need_preload) {
|
||||
if (tid < weight_threads + local_scale_threads + code_scale_threads) {
|
||||
constexpr int start_thread_id = weight_threads + local_scale_threads;
|
||||
constexpr int code_scale_per_thread_size = code_scale_size / code_scale_threads;
|
||||
constexpr int kIterations = (code_scale_per_thread_size + kBytesPerThread - 1) / kBytesPerThread;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < kIterations; ++i) {
|
||||
int offset = ((tid - start_thread_id) * code_scale_per_thread_size + i * kBytesPerThread) / sizeof(float);
|
||||
cutlass::arch::cp_async<kBytesPerThread, cutlass::arch::CacheOperation::Global>(
|
||||
args->code_scale_ptr + offset, g_code_scale_ptr + offset, true);
|
||||
}
|
||||
} else if (tid < weight_threads + local_scale_threads + code_scale_threads + code_zp_threads) {
|
||||
constexpr int start_thread_id = weight_threads + local_scale_threads + code_scale_threads;
|
||||
constexpr int code_zp_per_thread_size = code_zp_size / code_zp_threads;
|
||||
constexpr int kIterations = (code_zp_per_thread_size + kBytesPerThread - 1) / kBytesPerThread;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < kIterations; ++i) {
|
||||
int offset = ((tid - start_thread_id) * code_zp_per_thread_size + i * kBytesPerThread) / sizeof(float);
|
||||
cutlass::arch::cp_async<kBytesPerThread, cutlass::arch::CacheOperation::Global>(
|
||||
args->code_zp_ptr + offset, g_code_zp_ptr + offset, true);
|
||||
}
|
||||
} else if (tid < weight_threads + local_scale_threads + code_scale_threads + code_zp_threads + super_scale_threads) {
|
||||
if (g_super_scale_ptr) {
|
||||
constexpr int start_thread_id = weight_threads + local_scale_threads + code_scale_threads + code_zp_threads;
|
||||
constexpr int super_scale_per_thread_size = super_scale_size / super_scale_threads;
|
||||
constexpr int kIterations = (super_scale_per_thread_size + kBytesPerThread - 1) / kBytesPerThread;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < kIterations; ++i) {
|
||||
int offset = ((tid - start_thread_id) * super_scale_per_thread_size + i * kBytesPerThread) / sizeof(T);
|
||||
cutlass::arch::cp_async<kBytesPerThread, cutlass::arch::CacheOperation::Global>(
|
||||
args->super_scale_ptr + offset, g_super_scale_ptr + offset, true);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__device__ void Compute(const Arguments &args, T *out_ptr,
|
||||
const int64_t block_start_row) {
|
||||
int32_t shift_bits[4] = {9, 6, 3, 0};
|
||||
|
||||
int tid = threadIdx.x;
|
||||
|
||||
#pragma unroll
|
||||
for (int col = tid; col < TileColumns; col += NumThreads) {
|
||||
ScaleComputeT super_scale =
|
||||
static_cast<ScaleComputeT>(args.super_scale_ptr[col]);
|
||||
ScaleComputeT code_scale =
|
||||
static_cast<ScaleComputeT>(args.code_scale_ptr[col]);
|
||||
ScaleComputeT code_zp = static_cast<ScaleComputeT>(args.code_zp_ptr[col]);
|
||||
|
||||
#pragma unroll
|
||||
for (int group_id = 0; group_id < TileRows / 64; ++group_id) {
|
||||
int local_scale_offset = (group_id / 2) * TileColumns + col;
|
||||
int32_t local_scale =
|
||||
static_cast<int32_t>(args.local_scale_ptr[local_scale_offset]);
|
||||
|
||||
ScaleComputeT zipped_value[16];
|
||||
|
||||
#pragma unroll
|
||||
for (int zipped_row = 0; zipped_row < 16; ++zipped_row) {
|
||||
int zipped_offset = (group_id * 16 + zipped_row) * TileColumns + col;
|
||||
zipped_value[zipped_row] =
|
||||
static_cast<ScaleComputeT>(args.weight_ptr[zipped_offset]);
|
||||
}
|
||||
|
||||
int local_scale_shift = ((block_start_row / 64 + group_id + 1) & 1) * 4;
|
||||
int32_t shifted_local_scale =
|
||||
(local_scale >> local_scale_shift) & kLocalScaleMask;
|
||||
ScaleComputeT scale =
|
||||
static_cast<ScaleComputeT>(shifted_local_scale) * super_scale;
|
||||
|
||||
#pragma unroll
|
||||
for (int zipped_row = 0; zipped_row < 16; ++zipped_row) {
|
||||
int32_t decode_value =
|
||||
static_cast<int32_t>(floor(zipped_value[zipped_row] * code_scale + code_zp +
|
||||
static_cast<ScaleComputeT>(0.5)));
|
||||
|
||||
int row = group_id * 64 + zipped_row * 4;
|
||||
|
||||
#pragma unroll
|
||||
for (int shift_bit_id = 0; shift_bit_id < 4; ++shift_bit_id) {
|
||||
int32_t shift_bit = shift_bits[shift_bit_id];
|
||||
int32_t shifted_value = (decode_value >> shift_bit) & kWeightMask;
|
||||
|
||||
ScaleComputeT value =
|
||||
static_cast<ScaleComputeT>(shifted_value - kBZP);
|
||||
out_ptr[(row + shift_bit_id) * TileColumns + col] =
|
||||
static_cast<T>(scale * value);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
__device__ void ComputeVectorized(const Arguments &args, T *out_ptr,
|
||||
const int64_t block_start_row) {
|
||||
constexpr int kNumWeightsPerThread = TileRows * TileColumns / (4 * NumThreads);
|
||||
constexpr int N = (kNumWeightsPerThread >= 32) ? 4 : 2;
|
||||
constexpr int RowStride = NumThreads * N / TileColumns;
|
||||
constexpr int kNumIters = kNumWeightsPerThread / N;
|
||||
|
||||
static_assert(N * NumThreads >= TileColumns, "N * NumThreads should be no less than TileColumns.");
|
||||
|
||||
constexpr ScaleComputeT decode_value_zp = static_cast<ScaleComputeT>(0.5);
|
||||
|
||||
int tid = threadIdx.x;
|
||||
int begin_col_id = (tid * N) % TileColumns;
|
||||
int begin_row_id = (tid * N) / TileColumns;
|
||||
|
||||
static_assert(TileRows <= 128, "TileRows is expected to no more than 128.");
|
||||
|
||||
UnzipArray<uint8_t, N> local_scales =
|
||||
*reinterpret_cast<const UnzipArray<uint8_t, N> *>(args.local_scale_ptr + begin_col_id);
|
||||
|
||||
UnzipArray<uint8_t, N> zipped_values[2];
|
||||
int zipped_offset = begin_row_id * TileColumns + begin_col_id;
|
||||
zipped_values[0] =
|
||||
*reinterpret_cast<const UnzipArray<uint8_t, N> *>(args.weight_ptr + zipped_offset);
|
||||
|
||||
UnzipArray<T, N> super_scales =
|
||||
*reinterpret_cast<const UnzipArray<T, N> *>(args.super_scale_ptr + begin_col_id);
|
||||
UnzipArray<float, N> code_scales =
|
||||
*reinterpret_cast<const UnzipArray<float, N> *>(args.code_scale_ptr + begin_col_id);
|
||||
UnzipArray<float, N> code_zps =
|
||||
*reinterpret_cast<const UnzipArray<float, N> *>(args.code_zp_ptr + begin_col_id);
|
||||
|
||||
// special for TileRows = 64
|
||||
int local_scale_shift = (((block_start_row / 64) + 1) & 1) * 4;
|
||||
UnzipArray<ScaleComputeT, N> scales;
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < N; ++i) {
|
||||
int32_t shifted_local_scale =
|
||||
(static_cast<int32_t>(local_scales[i]) >> local_scale_shift) & kLocalScaleMask;
|
||||
scales[i] =
|
||||
static_cast<ScaleComputeT>(shifted_local_scale) * static_cast<ScaleComputeT>(super_scales[i]);
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int iter_id = 0; iter_id < kNumIters; ++iter_id) {
|
||||
int zipped_row = begin_row_id + iter_id * RowStride;
|
||||
int row = zipped_row * 4;
|
||||
|
||||
if (iter_id < kNumIters - 1) {
|
||||
int zipped_offset = (zipped_row + RowStride) * TileColumns + begin_col_id;
|
||||
zipped_values[(iter_id + 1) & 1] =
|
||||
*reinterpret_cast<const UnzipArray<uint8_t, N> *>(args.weight_ptr + zipped_offset);
|
||||
}
|
||||
|
||||
UnzipArray<T, N> outs[4];
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < N; ++i) {
|
||||
int32_t decode_value =
|
||||
static_cast<int32_t>(floor(static_cast<ScaleComputeT>(zipped_values[iter_id & 1][i]) * code_scales[i]
|
||||
+ code_zps[i] + decode_value_zp));
|
||||
|
||||
ScaleComputeT value_3 = static_cast<ScaleComputeT>((decode_value & kWeightMask) - kBZP);
|
||||
decode_value >>= 3;
|
||||
ScaleComputeT value_2 = static_cast<ScaleComputeT>((decode_value & kWeightMask) - kBZP);
|
||||
decode_value >>= 3;
|
||||
ScaleComputeT value_1 = static_cast<ScaleComputeT>((decode_value & kWeightMask) - kBZP);
|
||||
decode_value >>= 3;
|
||||
ScaleComputeT value_0 = static_cast<ScaleComputeT>((decode_value & kWeightMask) - kBZP);
|
||||
outs[0][i] = static_cast<T>(scales[i] * value_0);
|
||||
outs[1][i] = static_cast<T>(scales[i] * value_1);
|
||||
outs[2][i] = static_cast<T>(scales[i] * value_2);
|
||||
outs[3][i] = static_cast<T>(scales[i] * value_3);
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int shift_bit_id = 0; shift_bit_id < 4; ++shift_bit_id) {
|
||||
UnzipArray<T, N> *tmp_out_ptr = reinterpret_cast<UnzipArray<T, N> *>(
|
||||
out_ptr + (row + shift_bit_id) * TileColumns + begin_col_id);
|
||||
*tmp_out_ptr = outs[shift_bit_id];
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace threadblock
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
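// Illustrative, host-callable reference for the 2-bit decode implemented by
// Compute()/ComputeVectorized() above; a standalone documentation sketch
// (assumption: not used by the functors themselves). A zipped uint8_t is first
// mapped through the per-column code_scale/code_zp affine codebook, then split
// into four 6-bit fields at shifts {9, 6, 3, 0}, each biased by kBZP = 32.
// The group scale is a 4-bit local-scale nibble times the per-column
// super_scale; which nibble is used alternates per 64-row group.
inline void ExampleDecodeWint2Byte(uint8_t zipped, float code_scale,
                                   float code_zp, uint8_t local_scale,
                                   bool use_high_nibble, float super_scale,
                                   float out[4]) {
  const int shift_bits[4] = {9, 6, 3, 0};
  int32_t decode_value = static_cast<int32_t>(
      floorf(static_cast<float>(zipped) * code_scale + code_zp + 0.5f));
  float scale = static_cast<float>(
                    (local_scale >> (use_high_nibble ? 4 : 0)) & 0xF) *
                super_scale;
  for (int i = 0; i < 4; ++i) {
    int value = ((decode_value >> shift_bits[i]) & 0x3F) - 32;  // mask, kBZP
    out[i] = scale * static_cast<float>(value);
  }
}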
|
||||
custom_ops/gpu_ops/cutlass_extensions/wint_type_traits.h
@@ -0,0 +1,140 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cuda.h>
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda_fp16.h>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/layout/layout.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
namespace cutlass {
|
||||
|
||||
enum WintQuantMethod {
|
||||
kNone = 0,
|
||||
kWeightOnlyInt8 = 1,
|
||||
kWeightOnlyInt4 = 2,
|
||||
kWeightOnlyInt25 = 3,
|
||||
kWeightOnlyInt2 = 4
|
||||
};
|
||||
|
||||
// Convert CUDA data type to cutlass data type
|
||||
template <typename T> struct CutlassDataType {
|
||||
using Type = T;
|
||||
};
|
||||
|
||||
template <> struct CutlassDataType<half> {
|
||||
using Type = cutlass::half_t;
|
||||
};
|
||||
|
||||
template <> struct CutlassDataType<__nv_bfloat16> {
|
||||
using Type = cutlass::bfloat16_t;
|
||||
};
|
||||
|
||||
template <typename ElementT, WintQuantMethod Method> struct WintQuantTraits;
|
||||
|
||||
template <typename ElementT>
|
||||
struct WintQuantTraits<ElementT, WintQuantMethod::kNone> {
|
||||
using WeightType = ElementT;
|
||||
using MmaKernelType = typename CutlassDataType<ElementT>::Type;
|
||||
using MmaWeightType = typename CutlassDataType<ElementT>::Type;
|
||||
|
||||
static constexpr WintQuantMethod kQuantMethod = WintQuantMethod::kNone;
|
||||
|
||||
struct Arguments {};
|
||||
|
||||
CUTLASS_DEVICE
|
||||
static int64_t CaclPackedDim(int64_t dim) { return dim; }
|
||||
};
|
||||
|
||||
template <typename ElementT>
|
||||
struct WintQuantTraits<ElementT, WintQuantMethod::kWeightOnlyInt8> {
|
||||
using WeightType = uint8_t;
|
||||
using MmaKernelType = uint8_t;
|
||||
using MmaWeightType = uint8_t;
|
||||
|
||||
static constexpr WintQuantMethod kQuantMethod =
|
||||
WintQuantMethod::kWeightOnlyInt8;
|
||||
|
||||
struct Arguments {};
|
||||
|
||||
CUTLASS_DEVICE
|
||||
static int64_t CaclPackedDim(int64_t dim) { return dim; }
|
||||
};
|
||||
|
||||
template <typename ElementT>
|
||||
struct WintQuantTraits<ElementT, WintQuantMethod::kWeightOnlyInt4> {
|
||||
using WeightType = cutlass::uint4b_t;
|
||||
using MmaKernelType = cutlass::uint4b_t;
|
||||
using MmaWeightType = cutlass::uint4b_t;
|
||||
|
||||
static constexpr WintQuantMethod kQuantMethod =
|
||||
WintQuantMethod::kWeightOnlyInt4;
|
||||
|
||||
struct Arguments {};
|
||||
|
||||
CUTLASS_DEVICE
|
||||
static int64_t CaclPackedDim(int64_t dim) { return dim; }
|
||||
};
|
||||
|
||||
template <typename ElementT>
|
||||
struct WintQuantTraits<ElementT, WintQuantMethod::kWeightOnlyInt25> {
|
||||
using WeightType = uint16_t;
|
||||
using MmaKernelType = typename CutlassDataType<ElementT>::Type;
|
||||
using MmaWeightType = typename CutlassDataType<ElementT>::Type;
|
||||
|
||||
static constexpr WintQuantMethod kQuantMethod =
|
||||
WintQuantMethod::kWeightOnlyInt25;
|
||||
|
||||
static constexpr int32_t kGroupSize = 64;
|
||||
static constexpr int32_t kNumPackedValues = 7;
|
||||
static constexpr int32_t kPackedSize = 10;
|
||||
|
||||
struct Arguments {};
|
||||
|
||||
CUTLASS_DEVICE
|
||||
static int64_t CaclPackedDim(int64_t dim) {
|
||||
return dim * kPackedSize / kGroupSize;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename ElementT>
|
||||
struct WintQuantTraits<ElementT, WintQuantMethod::kWeightOnlyInt2> {
|
||||
using WeightType = uint8_t;
|
||||
using MmaKernelType = cutlass::uint2b_t;
|
||||
using MmaWeightType = typename CutlassDataType<ElementT>::Type;
|
||||
|
||||
static constexpr WintQuantMethod kQuantMethod =
|
||||
WintQuantMethod::kWeightOnlyInt2;
|
||||
|
||||
static constexpr int32_t kGroupSize = 64;
|
||||
static constexpr int32_t kNumPackedValues = 4;
|
||||
static constexpr int32_t kPackedSize = 16;
|
||||
|
||||
struct Arguments {
|
||||
const uint8_t *local_scale_ptr; // quantized 4-bit local scales
|
||||
const float *code_scale_ptr;
|
||||
const float *code_zp_ptr;
|
||||
};
|
||||
|
||||
CUTLASS_DEVICE
|
||||
static int64_t CaclPackedDim(int64_t dim) {
|
||||
return dim * kPackedSize / kGroupSize;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace cutlass
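// Illustrative host-side sketch (assumption: not part of this header) of what
// CaclPackedDim computes for the 2-bit path: 64 logical rows pack into
// kPackedSize = 16 rows of zipped uint8_t storage, so packed_dim =
// dim * 16 / 64. For example, K = 4096 logical rows of B occupy 1024 zipped
// rows; the 2.5-bit path instead packs 64 rows into 10 uint16_t rows.
inline int64_t ExampleWint2PackedDim(int64_t dim) {
  return dim * 16 / 64;  // mirrors WintQuantTraits<..., kWeightOnlyInt2>
}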
|
||||
@@ -16,106 +16,127 @@
|
||||
#include <mutex>
|
||||
|
||||
#include "cutlass/bfloat16.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/half.h"
|
||||
#include "helper.h"
|
||||
#include "paddle/extension.h"
|
||||
|
||||
/**
 * Helper function for checking CUTLASS errors
 */
#define CUTLASS_CHECK(status)                            \
  {                                                      \
    cutlass::Status error = status;                      \
    PD_CHECK(error == cutlass::Status::kSuccess,         \
             cutlassGetStatusString(error));             \
  }

/**
 * A wrapper for a kernel that is used to guard against compilation on
 * architectures that will never use the kernel. The purpose of this is to
 * reduce the size of the compiled binary.
 * __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef
 * into code that will be executed on the device where it is defined.
 */
template <typename Kernel> struct enable_sm90_or_later : Kernel {
  template <typename... Args> CUTLASS_DEVICE void operator()(Args &&...args) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 900
    Kernel::operator()(std::forward<Args>(args)...);
#endif
  }
};

template <paddle::DataType D> class CutlassDtypeTraits;

template <> class CutlassDtypeTraits<paddle::DataType::FLOAT32> {
public:
  typedef float DataType;
  typedef float data_t;
};

template <> class CutlassDtypeTraits<paddle::DataType::FLOAT16> {
public:
  typedef cutlass::half_t DataType;
  typedef paddle::float16 data_t;
};

template <> class CutlassDtypeTraits<paddle::DataType::BFLOAT16> {
public:
  typedef cutlass::bfloat16_t DataType;
  typedef paddle::bfloat16 data_t;
};
||||
|
||||
class CutlassGemmConfigMannager {
public:
  static CutlassGemmConfigMannager &getInstance() {
    static CutlassGemmConfigMannager instance;
    return instance;
  }

  CutlassGemmConfigMannager(const CutlassGemmConfigMannager &) = delete;
  CutlassGemmConfigMannager &
  operator=(const CutlassGemmConfigMannager &) = delete;

  void up_date_configs(const nlohmann::json &j) {
    std::lock_guard<std::mutex> lock(mutex_);
    for (auto it = j.begin(); it != j.end(); ++it) {
      json_[it.key()] = it.value();
    }
  }

  nlohmann::json *get_gemm_best_configs(const std::string &config_file_path) {
    if (!load_initialized_) {
      std::ifstream file(config_file_path);
      if (!file.good()) {
        throw std::runtime_error(
            "cutlass gemm_best_config cannot be found, please set "
            "gemm_best_config's path as "
            "FLAGS_use_cutlass_device_best_config_path, or unset "
            "FLAGS_use_cutlass_device_best_config_path to tune "
            "gemm_best_config");
      }
      json_ = readJsonFromFile(config_file_path);
      load_initialized_ = true;
      save_initialized_ = false;
    }
    return &json_;
  }

private:
  void save_gemm_best_configs_(const std::string &config_file_path) {
    std::ifstream file(config_file_path);
    if (!file.good()) {
      std::ofstream new_file(config_file_path);
      new_file << json_.dump(4);
      new_file.close();
    } else {
      nlohmann::json old_json = readJsonFromFile(config_file_path);
      for (auto it = json_.begin(); it != json_.end(); ++it) {
        old_json[it.key()] = it.value();
      }
      json_ = old_json;
      std::ofstream new_file(config_file_path, std::ios::out | std::ios::trunc);
      new_file << json_.dump(4);
      new_file.close();
      file.close();
    }
    return;
  }

  CutlassGemmConfigMannager()
      : json_(nullptr), load_initialized_(false), save_initialized_(true) {}
  ~CutlassGemmConfigMannager() {
    std::lock_guard<std::mutex> lock(mutex_);
    if (save_initialized_) {
      std::string config_file_path = "fp8_fuse_gemm_config.json";
      save_gemm_best_configs_(config_file_path);
    }
    save_initialized_ = true;
    load_initialized_ = false;
    json_.clear();
  }

  mutable std::mutex mutex_;
  nlohmann::json json_;
  bool load_initialized_;
  bool save_initialized_;
};
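// Illustrative usage sketch (assumption: the file name and JSON key below are
// placeholders, not values required by this class):
//
//   auto &mannager = CutlassGemmConfigMannager::getInstance();
//   nlohmann::json *configs =
//       mannager.get_gemm_best_configs("fp8_fuse_gemm_config.json");
//   nlohmann::json patch = {{"example_gemm_key", "example_config"}};
//   mannager.up_date_configs(patch);  // merged into the cached JSON and
//                                     // written back when the singleton
//                                     // is destroyed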
|
||||
|
||||
@@ -15,8 +15,8 @@
|
||||
#pragma once
|
||||
|
||||
#include "fp8_common.h"
|
||||
#include "fuse_dual_gemm_swiglu_template.h"
|
||||
#include "fuse_dual_gemm_act_template_3x.h"
|
||||
#include "fuse_dual_gemm_geglu_template.h"
|
||||
#include "fuse_dual_gemm_swiglu_template.h"
|
||||
|
||||
bool fp8_fp8_dual_gemm_scale_bias_act(
|
||||
DualGemmEpilogueAllParams params);
|
||||
bool fp8_fp8_dual_gemm_scale_bias_act(DualGemmEpilogueAllParams params);
|
||||
|
||||
@@ -15,12 +15,13 @@
|
||||
#pragma once
|
||||
|
||||
#include "fp8_common.h"
|
||||
#include "fuse_gemm_gelu_template.h"
|
||||
#include "fuse_gemm_noact_template.h"
|
||||
#include "fuse_gemm_relu_template.h"
|
||||
#include "fuse_gemm_gelu_template.h"
|
||||
|
||||
#include "fuse_block_gemm_act_template_3x.h"
|
||||
#include "fuse_gemm_act_template_3x.h"
|
||||
|
||||
bool fp8_fp8_gemm_scale_bias_act(GemmEpilogueAllParams params);
|
||||
|
||||
bool fp8_fp8_block_gemm_scale_bias_act(GemmEpilogueAllParams params);
|
||||
bool fp8_fp8_block_gemm_scale_bias_act(GemmEpilogueAllParams params);
|
||||
|
||||
@@ -0,0 +1,173 @@
|
||||
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/float8.h"
|
||||
#include "cutlass/gemm/device/gemm_universal.h"
|
||||
#include "fp8_common.h"
|
||||
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/gemm/kernel/tile_scheduler.hpp"
|
||||
#include "cutlass/util/packed_stride.hpp"
|
||||
#include "cutlass_extensions/gemm/collective/collective_builder_gated.hpp"
|
||||
#include "cutlass_extensions/gemm/kernel/gemm_universal_gated.hpp"
|
||||
|
||||
template <typename InputType, typename CTAShape, typename ClusterShape,
|
||||
typename MainloopScheduleType, typename EpilogueScheduleType,
|
||||
typename TileSchedulerType = void,
|
||||
template <class /* ElementCompute */> class Activation =
|
||||
cutlass::epilogue::thread::SiLu,
|
||||
bool SwapAB = true>
|
||||
bool dispatch_dual_gemm_act_sm90(DualGemmEpilogueAllParams params) {
|
||||
using namespace cute;
|
||||
using ElementA = typename std::conditional_t<
|
||||
std::is_same_v<InputType, phi::dtype::float8_e4m3fn>,
|
||||
cutlass::float_e4m3_t, cutlass::float_e5m2_t>;
|
||||
using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand
|
||||
static constexpr int AlignmentA =
|
||||
128 /
|
||||
cutlass::sizeof_bits<
|
||||
ElementA>::value; // Memory access granularity/alignment of A
|
||||
// matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// B matrix configuration
|
||||
using ElementB = ElementA; // Element type for B matrix operand
|
||||
using LayoutB =
|
||||
cutlass::layout::ColumnMajor; // Layout type for B matrix operand
|
||||
static constexpr int AlignmentB =
|
||||
128 /
|
||||
cutlass::sizeof_bits<
|
||||
ElementB>::value; // Memory access granularity/alignment of B
|
||||
// matrix in units of elements (up to 16 bytes)
|
||||
|
||||
using ElementC = ElementA; // Element type for C matrix operands
|
||||
|
||||
using LayoutC = cute::conditional_t<SwapAB, cutlass::layout::ColumnMajor,
|
||||
cutlass::layout::RowMajor>;
|
||||
static constexpr int AlignmentC =
|
||||
128 /
|
||||
cutlass::sizeof_bits<
|
||||
ElementC>::value; // Memory access granularity/alignment of C matrices
|
||||
// in units of elements (up to 16 bytes)
|
||||
|
||||
// Output matrix configuration
|
||||
using ElementOutput = ElementA; // Element type for output matrix operands
|
||||
// using LayoutOutput = cutlass::layout::RowMajor; // Layout type for output
|
||||
// matrix operands
|
||||
using LayoutOutput = cute::conditional_t<SwapAB, cutlass::layout::ColumnMajor,
|
||||
cutlass::layout::RowMajor>;
|
||||
static constexpr int AlignmentOutput =
|
||||
128 / cutlass::sizeof_bits<ElementOutput>::value;
|
||||
|
||||
// Multiply-accumulate blocking/pipelining details
|
||||
using ElementAccumulator = float; // Element type for internal accumulation
|
||||
using ElementCompute = float; // Element type for compute
|
||||
using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that
|
||||
// supports the intended feature
|
||||
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
|
||||
using TileShape = CTAShape; // Threadblock-level tile size
|
||||
using KernelSchedule = MainloopScheduleType;
|
||||
using EpilogueSchedule = EpilogueScheduleType;
|
||||
using TileScheduler = TileSchedulerType;
|
||||
|
||||
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
|
||||
static constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest;
|
||||
using FusionOperation =
|
||||
cutlass::epilogue::fusion::ScaledAcc<ElementOutput, ElementCompute>;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass, TileShape, ClusterShape, EpilogueTileType,
|
||||
ElementAccumulator, ElementAccumulator, ElementC, LayoutC, AlignmentC,
|
||||
ElementOutput, LayoutOutput, AlignmentOutput, EpilogueSchedule,
|
||||
FusionOperation>::CollectiveOp;
|
||||
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilderGated<
|
||||
ArchTag, OperatorClass, ElementA, LayoutA, AlignmentA, ElementB,
|
||||
LayoutB, AlignmentB, ElementAccumulator, TileShape, ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
|
||||
sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||
KernelSchedule, Activation, SwapAB>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversalGated<
|
||||
Shape<int, int, int, int>, // Indicates ProblemShape
|
||||
CollectiveMainloop, CollectiveEpilogue, TileScheduler>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
|
||||
using StrideA = typename Gemm::GemmKernel::StrideA;
|
||||
using StrideB = typename Gemm::GemmKernel::StrideB;
|
||||
using StrideC = typename Gemm::GemmKernel::StrideC;
|
||||
using StrideD = typename Gemm::GemmKernel::StrideD;
|
||||
|
||||
int arg_m = params.M;
|
||||
int arg_n = params.N;
|
||||
ElementA const *ptr_A = reinterpret_cast<ElementA const *>(params.A);
|
||||
ElementB const *ptr_B0 = reinterpret_cast<ElementB const *>(params.B0);
|
||||
ElementB const *ptr_B1 = reinterpret_cast<ElementB const *>(params.B1);
|
||||
if constexpr (SwapAB) {
|
||||
arg_m = params.N;
|
||||
arg_n = params.M;
|
||||
ptr_A = reinterpret_cast<ElementB const *>(params.B0);
|
||||
ptr_B0 = reinterpret_cast<ElementA const *>(params.A);
|
||||
}
|
||||
StrideA stride_A = cutlass::make_cute_packed_stride(
|
||||
StrideA{}, cute::make_shape(arg_m, params.K, params.batch_count));
|
||||
StrideB stride_B = cutlass::make_cute_packed_stride(
|
||||
StrideB{}, cute::make_shape(arg_n, params.K, params.batch_count));
|
||||
StrideC stride_C;
|
||||
StrideD stride_D = cutlass::make_cute_packed_stride(
|
||||
StrideD{}, cute::make_shape(arg_m, arg_n, params.batch_count));
|
||||
|
||||
typename Gemm::Arguments arguments = {
|
||||
cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
{arg_m, arg_n, params.K, params.batch_count},
|
||||
{ptr_A, stride_A, ptr_B0, ptr_B1, stride_B, params.scale0, params.scale1},
|
||||
{{}, // epilogue.thread
|
||||
nullptr,
|
||||
stride_C,
|
||||
reinterpret_cast<ElementOutput *>(params.D),
|
||||
stride_D}};
|
||||
arguments.epilogue.thread.alpha = params.scale_out;
|
||||
|
||||
Gemm gemm_op;
|
||||
|
||||
cutlass::Status status = gemm_op.can_implement(arguments);
|
||||
|
||||
if (status != cutlass::Status::kSuccess) {
|
||||
std::cerr << "Gemm::can_implement() failed" << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
size_t workspace_size = Gemm::get_workspace_size(arguments);
|
||||
phi::Allocator *allocator = paddle::GetAllocator(params.place);
|
||||
auto workspace = allocator->Allocate(workspace_size);
|
||||
|
||||
//
|
||||
// Run the GEMM
|
||||
//
|
||||
status = gemm_op(arguments, workspace->ptr(), params.stream);
|
||||
if (status != cutlass::Status::kSuccess) {
|
||||
std::cerr << "Gemm::run() failed" << std::endl;
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
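// Illustrative instantiation sketch (assumption: the tile shape, cluster shape
// and schedules below are example choices, not a tuned or guaranteed-supported
// configuration for every problem size):
//
//   bool ok = dispatch_dual_gemm_act_sm90<
//       phi::dtype::float8_e4m3fn,                        // InputType
//       cute::Shape<cute::_128, cute::_128, cute::_128>,  // CTAShape
//       cute::Shape<cute::_2, cute::_1, cute::_1>,        // ClusterShape
//       cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum,
//       cutlass::epilogue::TmaWarpSpecializedCooperative>(params);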
|
||||
@@ -0,0 +1,151 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once

#include "fp8_common.h"

#include "cutlass/cutlass.h"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/kernel/tile_scheduler.hpp"
#include "cutlass/util/packed_stride.hpp"

template <
    typename InputType,
    typename OutType,
    bool hasbias,
    template <class> typename Activation,
    typename TileShape,
    typename ClusterShape,
    typename KernelSchedule =
        cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum,
    typename EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized,
    typename SM = cutlass::arch::Sm90>
bool dispatch_fuse_gemm_act_sm90(GemmEpilogueAllParams params) {
  using namespace cute;
  using ElementA = typename std::conditional_t<
      std::is_same_v<InputType, phi::dtype::float8_e4m3fn>,
      cutlass::float_e4m3_t, cutlass::float_e5m2_t>;
  using ElementB = ElementA;
  using ElementD =
      typename std::conditional_t<std::is_same_v<OutType, phi::dtype::bfloat16>,
                                  cutlass::bfloat16_t, cutlass::half_t>;
  using ElementC = std::conditional_t<hasbias, ElementD, void>;

  using LayoutA = cutlass::layout::RowMajor;
  using LayoutB = cutlass::layout::ColumnMajor;
  using LayoutC = cutlass::layout::RowMajor;
  using LayoutD = cutlass::layout::RowMajor;

  using ElementAccumulator = float;
  using ElementCompute = float;
  using ElementScalar = float;

  // 16B alignment lets us use TMA
  static constexpr int AlignmentA = 16 / sizeof(ElementA);
  static constexpr int AlignmentB = 16 / sizeof(ElementB);
  static constexpr int AlignmentC = hasbias ? 16 / sizeof(ElementC) : 8;
  static constexpr int AlignmentD = 16 / sizeof(ElementD);

  static constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest;

  using FusionOperation =
      cutlass::epilogue::fusion::LinCombEltAct<Activation, ElementD,
                                               ElementCompute, ElementC,
                                               ElementScalar, RoundStyle>;

  using CollectiveEpilogue =
      typename cutlass::epilogue::collective::CollectiveBuilder<
          SM, cutlass::arch::OpClassTensorOp, TileShape, ClusterShape,
          cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator,
          ElementCompute, ElementC, LayoutC, AlignmentC, ElementD, LayoutD,
          AlignmentD, EpilogueSchedule, FusionOperation>::CollectiveOp;

  using CollectiveMainloop =
      typename cutlass::gemm::collective::CollectiveBuilder<
          SM, cutlass::arch::OpClassTensorOp, ElementA, LayoutA, AlignmentA,
          ElementB, LayoutB, AlignmentB, ElementAccumulator, TileShape,
          ClusterShape,
          cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
              sizeof(typename CollectiveEpilogue::SharedStorage))>,
          KernelSchedule>::CollectiveOp;

  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
      Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue,
      cutlass::gemm::PersistentScheduler>;

  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;

  using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape;

  using StrideA = typename Gemm::GemmKernel::StrideA;
  using StrideB = typename Gemm::GemmKernel::StrideB;
  using StrideC = typename Gemm::GemmKernel::StrideC;
  using StrideD = typename Gemm::GemmKernel::StrideD;

  //
  // Data members
  //

  /// Initialization
  StrideA stride_A{params.lda, cute::Int<1>{}, params.M * params.lda};
  StrideB stride_B{params.ldb, cute::Int<1>{}, params.N * params.ldb};
  StrideC stride_C{0, cute::Int<1>{}, 0};
  StrideD stride_D{params.ldd, cute::Int<1>{}, params.ldd * params.M};

  auto a_ptr = reinterpret_cast<ElementA *>(const_cast<void *>(params.A));
  auto b_ptr = reinterpret_cast<ElementB *>(const_cast<void *>(params.B));
  auto c_ptr = reinterpret_cast<ElementC *>(const_cast<void *>(params.bias));
  auto d_ptr = reinterpret_cast<ElementD *>(params.D);

  ProblemShapeType problem_size =
      ProblemShapeType{params.M, params.N, params.K, params.batch_count};

  typename Gemm::Arguments arguments{cutlass::gemm::GemmUniversalMode::kGemm,
                                     problem_size,
                                     {a_ptr, stride_A, b_ptr, stride_B},
                                     {{params.scale},  // epilogue.thread
                                      c_ptr,
                                      stride_C,
                                      d_ptr,
                                      stride_D}};
  if constexpr (hasbias) {
    arguments.epilogue.thread.beta = 1.0;
  }

  Gemm gemm_op;

  cutlass::Status status = gemm_op.can_implement(arguments);
  if (status != cutlass::Status::kSuccess) {
    std::cout << "Gemm::can_implement() failed. "
              << cutlassGetStatusString(status) << std::endl;
    return false;
  }
  size_t workspace_size = Gemm::get_workspace_size(arguments);
  phi::Allocator *allocator = paddle::GetAllocator(params.place);
  auto workspace = allocator->Allocate(workspace_size);

  status = gemm_op(arguments, workspace->ptr(), params.stream);
  if (status != cutlass::Status::kSuccess) {
    std::cout << "Gemm::run() failed." << cutlassGetStatusString(status)
              << std::endl;
    return false;
  }
  return true;
}
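// Illustrative usage sketch (not part of the diff above): the tile and cluster
// shapes and the ReLU activation are example choices, and
// "cutlass/epilogue/thread/activation.h" is assumed to be reachable for ReLU.
// A caller that has already populated a GemmEpilogueAllParams could dispatch an
// FP8 GEMM with a fused epilogue activation like this:
//
//   using ExampleTileShape    = cute::Shape<cute::Int<128>, cute::Int<128>, cute::Int<128>>;
//   using ExampleClusterShape = cute::Shape<cute::Int<2>, cute::Int<1>, cute::Int<1>>;
//   bool ok = dispatch_fuse_gemm_act_sm90<phi::dtype::float8_e4m3fn,
//                                         phi::dtype::bfloat16,
//                                         /*hasbias=*/true,
//                                         cutlass::epilogue::thread::ReLU,
//                                         ExampleTileShape,
//                                         ExampleClusterShape>(params);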
@@ -43,7 +43,9 @@
|
||||
#include "cutlass/trace.h"
|
||||
|
||||
#include "cutlass_extensions/gemm/kernel/gemm_moe_problem_visitor.h"
|
||||
#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/tile_interleaved_layout.h"
|
||||
#include "cutlass_extensions/gemm/threadblock/wint2x_tile_dequanter.h"
|
||||
#include "cutlass_extensions/tile_interleaved_layout.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
@@ -156,9 +158,6 @@ struct MoeFCGemm {
|
||||
using LayoutC = typename MapArguments::LayoutC;
|
||||
using ElementScale = ElementC;
|
||||
|
||||
static ComplexTransform const kTransformA = MapArguments::kTransformA;
|
||||
static ComplexTransform const kTransformB = MapArguments::kTransformB;
|
||||
|
||||
// Type definitions about the mainloop.
|
||||
using Operator = typename Mma::Operator;
|
||||
using OperatorClass = typename Mma::Operator::OperatorClass;
|
||||
@@ -209,6 +208,13 @@ struct MoeFCGemm {
|
||||
int64_t gemm_n;
|
||||
int64_t gemm_k;
|
||||
|
||||
WintQuantMethod quant_method;
|
||||
|
||||
// Extra arguments for wint2.0
|
||||
uint8_t* local_scale;
|
||||
float* code_scale;
|
||||
float* code_zp;
|
||||
|
||||
// Only used by device-level operator
|
||||
GemmCoord* host_problem_sizes;
|
||||
|
||||
@@ -230,6 +236,10 @@ struct MoeFCGemm {
|
||||
total_rows(-1),
|
||||
gemm_n(0),
|
||||
gemm_k(0),
|
||||
quant_method(WintQuantMethod::kNone),
|
||||
local_scale(nullptr),
|
||||
code_scale(nullptr),
|
||||
code_zp(nullptr),
|
||||
host_problem_sizes(nullptr) {}
|
||||
|
||||
/// Ctor
|
||||
@@ -246,6 +256,10 @@ struct MoeFCGemm {
|
||||
int64_t total_rows,
|
||||
int64_t gemm_n,
|
||||
int64_t gemm_k,
|
||||
WintQuantMethod quant_method,
|
||||
const uint8_t* local_scale,
|
||||
const float* code_scale,
|
||||
const float* code_zp,
|
||||
GemmCoord* host_problem_sizes = nullptr)
|
||||
: problem_count(problem_count),
|
||||
threadblock_count(threadblock_count),
|
||||
@@ -259,8 +273,12 @@ struct MoeFCGemm {
|
||||
total_rows(total_rows),
|
||||
gemm_n(gemm_n),
|
||||
gemm_k(gemm_k),
|
||||
quant_method(quant_method),
|
||||
local_scale(const_cast<uint8_t*>(local_scale)),
|
||||
code_scale(const_cast<float*>(code_scale)),
|
||||
code_zp(const_cast<float*>(code_zp)),
|
||||
host_problem_sizes(nullptr) {
|
||||
if (platform::is_same<uint8_t, ElementB>::value ||
|
||||
if (quant_method != WintQuantMethod::kNone || platform::is_same<uint8_t, ElementB>::value ||
|
||||
platform::is_same<uint4b_t, ElementB>::value) {
|
||||
assert(weight_scales);
|
||||
}
|
||||
@@ -284,6 +302,8 @@ struct MoeFCGemm {
|
||||
ElementC* ptr_C;
|
||||
ElementC* ptr_D;
|
||||
|
||||
WintQuantMethod quant_method;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
@@ -294,7 +314,8 @@ struct MoeFCGemm {
|
||||
ptr_B(nullptr),
|
||||
weight_scales(nullptr),
|
||||
ptr_C(nullptr),
|
||||
ptr_D(nullptr) {}
|
||||
ptr_D(nullptr),
|
||||
quant_method(WintQuantMethod::kNone) {}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(Arguments const& args,
|
||||
@@ -313,7 +334,8 @@ struct MoeFCGemm {
|
||||
ptr_B(args.ptr_B),
|
||||
weight_scales(args.weight_scales),
|
||||
ptr_C(args.ptr_C),
|
||||
ptr_D(args.ptr_D) {}
|
||||
ptr_D(args.ptr_D),
|
||||
quant_method(args.quant_method) {}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
void update(Arguments const& args,
|
||||
@@ -334,6 +356,7 @@ struct MoeFCGemm {
|
||||
weight_scales = args.weight_scales;
|
||||
ptr_C = args.ptr_C;
|
||||
ptr_D = args.ptr_D;
|
||||
quant_method = args.quant_method;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -358,7 +381,7 @@ struct MoeFCGemm {
|
||||
}
|
||||
|
||||
static Status can_implement(Arguments const& args) {
|
||||
if (platform::is_same<uint8_t, ElementB>::value ||
|
||||
if (args.quant_method != WintQuantMethod::kNone || platform::is_same<uint8_t, ElementB>::value ||
|
||||
platform::is_same<uint4b_t, ElementB>::value) {
|
||||
if (args.weight_scales == nullptr) {
|
||||
CUTLASS_TRACE_HOST(
|
||||
@@ -394,6 +417,7 @@ struct MoeFCGemm {
|
||||
|
||||
template <typename dummy>
|
||||
struct KernelRunner<true, dummy> {
|
||||
|
||||
CUTLASS_DEVICE
|
||||
static void run_kernel(Params const& params,
|
||||
SharedStorage& shared_storage) { // NOLINT
|
||||
@@ -401,12 +425,14 @@ struct MoeFCGemm {
|
||||
// These types shadow the type-level definitions and support the ability
|
||||
// to implement a 'transposed' GEMM that computes the transposed problems.
|
||||
//
|
||||
|
||||
using ElementA = typename Mma::IteratorA::Element;
|
||||
using LayoutA = typename Mma::IteratorA::Layout;
|
||||
using ElementB = typename Mma::IteratorB::Element;
|
||||
using LayoutB = typename Mma::IteratorB::Layout;
|
||||
using ElementC = typename Epilogue::OutputTileIterator::Element;
|
||||
using LayoutC = typename Epilogue::OutputTileIterator::Layout;
|
||||
|
||||
static constexpr int kInterleave =
|
||||
Mma::IteratorB::Shape::kRow / Mma::Shape::kK;
|
||||
static_assert(
|
||||
@@ -435,6 +461,7 @@ struct MoeFCGemm {
|
||||
|
||||
GemmCoord grid_shape = problem_visitor.grid_shape(problem_size);
|
||||
|
||||
// threadblock_offset of C
|
||||
cutlass::gemm::GemmCoord threadblock_offset(
|
||||
int(cta_idx / grid_shape.n()) * Mma::Shape::kM, // NOLINT
|
||||
int(cta_idx % grid_shape.n()) * Mma::Shape::kN, // NOLINT
|
||||
@@ -450,6 +477,7 @@ struct MoeFCGemm {
|
||||
rows_to_jump = problem_idx * (params.problem_visitor.total_rows / params.problem_visitor.problem_count);
|
||||
}
|
||||
|
||||
// begin address offset for A for current tile
|
||||
ElementA* ptr_A =
|
||||
reinterpret_cast<ElementA*>(params.ptr_A) + rows_to_jump * gemm_k;
|
||||
typename LayoutA::LongIndex ldm_A = gemm_k;
|
||||
@@ -463,14 +491,17 @@ struct MoeFCGemm {
|
||||
: gemm_k * kInterleave;
|
||||
|
||||
// Compute initial location in logical coordinates
|
||||
// the begin threadblock_offset of A, which holds the same row id with C
|
||||
cutlass::MatrixCoord tb_offset_A{
|
||||
threadblock_offset.m(),
|
||||
0,
|
||||
};
|
||||
|
||||
// the begin threadblock_offset of B, which holds the same column id with C
|
||||
cutlass::MatrixCoord tb_offset_B{0,
|
||||
threadblock_offset.n() / kInterleave};
|
||||
|
||||
// the begin threadblock_offset of scale, which holds the same column id with C, but with no row id
|
||||
cutlass::MatrixCoord tb_offset_scale{0, threadblock_offset.n()};
|
||||
|
||||
// Compute position within threadblock
|
||||
@@ -610,6 +641,381 @@ struct MoeFCGemm {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate
|
||||
typename Epilogue_, ///! Epilogue
|
||||
typename ThreadblockSwizzle_, ///! Threadblock swizzling function
|
||||
typename KernelArch, ///! The Architecture this kernel is compiled
|
||||
/// for. Used since SIMT kernels lose top-level
|
||||
/// arch.
|
||||
GroupScheduleMode GroupScheduleMode_  ///! Type of scheduling to perform // NOLINT
|
||||
>
|
||||
struct Wint2xMoeFCGemm : public MoeFCGemm<Mma_, Epilogue_, ThreadblockSwizzle_, KernelArch, GroupScheduleMode_> {
|
||||
public:
|
||||
using Base = MoeFCGemm<Mma_, Epilogue_, ThreadblockSwizzle_, KernelArch, GroupScheduleMode_>;
|
||||
using Mma = Mma_;
|
||||
using Epilogue = Epilogue_;
|
||||
using EpilogueOutputOp = typename Epilogue::OutputOp;
|
||||
using ThreadblockSwizzle = ThreadblockSwizzle_;
|
||||
static GroupScheduleMode const kGroupScheduleMode = GroupScheduleMode_;
|
||||
static bool const kTransposed = false;
|
||||
|
||||
// Optional transpose
|
||||
using MapArguments = typename Base::MapArguments;
|
||||
|
||||
// Public-facing type definitions related to operand element type, layout, and
|
||||
// complex conjugate operation. Must interact with the 'kTransposed' notion.
|
||||
static_assert(!kTransposed, "Transpose problem not supported");
|
||||
|
||||
using ElementA = typename MapArguments::ElementA;
|
||||
using LayoutA = typename MapArguments::LayoutA;
|
||||
using ElementB = typename MapArguments::ElementB;
|
||||
using LayoutB = typename MapArguments::LayoutB;
|
||||
using ElementC = typename Epilogue::OutputTileIterator::Element;
|
||||
using LayoutC = typename MapArguments::LayoutC;
|
||||
using ElementScale = ElementC;
|
||||
|
||||
// Type definitions about the mainloop.
|
||||
using Operator = typename Mma::Operator;
|
||||
using OperatorClass = typename Mma::Operator::OperatorClass;
|
||||
using ThreadblockShape = typename Mma::Shape;
|
||||
using WarpShape = typename Mma::Operator::Shape;
|
||||
using InstructionShape = typename Mma::Policy::Operator::InstructionShape;
|
||||
using ArchTag = typename Mma::ArchTag;
|
||||
|
||||
static int const kStages = Mma::kStages;
|
||||
static int const kAlignmentA = MapArguments::kAlignmentA;
|
||||
static int const kAlignmentB = MapArguments::kAlignmentB;
|
||||
static int const kAlignmentC =
|
||||
Epilogue::OutputTileIterator::kElementsPerAccess;
|
||||
|
||||
/// Warp count (concept: GemmShape)
|
||||
using WarpCount = typename Mma::WarpCount;
|
||||
static int const kThreadCount = 32 * WarpCount::kCount;
|
||||
|
||||
using ProblemVisitor = typename Base::ProblemVisitor;
|
||||
using Arguments = typename Base::Arguments;
|
||||
|
||||
//
|
||||
// Structure for precomputing values in host memory and passing to kernels
|
||||
//
|
||||
|
||||
/// Parameters structure
|
||||
struct Params : Base::Params {
|
||||
// Extra arguments for wint2.0
|
||||
uint8_t* local_scale;
|
||||
float* code_scale;
|
||||
float* code_zp;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params() : Base::Params(), local_scale(nullptr), code_scale(nullptr), code_zp(nullptr) {}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(Arguments const& args,
|
||||
void* workspace = nullptr,
|
||||
int tile_count = 0) // NOLINT
|
||||
: Base::Params(args, workspace, tile_count),
|
||||
local_scale(args.local_scale),
|
||||
code_scale(args.code_scale),
|
||||
code_zp(args.code_zp) {}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
void update(Arguments const& args,
|
||||
void* workspace = nullptr,
|
||||
int tile_count = 0) {
|
||||
Base::update(args, workspace, tile_count);
|
||||
|
||||
local_scale = args.local_scale;
|
||||
code_scale = args.code_scale;
|
||||
code_zp = args.code_zp;
|
||||
}
|
||||
};
|
||||
|
||||
/// Shared memory storage structure
|
||||
using SharedStorage = typename Base::SharedStorage;
|
||||
|
||||
public:
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
CUTLASS_DEVICE
|
||||
Wint2xMoeFCGemm() {}
|
||||
|
||||
static Status can_implement(Arguments const& args) {
|
||||
if (args.quant_method != WintQuantMethod::kWeightOnlyInt2) {
|
||||
CUTLASS_TRACE_HOST(
|
||||
"Wint2xMoeFCGemm::can_implement() - only support weight_only_int2!");
|
||||
return Status::kInvalid;
|
||||
} else if (args.weight_scales == nullptr || args.local_scale == nullptr) {
|
||||
CUTLASS_TRACE_HOST(
|
||||
"Wint2xMoeFCGemm::can_implement() - weight_scales and local_scale is expected to be not nullptr!");
|
||||
return Status::kInvalid;
|
||||
}
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
// The dummy template parameter is not used and exists so that we can compile
|
||||
// this code using a standard earlier than C++17. Prior to C++17, fully
|
||||
// specialized templates had to exist at namespace scope.
|
||||
template <WintQuantMethod QuantMethod, bool B, typename dummy = void>
|
||||
struct KernelRunner {
|
||||
CUTLASS_DEVICE
|
||||
static void run_kernel(Params const& params,
|
||||
SharedStorage& shared_storage) { // NOLINT
|
||||
CUTLASS_NOT_IMPLEMENTED();
|
||||
}
|
||||
};
|
||||
|
||||
template <WintQuantMethod QuantMethod, typename dummy>
|
||||
struct KernelRunner<QuantMethod, true, dummy> {
|
||||
using WeightQuantTraits = WintQuantTraits<ElementA, QuantMethod>;
|
||||
using QuantArguments = typename WeightQuantTraits::Arguments;
|
||||
|
||||
CUTLASS_DEVICE
|
||||
static QuantArguments get_quant_args(Params const& params, int32_t problem_idx, const int64_t gemm_k, const int64_t gemm_n) {
|
||||
QuantArguments quant_args;
|
||||
if constexpr (QuantMethod == WintQuantMethod::kWeightOnlyInt2) {
|
||||
quant_args.local_scale_ptr = params.local_scale + problem_idx * gemm_k * gemm_n / 128;
|
||||
quant_args.code_scale_ptr = params.code_scale + problem_idx * gemm_n;
|
||||
quant_args.code_zp_ptr = params.code_zp + problem_idx * gemm_n;
|
||||
}
|
||||
return quant_args;
|
||||
}
|
||||
|
||||
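// Worked example of the per-expert offsets computed above (illustrative
// numbers, not taken from the diff): with gemm_k = 4096, gemm_n = 1536 and
// problem_idx = 2, local_scale_ptr advances by 2 * 4096 * 1536 / 128 entries,
// while code_scale_ptr and code_zp_ptr each advance by 2 * 1536 entries,
// i.e. one value per output column of the selected expert.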
CUTLASS_DEVICE
|
||||
static void run_kernel(Params const& params,
|
||||
SharedStorage& shared_storage) { // NOLINT
|
||||
//
|
||||
// These types shadow the type-level definitions and support the ability
|
||||
// to implement a 'transposed' GEMM that computes the transposed problems.
|
||||
//
|
||||
|
||||
using ElementA = typename Mma::IteratorA::Element;
|
||||
using LayoutA = typename Mma::IteratorA::Layout;
|
||||
using ElementB = typename Mma::IteratorB::Element;
|
||||
using LayoutB = typename Mma::IteratorB::Layout;
|
||||
using ElementC = typename Epilogue::OutputTileIterator::Element;
|
||||
using LayoutC = typename Epilogue::OutputTileIterator::Layout;
|
||||
using QuantElementB = typename WeightQuantTraits::WeightType;
|
||||
using MmaElementB = typename WeightQuantTraits::MmaWeightType;
|
||||
|
||||
static constexpr int kInterleave =
|
||||
Mma::IteratorB::Shape::kRow / Mma::Shape::kK;
|
||||
static_assert(
|
||||
platform::is_same<LayoutB, layout::RowMajor>::value &&
|
||||
kInterleave == 1 ||
|
||||
platform::is_same<LayoutB, layout::ColumnMajor>::value &&
|
||||
kInterleave >= 1,
|
||||
"B must be row major/col major OR col major interleaved.");
|
||||
|
||||
// LayoutB should be RowMajor
|
||||
using TileDequanterB = cutlass::gemm::threadblock::TileDequanter<ElementA, ElementScale, ThreadblockShape::kK, ThreadblockShape::kN, kStages, kThreadCount, QuantMethod>;
|
||||
|
||||
//
|
||||
// Problem visitor.
|
||||
//
|
||||
ProblemVisitor problem_visitor(
|
||||
params.problem_visitor, shared_storage.problem_visitor, blockIdx.x);
|
||||
|
||||
const int64_t gemm_k = params.problem_visitor.gemm_k;
|
||||
const int64_t gemm_n = params.problem_visitor.gemm_n;
|
||||
// wint2.5 and wint2.0 weights are quantized and packed along the k dimension with group_size 64.
|
||||
const int64_t quant_gemm_k = WintQuantTraits<ElementA, QuantMethod>::CaclPackedDim(gemm_k);
|
||||
int64_t bytes_per_expert_matrix = (quant_gemm_k * gemm_n / 8) * cutlass::sizeof_bits<QuantElementB>::value;
|
||||
|
||||
// Outer 'persistent' loop to iterate over tiles
|
||||
while (problem_visitor.next_tile()) {
|
||||
GemmCoord problem_size = problem_visitor.problem_size();
|
||||
int32_t problem_idx = problem_visitor.problem_index();
|
||||
int32_t cta_idx = int32_t(problem_visitor.threadblock_idx());
|
||||
|
||||
GemmCoord grid_shape = problem_visitor.grid_shape(problem_size);
|
||||
|
||||
// threadblock_offset of C
|
||||
cutlass::gemm::GemmCoord threadblock_offset(
|
||||
int(cta_idx / grid_shape.n()) * Mma::Shape::kM, // NOLINT
|
||||
int(cta_idx % grid_shape.n()) * Mma::Shape::kN, // NOLINT
|
||||
0);
|
||||
|
||||
// begin address offset for weight_scale.
|
||||
ElementScale* weight_scale_ptr =
|
||||
params.weight_scales ? params.weight_scales + problem_idx * problem_size.n() : nullptr;
|
||||
// the begin threadblock_offset of scale, which holds the same column id with C, but with no row id
|
||||
cutlass::MatrixCoord tb_offset_scale{0, threadblock_offset.n()};
|
||||
|
||||
// Load element pointers. Exchange pointers and strides if working on
|
||||
// the transpose
|
||||
int64_t rows_to_jump = 0;
|
||||
|
||||
if (params.problem_visitor.total_rows < 0) {
|
||||
rows_to_jump = problem_idx == 0 ? 0 : params.problem_visitor.last_row_for_problem[problem_idx - 1];
|
||||
} else {
|
||||
rows_to_jump = problem_idx * (params.problem_visitor.total_rows / params.problem_visitor.problem_count);
|
||||
}
|
||||
|
||||
// begin address offset for A for current tile
|
||||
ElementA* ptr_A =
|
||||
reinterpret_cast<ElementA*>(params.ptr_A) + rows_to_jump * gemm_k;
|
||||
typename LayoutA::LongIndex ldm_A = gemm_k;
|
||||
|
||||
// Compute initial location in logical coordinates
|
||||
// the begin threadblock_offset of A, which holds the same row id with C
|
||||
cutlass::MatrixCoord tb_offset_A{
|
||||
threadblock_offset.m(),
|
||||
0,
|
||||
};
|
||||
|
||||
// begin address offset for B for current problem_idx, totally num_experts problems
|
||||
char* byte_ptr_B = ((char*)params.ptr_B) + // NOLINT
|
||||
problem_idx * bytes_per_expert_matrix; // NOLINT
|
||||
|
||||
typename LayoutB::LongIndex ldm_B =
|
||||
platform::is_same<layout::RowMajor, LayoutB>::value
|
||||
? gemm_n
|
||||
: gemm_k * kInterleave;
|
||||
typename LayoutB::LongIndex ldm_B_shared = TileDequanterB::kColumns;
|
||||
|
||||
// the begin threadblock_offset of B, which holds the same column id with C
|
||||
cutlass::MatrixCoord tb_offset_B{0,
|
||||
threadblock_offset.n() / kInterleave};
|
||||
|
||||
cutlass::MatrixCoord extent_B{problem_size.k() * kInterleave, problem_size.n() / kInterleave};
|
||||
cutlass::MatrixCoord extent_B_shared{TileDequanterB::kRows, TileDequanterB::kColumns};
|
||||
|
||||
MmaElementB* smem_unzip_B_ptr = nullptr;
|
||||
if constexpr (QuantMethod == WintQuantMethod::kWeightOnlyInt2) {
|
||||
smem_unzip_B_ptr = shared_storage.main_loop.operand_unzip_B_ptr();
|
||||
}
|
||||
QuantArguments quant_args = get_quant_args(params, problem_idx, gemm_k, gemm_n);
|
||||
TileDequanterB tile_dequanter_B(smem_unzip_B_ptr,
|
||||
byte_ptr_B,
|
||||
ldm_B,
|
||||
extent_B,
|
||||
tb_offset_B,
|
||||
weight_scale_ptr,
|
||||
tb_offset_scale,
|
||||
quant_args);
|
||||
MmaElementB* ptr_B = tile_dequanter_B.GetOutPtr();
|
||||
|
||||
// Compute position within threadblock
|
||||
int thread_idx = threadIdx.x;
|
||||
|
||||
// Construct iterators to A and B operands
|
||||
typename Mma::IteratorA iterator_A(LayoutA(ldm_A),
|
||||
ptr_A,
|
||||
{problem_size.m(), problem_size.k()},
|
||||
thread_idx,
|
||||
tb_offset_A);
|
||||
|
||||
typename Mma::IteratorB iterator_B(
|
||||
LayoutB(TileDequanterB::kUseSharedMemory ? ldm_B_shared : ldm_B),
|
||||
ptr_B,
|
||||
TileDequanterB::kUseSharedMemory ? extent_B_shared : extent_B,
|
||||
thread_idx,
|
||||
TileDequanterB::kUseSharedMemory ? cutlass::make_Coord(0, 0) : tb_offset_B);
|
||||
|
||||
typename Mma::FragmentC accumulators;
|
||||
|
||||
accumulators.clear();
|
||||
|
||||
// Broadcast the warp_id computed by lane 0 to ensure dependent code
|
||||
// is compiled as warp-uniform.
|
||||
int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
|
||||
|
||||
int lane_idx = threadIdx.x % 32;
|
||||
|
||||
//
|
||||
// Matrix multiply phase
|
||||
//
|
||||
|
||||
// Construct thread-scoped matrix multiply
|
||||
Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);
|
||||
|
||||
// Compute threadblock-scoped matrix multiply-add
|
||||
int gemm_k_iterations =
|
||||
(problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK;
|
||||
|
||||
// Wait for all threads to finish their epilogue phases from the
|
||||
// previous tile.
|
||||
__syncthreads();
|
||||
|
||||
// Compute threadblock-scoped matrix multiply-add
|
||||
mma(gemm_k_iterations,
|
||||
accumulators,
|
||||
iterator_A,
|
||||
iterator_B,
|
||||
tile_dequanter_B,
|
||||
accumulators);
|
||||
|
||||
//
|
||||
// Epilogue
|
||||
//
|
||||
|
||||
EpilogueOutputOp output_op(params.output_op);
|
||||
|
||||
ElementC* ptr_C =
|
||||
params.ptr_C ? reinterpret_cast<ElementC*>(params.ptr_C) + problem_idx * gemm_n : nullptr;
|
||||
ElementC* ptr_D =
|
||||
reinterpret_cast<ElementC*>(params.ptr_D) + rows_to_jump * gemm_n;
|
||||
|
||||
LayoutC layout_C(0);
|
||||
LayoutC layout_D(gemm_n);
|
||||
|
||||
typename Epilogue::OutputTileIterator::Params params_C(layout_C);
|
||||
typename Epilogue::OutputTileIterator::Params params_D(layout_D);
|
||||
|
||||
// Tile iterator loading from source tensor.
|
||||
typename Epilogue::OutputTileIterator iterator_C(
|
||||
params_C,
|
||||
ptr_C,
|
||||
problem_size.mn(),
|
||||
thread_idx,
|
||||
threadblock_offset.mn());
|
||||
|
||||
// Tile iterator writing to destination tensor.
|
||||
typename Epilogue::OutputTileIterator iterator_D(
|
||||
params_D,
|
||||
ptr_D,
|
||||
problem_size.mn(),
|
||||
thread_idx,
|
||||
threadblock_offset.mn());
|
||||
|
||||
Epilogue epilogue(
|
||||
shared_storage.epilogue, thread_idx, warp_idx, lane_idx);
|
||||
|
||||
// Execute the epilogue operator to update the destination tensor.
|
||||
epilogue(output_op, iterator_D, accumulators, iterator_C);
|
||||
|
||||
// Next tile
|
||||
problem_visitor.advance(gridDim.x);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
  /*
    To improve compilation speed, we do not compile the device operator if the
    CUDA_ARCH does not correspond to the ArchTag of the cutlass kernel operator.
  */
  /// Executes one GEMM
  CUTLASS_DEVICE
  void operator()(Params const& params,
                  SharedStorage& shared_storage) {  // NOLINT
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 910)
    KernelRunner<WintQuantMethod::kWeightOnlyInt2, true>::run_kernel(params, shared_storage);
#else
    CUTLASS_NOT_IMPLEMENTED();
#endif
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace kernel
}  // namespace gemm
}  // namespace cutlass

@@ -15,16 +15,22 @@
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cuda_runtime_api.h>
|
||||
#include <string>
|
||||
|
||||
#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/ft_gemm_configs.h"
|
||||
#include "cutlass_extensions/wint_type_traits.h"
|
||||
|
||||
namespace phi {
|
||||
|
||||
template <typename T, /*The type used for activations/scales/compute*/
|
||||
typename WeightType /* The type for the MoE weights */>
|
||||
typename WeightQuantTraits /* The quant traits for the MoE weights */>
|
||||
class MoeGemmRunner {
|
||||
public:
|
||||
using WeightType = typename WeightQuantTraits::WeightType;
|
||||
using Arguments = typename WeightQuantTraits::Arguments;
|
||||
|
||||
MoeGemmRunner();
|
||||
|
||||
void moe_gemm_bias_act(const T* A,
|
||||
@@ -38,6 +44,7 @@ class MoeGemmRunner {
|
||||
int64_t gemm_n,
|
||||
int64_t gemm_k,
|
||||
int num_experts,
|
||||
const Arguments& quant_args_B,
|
||||
std::string activation_type,
|
||||
cudaStream_t stream);
|
||||
|
||||
@@ -51,6 +58,7 @@ class MoeGemmRunner {
|
||||
int64_t gemm_n,
|
||||
int64_t gemm_k,
|
||||
int num_experts,
|
||||
const Arguments& quant_args_B,
|
||||
cudaStream_t stream);
|
||||
|
||||
private:
|
||||
@@ -65,6 +73,7 @@ class MoeGemmRunner {
|
||||
int64_t gemm_n,
|
||||
int64_t gemm_k,
|
||||
int num_experts,
|
||||
const Arguments& quant_args_B,
|
||||
CutlassGemmConfig gemm_config,
|
||||
cudaStream_t stream,
|
||||
int* occupancy = nullptr);
|
||||
@@ -81,6 +90,7 @@ class MoeGemmRunner {
|
||||
int64_t gemm_n,
|
||||
int64_t gemm_k,
|
||||
int num_experts,
|
||||
const Arguments& quant_args_B,
|
||||
cudaStream_t stream);
|
||||
|
||||
private:
|
||||
|
||||
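// Hypothetical call site for the runner declared above (sketch only; the
// pointer variables are placeholders and the wint2 traits specialization is
// assumed to be instantiated elsewhere in this diff):
//
//   using Wint2Traits =
//       cutlass::WintQuantTraits<half, cutlass::WintQuantMethod::kWeightOnlyInt2>;
//   MoeGemmRunner<half, Wint2Traits> runner;
//   typename Wint2Traits::Arguments quant_args_B;  // local_scale / code_scale / code_zp
//   runner.moe_gemm(A, B, weight_scales, C, total_rows_before_expert,
//                   total_rows, actual_total_rows, gemm_n, gemm_k,
//                   num_experts, quant_args_B, stream);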
@@ -13,7 +13,7 @@
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
|
||||
#pragma once
|
||||
#include "cutlass_kernels/moe_gemm/fused_moe_gemm_kernels.h"
|
||||
#include "cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_template.h"
|
||||
@@ -22,7 +22,8 @@
|
||||
namespace phi {
|
||||
|
||||
#ifdef PADDLE_CUDA_BF16
|
||||
template class MoeGemmRunner<__nv_bfloat16, __nv_bfloat16>;
|
||||
template class MoeGemmRunner<
|
||||
__nv_bfloat16, cutlass::WintQuantTraits<__nv_bfloat16, cutlass::WintQuantMethod::kNone>>;
|
||||
#endif
|
||||
|
||||
} // namespace phi
|
||||
} // namespace phi
|
||||
|
||||
@@ -0,0 +1,30 @@
|
||||
/*
|
||||
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
#include "cutlass_kernels/moe_gemm/fused_moe_gemm_kernels.h"
|
||||
#include "cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_template.h"
|
||||
#include "helper.h"
|
||||
|
||||
namespace phi {
|
||||
|
||||
#ifdef PADDLE_CUDA_BF16
|
||||
template class MoeGemmRunner<
|
||||
__nv_bfloat16,
|
||||
cutlass::WintQuantTraits<__nv_bfloat16, cutlass::WintQuantMethod::kWeightOnlyInt2>>;
|
||||
#endif
|
||||
|
||||
} // namespace phi
|
||||
@@ -21,7 +21,9 @@
|
||||
namespace phi {
|
||||
|
||||
#ifdef PADDLE_CUDA_BF16
|
||||
template class MoeGemmRunner<__nv_bfloat16, cutlass::uint4b_t>;
|
||||
template class MoeGemmRunner<
|
||||
__nv_bfloat16,
|
||||
cutlass::WintQuantTraits<__nv_bfloat16, cutlass::WintQuantMethod::kWeightOnlyInt4>>;
|
||||
#endif
|
||||
|
||||
} // namespace phi
|
||||
} // namespace phi
|
||||
|
||||
@@ -22,8 +22,9 @@
|
||||
namespace phi {
|
||||
|
||||
#ifdef PADDLE_CUDA_BF16
|
||||
template class MoeGemmRunner<__nv_bfloat16, uint8_t>;
|
||||
template class MoeGemmRunner<
|
||||
__nv_bfloat16,
|
||||
cutlass::WintQuantTraits<__nv_bfloat16, cutlass::WintQuantMethod::kWeightOnlyInt8>>;
|
||||
#endif
|
||||
|
||||
} // namespace phi
|
||||
|
||||
} // namespace phi
|
||||
|
||||
@@ -21,6 +21,7 @@
|
||||
|
||||
namespace phi {
|
||||
|
||||
template class MoeGemmRunner<half, half>;
|
||||
template class MoeGemmRunner<half,
|
||||
cutlass::WintQuantTraits<half, cutlass::WintQuantMethod::kNone>>;
|
||||
|
||||
} // namespace phi
|
||||
} // namespace phi
|
||||
|
||||
@@ -0,0 +1,27 @@
|
||||
/*
|
||||
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
#include "cutlass_kernels/moe_gemm/fused_moe_gemm_kernels.h"
|
||||
#include "cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_template.h"
|
||||
#include "helper.h"
|
||||
|
||||
namespace phi {
|
||||
|
||||
template class MoeGemmRunner<
|
||||
half, cutlass::WintQuantTraits<half, cutlass::WintQuantMethod::kWeightOnlyInt2>>;
|
||||
|
||||
} // namespace phi
|
||||
@@ -21,6 +21,7 @@
|
||||
|
||||
namespace phi {
|
||||
|
||||
template class MoeGemmRunner<half, cutlass::uint4b_t>;
|
||||
template class MoeGemmRunner<
|
||||
half, cutlass::WintQuantTraits<half, cutlass::WintQuantMethod::kWeightOnlyInt4>>;
|
||||
|
||||
} // namespace phi
|
||||
} // namespace phi
|
||||
|
||||
@@ -21,6 +21,7 @@
|
||||
|
||||
namespace phi {
|
||||
|
||||
template class MoeGemmRunner<half, uint8_t>;
|
||||
template class MoeGemmRunner<
|
||||
half, cutlass::WintQuantTraits<half, cutlass::WintQuantMethod::kWeightOnlyInt8>>;
|
||||
|
||||
} // namespace phi
|
||||
} // namespace phi
|
||||
|
||||
@@ -24,9 +24,10 @@
|
||||
#include <math.h>
|
||||
#include <optional>
|
||||
#include <sstream>
|
||||
#include "cutlass/array.h"
|
||||
#include "cutlass/numeric_conversion.h"
|
||||
|
||||
#include "cutlass/array.h"
|
||||
#include "cutlass/trace.h"
|
||||
#include "cutlass/numeric_conversion.h"
|
||||
#include "cutlass/gemm/device/gemm_grouped.h"
|
||||
#include "cutlass/gemm/kernel/default_gemm_grouped.h"
|
||||
|
||||
@@ -35,8 +36,11 @@
|
||||
|
||||
#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/compute_occupancy.h"
|
||||
#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/epilogue_helpers.h"
|
||||
#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h"
|
||||
#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/threadblock/default_mma.h"
|
||||
|
||||
#include "cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h"
|
||||
#include "cutlass_extensions/gemm/threadblock/default_mma.h"
|
||||
#include "cutlass_extensions/wint_type_traits.h"
|
||||
|
||||
#include "cutlass_kernels/moe_gemm/fused_moe_cutlass_kernel.h"
|
||||
#include "cutlass_kernels/moe_gemm/fused_moe_gemm_kernels.h"
|
||||
|
||||
@@ -48,17 +52,47 @@
|
||||
#include "helper.h"
|
||||
|
||||
namespace phi {
|
||||
// ============================= Variable batched Gemm things
|
||||
// ===========================
|
||||
|
||||
template <typename MixedGemmArchTraits, cutlass::WintQuantMethod Method>
|
||||
struct CutlassLayoutB {
|
||||
using Type = typename MixedGemmArchTraits::LayoutB;
|
||||
};
|
||||
|
||||
template <typename MixedGemmArchTraits>
|
||||
struct CutlassLayoutB<MixedGemmArchTraits, cutlass::WintQuantMethod::kNone> {
|
||||
using Type = cutlass::layout::RowMajor;
|
||||
};
|
||||
|
||||
template <typename BaseGemmKernel, typename Arch, cutlass::WintQuantMethod Method>
|
||||
struct CutlassGemmKernel {
|
||||
using Type =
|
||||
cutlass::gemm::kernel::MoeFCGemm<typename BaseGemmKernel::Mma,
|
||||
typename BaseGemmKernel::Epilogue,
|
||||
typename BaseGemmKernel::ThreadblockSwizzle,
|
||||
Arch,
|
||||
BaseGemmKernel::kGroupScheduleMode>;
|
||||
};
|
||||
|
||||
template <typename BaseGemmKernel, typename Arch>
|
||||
struct CutlassGemmKernel<BaseGemmKernel, Arch, cutlass::WintQuantMethod::kWeightOnlyInt2> {
|
||||
using Type =
|
||||
cutlass::gemm::kernel::Wint2xMoeFCGemm<typename BaseGemmKernel::Mma,
|
||||
typename BaseGemmKernel::Epilogue,
|
||||
typename BaseGemmKernel::ThreadblockSwizzle,
|
||||
Arch,
|
||||
BaseGemmKernel::kGroupScheduleMode>;
|
||||
};
|
||||
|
||||
// ======================= Variable batched Gemm things =======================
|
||||
template <typename T,
|
||||
typename WeightType,
|
||||
typename WeightQuantTraits,
|
||||
typename arch,
|
||||
typename EpilogueTag,
|
||||
typename ThreadblockShape,
|
||||
typename WarpShape,
|
||||
int Stages>
|
||||
void generic_moe_gemm_kernelLauncher(const T* A,
|
||||
const WeightType* B,
|
||||
const typename WeightQuantTraits::WeightType* B,
|
||||
const T* weight_scales,
|
||||
const T* biases,
|
||||
T* C,
|
||||
@@ -67,6 +101,7 @@ void generic_moe_gemm_kernelLauncher(const T* A,
|
||||
int64_t gemm_n,
|
||||
int64_t gemm_k,
|
||||
int num_experts,
|
||||
const typename WeightQuantTraits::Arguments& quant_args_B,
|
||||
CutlassGemmConfig gemm_config,
|
||||
const int multi_processor_count,
|
||||
cudaStream_t stream,
|
||||
@@ -86,44 +121,26 @@ void generic_moe_gemm_kernelLauncher(const T* A,
|
||||
"Specialized for half, float");
|
||||
#endif
|
||||
|
||||
using WeightType = typename WeightQuantTraits::WeightType;
|
||||
|
||||
static_assert(
|
||||
cutlass::platform::is_same<T, WeightType>::value ||
|
||||
cutlass::platform::is_same<WeightType, uint8_t>::value ||
|
||||
cutlass::platform::is_same<WeightType, cutlass::uint4b_t>::value,
|
||||
"");
|
||||
cutlass::platform::is_same<WeightType, cutlass::uint4b_t>::value ||
|
||||
cutlass::platform::is_same<WeightType, uint16_t>::value,
|
||||
"Specialized for bfloat16, half, float, uint8_t (wint8), uint4b_t (wint4), uint16_t (wint2.5)");
|
||||
|
||||
// The cutlass type for the input elements. This is needed to convert to
|
||||
// cutlass::half_t if necessary.
|
||||
using ElementType_ = typename cutlass::platform::conditional<
|
||||
cutlass::platform::is_same<T, half>::value,
|
||||
cutlass::half_t,
|
||||
T>::type;
|
||||
#ifdef PADDLE_CUDA_BF16
|
||||
using ElementType = typename cutlass::platform::conditional<
|
||||
cutlass::platform::is_same<ElementType_, __nv_bfloat16>::value,
|
||||
cutlass::bfloat16_t,
|
||||
ElementType_>::type;
|
||||
#else
|
||||
using ElementType = ElementType_;
|
||||
#endif
|
||||
|
||||
using CutlassWeightType_ = typename cutlass::platform::conditional<
|
||||
cutlass::platform::is_same<WeightType, half>::value,
|
||||
cutlass::half_t,
|
||||
WeightType>::type;
|
||||
#ifdef PADDLE_CUDA_BF16
|
||||
using CutlassWeightType = typename cutlass::platform::conditional<
|
||||
cutlass::platform::is_same<CutlassWeightType_, __nv_bfloat16>::value,
|
||||
cutlass::bfloat16_t,
|
||||
CutlassWeightType_>::type;
|
||||
#else
|
||||
using CutlassWeightType = CutlassWeightType_;
|
||||
#endif
|
||||
using ElementType = typename cutlass::CutlassDataType<T>::Type;
|
||||
using CutlassWeightType = typename cutlass::CutlassDataType<typename WeightQuantTraits::WeightType>::Type;
|
||||
using CutlassMmaWeightType = typename WeightQuantTraits::MmaWeightType;
|
||||
using CutlassMmaKernelType = typename WeightQuantTraits::MmaKernelType;
|
||||
|
||||
// We need separate config for each architecture since we will target
|
||||
// different tensorcore instructions. For float, we do not target TCs.
|
||||
using MixedGemmArchTraits = cutlass::gemm::kernel::
|
||||
MixedGemmArchTraits<ElementType, CutlassWeightType, arch>;
|
||||
MixedGemmArchTraits<ElementType, CutlassMmaKernelType, arch>;
|
||||
using ElementAccumulator = typename MixedGemmArchTraits::AccType;
|
||||
|
||||
using EpilogueOp = typename Epilogue<ElementType,
|
||||
@@ -132,13 +149,13 @@ void generic_moe_gemm_kernelLauncher(const T* A,
|
||||
EpilogueTag>::Op;
|
||||
|
||||
// Finally, set up the kernel.
|
||||
using GemmKernel_ = typename cutlass::gemm::kernel::DefaultGemmGrouped<
|
||||
using BaseGemmKernel = typename cutlass::gemm::kernel::DefaultGemmGrouped<
|
||||
ElementType,
|
||||
cutlass::layout::RowMajor,
|
||||
cutlass::ComplexTransform::kNone,
|
||||
MixedGemmArchTraits::ElementsPerAccessA,
|
||||
CutlassWeightType,
|
||||
typename MixedGemmArchTraits::LayoutB,
|
||||
CutlassMmaKernelType,
|
||||
typename CutlassLayoutB<MixedGemmArchTraits, WeightQuantTraits::kQuantMethod>::Type,
|
||||
cutlass::ComplexTransform::kNone,
|
||||
MixedGemmArchTraits::ElementsPerAccessB,
|
||||
ElementType,
|
||||
@@ -155,14 +172,7 @@ void generic_moe_gemm_kernelLauncher(const T* A,
|
||||
cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly,
|
||||
typename MixedGemmArchTraits::Operator>::GemmKernel;
|
||||
|
||||
using GemmKernel =
|
||||
cutlass::gemm::kernel::MoeFCGemm<typename GemmKernel_::Mma,
|
||||
typename GemmKernel_::Epilogue,
|
||||
typename GemmKernel_::ThreadblockSwizzle,
|
||||
arch, // Ensure top level arch is used
|
||||
// for dispatch
|
||||
GemmKernel_::kGroupScheduleMode>;
|
||||
|
||||
using GemmKernel = typename CutlassGemmKernel<BaseGemmKernel, arch, WeightQuantTraits::kQuantMethod>::Type;
|
||||
using GemmGrouped = cutlass::gemm::device::GemmGrouped<GemmKernel>;
|
||||
|
||||
if (kernel_occupancy != nullptr) {
|
||||
@@ -181,19 +191,32 @@ void generic_moe_gemm_kernelLauncher(const T* A,
|
||||
typename EpilogueOp::Params epilogue_op(ElementAccumulator(1.f),
|
||||
ElementAccumulator(0.f));
|
||||
|
||||
const uint8_t* local_scale_B = nullptr;
|
||||
const float* code_scale_B = nullptr;
|
||||
const float* code_zp_B = nullptr;
|
||||
if constexpr (WeightQuantTraits::kQuantMethod == cutlass::WintQuantMethod::kWeightOnlyInt2) {
|
||||
local_scale_B = quant_args_B.local_scale_ptr;
|
||||
code_scale_B = quant_args_B.code_scale_ptr;
|
||||
code_zp_B = quant_args_B.code_zp_ptr;
|
||||
}
|
||||
|
||||
typename GemmGrouped::Arguments args(
|
||||
num_experts,
|
||||
threadblock_count,
|
||||
epilogue_op,
|
||||
reinterpret_cast<const ElementType*>(A),
|
||||
reinterpret_cast<const CutlassWeightType*>(B),
|
||||
reinterpret_cast<const CutlassMmaWeightType*>(B),
|
||||
reinterpret_cast<const ElementType*>(weight_scales),
|
||||
reinterpret_cast<const ElementType*>(biases),
|
||||
reinterpret_cast<ElementType*>(C),
|
||||
total_rows_before_expert,
|
||||
total_rows,
|
||||
gemm_n,
|
||||
gemm_k);
|
||||
gemm_k,
|
||||
WeightQuantTraits::kQuantMethod,
|
||||
local_scale_B,
|
||||
code_scale_B,
|
||||
code_zp_B);
|
||||
|
||||
GemmGrouped gemm;
|
||||
|
||||
@@ -222,7 +245,7 @@ void generic_moe_gemm_kernelLauncher(const T* A,
|
||||
}
|
||||
|
||||
template <typename T,
|
||||
typename WeightType,
|
||||
typename WeightQuantTraits,
|
||||
typename arch,
|
||||
typename EpilogueTag,
|
||||
typename ThreadblockShape,
|
||||
@@ -231,7 +254,7 @@ template <typename T,
|
||||
typename Enable = void>
|
||||
struct dispatch_stages {
|
||||
static void dispatch(const T* A,
|
||||
const WeightType* B,
|
||||
const typename WeightQuantTraits::WeightType* B,
|
||||
const T* weight_scales,
|
||||
const T* biases,
|
||||
T* C,
|
||||
@@ -240,6 +263,7 @@ struct dispatch_stages {
|
||||
int64_t gemm_n,
|
||||
int64_t gemm_k,
|
||||
int num_experts,
|
||||
const typename WeightQuantTraits::Arguments& quant_args_B,
|
||||
CutlassGemmConfig gemm_config,
|
||||
int multi_processor_count,
|
||||
cudaStream_t stream,
|
||||
@@ -253,20 +277,20 @@ struct dispatch_stages {
|
||||
};
|
||||
|
||||
template <typename T,
|
||||
typename WeightType,
|
||||
typename WeightQuantTraits,
|
||||
typename arch,
|
||||
typename EpilogueTag,
|
||||
typename ThreadblockShape,
|
||||
typename WarpShape>
|
||||
struct dispatch_stages<T,
|
||||
WeightType,
|
||||
WeightQuantTraits,
|
||||
arch,
|
||||
EpilogueTag,
|
||||
ThreadblockShape,
|
||||
WarpShape,
|
||||
2> {
|
||||
static void dispatch(const T* A,
|
||||
const WeightType* B,
|
||||
const typename WeightQuantTraits::WeightType* B,
|
||||
const T* weight_scales,
|
||||
const T* biases,
|
||||
T* C,
|
||||
@@ -275,12 +299,13 @@ struct dispatch_stages<T,
|
||||
int64_t gemm_n,
|
||||
int64_t gemm_k,
|
||||
int num_experts,
|
||||
const typename WeightQuantTraits::Arguments& quant_args_B,
|
||||
CutlassGemmConfig gemm_config,
|
||||
int multi_processor_count,
|
||||
cudaStream_t stream,
|
||||
int* occupancy = nullptr) {
|
||||
generic_moe_gemm_kernelLauncher<T,
|
||||
WeightType,
|
||||
WeightQuantTraits,
|
||||
arch,
|
||||
EpilogueTag,
|
||||
ThreadblockShape,
|
||||
@@ -295,6 +320,7 @@ struct dispatch_stages<T,
|
||||
gemm_n,
|
||||
gemm_k,
|
||||
num_experts,
|
||||
quant_args_B,
|
||||
gemm_config,
|
||||
multi_processor_count,
|
||||
stream,
|
||||
@@ -303,13 +329,13 @@ struct dispatch_stages<T,
|
||||
};
|
||||
|
||||
template <typename T,
|
||||
typename WeightType,
|
||||
typename WeightQuantTraits,
|
||||
typename EpilogueTag,
|
||||
typename ThreadblockShape,
|
||||
typename WarpShape,
|
||||
int Stages>
|
||||
struct dispatch_stages<T,
|
||||
WeightType,
|
||||
WeightQuantTraits,
|
||||
cutlass::arch::Sm80,
|
||||
EpilogueTag,
|
||||
ThreadblockShape,
|
||||
@@ -317,7 +343,7 @@ struct dispatch_stages<T,
|
||||
Stages,
|
||||
typename std::enable_if<(Stages > 2)>::type> {
|
||||
static void dispatch(const T* A,
|
||||
const WeightType* B,
|
||||
const typename WeightQuantTraits::WeightType* B,
|
||||
const T* weight_scales,
|
||||
const T* biases,
|
||||
T* C,
|
||||
@@ -326,12 +352,13 @@ struct dispatch_stages<T,
|
||||
int64_t gemm_n,
|
||||
int64_t gemm_k,
|
||||
int num_experts,
|
||||
const typename WeightQuantTraits::Arguments& quant_args_B,
|
||||
CutlassGemmConfig gemm_config,
|
||||
int multi_processor_count,
|
||||
cudaStream_t stream,
|
||||
int* occupancy = nullptr) {
|
||||
generic_moe_gemm_kernelLauncher<T,
|
||||
WeightType,
|
||||
WeightQuantTraits,
|
||||
cutlass::arch::Sm80,
|
||||
EpilogueTag,
|
||||
ThreadblockShape,
|
||||
@@ -346,6 +373,7 @@ struct dispatch_stages<T,
|
||||
gemm_n,
|
||||
gemm_k,
|
||||
num_experts,
|
||||
quant_args_B,
|
||||
gemm_config,
|
||||
multi_processor_count,
|
||||
stream,
|
||||
@@ -354,13 +382,13 @@ struct dispatch_stages<T,
|
||||
};
|
||||
|
||||
template <typename T,
|
||||
typename WeightType,
|
||||
typename WeightQuantTraits,
|
||||
typename arch,
|
||||
typename EpilogueTag,
|
||||
typename ThreadblockShape,
|
||||
typename WarpShape>
|
||||
void dispatch_gemm_config(const T* A,
|
||||
const WeightType* B,
|
||||
const typename WeightQuantTraits::WeightType* B,
|
||||
const T* weight_scales,
|
||||
const T* biases,
|
||||
T* C,
|
||||
@@ -369,6 +397,7 @@ void dispatch_gemm_config(const T* A,
|
||||
int64_t gemm_n,
|
||||
int64_t gemm_k,
|
||||
int num_experts,
|
||||
const typename WeightQuantTraits::Arguments& quant_args_B,
|
||||
CutlassGemmConfig gemm_config,
|
||||
int multi_processor_count,
|
||||
cudaStream_t stream,
|
||||
@@ -376,7 +405,7 @@ void dispatch_gemm_config(const T* A,
|
||||
#define dispatch_stages_macro(STAGE) \
|
||||
case STAGE: \
|
||||
dispatch_stages<T, \
|
||||
WeightType, \
|
||||
WeightQuantTraits, \
|
||||
arch, \
|
||||
EpilogueTag, \
|
||||
ThreadblockShape, \
|
||||
@@ -391,6 +420,7 @@ void dispatch_gemm_config(const T* A,
|
||||
gemm_n, \
|
||||
gemm_k, \
|
||||
num_experts, \
|
||||
quant_args_B, \
|
||||
gemm_config, \
|
||||
multi_processor_count, \
|
||||
stream, \
|
||||
@@ -414,7 +444,7 @@ void dispatch_gemm_config(const T* A,
|
||||
case CutlassTileConfig:: \
|
||||
CtaShape##AA##x##BB##x##CC##_WarpShape##DD##x##EE##x##FF: \
|
||||
dispatch_gemm_config<T, \
|
||||
WeightType, \
|
||||
WeightQuantTraits, \
|
||||
arch, \
|
||||
EpilogueTag, \
|
||||
cutlass::gemm::GemmShape<AA, BB, CC>, \
|
||||
@@ -425,10 +455,11 @@ void dispatch_gemm_config(const T* A,
|
||||
biases, \
|
||||
C, \
|
||||
total_rows_before_expert, \
|
||||
total_rows, \
|
||||
total_rows, \
|
||||
gemm_n, \
|
||||
gemm_k, \
|
||||
num_experts, \
|
||||
quant_args_B, \
|
||||
gemm_config, \
|
||||
multi_processor_count, \
|
||||
stream, \
|
||||
@@ -438,14 +469,14 @@ void dispatch_gemm_config(const T* A,
|
||||
// This overload will handle tensorop gemms. It is disabled via SFINAE for fp32.
|
||||
// This overload is only enabled when T == WeightType.
|
||||
template <typename T,
|
||||
typename WeightType,
|
||||
typename WeightQuantTraits,
|
||||
typename arch,
|
||||
typename EpilogueTag,
|
||||
typename std::enable_if<!std::is_same<T, float>::value &&
|
||||
std::is_same<T, WeightType>::value>::type* =
|
||||
std::is_same<T, typename WeightQuantTraits::WeightType>::value>::type* =
|
||||
nullptr>
|
||||
void dispatch_moe_gemm_to_cutlass(const T* A,
|
||||
const WeightType* B,
|
||||
const typename WeightQuantTraits::WeightType* B,
|
||||
const T* weight_scales,
|
||||
const T* biases,
|
||||
T* C,
|
||||
@@ -454,6 +485,7 @@ void dispatch_moe_gemm_to_cutlass(const T* A,
|
||||
int64_t gemm_n,
|
||||
int64_t gemm_k,
|
||||
int num_experts,
|
||||
const typename WeightQuantTraits::Arguments& quant_args_B,
|
||||
CutlassGemmConfig gemm_config,
|
||||
int sm_version,
|
||||
int multi_processor_count,
|
||||
@@ -474,7 +506,7 @@ void dispatch_moe_gemm_to_cutlass(const T* A,
|
||||
default:
|
||||
throw std::runtime_error(
|
||||
"[dispatch_moe_gemm_to_cutlass] Config is invalid for same "
|
||||
"type MoE tensorop GEMM.");
|
||||
"type MoE tensorop GEMM for FP16/BF16.");
|
||||
break;
|
||||
}
|
||||
}
|
||||
@@ -483,14 +515,14 @@ void dispatch_moe_gemm_to_cutlass(const T* A,
|
||||
// Overload for quantize MoE GEMMs. We disable some warp configs here since they
|
||||
// will not be used and we can improve compile time
|
||||
template <typename T,
|
||||
typename WeightType,
|
||||
typename WeightQuantTraits,
|
||||
typename arch,
|
||||
typename EpilogueTag,
|
||||
typename std::enable_if<!std::is_same<T, float>::value &&
|
||||
!std::is_same<T, WeightType>::value>::type* =
|
||||
!std::is_same<T, typename WeightQuantTraits::WeightType>::value>::type* =
|
||||
nullptr>
|
||||
void dispatch_moe_gemm_to_cutlass(const T* A,
|
||||
const WeightType* B,
|
||||
const typename WeightQuantTraits::WeightType* B,
|
||||
const T* weight_scales,
|
||||
const T* biases,
|
||||
T* C,
|
||||
@@ -499,28 +531,34 @@ void dispatch_moe_gemm_to_cutlass(const T* A,
|
||||
int64_t gemm_n,
|
||||
int64_t gemm_k,
|
||||
int num_experts,
|
||||
const typename WeightQuantTraits::Arguments& quant_args_B,
|
||||
CutlassGemmConfig gemm_config,
|
||||
int sm_version,
|
||||
int multi_processor_count,
|
||||
cudaStream_t stream,
|
||||
int* occupancy = nullptr) {
|
||||
if constexpr (std::is_same<arch, cutlass::arch::Sm70>::value) {
|
||||
switch (gemm_config.tile_config) {
|
||||
dispatch_gemm_config_macro(32, 128, 64, 32, 32, 64);
|
||||
dispatch_gemm_config_macro(64, 128, 64, 64, 64, 64);
|
||||
case CutlassTileConfig::Undefined:
|
||||
throw std::runtime_error("[dispatch_moe_gemm_to_cutlass] gemm config undefined.");
|
||||
break;
|
||||
case CutlassTileConfig::ChooseWithHeuristic:
|
||||
throw std::runtime_error(
|
||||
"[dispatch_moe_gemm_to_cutlass] gemm config should have "
|
||||
"already been set by heuristic.");
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error(
|
||||
"[dispatch_moe_gemm_to_cutlass] Config is invalid for "
|
||||
"mixed type tensorop GEMM.");
|
||||
break;
|
||||
if constexpr (WeightQuantTraits::kQuantMethod != cutlass::WintQuantMethod::kWeightOnlyInt2) {
|
||||
switch (gemm_config.tile_config) {
|
||||
dispatch_gemm_config_macro(32, 128, 64, 32, 32, 64);
|
||||
dispatch_gemm_config_macro(64, 128, 64, 64, 64, 64);
|
||||
case CutlassTileConfig::Undefined:
|
||||
throw std::runtime_error("[dispatch_moe_gemm_to_cutlass] gemm config undefined.");
|
||||
break;
|
||||
case CutlassTileConfig::ChooseWithHeuristic:
|
||||
throw std::runtime_error(
|
||||
"[dispatch_moe_gemm_to_cutlass] gemm config should have "
|
||||
"already been set by heuristic.");
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error(
|
||||
"[dispatch_moe_gemm_to_cutlass] Config is invalid for "
|
||||
"mixed type tensorop GEMM for sm70.");
|
||||
break;
|
||||
}
|
||||
} else {
|
||||
throw std::runtime_error(
|
||||
"[dispatch_moe_gemm_to_cutlass] weight_only_int2 does not support sm70.");
|
||||
}
|
||||
} else {
|
||||
switch (gemm_config.tile_config) {
|
||||
@@ -555,12 +593,12 @@ void dispatch_moe_gemm_to_cutlass(const T* A,
|
||||
// This overload will handle simt gemms. It is disabled via SFINAE for tensorop.
|
||||
template <
|
||||
typename T,
|
||||
typename WeightType,
|
||||
typename WeightQuantTraits,
|
||||
typename arch,
|
||||
typename EpilogueTag,
|
||||
typename std::enable_if<std::is_same<T, float>::value>::type* = nullptr>
|
||||
void dispatch_moe_gemm_to_cutlass(const T* A,
|
||||
const WeightType* B,
|
||||
const typename WeightQuantTraits::WeightType* B,
|
||||
const T* weight_scales,
|
||||
const T* biases,
|
||||
T* C,
|
||||
@@ -569,6 +607,7 @@ void dispatch_moe_gemm_to_cutlass(const T* A,
|
||||
int64_t gemm_n,
|
||||
int64_t gemm_k,
|
||||
int num_experts,
|
||||
const typename WeightQuantTraits::Arguments& quant_args_B,
|
||||
CutlassGemmConfig gemm_config,
|
||||
int sm_version,
|
||||
int multi_processor_count,
|
||||
@@ -594,8 +633,8 @@ void dispatch_moe_gemm_to_cutlass(const T* A,
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename WeightType>
|
||||
MoeGemmRunner<T, WeightType>::MoeGemmRunner() {
|
||||
template <typename T, typename WeightQuantTraits>
|
||||
MoeGemmRunner<T, WeightQuantTraits>::MoeGemmRunner() {
|
||||
int device{-1};
|
||||
check_cuda_error(cudaGetDevice(&device));
|
||||
sm_ = getSMVersion();
|
||||
@@ -603,11 +642,11 @@ MoeGemmRunner<T, WeightType>::MoeGemmRunner() {
|
||||
&multi_processor_count_, cudaDevAttrMultiProcessorCount, device));
|
||||
}
|
||||
|
||||
template <typename T, typename WeightType>
|
||||
template <typename T, typename WeightQuantTraits>
|
||||
template <typename EpilogueTag>
|
||||
void MoeGemmRunner<T, WeightType>::dispatch_to_arch<EpilogueTag>(
|
||||
void MoeGemmRunner<T, WeightQuantTraits>::dispatch_to_arch<EpilogueTag>(
|
||||
const T* A,
|
||||
const WeightType* B,
|
||||
const typename WeightQuantTraits::WeightType* B,
|
||||
const T* weight_scales,
|
||||
const T* biases,
|
||||
T* C,
|
||||
@@ -616,11 +655,12 @@ void MoeGemmRunner<T, WeightType>::dispatch_to_arch<EpilogueTag>(
|
||||
int64_t gemm_n,
|
||||
int64_t gemm_k,
|
||||
int num_experts,
|
||||
const typename WeightQuantTraits::Arguments& quant_args_B,
|
||||
CutlassGemmConfig gemm_config,
|
||||
cudaStream_t stream,
int* occupancy) {
#define dispatch_moe_gemm_to_cutlass_macro(ARCH) \
dispatch_moe_gemm_to_cutlass<T, WeightType, ARCH, EpilogueTag>( \
dispatch_moe_gemm_to_cutlass<T, WeightQuantTraits, ARCH, EpilogueTag>( \
A, \
B, \
weight_scales, \
@@ -631,6 +671,7 @@ void MoeGemmRunner<T, WeightType>::dispatch_to_arch<EpilogueTag>(
gemm_n, \
gemm_k, \
num_experts, \
quant_args_B, \
gemm_config, \
sm_, \
multi_processor_count_, \
@@ -648,25 +689,28 @@ void MoeGemmRunner<T, WeightType>::dispatch_to_arch<EpilogueTag>(
}
}

template <typename T, typename WeightType>
template <typename T, typename WeightQuantTraits>
template <typename EpilogueTag>
void MoeGemmRunner<T, WeightType>::run_gemm<EpilogueTag>(
void MoeGemmRunner<T, WeightQuantTraits>::run_gemm<EpilogueTag>(
const T* A,
const WeightType* B,
const typename WeightQuantTraits::WeightType* B,
const T* weight_scales,
const T* biases,
T* C,
int64_t* total_rows_before_expert,
int64_t total_rows,
int64_t tune_total_rows,
int64_t actual_total_rows,
int64_t gemm_n,
int64_t gemm_k,
int num_experts,
const typename WeightQuantTraits::Arguments& quant_args_B,
cudaStream_t stream) {
static constexpr bool is_weight_only = !std::is_same<T, WeightType>::value;
static constexpr bool is_weight_only = !std::is_same<T, typename WeightQuantTraits::WeightType>::value;
static constexpr bool only_simt_configs = std::is_same<T, float>::value;

std::vector<CutlassGemmConfig> candidate_configs =
get_candidate_configs(sm_, -1, is_weight_only, only_simt_configs, true);

static constexpr int warm_time = 5;
static constexpr int test_time = 10;
auto& gemmConfigManager = GemmConfigManager::Instance();
@@ -676,17 +720,19 @@ void MoeGemmRunner<T, WeightType>::run_gemm<EpilogueTag>(
gemm_n, gemm_k, GemmType::MOEGEMM, dtype, wdtype, num_experts};
CutlassGemmConfig chosen_config;
auto chosen_config_optional =
gemmConfigManager.getBestConfig(gemmId, tune_total_rows);
gemmConfigManager.getBestConfig(gemmId, actual_total_rows);
if (chosen_config_optional != std::nullopt) {
chosen_config = chosen_config_optional.value();
} else {
size_t best_id = -1;
float best_time = std::numeric_limits<float>::max();
CutlassGemmConfig best_config;
int profile_total_rows =
std::min(gemmConfigManager.nextPowerOfTwo(tune_total_rows),
std::min(gemmConfigManager.nextPowerOfTwo(actual_total_rows),
gemmConfigManager.getMaxProfileM());
bool find_one = false;
for (size_t ii = 0; ii < candidate_configs.size(); ++ii) {
size_t num_candidate_configs_size = candidate_configs.size();
for (size_t ii = 0; ii < num_candidate_configs_size; ++ii) {
try {
for (int i = 0; i < warm_time; i++) {
dispatch_to_arch<EpilogueTag>(A,
@@ -699,6 +745,7 @@ void MoeGemmRunner<T, WeightType>::run_gemm<EpilogueTag>(
gemm_n,
gemm_k,
num_experts,
quant_args_B,
candidate_configs[ii],
stream);
}
@@ -719,6 +766,7 @@ void MoeGemmRunner<T, WeightType>::run_gemm<EpilogueTag>(
gemm_n,
gemm_k,
num_experts,
quant_args_B,
candidate_configs[ii],
stream);
}
@@ -728,7 +776,9 @@ void MoeGemmRunner<T, WeightType>::run_gemm<EpilogueTag>(
check_cuda_error(cudaEventElapsedTime(&elapsed, start, stop));
check_cuda_error(cudaEventDestroy(start));
check_cuda_error(cudaEventDestroy(stop));
//std::cout << "[TUNING] config: " << ii << ", time: " << elapsed << " ms" << std::endl;
if (elapsed < best_time) {
best_id = ii;
best_time = elapsed;
best_config = candidate_configs[ii];
}
@@ -739,6 +789,7 @@ void MoeGemmRunner<T, WeightType>::run_gemm<EpilogueTag>(
}
}
if (find_one) {
//std::cout << "[TUNING] best_config: " << best_id << ", time: " << best_time << " ms" << std::endl;
gemmConfigManager.addBestConfig(gemmId, profile_total_rows, best_config);
chosen_config = best_config;
} else {
@@ -756,23 +807,25 @@ void MoeGemmRunner<T, WeightType>::run_gemm<EpilogueTag>(
gemm_n,
gemm_k,
num_experts,
quant_args_B,
chosen_config,
stream);
}

template <typename T, typename WeightType>
void MoeGemmRunner<T, WeightType>::moe_gemm_bias_act(
template <typename T, typename WeightQuantTraits>
void MoeGemmRunner<T, WeightQuantTraits>::moe_gemm_bias_act(
const T* A,
const WeightType* B,
const typename WeightQuantTraits::WeightType* B,
const T* weight_scales,
const T* biases,
T* C,
int64_t* total_rows_before_expert,
int64_t total_rows,
int64_t tune_total_rows,
int64_t actual_total_rows,
int64_t gemm_n,
int64_t gemm_k,
int num_experts,
const typename WeightQuantTraits::Arguments& quant_args_B,
std::string activation_type,
cudaStream_t stream) {
if (activation_type == "none") {
@@ -784,10 +837,11 @@ void MoeGemmRunner<T, WeightType>::moe_gemm_bias_act(
C,
total_rows_before_expert,
total_rows,
tune_total_rows,
actual_total_rows,
gemm_n,
gemm_k,
num_experts,
quant_args_B,
stream);
} else {
run_gemm<EpilogueOpNoBias>(A,
@@ -797,27 +851,30 @@ void MoeGemmRunner<T, WeightType>::moe_gemm_bias_act(
C,
total_rows_before_expert,
total_rows,
tune_total_rows,
actual_total_rows,
gemm_n,
gemm_k,
num_experts,
quant_args_B,
stream);
}
}
}

template <typename T, typename WeightType>
void MoeGemmRunner<T, WeightType>::moe_gemm(const T* A,
const WeightType* B,
const T* weight_scales,
T* C,
int64_t* total_rows_before_expert,
int64_t total_rows,
int64_t tune_total_rows,
int64_t gemm_n,
int64_t gemm_k,
int num_experts,
cudaStream_t stream) {
template <typename T, typename WeightQuantTraits>
void MoeGemmRunner<T, WeightQuantTraits>::moe_gemm(
const T* A,
const typename WeightQuantTraits::WeightType* B,
const T* weight_scales,
T* C,
int64_t* total_rows_before_expert,
int64_t total_rows,
int64_t actual_total_rows,
int64_t gemm_n,
int64_t gemm_k,
int num_experts,
const typename WeightQuantTraits::Arguments& quant_args_B,
cudaStream_t stream) {
run_gemm<EpilogueOpNoBias>(A,
B,
weight_scales,
@@ -825,10 +882,11 @@ void MoeGemmRunner<T, WeightType>::moe_gemm(const T* A,
C,
total_rows_before_expert,
total_rows,
tune_total_rows,
actual_total_rows,
gemm_n,
gemm_k,
num_experts,
quant_args_B,
stream);
}
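The tuning path in run_gemm above is a profile-once, cache-by-shape pattern: GemmConfigManager is first queried for a config cached for this GEMM id and actual_total_rows; on a miss, every candidate config is warmed up and timed with CUDA events, and the winner is stored under the row count rounded to the next power of two. Below is a minimal, self-contained sketch of that pattern, not the FastDeploy implementation; names such as pick_best_config, run_with and cache are illustrative placeholders.

// Illustrative sketch only (assumed helper names, not FastDeploy API).
#include <cuda_runtime.h>
#include <limits>
#include <map>
#include <vector>

template <typename Config, typename RunFn>
Config pick_best_config(const std::vector<Config>& configs, RunFn run_with,
                        int64_t rows, std::map<int64_t, Config>& cache,
                        cudaStream_t stream) {
  // Bucket by the next power of two of the row count, as the runner does.
  int64_t key = 1;
  while (key < rows) key <<= 1;
  auto it = cache.find(key);
  if (it != cache.end()) return it->second;  // cache hit: reuse tuned config

  float best_time = std::numeric_limits<float>::max();
  Config best = configs.front();
  cudaEvent_t start, stop;
  cudaEventCreate(&start);
  cudaEventCreate(&stop);
  for (const Config& c : configs) {
    for (int i = 0; i < 5; ++i) run_with(c, stream);   // warm-up runs
    cudaEventRecord(start, stream);
    for (int i = 0; i < 10; ++i) run_with(c, stream);  // timed runs
    cudaEventRecord(stop, stream);
    cudaEventSynchronize(stop);
    float ms = 0.f;
    cudaEventElapsedTime(&ms, start, stop);
    if (ms < best_time) { best_time = ms; best = c; }
  }
  cudaEventDestroy(start);
  cudaEventDestroy(stop);
  cache.emplace(key, best);  // remember the winner for this shape bucket
  return best;
}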
File diff suppressed because it is too large
@@ -0,0 +1,102 @@
// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/c3x/cutlass_gemm_caller.cuh
#pragma once

// clang-format will break include orders
// clang-format off
#include "helper.h"

#include "cutlass/cutlass.h"

#include "cute/tensor.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cutlass/numeric_types.h"

#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass_helper.h"
// clang-format on

namespace fastdeploy::c3x {

static inline cute::Shape<int, int, int, int>
get_problem_shape(paddle::Tensor const &a, paddle::Tensor const &b) {
  int32_t m = a.dims()[0], n = b.dims()[0], k = a.dims()[1];
  return {m, n, k, 1};
}

template <typename GemmKernel>
void cutlass_gemm_caller(
    phi::Place device, cute::Shape<int, int, int, int> prob_shape,
    typename GemmKernel::MainloopArguments mainloop_args,
    typename GemmKernel::EpilogueArguments epilogue_args,
    typename GemmKernel::TileSchedulerArguments scheduler = {}) {
  cutlass::KernelHardwareInfo hw_info;
  typename GemmKernel::Arguments args{cutlass::gemm::GemmUniversalMode::kGemm,
                                      prob_shape,
                                      mainloop_args,
                                      epilogue_args,
                                      hw_info,
                                      scheduler};

  // Launch the CUTLASS GEMM kernel.
  using GemmOp = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
  GemmOp gemm_op;
  CUTLASS_CHECK(gemm_op.can_implement(args));

  size_t workspace_size = gemm_op.get_workspace_size(args);
  phi::Allocator *allocator = paddle::GetAllocator(device);
  auto workspace = allocator->Allocate(workspace_size);

  auto stream = paddle::GetCurrentCUDAStream(device)->raw_stream();

  cutlass::Status status = gemm_op.run(args, workspace->ptr(), stream);
  CUTLASS_CHECK(status);
}

template <typename Gemm, typename... EpilogueArgs>
void cutlass_gemm_caller(paddle::Tensor &out, paddle::Tensor const &a,
                         paddle::Tensor const &b,
                         EpilogueArgs &&...epilogue_params) {
  using ElementAB = typename Gemm::ElementAB;
  using ElementC = typename Gemm::ElementC;
  using ElementD = typename Gemm::ElementD;
  using GemmKernel = typename Gemm::GemmKernel;

  using StrideA = typename Gemm::GemmKernel::StrideA;
  using StrideB = typename Gemm::GemmKernel::StrideB;
  using StrideC = typename Gemm::GemmKernel::StrideC;
  using StrideD = StrideC;
  using StrideAux = StrideC;

  typename GemmKernel::ProblemShape prob_shape = get_problem_shape(a, b);
  auto [M, N, K, L] = prob_shape;

  StrideA a_stride =
      cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L));
  StrideB b_stride =
      cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L));
  StrideC c_stride =
      cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L));
  StrideD d_stride =
      cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L));
  StrideAux aux_stride = d_stride;

  auto a_ptr = static_cast<ElementAB *>(const_cast<void *>(a.data()));
  auto b_ptr = static_cast<ElementAB *>(const_cast<void *>(b.data()));
  typename GemmKernel::MainloopArguments mainloop_args{a_ptr, a_stride, b_ptr,
                                                       b_stride};

  auto c_ptr = static_cast<ElementD *>(const_cast<void *>(out.data()));
  typename GemmKernel::EpilogueArguments epilogue_args{
      Gemm::Epilogue::prepare_args(
          std::forward<EpilogueArgs>(epilogue_params)...),
      c_ptr, c_stride, c_ptr, d_stride};

  cutlass_gemm_caller<GemmKernel>(a.place(), prob_shape, mainloop_args,
                                  epilogue_args);
}

} // namespace fastdeploy::c3x
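For orientation, a hedged usage sketch of the tensor-level caller above: it assumes the cutlass_3x_gemm template from scaled_mm.cuh (below) and the c3x::ScaledEpilogue from cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp are available; the alias MyGemm, the wrapper name and the exact template arguments are illustrative, not part of the FastDeploy API.

// Sketch only: instantiate one Gemm configuration and hand it to the caller.
using MyGemm = fastdeploy::cutlass_3x_gemm<
    cutlass::float_e4m3_t, cutlass::bfloat16_t, fastdeploy::c3x::ScaledEpilogue,
    cute::Shape<cute::_128, cute::_128, cute::_128>,
    cute::Shape<cute::_2, cute::_1, cute::_1>,
    cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum,
    cutlass::epilogue::TmaWarpSpecialized>;

void run_scaled_mm_sketch(paddle::Tensor &out, paddle::Tensor const &a,
                          paddle::Tensor const &b,
                          paddle::Tensor const &a_scales,
                          paddle::Tensor const &b_scales) {
  // Strides, problem shape and workspace are derived inside the caller;
  // the trailing arguments are forwarded to Epilogue::prepare_args.
  fastdeploy::c3x::cutlass_gemm_caller<MyGemm>(out, a, b, a_scales, b_scales);
}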
custom_ops/gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm.cuh (new file, 149 lines)
@@ -0,0 +1,149 @@
|
||||
// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/c3x/scaled_mm.cuh
|
||||
|
||||
#pragma once
|
||||
|
||||
// clang-format will break include orders
|
||||
// clang-format off
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cute/atom/mma_atom.hpp"
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
|
||||
#include "cutlass_helper.h"
|
||||
#include "helper.h"
|
||||
// clang-format on
|
||||
|
||||
/*
|
||||
Epilogues defined in,
|
||||
csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp,
|
||||
must contain a public type named EVTCompute of type Sm90EVT, as well as a
|
||||
static prepare_args function that constructs an EVTCompute::Arguments struct.
|
||||
*/
|
||||
|
||||
using namespace cute;
|
||||
|
||||
namespace fastdeploy {
|
||||
|
||||
template <typename ElementAB_, typename ElementD_,
|
||||
template <typename, typename, typename> typename Epilogue_,
|
||||
typename TileShape, typename ClusterShape, typename KernelSchedule,
|
||||
typename EpilogueSchedule>
|
||||
struct cutlass_3x_gemm {
|
||||
using ElementAB = ElementAB_;
|
||||
using ElementD = ElementD_;
|
||||
using ElementAcc =
|
||||
typename std::conditional<std::is_same_v<ElementAB, int8_t>, int32_t,
|
||||
float>::type;
|
||||
|
||||
using Epilogue = Epilogue_<ElementAcc, ElementD, TileShape>;
|
||||
|
||||
using StrideD = Stride<int64_t, Int<1>, Int<0>>;
|
||||
using ElementC = void;
|
||||
using StrideC = StrideD;
|
||||
|
||||
using EVTCompute = typename Epilogue::EVTCompute;
|
||||
|
||||
// These are the minimum alignments needed for the kernels to compile
|
||||
static constexpr int AlignmentAB =
|
||||
128 / cutlass::sizeof_bits<ElementAB>::value;
|
||||
static constexpr int AlignmentCD = 4;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape,
|
||||
ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
ElementAcc, float, ElementC, StrideC, AlignmentCD, ElementD, StrideD,
|
||||
AlignmentCD, EpilogueSchedule, EVTCompute>::CollectiveOp;
|
||||
|
||||
static constexpr size_t CEStorageSize =
|
||||
sizeof(typename CollectiveEpilogue::SharedStorage);
|
||||
using Stages = typename cutlass::gemm::collective::StageCountAutoCarveout<
|
||||
static_cast<int>(CEStorageSize)>;
|
||||
|
||||
// clang-format off
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
ElementAB, cutlass::layout::RowMajor, AlignmentAB,
|
||||
ElementAB, cutlass::layout::ColumnMajor, AlignmentAB,
|
||||
ElementAcc, TileShape, ClusterShape,
|
||||
Stages,
|
||||
KernelSchedule>::CollectiveOp;
|
||||
// clang-format on
|
||||
|
||||
using KernelType = enable_sm90_or_later<cutlass::gemm::kernel::GemmUniversal<
|
||||
cute::Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue,
|
||||
cutlass::gemm::PersistentScheduler>>;
|
||||
|
||||
struct GemmKernel : public KernelType {};
|
||||
};
|
||||
|
||||
template <typename ElementAB_, typename ElementD_,
|
||||
template <typename, typename, typename> typename Epilogue_,
|
||||
typename TileShape, typename ClusterShape, typename KernelSchedule,
|
||||
typename EpilogueSchedule>
|
||||
struct cutlass_3x_gemm_sm100 {
|
||||
using ElementAB = ElementAB_;
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
static constexpr int AlignmentA =
|
||||
128 / cutlass::sizeof_bits<ElementAB>::value;
|
||||
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
static constexpr int AlignmentB =
|
||||
128 / cutlass::sizeof_bits<ElementAB>::value;
|
||||
|
||||
using ElementC = void;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
static constexpr int AlignmentC =
|
||||
128 / cutlass::sizeof_bits<ElementD_>::value;
|
||||
|
||||
using ElementD = ElementD_;
|
||||
using LayoutD = cutlass::layout::RowMajor;
|
||||
static constexpr int AlignmentD = AlignmentC;
|
||||
|
||||
using ElementAcc =
|
||||
typename std::conditional<std::is_same_v<ElementAB, int8_t>, int32_t,
|
||||
float>::type;
|
||||
using Epilogue = Epilogue_<ElementAcc, ElementD, TileShape>;
|
||||
|
||||
// MMA type
|
||||
using ElementAccumulator = float;
|
||||
|
||||
// Epilogue types
|
||||
using ElementBias = cutlass::half_t;
|
||||
using ElementCompute = float;
|
||||
using ElementAux = ElementD;
|
||||
using LayoutAux = LayoutD;
|
||||
using ElementAmax = float;
|
||||
|
||||
using EVTCompute = typename Epilogue::EVTCompute;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, TileShape,
|
||||
ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
ElementAccumulator, ElementCompute, ElementC, LayoutC, AlignmentC,
|
||||
ElementD, LayoutD, AlignmentD, EpilogueSchedule,
|
||||
EVTCompute>::CollectiveOp;
|
||||
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, ElementAB,
|
||||
LayoutA, AlignmentA, ElementAB, LayoutB, AlignmentB,
|
||||
ElementAccumulator, TileShape, ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
|
||||
sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||
KernelSchedule>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue, void>;
|
||||
};
|
||||
|
||||
} // namespace fastdeploy
|
||||
@@ -0,0 +1,27 @@
// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_azp_sm90_int8.cu

// clang-format will break include orders
// clang-format off
#include "scaled_mm_kernels.hpp"
#include "scaled_mm_sm90_int8_dispatch.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
// clang-format on

namespace fastdeploy {

void cutlass_scaled_mm_azp_sm90_int8(
    paddle::Tensor &out, paddle::Tensor const &a, paddle::Tensor const &b,
    paddle::Tensor const &a_scales, paddle::Tensor const &b_scales,
    paddle::Tensor const &azp_adj, paddle::optional<paddle::Tensor> const &azp,
    paddle::optional<paddle::Tensor> const &bias) {
  if (azp) {
    return cutlass_scaled_mm_sm90_int8_epilogue<
        c3x::ScaledEpilogueBiasAzpToken>(out, a, b, a_scales, b_scales, azp_adj,
                                         *azp, bias);
  } else {
    return cutlass_scaled_mm_sm90_int8_epilogue<c3x::ScaledEpilogueBiasAzp>(
        out, a, b, a_scales, b_scales, azp_adj, bias);
  }
}

} // namespace fastdeploy
@@ -0,0 +1,34 @@
// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_helper.hpp

#include "helper.h"

template <typename Fp8Func, typename Int8Func>
void dispatch_scaled_mm(paddle::Tensor &c, paddle::Tensor const &a,
                        paddle::Tensor const &b, paddle::Tensor const &a_scales,
                        paddle::Tensor const &b_scales,
                        paddle::optional<paddle::Tensor> const &bias,
                        Fp8Func fp8_func, Int8Func int8_func) {
  PD_CHECK(a_scales.dtype() == paddle::DataType::FLOAT32);
  PD_CHECK(b_scales.dtype() == paddle::DataType::FLOAT32);

  int M = a.dims()[0], N = b.dims()[0], K = a.dims()[1];

  if ((a_scales.numel() == 1 || a_scales.numel() == a.dims()[0]) &&
      (b_scales.numel() == 1 || b_scales.numel() == b.dims()[0])) {
    // Standard per-tensor/per-token/per-channel scaling
    PD_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
    if (a.dtype() == phi::DataType::FLOAT8_E4M3FN) {
      fp8_func(c, a, b, a_scales, b_scales, bias);
    } else {
      PD_CHECK(a.dtype() == paddle::DataType::INT8);
      if constexpr (!std::is_same_v<Int8Func, std::nullptr_t>) {
        int8_func(c, a, b, a_scales, b_scales, bias);
      } else {
        PD_CHECK(false, "Int8 not supported for this architecture");
      }
    }
  } else {
    PADDLE_THROW(phi::errors::Unimplemented(
        "No kernel for this combination of input dtypes is implemented."));
  }
}
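dispatch_scaled_mm routes on the activation dtype (FP8 vs INT8) and only accepts per-tensor, per-token or per-channel scale layouts; the fp8/int8 callbacks are plain callables. A hedged usage sketch follows: it assumes the SM90 entry points declared in scaled_mm_kernels.hpp below are linked in, and the wrapper name scaled_mm_sm90_sketch is illustrative.

// Sketch only: pass the SM90 kernels through lambdas.
void scaled_mm_sm90_sketch(paddle::Tensor &c, paddle::Tensor const &a,
                           paddle::Tensor const &b,
                           paddle::Tensor const &a_scales,
                           paddle::Tensor const &b_scales,
                           paddle::optional<paddle::Tensor> const &bias) {
  dispatch_scaled_mm(
      c, a, b, a_scales, b_scales, bias,
      // FP8 path
      [](paddle::Tensor &out, paddle::Tensor const &x, paddle::Tensor const &w,
         paddle::Tensor const &xs, paddle::Tensor const &ws,
         paddle::optional<paddle::Tensor> const &b_) {
        fastdeploy::cutlass_scaled_mm_sm90_fp8(out, x, w, xs, ws, b_);
      },
      // INT8 path (pass nullptr instead on architectures without INT8 support)
      [](paddle::Tensor &out, paddle::Tensor const &x, paddle::Tensor const &w,
         paddle::Tensor const &xs, paddle::Tensor const &ws,
         paddle::optional<paddle::Tensor> const &b_) {
        fastdeploy::cutlass_scaled_mm_sm90_int8(out, x, w, xs, ws, b_);
      });
}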
@@ -0,0 +1,35 @@
// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_kernels.hpp

#pragma once

#include "helper.h"

namespace fastdeploy {

void cutlass_scaled_mm_sm90_fp8(paddle::Tensor &out, paddle::Tensor const &a,
                                paddle::Tensor const &b,
                                paddle::Tensor const &a_scales,
                                paddle::Tensor const &b_scales,
                                paddle::optional<paddle::Tensor> const &bias);

void cutlass_scaled_mm_sm90_int8(paddle::Tensor &out, paddle::Tensor const &a,
                                 paddle::Tensor const &b,
                                 paddle::Tensor const &a_scales,
                                 paddle::Tensor const &b_scales,
                                 paddle::optional<paddle::Tensor> const &bias);

void cutlass_scaled_mm_azp_sm90_int8(paddle::Tensor& out, paddle::Tensor const& a,
                                     paddle::Tensor const& b,
                                     paddle::Tensor const& a_scales,
                                     paddle::Tensor const& b_scales,
                                     paddle::Tensor const& azp_adj,
                                     paddle::optional<paddle::Tensor> const& azp,
                                     paddle::optional<paddle::Tensor> const& bias);

void cutlass_scaled_mm_sm100_fp8(paddle::Tensor &out, paddle::Tensor const &a,
                                 paddle::Tensor const &b,
                                 paddle::Tensor const &a_scales,
                                 paddle::Tensor const &b_scales,
                                 paddle::optional<paddle::Tensor> const &bias);

} // namespace fastdeploy
@@ -0,0 +1,28 @@
// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_fp8.cu

// clang-format will break include orders
// clang-format off
#include "scaled_mm_kernels.hpp"
#include "scaled_mm_sm90_fp8_dispatch.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
// clang-format on

namespace fastdeploy {

void cutlass_scaled_mm_sm90_fp8(paddle::Tensor &out, paddle::Tensor const &a,
                                paddle::Tensor const &b,
                                paddle::Tensor const &a_scales,
                                paddle::Tensor const &b_scales,
                                paddle::optional<paddle::Tensor> const &bias) {
  PD_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
  if (bias) {
    PD_CHECK(bias->dtype() == out.dtype(),
             "currently bias dtype must match output dtype ", out.dtype());
    return cutlass_scaled_mm_sm90_fp8_epilogue<c3x::ScaledEpilogueBias>(
        out, a, b, a_scales, b_scales, *bias);
  } else {
    return cutlass_scaled_mm_sm90_fp8_epilogue<c3x::ScaledEpilogue>(
        out, a, b, a_scales, b_scales);
  }
}
} // namespace fastdeploy
@@ -0,0 +1,125 @@
// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_fp8_dispatch.cuh

#pragma once

// clang-format will break include orders
// clang-format off
#include "scaled_mm.cuh"
#include "cutlass_gemm_caller.cuh"
// clang-format on

/**
 * This file defines Gemm kernel configurations for SM90 (fp8) based on the Gemm
 * shape.
 */

namespace fastdeploy {

using c3x::cutlass_gemm_caller;

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_fp8_config_default {
  // M in (128, inf)
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  using KernelSchedule =
      cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
  using TileShape = Shape<_128, _128, _128>;
  using ClusterShape = Shape<_2, _1, _1>;
  using Cutlass3xGemm =
      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                      KernelSchedule, EpilogueSchedule>;
};

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_fp8_config_M128 {
  // M in (64, 128]
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  using KernelSchedule =
      cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
  using TileShape = Shape<_64, _128, _128>;
  using ClusterShape = Shape<_2, _1, _1>;
  using Cutlass3xGemm =
      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                      KernelSchedule, EpilogueSchedule>;
};

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_fp8_config_M64 {
  // M in [1, 64]
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  using KernelSchedule =
      cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
  using TileShape = Shape<_64, _64, _128>;
  using ClusterShape = Shape<_1, _8, _1>;

  using Cutlass3xGemm =
      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                      KernelSchedule, EpilogueSchedule>;
};

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue,
          typename... EpilogueArgs>
inline void cutlass_gemm_sm90_fp8_dispatch(paddle::Tensor &out,
                                           paddle::Tensor const &a,
                                           paddle::Tensor const &b,
                                           EpilogueArgs &&...args) {
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  PD_CHECK(a.dtype() == phi::DataType::FLOAT8_E4M3FN);
  PD_CHECK(b.dtype() == phi::DataType::FLOAT8_E4M3FN);

  using Cutlass3xGemmDefault =
      typename sm90_fp8_config_default<InType, OutType,
                                       Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemmM64 =
      typename sm90_fp8_config_M64<InType, OutType, Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemmM128 =
      typename sm90_fp8_config_M128<InType, OutType, Epilogue>::Cutlass3xGemm;

  uint32_t const m = a.dims()[0];
  uint32_t const mp2 =
      std::max(static_cast<uint32_t>(64), next_pow_2(m));  // next power of 2

  if (mp2 <= 64) {
    // m in [1, 64]
    return cutlass_gemm_caller<Cutlass3xGemmM64>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  } else if (mp2 <= 128) {
    // m in (64, 128]
    return cutlass_gemm_caller<Cutlass3xGemmM128>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  } else {
    // m in (128, inf)
    return cutlass_gemm_caller<Cutlass3xGemmDefault>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  }
}

template <template <typename, typename, typename> typename Epilogue,
          typename... EpilogueArgs>
void cutlass_scaled_mm_sm90_fp8_epilogue(paddle::Tensor &out,
                                         paddle::Tensor const &a,
                                         paddle::Tensor const &b,
                                         EpilogueArgs &&...epilogue_args) {
  PD_CHECK(a.dtype() == phi::DataType::FLOAT8_E4M3FN);
  PD_CHECK(b.dtype() == phi::DataType::FLOAT8_E4M3FN);

  if (out.dtype() == paddle::DataType::BFLOAT16) {
    return cutlass_gemm_sm90_fp8_dispatch<cutlass::float_e4m3_t,
                                          cutlass::bfloat16_t, Epilogue>(
        out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
  } else {
    PD_CHECK(out.dtype() == paddle::DataType::FLOAT16);
    return cutlass_gemm_sm90_fp8_dispatch<cutlass::float_e4m3_t,
                                          cutlass::half_t, Epilogue>(
        out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
  }
}

} // namespace fastdeploy
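The dispatch above rounds M up to a power of two before selecting a tile configuration; next_pow_2 comes from helper.h. A minimal stand-in with the same intent is sketched below (an assumption about the helper, not its FastDeploy definition; __builtin_clz assumes GCC/Clang), followed by the resulting buckets taken from the configs above.

// Sketch only: smallest power of two >= x.
static inline uint32_t next_pow_2_sketch(uint32_t x) {
  if (x <= 1) return 1;
  return 1u << (32 - __builtin_clz(x - 1));
}
// M = 1..64    -> bucket 64   -> sm90_fp8_config_M64     (tile 64x64x128,   cluster 1x8x1)
// M = 65..128  -> bucket 128  -> sm90_fp8_config_M128    (tile 64x128x128,  cluster 2x1x1)
// M > 128      -> default     -> sm90_fp8_config_default (tile 128x128x128, cluster 2x1x1)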
@@ -0,0 +1,29 @@
// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_int8.cu

// clang-format will break include orders
// clang-format off
#include "scaled_mm_kernels.hpp"
#include "scaled_mm_sm90_int8_dispatch.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
// clang-format on

namespace fastdeploy {

void cutlass_scaled_mm_sm90_int8(paddle::Tensor &out, paddle::Tensor const &a,
                                 paddle::Tensor const &b,
                                 paddle::Tensor const &a_scales,
                                 paddle::Tensor const &b_scales,
                                 paddle::optional<paddle::Tensor> const &bias) {
  PD_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
  if (bias) {
    PD_CHECK(bias->dtype() == out.dtype(),
             "currently bias dtype must match output dtype ", out.dtype());
    return cutlass_scaled_mm_sm90_int8_epilogue<c3x::ScaledEpilogueBias>(
        out, a, b, a_scales, b_scales, *bias);
  } else {
    return cutlass_scaled_mm_sm90_int8_epilogue<c3x::ScaledEpilogue>(
        out, a, b, a_scales, b_scales);
  }
}

} // namespace fastdeploy
@@ -0,0 +1,168 @@
|
||||
// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_int8_dispatch.cuh
|
||||
|
||||
#pragma once
|
||||
|
||||
// clang-format will break include orders
|
||||
// clang-format off
|
||||
#include "scaled_mm.cuh"
|
||||
#include "cutlass_gemm_caller.cuh"
|
||||
// clang-format on
|
||||
|
||||
/**
|
||||
* This file defines Gemm kernel configurations for SM90 (int8) based on the
|
||||
* Gemm shape.
|
||||
*/
|
||||
|
||||
namespace fastdeploy {
|
||||
|
||||
using c3x::cutlass_gemm_caller;
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename, typename> typename Epilogue>
|
||||
struct sm90_int8_config_default {
|
||||
// For M > 128 and any N
|
||||
static_assert(std::is_same<InType, int8_t>());
|
||||
using KernelSchedule =
|
||||
typename cutlass::gemm::KernelTmaWarpSpecializedPingpong;
|
||||
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
|
||||
using TileShape = Shape<_128, _128, _128>;
|
||||
using ClusterShape = Shape<_2, _1, _1>;
|
||||
using Cutlass3xGemm =
|
||||
cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
|
||||
KernelSchedule, EpilogueSchedule>;
|
||||
};
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename, typename> typename Epilogue>
|
||||
struct sm90_int8_config_M128 {
|
||||
// For M in (64, 128] and any N
|
||||
static_assert(std::is_same<InType, int8_t>());
|
||||
using KernelSchedule =
|
||||
typename cutlass::gemm::KernelTmaWarpSpecializedPingpong;
|
||||
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
|
||||
using TileShape = Shape<_64, _128, _128>;
|
||||
using ClusterShape = Shape<_2, _1, _1>;
|
||||
using Cutlass3xGemm =
|
||||
cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
|
||||
KernelSchedule, EpilogueSchedule>;
|
||||
};
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename, typename> typename Epilogue>
|
||||
struct sm90_int8_config_M64 {
|
||||
// For M in (32, 64] and any N
|
||||
static_assert(std::is_same<InType, int8_t>());
|
||||
using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized;
|
||||
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
|
||||
using TileShape = Shape<_64, _64, _256>;
|
||||
using ClusterShape = Shape<_1, _1, _1>;
|
||||
using Cutlass3xGemm =
|
||||
cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
|
||||
KernelSchedule, EpilogueSchedule>;
|
||||
};
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename, typename> typename Epilogue>
|
||||
struct sm90_int8_config_M32_NBig {
|
||||
// For M in [1, 32] and N >= 8192
|
||||
static_assert(std::is_same<InType, int8_t>());
|
||||
using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized;
|
||||
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
|
||||
using TileShape = Shape<_64, _128, _256>;
|
||||
using ClusterShape = Shape<_1, _4, _1>;
|
||||
using Cutlass3xGemm =
|
||||
cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
|
||||
KernelSchedule, EpilogueSchedule>;
|
||||
};
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename, typename> typename Epilogue>
|
||||
struct sm90_int8_config_M32_NSmall {
|
||||
// For M in [1, 32] and N < 8192
|
||||
static_assert(std::is_same<InType, int8_t>());
|
||||
using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized;
|
||||
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
|
||||
using TileShape = Shape<_64, _64, _256>;
|
||||
using ClusterShape = Shape<_1, _8, _1>;
|
||||
using Cutlass3xGemm =
|
||||
cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
|
||||
KernelSchedule, EpilogueSchedule>;
|
||||
};
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
inline void cutlass_gemm_sm90_int8_dispatch(paddle::Tensor &out,
|
||||
paddle::Tensor const &a,
|
||||
paddle::Tensor const &b,
|
||||
EpilogueArgs &&...args) {
|
||||
static_assert(std::is_same<InType, int8_t>());
|
||||
PD_CHECK(a.dtype() == paddle::DataType::INT8);
|
||||
PD_CHECK(b.dtype() == paddle::DataType::INT8);
|
||||
|
||||
using Cutlass3xGemmDefault =
|
||||
typename sm90_int8_config_default<InType, OutType,
|
||||
Epilogue>::Cutlass3xGemm;
|
||||
using Cutlass3xGemmM128 =
|
||||
typename sm90_int8_config_M128<InType, OutType, Epilogue>::Cutlass3xGemm;
|
||||
using Cutlass3xGemmM64 =
|
||||
typename sm90_int8_config_M64<InType, OutType, Epilogue>::Cutlass3xGemm;
|
||||
using Cutlass3xGemmM32NBig =
|
||||
typename sm90_int8_config_M32_NBig<InType, OutType,
|
||||
Epilogue>::Cutlass3xGemm;
|
||||
using Cutlass3xGemmM32NSmall =
|
||||
typename sm90_int8_config_M32_NSmall<InType, OutType,
|
||||
Epilogue>::Cutlass3xGemm;
|
||||
|
||||
uint32_t const n = out.dims()[1];
|
||||
bool const is_small_n = n < 8192;
|
||||
|
||||
uint32_t const m = a.dims()[0];
|
||||
uint32_t const mp2 =
|
||||
std::max(static_cast<uint32_t>(32), next_pow_2(m)); // next power of 2
|
||||
|
||||
if (mp2 <= 32) {
|
||||
// m in [1, 32]
|
||||
if (is_small_n) {
|
||||
return cutlass_gemm_caller<Cutlass3xGemmM32NSmall>(
|
||||
out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else {
|
||||
return cutlass_gemm_caller<Cutlass3xGemmM32NBig>(
|
||||
out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
}
|
||||
} else if (mp2 <= 64) {
|
||||
// m in (32, 64]
|
||||
return cutlass_gemm_caller<Cutlass3xGemmM64>(
|
||||
out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else if (mp2 <= 128) {
|
||||
// m in (64, 128]
|
||||
return cutlass_gemm_caller<Cutlass3xGemmM128>(
|
||||
out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else {
|
||||
// m in (128, inf)
|
||||
return cutlass_gemm_caller<Cutlass3xGemmDefault>(
|
||||
out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
}
|
||||
}
|
||||
|
||||
template <template <typename, typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
void cutlass_scaled_mm_sm90_int8_epilogue(paddle::Tensor &out,
|
||||
paddle::Tensor const &a,
|
||||
paddle::Tensor const &b,
|
||||
EpilogueArgs &&...epilogue_args) {
|
||||
PD_CHECK(a.dtype() == paddle::DataType::INT8);
|
||||
PD_CHECK(b.dtype() == paddle::DataType::INT8);
|
||||
|
||||
if (out.dtype() == paddle::DataType::BFLOAT16) {
|
||||
return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::bfloat16_t,
|
||||
Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
} else {
|
||||
PD_CHECK(out.dtype() == paddle::DataType::FLOAT16);
|
||||
return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::half_t, Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace fastdeploy
|
||||
custom_ops/gpu_ops/cutlass_kernels/w8a8/scaled_mm_c2x.cu (new file, 200 lines)
@@ -0,0 +1,200 @@
|
||||
// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu
|
||||
|
||||
#include "helper.h"
|
||||
#include <stddef.h>
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "scaled_mm_c2x.cuh"
|
||||
#include "scaled_mm_c2x_sm75_dispatch.cuh"
|
||||
#include "scaled_mm_c2x_sm80_dispatch.cuh"
|
||||
#include "scaled_mm_c2x_sm89_fp8_dispatch.cuh"
|
||||
#include "scaled_mm_c2x_sm89_int8_dispatch.cuh"
|
||||
|
||||
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp"
|
||||
|
||||
using namespace fastdeploy;
|
||||
|
||||
/*
|
||||
This file defines quantized GEMM operations using the CUTLASS 2.x API, for
|
||||
NVIDIA GPUs with SM versions prior to sm90 (Hopper).
|
||||
*/
|
||||
|
||||
template <template <typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
void cutlass_scaled_mm_sm75_epilogue(paddle::Tensor& out, paddle::Tensor const& a,
|
||||
paddle::Tensor const& b,
|
||||
EpilogueArgs&&... epilogue_args) {
|
||||
PD_CHECK(a.dtype() == paddle::DataType::INT8);
|
||||
PD_CHECK(b.dtype() == paddle::DataType::INT8);
|
||||
|
||||
if (out.dtype() == paddle::DataType::BFLOAT16) {
|
||||
return cutlass_gemm_sm75_dispatch<int8_t, cutlass::bfloat16_t, Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
} else {
|
||||
PD_CHECK(out.dtype() == paddle::DataType::FLOAT16);
|
||||
return cutlass_gemm_sm75_dispatch<int8_t, cutlass::half_t, Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
}
|
||||
}
|
||||
|
||||
void cutlass_scaled_mm_sm75(paddle::Tensor& out, paddle::Tensor const& a,
|
||||
paddle::Tensor const& b,
|
||||
paddle::Tensor const& a_scales,
|
||||
paddle::Tensor const& b_scales,
|
||||
paddle::optional<paddle::Tensor> const& bias) {
|
||||
PD_CHECK(a_scales.dtype() == paddle::DataType::FLOAT32);
|
||||
PD_CHECK(b_scales.dtype() == paddle::DataType::FLOAT32);
|
||||
if (bias) {
|
||||
PD_CHECK(bias->dtype() == out.dtype(),
|
||||
"currently bias dtype must match output dtype ", out.dtype());
|
||||
return cutlass_scaled_mm_sm75_epilogue<c2x::ScaledEpilogueBias>(
|
||||
out, a, b, a_scales, b_scales, *bias);
|
||||
} else {
|
||||
return cutlass_scaled_mm_sm75_epilogue<c2x::ScaledEpilogue>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
}
|
||||
}
|
||||
|
||||
void cutlass_scaled_mm_azp_sm75(paddle::Tensor& out, paddle::Tensor const& a,
|
||||
paddle::Tensor const& b,
|
||||
paddle::Tensor const& a_scales,
|
||||
paddle::Tensor const& b_scales,
|
||||
paddle::Tensor const& azp_adj,
|
||||
paddle::optional<paddle::Tensor> const& azp,
|
||||
paddle::optional<paddle::Tensor> const& bias) {
|
||||
PD_CHECK(a_scales.dtype() == paddle::DataType::FLOAT32);
|
||||
PD_CHECK(b_scales.dtype() == paddle::DataType::FLOAT32);
|
||||
|
||||
if (azp) {
|
||||
return cutlass_scaled_mm_sm75_epilogue<c2x::ScaledEpilogueBiasAzpToken>(
|
||||
out, a, b, a_scales, b_scales, azp_adj, *azp, bias);
|
||||
} else {
|
||||
return cutlass_scaled_mm_sm75_epilogue<c2x::ScaledEpilogueBiasAzp>(
|
||||
out, a, b, a_scales, b_scales, azp_adj, bias);
|
||||
}
|
||||
}
|
||||
|
||||
template <template <typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
void cutlass_scaled_mm_sm80_epilogue(paddle::Tensor& out, paddle::Tensor const& a,
|
||||
paddle::Tensor const& b,
|
||||
EpilogueArgs&&... epilogue_args) {
|
||||
PD_CHECK(a.dtype() == paddle::DataType::INT8);
|
||||
PD_CHECK(b.dtype() == paddle::DataType::INT8);
|
||||
|
||||
if (out.dtype() == paddle::DataType::BFLOAT16) {
|
||||
return cutlass_gemm_sm80_dispatch<int8_t, cutlass::bfloat16_t, Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
} else {
|
||||
PD_CHECK(out.dtype() == paddle::DataType::FLOAT16);
|
||||
return cutlass_gemm_sm80_dispatch<int8_t, cutlass::half_t, Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
}
|
||||
}
|
||||
|
||||
void cutlass_scaled_mm_sm80(paddle::Tensor& out, paddle::Tensor const& a,
|
||||
paddle::Tensor const& b,
|
||||
paddle::Tensor const& a_scales,
|
||||
paddle::Tensor const& b_scales,
|
||||
paddle::optional<paddle::Tensor> const& bias) {
|
||||
PD_CHECK(a_scales.dtype() == paddle::DataType::FLOAT32);
|
||||
PD_CHECK(b_scales.dtype() == paddle::DataType::FLOAT32);
|
||||
if (bias) {
|
||||
PD_CHECK(bias->dtype() == out.dtype(),
|
||||
"currently bias dtype must match output dtype ", out.dtype());
|
||||
return cutlass_scaled_mm_sm80_epilogue<c2x::ScaledEpilogueBias>(
|
||||
out, a, b, a_scales, b_scales, *bias);
|
||||
} else {
|
||||
return cutlass_scaled_mm_sm80_epilogue<c2x::ScaledEpilogue>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
}
|
||||
}
|
||||
|
||||
void cutlass_scaled_mm_azp_sm80(paddle::Tensor& out, paddle::Tensor const& a,
|
||||
paddle::Tensor const& b,
|
||||
paddle::Tensor const& a_scales,
|
||||
paddle::Tensor const& b_scales,
|
||||
paddle::Tensor const& azp_adj,
|
||||
paddle::optional<paddle::Tensor> const& azp,
|
||||
paddle::optional<paddle::Tensor> const& bias) {
|
||||
PD_CHECK(a_scales.dtype() == paddle::DataType::FLOAT32);
|
||||
PD_CHECK(b_scales.dtype() == paddle::DataType::FLOAT32);
|
||||
|
||||
if (azp) {
|
||||
return cutlass_scaled_mm_sm80_epilogue<c2x::ScaledEpilogueBiasAzpToken>(
|
||||
out, a, b, a_scales, b_scales, azp_adj, *azp, bias);
|
||||
} else {
|
||||
return cutlass_scaled_mm_sm80_epilogue<c2x::ScaledEpilogueBiasAzp>(
|
||||
out, a, b, a_scales, b_scales, azp_adj, bias);
|
||||
}
|
||||
}
|
||||
|
||||
template <template <typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
void cutlass_scaled_mm_sm89_epilogue(paddle::Tensor& out, paddle::Tensor const& a,
|
||||
paddle::Tensor const& b,
|
||||
EpilogueArgs&&... epilogue_args) {
|
||||
if (a.dtype() == paddle::DataType::INT8) {
|
||||
PD_CHECK(b.dtype() == paddle::DataType::INT8);
|
||||
|
||||
if (out.dtype() == paddle::DataType::BFLOAT16) {
|
||||
return cutlass_gemm_sm89_int8_dispatch<int8_t, cutlass::bfloat16_t,
|
||||
Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
} else {
|
||||
assert(out.dtype() == paddle::DataType::FLOAT16);
|
||||
return cutlass_gemm_sm89_int8_dispatch<int8_t, cutlass::half_t, Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
}
|
||||
} else {
|
||||
PD_CHECK(a.dtype() == paddle::DataType::FLOAT8_E4M3FN);
|
||||
PD_CHECK(b.dtype() == paddle::DataType::FLOAT8_E4M3FN);
|
||||
|
||||
if (out.dtype() == paddle::DataType::BFLOAT16) {
|
||||
return cutlass_gemm_sm89_fp8_dispatch<cutlass::float_e4m3_t,
|
||||
cutlass::bfloat16_t, Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
} else {
|
||||
PD_CHECK(out.dtype() == paddle::DataType::FLOAT16);
|
||||
return cutlass_gemm_sm89_fp8_dispatch<cutlass::float_e4m3_t,
|
||||
cutlass::half_t, Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void cutlass_scaled_mm_sm89(paddle::Tensor& out, paddle::Tensor const& a,
|
||||
paddle::Tensor const& b,
|
||||
paddle::Tensor const& a_scales,
|
||||
paddle::Tensor const& b_scales,
|
||||
paddle::optional<paddle::Tensor> const& bias) {
|
||||
PD_CHECK(a_scales.dtype() == paddle::DataType::FLOAT32);
|
||||
PD_CHECK(b_scales.dtype() == paddle::DataType::FLOAT32);
|
||||
if (bias) {
|
||||
PD_CHECK(bias->dtype() == out.dtype(),
|
||||
"currently bias dtype must match output dtype ", out.dtype());
|
||||
return cutlass_scaled_mm_sm89_epilogue<c2x::ScaledEpilogueBias>(
|
||||
out, a, b, a_scales, b_scales, *bias);
|
||||
} else {
|
||||
return cutlass_scaled_mm_sm89_epilogue<c2x::ScaledEpilogue>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
}
|
||||
}
|
||||
|
||||
void cutlass_scaled_mm_azp_sm89(paddle::Tensor& out, paddle::Tensor const& a,
|
||||
paddle::Tensor const& b,
|
||||
paddle::Tensor const& a_scales,
|
||||
paddle::Tensor const& b_scales,
|
||||
paddle::Tensor const& azp_adj,
|
||||
paddle::optional<paddle::Tensor> const& azp,
|
||||
paddle::optional<paddle::Tensor> const& bias) {
|
||||
PD_CHECK(a_scales.dtype() == paddle::DataType::FLOAT32);
|
||||
PD_CHECK(b_scales.dtype() == paddle::DataType::FLOAT32);
|
||||
|
||||
if (azp) {
|
||||
return cutlass_scaled_mm_sm89_epilogue<c2x::ScaledEpilogueBiasAzpToken>(
|
||||
out, a, b, a_scales, b_scales, azp_adj, *azp, bias);
|
||||
} else {
|
||||
return cutlass_scaled_mm_sm89_epilogue<c2x::ScaledEpilogueBiasAzp>(
|
||||
out, a, b, a_scales, b_scales, azp_adj, bias);
|
||||
}
|
||||
}
|
||||
custom_ops/gpu_ops/cutlass_kernels/w8a8/scaled_mm_c2x.cuh (new file, 223 lines)
@@ -0,0 +1,223 @@
|
||||
// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh
|
||||
|
||||
#pragma once
|
||||
#include <stddef.h>
|
||||
#include "helper.h"
|
||||
|
||||
// clang-format will break include orders
|
||||
// clang-format off
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cute/atom/mma_atom.hpp"
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm_coord.h"
|
||||
#include "cutlass/arch/mma_sm75.h"
|
||||
#include "cutlass/arch/arch.h"
|
||||
#include "cutlass/arch/mma.h"
|
||||
#include "cutlass/gemm/device/gemm.h"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
|
||||
#include "cutlass/epilogue/threadblock/fusion/visitors.hpp"
|
||||
#include "cutlass/gemm/kernel/default_gemm_universal_with_visitor.h"
|
||||
|
||||
#include "cutlass_helper.h"
|
||||
// clang-format on
|
||||
|
||||
/*
|
||||
Epilogues defined in,
|
||||
csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp
|
||||
must contain a public type named EVTCompute of type Sm80EVT,
|
||||
as well as a static prepare_args function that constructs an
|
||||
EVTCompute::Arguments struct.
|
||||
*/
|
||||
|
||||
namespace fastdeploy {
|
||||
|
||||
using namespace cute;
|
||||
|
||||
// Wrappers for the GEMM kernel that is used to guard against compilation on
|
||||
// architectures that will never use the kernel. The purpose of this is to
|
||||
// reduce the size of the compiled binary.
|
||||
// __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef
|
||||
// into code that will be executed on the device where it is defined.
|
||||
template <typename Kernel>
|
||||
struct enable_sm75_to_sm80 : Kernel {
|
||||
template <typename... Args>
|
||||
CUTLASS_DEVICE static void invoke(Args&&... args) {
|
||||
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 750 && __CUDA_ARCH__ < 800
|
||||
Kernel::invoke(std::forward<Args>(args)...);
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Kernel>
|
||||
struct enable_sm80_to_sm89 : Kernel {
|
||||
template <typename... Args>
|
||||
CUTLASS_DEVICE static void invoke(Args&&... args) {
|
||||
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800 && __CUDA_ARCH__ < 890
|
||||
Kernel::invoke(std::forward<Args>(args)...);
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Kernel>
|
||||
struct enable_sm89_to_sm90 : Kernel {
|
||||
template <typename... Args>
|
||||
CUTLASS_DEVICE static void invoke(Args&&... args) {
|
||||
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 890 && __CUDA_ARCH__ < 900
|
||||
Kernel::invoke(std::forward<Args>(args)...);
|
||||
#endif
|
||||
}
|
||||
};
|
||||
template <typename Arch, template <typename> typename ArchGuard,
|
||||
typename ElementAB_, typename ElementD_,
|
||||
template <typename, typename> typename Epilogue_, typename TileShape,
|
||||
typename WarpShape, typename InstructionShape, int32_t MainLoopStages,
|
||||
typename FP8MathOperator = cutlass::arch::OpMultiplyAdd>
|
||||
struct cutlass_2x_gemm {
|
||||
using ElementAB = ElementAB_;
|
||||
using ElementD = ElementD_;
|
||||
|
||||
using ElementAcc =
|
||||
typename std::conditional<std::is_same_v<ElementAB, int8_t>, int32_t,
|
||||
float>::type;
|
||||
|
||||
using Operator =
|
||||
typename std::conditional<std::is_same_v<ElementAB, int8_t>,
|
||||
cutlass::arch::OpMultiplyAddSaturate,
|
||||
FP8MathOperator>::type;
|
||||
|
||||
using OutputTileThreadMap =
|
||||
cutlass::epilogue::threadblock::OutputTileThreadLayout<
|
||||
TileShape, WarpShape, float, 4, 1 /* epilogue stages */
|
||||
>;
|
||||
|
||||
using Epilogue = Epilogue_<ElementD, OutputTileThreadMap>;
|
||||
using EVTCompute = typename Epilogue::EVTCompute;
|
||||
|
||||
using D = cutlass::epilogue::threadblock::VisitorAuxStore<
|
||||
OutputTileThreadMap, ElementD, cutlass::FloatRoundStyle::round_to_nearest,
|
||||
Stride<int64_t, Int<1>, Int<0>>>;
|
||||
|
||||
using EVTD = cutlass::epilogue::threadblock::Sm80EVT<D, EVTCompute>;
|
||||
|
||||
// These are the minimum alignments needed for the kernels to compile
|
||||
static constexpr int AlignmentAB =
|
||||
128 / cutlass::sizeof_bits<ElementAB>::value;
|
||||
static constexpr int AlignmentCD = 4;
|
||||
|
||||
// clang-format off
|
||||
using RowMajor = typename cutlass::layout::RowMajor;
|
||||
using ColumnMajor = typename cutlass::layout::ColumnMajor;
|
||||
using KernelType =
|
||||
ArchGuard<typename cutlass::gemm::kernel::DefaultGemmWithVisitor<
|
||||
ElementAB, RowMajor, cutlass::ComplexTransform::kNone, AlignmentAB,
|
||||
ElementAB, ColumnMajor, cutlass::ComplexTransform::kNone, AlignmentAB,
|
||||
float, cutlass::layout::RowMajor, AlignmentCD,
|
||||
ElementAcc, float, cutlass::arch::OpClassTensorOp,
|
||||
Arch,
|
||||
TileShape, WarpShape, InstructionShape,
|
||||
EVTD,
|
||||
cutlass::gemm::threadblock::ThreadblockSwizzleStreamK,
|
||||
MainLoopStages, Operator,
|
||||
1 /* epilogue stages */
|
||||
>::GemmKernel>;
|
||||
// clang-format on
|
||||
|
||||
using Op = cutlass::gemm::device::GemmUniversalAdapter<KernelType>;
|
||||
};
|
||||
|
||||
template <typename Gemm, typename... EpilogueArgs>
|
||||
inline void cutlass_gemm_caller(paddle::Tensor& out, paddle::Tensor const& a,
|
||||
paddle::Tensor const& b,
|
||||
EpilogueArgs&&... epilogue_params) {
|
||||
using ElementAB = typename Gemm::ElementAB;
|
||||
using ElementD = typename Gemm::ElementD;
|
||||
|
||||
int32_t m = a.dims()[0];
|
||||
int32_t n = b.dims()[0];
|
||||
int32_t k = a.dims()[1];
|
||||
cutlass::gemm::GemmCoord problem_size{m, n, k};
|
||||
|
||||
int64_t lda = a.strides()[0];
|
||||
int64_t ldb = b.strides()[0];
|
||||
int64_t ldc = out.strides()[0];
|
||||
|
||||
using StrideC = Stride<int64_t, Int<1>, Int<0>>;
|
||||
StrideC c_stride{ldc, Int<1>{}, Int<0>{}};
|
||||
|
||||
auto a_ptr = static_cast<ElementAB const*>(a.data());
|
||||
auto b_ptr = static_cast<ElementAB const*>(b.data());
|
||||
auto c_ptr = static_cast<ElementD*>(out.data());
|
||||
|
||||
typename Gemm::D::Arguments d_args{c_ptr, c_stride};
|
||||
|
||||
using Epilogue = typename Gemm::Epilogue;
|
||||
auto evt_args =
|
||||
Epilogue::prepare_args(std::forward<EpilogueArgs>(epilogue_params)...);
|
||||
|
||||
typename Gemm::EVTD::Arguments epilogue_args{
|
||||
evt_args,
|
||||
d_args,
|
||||
};
|
||||
|
||||
typename Gemm::Op::Arguments args{
|
||||
cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel, // universal mode
|
||||
problem_size, // problem size
|
||||
1, // batch count
|
||||
epilogue_args,
|
||||
a_ptr,
|
||||
b_ptr,
|
||||
nullptr,
|
||||
nullptr,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
lda,
|
||||
ldb,
|
||||
ldc,
|
||||
ldc};
|
||||
|
||||
// Launch the CUTLASS GEMM kernel.
|
||||
typename Gemm::Op gemm_op;
|
||||
size_t workspace_size = gemm_op.get_workspace_size(args);
|
||||
phi::Allocator *allocator = paddle::GetAllocator(a.place());
|
||||
auto workspace = allocator->Allocate(workspace_size);
|
||||
|
||||
auto stream = a.stream();
|
||||
|
||||
CUTLASS_CHECK(gemm_op.can_implement(args));
|
||||
cutlass::Status status = gemm_op(args, workspace->ptr(), stream);
|
||||
CUTLASS_CHECK(status);
|
||||
}
|
||||
|
||||
template <typename Gemm, typename FallbackGemm, typename... EpilogueArgs>
|
||||
inline void fallback_cutlass_gemm_caller(paddle::Tensor& out,
|
||||
paddle::Tensor const& a,
|
||||
paddle::Tensor const& b,
|
||||
EpilogueArgs&&... args) {
|
||||
// In some cases, the GPU isn't able to accommodate the
|
||||
// shared memory requirements of the Gemm. In such cases, use
|
||||
// the FallbackGemm instead.
|
||||
static const int max_shared_mem_per_block_opt_in =
|
||||
get_cuda_max_shared_memory_per_block_opt_in(0);
|
||||
|
||||
size_t const gemm_shared_mem_size =
|
||||
sizeof(typename Gemm::KernelType::SharedStorage);
|
||||
size_t const fallback_gemm_shared_mem_size =
|
||||
sizeof(typename FallbackGemm::KernelType::SharedStorage);
|
||||
|
||||
if (gemm_shared_mem_size <= max_shared_mem_per_block_opt_in) {
|
||||
return cutlass_gemm_caller<Gemm>(out, a, b,
|
||||
std::forward<EpilogueArgs>(args)...);
|
||||
} else {
|
||||
PD_CHECK(fallback_gemm_shared_mem_size <=
|
||||
max_shared_mem_per_block_opt_in);
|
||||
return cutlass_gemm_caller<FallbackGemm>(
|
||||
out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace fastdeploy
|
||||
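fallback_cutlass_gemm_caller above compares each kernel's SharedStorage size with the device's opt-in shared-memory limit and falls back to a lower-footprint GEMM when the preferred tile does not fit. A minimal sketch of the query behind get_cuda_max_shared_memory_per_block_opt_in follows; this is an assumption about its intent, not the FastDeploy implementation.

// Sketch only: bytes of shared memory a block may use after opting in.
#include <cuda_runtime.h>

static int max_shared_mem_opt_in_sketch(int device_id) {
  int smem = 0;
  cudaDeviceGetAttribute(&smem, cudaDevAttrMaxSharedMemoryPerBlockOptin,
                         device_id);
  return smem;
}
// A Gemm whose SharedStorage exceeds this value cannot launch, so the caller
// switches to FallbackGemm; the dispatch headers that follow pick their
// lowest-footprint configuration (e.g. sm75_config_default) for that role.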
@@ -0,0 +1,125 @@
|
||||
// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm75_dispatch.cuh
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "scaled_mm_c2x.cuh"
|
||||
|
||||
/**
|
||||
* This file defines Gemm kernel configurations for SM75 based on the Gemm
|
||||
* shape.
|
||||
*/
|
||||
|
||||
namespace fastdeploy {
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename> typename Epilogue>
|
||||
struct sm75_config_default {
|
||||
// This config is used in 2 cases,
|
||||
// - M in (256, inf]
|
||||
// - M in (64, 128]
|
||||
// Shared memory required by this Gemm 32768
|
||||
static_assert(std::is_same<InType, int8_t>());
|
||||
using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
|
||||
using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
using InstructionShape = typename cutlass::gemm::GemmShape<8, 8, 16>;
|
||||
using Cutlass2xGemm =
|
||||
cutlass_2x_gemm<cutlass::arch::Sm75, enable_sm75_to_sm80, InType, OutType,
|
||||
Epilogue, TileShape, WarpShape, InstructionShape, 2>;
|
||||
};
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename> typename Epilogue>
|
||||
struct sm75_config_M256 {
|
||||
// M in (128, 256]
|
||||
// Shared memory required by this Gemm 65536
|
||||
static_assert(std::is_same<InType, int8_t>());
|
||||
using TileShape = typename cutlass::gemm::GemmShape<128, 128, 128>;
|
||||
using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
using InstructionShape = typename cutlass::gemm::GemmShape<8, 8, 16>;
|
||||
using Cutlass2xGemm =
|
||||
cutlass_2x_gemm<cutlass::arch::Sm75, enable_sm75_to_sm80, InType, OutType,
|
||||
Epilogue, TileShape, WarpShape, InstructionShape, 2>;
|
||||
};
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename> typename Epilogue>
|
||||
struct sm75_config_M64 {
|
||||
// M in (32, 64]
|
||||
// Shared memory required by this Gemm 49152
|
||||
static_assert(std::is_same<InType, int8_t>());
|
||||
using TileShape = typename cutlass::gemm::GemmShape<64, 128, 128>;
|
||||
using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
using InstructionShape = typename cutlass::gemm::GemmShape<8, 8, 16>;
|
||||
using Cutlass2xGemm =
|
||||
cutlass_2x_gemm<cutlass::arch::Sm75, enable_sm75_to_sm80, InType, OutType,
|
||||
Epilogue, TileShape, WarpShape, InstructionShape, 2>;
|
||||
};
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename> typename Epilogue>
|
||||
struct sm75_config_M32 {
|
||||
// M in [1, 32]
|
||||
// Shared memory required by this Gemm 49152
|
||||
static_assert(std::is_same<InType, int8_t>());
|
||||
using TileShape = typename cutlass::gemm::GemmShape<32, 128, 64>;
|
||||
using WarpShape = typename cutlass::gemm::GemmShape<32, 64, 64>;
|
||||
using InstructionShape = typename cutlass::gemm::GemmShape<8, 8, 16>;
|
||||
using Cutlass2xGemm =
|
||||
cutlass_2x_gemm<cutlass::arch::Sm75, enable_sm75_to_sm80, InType, OutType,
|
||||
Epilogue, TileShape, WarpShape, InstructionShape, 2>;
|
||||
};
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
inline void cutlass_gemm_sm75_dispatch(paddle::Tensor& out,
|
||||
paddle::Tensor const& a,
|
||||
paddle::Tensor const& b,
|
||||
EpilogueArgs&&... args) {
|
||||
static_assert(std::is_same<InType, int8_t>());
|
||||
PD_CHECK(a.dtype() == paddle::DataType::INT8);
|
||||
PD_CHECK(b.dtype() == paddle::DataType::INT8);
|
||||
|
||||
using Cutlass2xGemmDefault =
|
||||
typename sm75_config_default<InType, OutType, Epilogue>::Cutlass2xGemm;
|
||||
using Cutlass2xGemmM256 =
|
||||
typename sm75_config_M256<InType, OutType, Epilogue>::Cutlass2xGemm;
|
||||
using Cutlass2xGemmM128 = Cutlass2xGemmDefault;
|
||||
using Cutlass2xGemmM64 =
|
||||
typename sm75_config_M64<InType, OutType, Epilogue>::Cutlass2xGemm;
|
||||
using Cutlass2xGemmM32 =
|
||||
typename sm75_config_M32<InType, OutType, Epilogue>::Cutlass2xGemm;
|
||||
|
||||
// Due to shared memory requirements, some Gemms may fail to run on some
|
||||
// GPUs. As the name indicates, the Fallback Gemm is used as an alternative
|
||||
// in such cases.
|
||||
// sm75_config_default has the least shared-memory requirements.
|
||||
using FallbackGemm = Cutlass2xGemmDefault;
|
||||
|
||||
uint32_t const m = a.dims()[0];
|
||||
uint32_t const mp2 =
|
||||
std::max(static_cast<uint32_t>(32), next_pow_2(m)); // next power of 2
|
||||
if (mp2 <= 32) {
|
||||
// M in [1, 32]
|
||||
return fallback_cutlass_gemm_caller<Cutlass2xGemmM32, FallbackGemm>(
|
||||
out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else if (mp2 <= 64) {
|
||||
// M in (32, 64]
|
||||
return fallback_cutlass_gemm_caller<Cutlass2xGemmM64, FallbackGemm>(
|
||||
out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else if (mp2 <= 128) {
|
||||
// M in (64, 128]
|
||||
return fallback_cutlass_gemm_caller<Cutlass2xGemmM128, FallbackGemm>(
|
||||
out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else if (mp2 <= 256) {
|
||||
// M in (128, 256]
|
||||
return fallback_cutlass_gemm_caller<Cutlass2xGemmM256, FallbackGemm>(
|
||||
out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else {
|
||||
// M in (256, inf)
|
||||
return fallback_cutlass_gemm_caller<Cutlass2xGemmDefault, FallbackGemm>(
|
||||
out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace fastdeploy
|
||||
@@ -0,0 +1,141 @@
|
||||
// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm80_dispatch.cuh
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "scaled_mm_c2x.cuh"
|
||||
|
||||
/**
|
||||
* This file defines Gemm kernel configurations for SM80 based on the Gemm
|
||||
* shape.
|
||||
*/
|
||||
|
||||
namespace fastdeploy {
|
||||
|
||||
template <typename InType, typename OutType,
          template <typename, typename> typename Epilogue>
struct sm80_config_default {
  // This config is used in 2 cases,
  // - M in (128, inf)
  // - M in (64, 128] and N >= 8192
  // Shared Memory required by this Gemm - 81920 bytes
  static_assert(std::is_same<InType, int8_t>());
  using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
  using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
  using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
  using Cutlass2xGemm =
      cutlass_2x_gemm<cutlass::arch::Sm80, enable_sm80_to_sm89, InType, OutType,
                      Epilogue, TileShape, WarpShape, InstructionShape, 5>;
};

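// Illustrative sketch: the "Shared Memory required" figures quoted in these
// configs match the mainloop staging buffers of a CUTLASS 2.x multistage
// kernel, i.e. stages * (BLOCK_M * BLOCK_K + BLOCK_N * BLOCK_K) bytes for
// int8 operands (1 byte each). The disabled checks below reproduce the
// numbers for the configs in this file; the epilogue's shared memory is
// assumed to be reused and is not counted here.
#if 0
constexpr uint32_t mainloop_smem_bytes(uint32_t bm, uint32_t bn, uint32_t bk,
                                       uint32_t stages) {
  return stages * (bm * bk + bn * bk);  // sizeof(int8_t) == 1
}
static_assert(mainloop_smem_bytes(128, 128, 64, 5) == 81920);   // sm80_config_default
static_assert(mainloop_smem_bytes(64, 128, 128, 5) == 122880);  // sm80_config_M64
static_assert(mainloop_smem_bytes(32, 64, 128, 5) == 61440);    // sm80_config_M32
static_assert(mainloop_smem_bytes(16, 64, 128, 5) == 51200);    // sm80_config_M16
#endif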
template <typename InType, typename OutType,
          template <typename, typename> typename Epilogue>
struct sm80_config_M64 {
  // This config is used in 2 cases,
  // - M in (32, 64]
  // - M in (64, 128] and N < 8192
  // Shared Memory required by this Gemm - 122880 bytes
  static_assert(std::is_same<InType, int8_t>());
  using TileShape = typename cutlass::gemm::GemmShape<64, 128, 128>;
  using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
  using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
  using Cutlass2xGemm =
      cutlass_2x_gemm<cutlass::arch::Sm80, enable_sm80_to_sm89, InType, OutType,
                      Epilogue, TileShape, WarpShape, InstructionShape, 5>;
};

template <typename InType, typename OutType,
          template <typename, typename> typename Epilogue>
struct sm80_config_M32 {
  // M in (16, 32]
  // Shared Memory required by this Gemm - 61440 bytes
  static_assert(std::is_same<InType, int8_t>());
  using TileShape = typename cutlass::gemm::GemmShape<32, 64, 128>;
  using WarpShape = typename cutlass::gemm::GemmShape<32, 64, 64>;
  using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
  using Cutlass2xGemm =
      cutlass_2x_gemm<cutlass::arch::Sm80, enable_sm80_to_sm89, InType, OutType,
                      Epilogue, TileShape, WarpShape, InstructionShape, 5>;
};

template <typename InType, typename OutType,
          template <typename, typename> typename Epilogue>
struct sm80_config_M16 {
  // M in [1, 16]
  // Shared Memory required by this Gemm - 51200 bytes
  static_assert(std::is_same<InType, int8_t>());
  using TileShape = typename cutlass::gemm::GemmShape<16, 64, 128>;
  using WarpShape = typename cutlass::gemm::GemmShape<16, 64, 64>;
  using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
  using Cutlass2xGemm =
      cutlass_2x_gemm<cutlass::arch::Sm80, enable_sm80_to_sm89, InType, OutType,
                      Epilogue, TileShape, WarpShape, InstructionShape, 5>;
};

template <typename InType, typename OutType,
          template <typename, typename> typename Epilogue,
          typename... EpilogueArgs>
inline void cutlass_gemm_sm80_dispatch(paddle::Tensor& out,
                                       paddle::Tensor const& a,
                                       paddle::Tensor const& b,
                                       EpilogueArgs&&... args) {
  static_assert(std::is_same<InType, int8_t>());
  PD_CHECK(a.dtype() == paddle::DataType::INT8);
  PD_CHECK(b.dtype() == paddle::DataType::INT8);

  using Cutlass2xGemmDefault =
      typename sm80_config_default<InType, OutType, Epilogue>::Cutlass2xGemm;
  using Cutlass2xGemmM128BigN =
      typename sm80_config_default<InType, OutType, Epilogue>::Cutlass2xGemm;
  using Cutlass2xGemmM128SmallN =
      typename sm80_config_M64<InType, OutType, Epilogue>::Cutlass2xGemm;
  using Cutlass2xGemmM64 =
      typename sm80_config_M64<InType, OutType, Epilogue>::Cutlass2xGemm;
  using Cutlass2xGemmM32 =
      typename sm80_config_M32<InType, OutType, Epilogue>::Cutlass2xGemm;
  using Cutlass2xGemmM16 =
      typename sm80_config_M16<InType, OutType, Epilogue>::Cutlass2xGemm;

  // Due to shared memory requirements, some Gemms may fail to run on some
  // GPUs. As the name indicates, the Fallback Gemm is used as an alternative
  // in such cases.
  // sm80_config_M16 has the least shared-memory requirement. However,
  // based on some profiling, we select sm80_config_M32 as a better alternative
  // performance wise.
  using FallbackGemm =
      typename sm80_config_M32<InType, OutType, Epilogue>::Cutlass2xGemm;

  uint32_t const m = a.dims()[0];
  uint32_t const mp2 =
      std::max(static_cast<uint32_t>(16), next_pow_2(m));  // next power of 2
  if (mp2 <= 16) {
    // M in [1, 16]
    return fallback_cutlass_gemm_caller<Cutlass2xGemmM16, FallbackGemm>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  } else if (mp2 <= 32) {
    // M in (16, 32]
    return fallback_cutlass_gemm_caller<Cutlass2xGemmM32, FallbackGemm>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  } else if (mp2 <= 64) {
    // M in (32, 64]
    return fallback_cutlass_gemm_caller<Cutlass2xGemmM64, FallbackGemm>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  } else if (mp2 <= 128) {
    // M in (64, 128]
    uint32_t const n = out.dims()[1];
    bool const small_n = n < 8192;
    if (small_n) {
      return fallback_cutlass_gemm_caller<Cutlass2xGemmM128SmallN,
                                          FallbackGemm>(
          out, a, b, std::forward<EpilogueArgs>(args)...);
    } else {
      return fallback_cutlass_gemm_caller<Cutlass2xGemmM128BigN, FallbackGemm>(
          out, a, b, std::forward<EpilogueArgs>(args)...);
    }
  } else {
    // M in (128, inf)
    return fallback_cutlass_gemm_caller<Cutlass2xGemmDefault, FallbackGemm>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  }
}

}  // namespace fastdeploy
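// Illustrative sketch: the bucket selection above first rounds M up to the
// next power of two and clamps it to the smallest bucket, so every M maps to
// exactly one branch. Assuming next_pow_2 (defined elsewhere in these
// headers) rounds an integer up to the nearest power of two, an equivalent
// standalone helper and a few example mappings are shown in the disabled
// code below.
#if 0
#include <algorithm>
#include <cstdint>

constexpr uint32_t next_pow_2_sketch(uint32_t v) {
  uint32_t p = 1;
  while (p < v) p <<= 1;
  return p;
}
// M = 1   -> max(16, 1)   = 16  -> sm80_config_M16
// M = 48  -> max(16, 64)  = 64  -> sm80_config_M64
// M = 200 -> max(16, 256) = 256 -> sm80_config_default ("M in (128, inf)")
static_assert(std::max(16u, next_pow_2_sketch(1)) == 16);
static_assert(std::max(16u, next_pow_2_sketch(48)) == 64);
static_assert(std::max(16u, next_pow_2_sketch(200)) == 256);
#endif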
@@ -0,0 +1,370 @@
|
||||
// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm89_fp8_dispatch.cuh
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "scaled_mm_c2x.cuh"
|
||||
#include "cutlass/float8.h"
|
||||
|
||||
/**
|
||||
* This file defines Gemm kernel configurations for SM89 (FP8) based on the Gemm
|
||||
* shape.
|
||||
*/
|
||||
|
||||
namespace fastdeploy {
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename> typename Epilogue>
|
||||
struct sm89_fp8_fallback_gemm {
|
||||
// Shared Memory required by this Gemm - 61440 bytes
|
||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||
using TileShape = typename cutlass::gemm::GemmShape<64, 128, 64>;
|
||||
using WarpShape = typename cutlass::gemm::GemmShape<32, 64, 64>;
|
||||
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
|
||||
using FP8MathOperator = typename cutlass::arch::OpMultiplyAdd;
|
||||
using Cutlass2xGemm =
|
||||
cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90, InType, OutType,
|
||||
Epilogue, TileShape, WarpShape, InstructionShape, 5,
|
||||
FP8MathOperator>;
|
||||
};
|
||||
|
||||
struct sm89_fp8_config_default {
|
||||
// M in (256, inf)
|
||||
using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
|
||||
using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum;
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
static void dispatch(paddle::Tensor& out, paddle::Tensor const& a,
|
||||
paddle::Tensor const& b, EpilogueArgs&&... args) {
|
||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||
PD_CHECK(a.dtype() == paddle::DataType::FLOAT8_E4M3FN);
|
||||
|
||||
using FallbackGemm =
|
||||
typename sm89_fp8_fallback_gemm<InType, OutType,
|
||||
Epilogue>::Cutlass2xGemm;
|
||||
|
||||
uint32_t const n = out.dims()[1];
|
||||
uint32_t const np2 = next_pow_2(n);
|
||||
|
||||
if (np2 <= 4096) {
|
||||
using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
|
||||
|
||||
return fastdeploy::fallback_cutlass_gemm_caller<
|
||||
fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89, fastdeploy::enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, 5, FP8MathOperator>,
|
||||
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else if (np2 <= 8192) {
|
||||
using TileShape = typename cutlass::gemm::GemmShape<256, 128, 64>;
|
||||
|
||||
return fastdeploy::fallback_cutlass_gemm_caller<
|
||||
fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89, fastdeploy::enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, 3, FP8MathOperator>,
|
||||
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
|
||||
} else {
|
||||
using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
|
||||
|
||||
return fastdeploy::fallback_cutlass_gemm_caller<
|
||||
fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89, fastdeploy::enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, 5, FP8MathOperator>,
|
||||
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct sm89_fp8_config_M256 {
|
||||
// M in (128, 256]
|
||||
using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
|
||||
using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum;
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
static void dispatch(paddle::Tensor& out, paddle::Tensor const& a,
|
||||
paddle::Tensor const& b, EpilogueArgs&&... args) {
|
||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||
PD_CHECK(a.dtype() == paddle::DataType::FLOAT8_E4M3FN);
|
||||
|
||||
using FallbackGemm =
|
||||
typename sm89_fp8_fallback_gemm<InType, OutType,
|
||||
Epilogue>::Cutlass2xGemm;
|
||||
|
||||
uint32_t const n = out.dims()[1];
|
||||
uint32_t const np2 = next_pow_2(n);
|
||||
|
||||
if (np2 <= 4096) {
|
||||
using TileShape = typename cutlass::gemm::GemmShape<64, 128, 128>;
|
||||
|
||||
return fastdeploy::fallback_cutlass_gemm_caller<
|
||||
fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89, fastdeploy::enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, 3, FP8MathOperator>,
|
||||
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else {
|
||||
using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
|
||||
|
||||
return fastdeploy::fallback_cutlass_gemm_caller<
|
||||
fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89, fastdeploy::enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, 5, FP8MathOperator>,
|
||||
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct sm89_fp8_config_M128 {
|
||||
// M in (64, 128]
|
||||
using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
|
||||
using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum;
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
static void dispatch(paddle::Tensor& out, paddle::Tensor const& a,
|
||||
paddle::Tensor const& b, EpilogueArgs&&... args) {
|
||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||
PD_CHECK(a.dtype() == paddle::DataType::FLOAT8_E4M3FN);
|
||||
|
||||
using FallbackGemm =
|
||||
typename sm89_fp8_fallback_gemm<InType, OutType,
|
||||
Epilogue>::Cutlass2xGemm;
|
||||
|
||||
uint32_t const n = out.dims()[1];
|
||||
uint32_t const np2 = next_pow_2(n);
|
||||
|
||||
if (np2 <= 8192) {
|
||||
using TileShape = typename cutlass::gemm::GemmShape<64, 128, 128>;
|
||||
|
||||
return fastdeploy::fallback_cutlass_gemm_caller<
|
||||
fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89, fastdeploy::enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, 3, FP8MathOperator>,
|
||||
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
|
||||
} else if (np2 <= 16384) {
|
||||
using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
|
||||
|
||||
return fastdeploy::fallback_cutlass_gemm_caller<
|
||||
fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89, fastdeploy::enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, 5, FP8MathOperator>,
|
||||
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else {
|
||||
using TileShape = typename cutlass::gemm::GemmShape<128, 64, 128>;
|
||||
|
||||
return fastdeploy::fallback_cutlass_gemm_caller<
|
||||
fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89, fastdeploy::enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, 3, FP8MathOperator>,
|
||||
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct sm89_fp8_config_M64 {
|
||||
// M in (32, 64]
|
||||
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
static void dispatch(paddle::Tensor& out, paddle::Tensor const& a,
|
||||
paddle::Tensor const& b, EpilogueArgs&&... args) {
|
||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||
PD_CHECK(a.dtype() == paddle::DataType::FLOAT8_E4M3FN);
|
||||
|
||||
using FallbackGemm =
|
||||
typename sm89_fp8_fallback_gemm<InType, OutType,
|
||||
Epilogue>::Cutlass2xGemm;
|
||||
|
||||
uint32_t const n = out.dims()[1];
|
||||
uint32_t const np2 = next_pow_2(n);
|
||||
|
||||
if (np2 <= 8192) {
|
||||
using TileShape = typename cutlass::gemm::GemmShape<64, 64, 128>;
|
||||
using WarpShape = typename cutlass::gemm::GemmShape<32, 64, 64>;
|
||||
using FP8MathOperator = typename cutlass::arch::OpMultiplyAdd;
|
||||
|
||||
return fastdeploy::fallback_cutlass_gemm_caller<
|
||||
fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89, fastdeploy::enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, 5, FP8MathOperator>,
|
||||
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else if (np2 <= 16384) {
|
||||
using TileShape = typename cutlass::gemm::GemmShape<64, 128, 128>;
|
||||
using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum;
|
||||
|
||||
return fastdeploy::fallback_cutlass_gemm_caller<
|
||||
fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89, fastdeploy::enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, 3, FP8MathOperator>,
|
||||
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else {
|
||||
using TileShape = typename cutlass::gemm::GemmShape<64, 64, 128>;
|
||||
using WarpShape = typename cutlass::gemm::GemmShape<32, 64, 64>;
|
||||
using FP8MathOperator = typename cutlass::arch::OpMultiplyAdd;
|
||||
|
||||
return fastdeploy::fallback_cutlass_gemm_caller<
|
||||
fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89, fastdeploy::enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, 5, FP8MathOperator>,
|
||||
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct sm89_fp8_config_M32 {
|
||||
// M in (16, 32]
|
||||
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
|
||||
using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum;
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
static void dispatch(paddle::Tensor& out, paddle::Tensor const& a,
|
||||
paddle::Tensor const& b, EpilogueArgs&&... args) {
|
||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||
PD_CHECK(a.dtype() == paddle::DataType::FLOAT8_E4M3FN);
|
||||
|
||||
using FallbackGemm =
|
||||
typename sm89_fp8_fallback_gemm<InType, OutType,
|
||||
Epilogue>::Cutlass2xGemm;
|
||||
|
||||
uint32_t const n = out.dims()[1];
|
||||
uint32_t const np2 = next_pow_2(n);
|
||||
|
||||
if (np2 <= 8192) {
|
||||
using TileShape = typename cutlass::gemm::GemmShape<32, 64, 128>;
|
||||
using WarpShape = typename cutlass::gemm::GemmShape<16, 64, 64>;
|
||||
|
||||
return fastdeploy::fallback_cutlass_gemm_caller<
|
||||
fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89, fastdeploy::enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, 5, FP8MathOperator>,
|
||||
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else if (np2 <= 16384) {
|
||||
using TileShape = typename cutlass::gemm::GemmShape<32, 128, 128>;
|
||||
using WarpShape = typename cutlass::gemm::GemmShape<32, 64, 64>;
|
||||
|
||||
return fastdeploy::fallback_cutlass_gemm_caller<
|
||||
fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89, fastdeploy::enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, 4, FP8MathOperator>,
|
||||
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else {
|
||||
using TileShape = typename cutlass::gemm::GemmShape<32, 64, 128>;
|
||||
using WarpShape = typename cutlass::gemm::GemmShape<16, 64, 64>;
|
||||
|
||||
return fastdeploy::fallback_cutlass_gemm_caller<
|
||||
fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89, fastdeploy::enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, 5, FP8MathOperator>,
|
||||
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct sm89_fp8_config_M16 {
|
||||
// M in [1, 16]
|
||||
using WarpShape = typename cutlass::gemm::GemmShape<16, 64, 64>;
|
||||
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
|
||||
using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum;
|
||||
static const int32_t MainLoopStages = 5;
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
static void dispatch(paddle::Tensor& out, paddle::Tensor const& a,
|
||||
paddle::Tensor const& b, EpilogueArgs&&... args) {
|
||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||
PD_CHECK(a.dtype() == paddle::DataType::FLOAT8_E4M3FN);
|
||||
|
||||
using FallbackGemm =
|
||||
typename sm89_fp8_fallback_gemm<InType, OutType,
|
||||
Epilogue>::Cutlass2xGemm;
|
||||
|
||||
uint32_t const n = out.dims()[1];
|
||||
uint32_t const np2 = next_pow_2(n);
|
||||
|
||||
if (np2 <= 8192) {
|
||||
using TileShape = typename cutlass::gemm::GemmShape<16, 64, 128>;
|
||||
|
||||
return fastdeploy::fallback_cutlass_gemm_caller<
|
||||
fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89, fastdeploy::enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, MainLoopStages,
|
||||
FP8MathOperator>,
|
||||
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else if (np2 <= 24576) {
|
||||
using TileShape = typename cutlass::gemm::GemmShape<16, 128, 64>;
|
||||
|
||||
return fastdeploy::fallback_cutlass_gemm_caller<
|
||||
fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89, fastdeploy::enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, MainLoopStages,
|
||||
FP8MathOperator>,
|
||||
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else {
|
||||
using TileShape = typename cutlass::gemm::GemmShape<32, 64, 128>;
|
||||
|
||||
return fastdeploy::fallback_cutlass_gemm_caller<
|
||||
fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89, fastdeploy::enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, MainLoopStages,
|
||||
FP8MathOperator>,
|
||||
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename InType, typename OutType,
          template <typename, typename> typename Epilogue,
          typename... EpilogueArgs>
inline void cutlass_gemm_sm89_fp8_dispatch(paddle::Tensor& out,
                                           paddle::Tensor const& a,
                                           paddle::Tensor const& b,
                                           EpilogueArgs&&... args) {
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  PD_CHECK(a.dtype() == paddle::DataType::FLOAT8_E4M3FN);
  PD_CHECK(b.dtype() == paddle::DataType::FLOAT8_E4M3FN);

  uint32_t const m = a.dims()[0];
  uint32_t const mp2 =
      std::max(static_cast<uint32_t>(16), next_pow_2(m));  // next power of 2

  if (mp2 <= 16) {
    // M in [1, 16]
    return sm89_fp8_config_M16::dispatch<InType, OutType, Epilogue>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  } else if (mp2 <= 32) {
    // M in (16, 32]
    return sm89_fp8_config_M32::dispatch<InType, OutType, Epilogue>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  } else if (mp2 <= 64) {
    // M in (32, 64]
    return sm89_fp8_config_M64::dispatch<InType, OutType, Epilogue>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  } else if (mp2 <= 128) {
    // M in (64, 128]
    return sm89_fp8_config_M128::dispatch<InType, OutType, Epilogue>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  } else if (mp2 <= 256) {
    // M in (128, 256]
    return sm89_fp8_config_M256::dispatch<InType, OutType, Epilogue>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  } else {
    // M in (256, inf)
    return sm89_fp8_config_default::dispatch<InType, OutType, Epilogue>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  }
}

}  // namespace fastdeploy
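// Illustrative sketch: callers are expected to reach this header through the
// SM89 scaled_mm entry point rather than directly. A direct call would look
// roughly like the disabled code below; ScaledEpilogue and the bf16 output
// type are assumptions of this sketch and stand in for whichever epilogue and
// output type the entry point actually selects.
#if 0
void sm89_fp8_example(paddle::Tensor& out, paddle::Tensor const& a,
                      paddle::Tensor const& b, paddle::Tensor const& a_scales,
                      paddle::Tensor const& b_scales) {
  fastdeploy::cutlass_gemm_sm89_fp8_dispatch<cutlass::float_e4m3_t,
                                             cutlass::bfloat16_t,
                                             fastdeploy::ScaledEpilogue>(
      out, a, b, a_scales, b_scales);
}
#endif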
@@ -0,0 +1,355 @@
// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm89_int8_dispatch.cuh

#pragma once

#include "scaled_mm_c2x.cuh"

/**
 * This file defines Gemm kernel configurations for SM89 (int8) based on the
 * Gemm shape.
 */

namespace fastdeploy {

template <typename InType, typename OutType,
          template <typename, typename> typename Epilogue>
struct sm89_int8_fallback_gemm {
  // Shared mem requirement : 61440
  static_assert(std::is_same<InType, int8_t>());
  using TileShape = cutlass::gemm::GemmShape<32, 64, 128>;
  using WarpShape = cutlass::gemm::GemmShape<16, 64, 64>;
  using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
  static int32_t const MainLoopStages = 5;

  using Cutlass2xGemm =
      cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90, InType, OutType,
                      Epilogue, TileShape, WarpShape, InstructionShape, 5>;
};
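// Illustrative sketch: in CUTLASS 2.x the number of warps per CTA is the
// product of the threadblock/warp shape ratios along M, N and K. For the
// fallback config above that gives (32/16) * (64/64) * (128/64) = 4 warps,
// i.e. a 128-thread block. The disabled check below just restates that
// arithmetic.
#if 0
static_assert((32 / 16) * (64 / 64) * (128 / 64) * 32 == 128,
              "fallback tile/warp shapes imply a 128-thread CTA");
#endif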
|
||||
struct sm89_int8_config_default {
|
||||
// M in (256, inf)
|
||||
using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
static void dispatch(paddle::Tensor& out, paddle::Tensor const& a,
|
||||
paddle::Tensor const& b, EpilogueArgs&&... args) {
|
||||
static_assert(std::is_same<InType, int8_t>());
|
||||
PD_CHECK(a.dtype() == paddle::DataType::INT8);
|
||||
|
||||
using FallbackGemm =
|
||||
typename sm89_int8_fallback_gemm<InType, OutType,
|
||||
Epilogue>::Cutlass2xGemm;
|
||||
|
||||
uint32_t const n = out.dims()[1];
|
||||
uint32_t const np2 = next_pow_2(n);
|
||||
|
||||
if (np2 <= 4096) {
|
||||
using TileShape = cutlass::gemm::GemmShape<128, 128, 64>;
|
||||
|
||||
return fastdeploy::fallback_cutlass_gemm_caller<
|
||||
fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89, fastdeploy::enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, 5>,
|
||||
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else if (np2 <= 8192) {
|
||||
using TileShape = cutlass::gemm::GemmShape<256, 128, 64>;
|
||||
|
||||
return fastdeploy::fallback_cutlass_gemm_caller<
|
||||
fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89, fastdeploy::enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, 3>,
|
||||
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else if (np2 <= 16384) {
|
||||
using TileShape = cutlass::gemm::GemmShape<128, 128, 64>;
|
||||
|
||||
return fastdeploy::fallback_cutlass_gemm_caller<
|
||||
fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89, fastdeploy::enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, 5>,
|
||||
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else {
|
||||
using TileShape = cutlass::gemm::GemmShape<256, 128, 64>;
|
||||
|
||||
return fastdeploy::fallback_cutlass_gemm_caller<
|
||||
fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89, fastdeploy::enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, 3>,
|
||||
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct sm89_int8_config_M256 {
|
||||
// M in (128, 256]
|
||||
using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
static void dispatch(paddle::Tensor& out, paddle::Tensor const& a,
|
||||
paddle::Tensor const& b, EpilogueArgs&&... args) {
|
||||
static_assert(std::is_same<InType, int8_t>());
|
||||
PD_CHECK(a.dtype() == paddle::DataType::INT8);
|
||||
|
||||
using FallbackGemm =
|
||||
typename sm89_int8_fallback_gemm<InType, OutType,
|
||||
Epilogue>::Cutlass2xGemm;
|
||||
|
||||
uint32_t const n = out.dims()[1];
|
||||
uint32_t const np2 = next_pow_2(n);
|
||||
|
||||
if (np2 <= 4096) {
|
||||
using TileShape = cutlass::gemm::GemmShape<64, 128, 128>;
|
||||
|
||||
return fastdeploy::fallback_cutlass_gemm_caller<
|
||||
fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89, fastdeploy::enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, 3>,
|
||||
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else if (np2 <= 8192) {
|
||||
using TileShape = cutlass::gemm::GemmShape<128, 128, 64>;
|
||||
|
||||
return fastdeploy::fallback_cutlass_gemm_caller<
|
||||
fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89, fastdeploy::enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, 5>,
|
||||
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else if (np2 <= 16384) {
|
||||
using TileShape = cutlass::gemm::GemmShape<256, 128, 64>;
|
||||
|
||||
return fastdeploy::fallback_cutlass_gemm_caller<
|
||||
fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89, fastdeploy::enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, 3>,
|
||||
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else {
|
||||
using TileShape = cutlass::gemm::GemmShape<128, 128, 64>;
|
||||
|
||||
return fastdeploy::fallback_cutlass_gemm_caller<
|
||||
fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89, fastdeploy::enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, 5>,
|
||||
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct sm89_int8_config_M128 {
|
||||
// M in (64, 128]
|
||||
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
static void dispatch(paddle::Tensor& out, paddle::Tensor const& a,
|
||||
paddle::Tensor const& b, EpilogueArgs&&... args) {
|
||||
static_assert(std::is_same<InType, int8_t>());
|
||||
PD_CHECK(a.dtype() == paddle::DataType::INT8);
|
||||
|
||||
using FallbackGemm =
|
||||
typename sm89_int8_fallback_gemm<InType, OutType,
|
||||
Epilogue>::Cutlass2xGemm;
|
||||
|
||||
uint32_t const n = out.dims()[1];
|
||||
uint32_t const np2 = next_pow_2(n);
|
||||
|
||||
if (np2 <= 8192) {
|
||||
using TileShape = cutlass::gemm::GemmShape<64, 128, 128>;
|
||||
using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
|
||||
return fastdeploy::fallback_cutlass_gemm_caller<
|
||||
fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89, fastdeploy::enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, 3>,
|
||||
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else if (np2 <= 16384) {
|
||||
using TileShape = cutlass::gemm::GemmShape<128, 128, 64>;
|
||||
using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
|
||||
return fastdeploy::fallback_cutlass_gemm_caller<
|
||||
fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89, fastdeploy::enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, 5>,
|
||||
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else {
|
||||
using TileShape = cutlass::gemm::GemmShape<64, 64, 128>;
|
||||
using WarpShape = cutlass::gemm::GemmShape<32, 64, 64>;
|
||||
|
||||
return fastdeploy::fallback_cutlass_gemm_caller<
|
||||
fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89, fastdeploy::enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, 5>,
|
||||
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct sm89_int8_config_M64 {
|
||||
// M in (32, 64]
|
||||
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
static void dispatch(paddle::Tensor& out, paddle::Tensor const& a,
|
||||
paddle::Tensor const& b, EpilogueArgs&&... args) {
|
||||
static_assert(std::is_same<InType, int8_t>());
|
||||
PD_CHECK(a.dtype() == paddle::DataType::INT8);
|
||||
|
||||
using FallbackGemm =
|
||||
typename sm89_int8_fallback_gemm<InType, OutType,
|
||||
Epilogue>::Cutlass2xGemm;
|
||||
|
||||
uint32_t const n = out.dims()[1];
|
||||
uint32_t const np2 = next_pow_2(n);
|
||||
|
||||
if (np2 <= 8192) {
|
||||
using TileShape = cutlass::gemm::GemmShape<64, 64, 128>;
|
||||
using WarpShape = cutlass::gemm::GemmShape<32, 64, 64>;
|
||||
|
||||
return fastdeploy::fallback_cutlass_gemm_caller<
|
||||
fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89, fastdeploy::enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, 5>,
|
||||
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else {
|
||||
using TileShape = cutlass::gemm::GemmShape<64, 128, 128>;
|
||||
using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
|
||||
return fastdeploy::fallback_cutlass_gemm_caller<
|
||||
fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89, fastdeploy::enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, 3>,
|
||||
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct sm89_int8_config_M32 {
|
||||
// M in (16, 32]
|
||||
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
static void dispatch(paddle::Tensor& out, paddle::Tensor const& a,
|
||||
paddle::Tensor const& b, EpilogueArgs&&... args) {
|
||||
static_assert(std::is_same<InType, int8_t>());
|
||||
PD_CHECK(a.dtype() == paddle::DataType::INT8);
|
||||
|
||||
using FallbackGemm =
|
||||
typename sm89_int8_fallback_gemm<InType, OutType,
|
||||
Epilogue>::Cutlass2xGemm;
|
||||
|
||||
uint32_t const n = out.dims()[1];
|
||||
uint32_t const np2 = next_pow_2(n);
|
||||
|
||||
if (np2 <= 8192) {
|
||||
using TileShape = cutlass::gemm::GemmShape<32, 64, 128>;
|
||||
using WarpShape = cutlass::gemm::GemmShape<16, 64, 64>;
|
||||
|
||||
return fastdeploy::fallback_cutlass_gemm_caller<
|
||||
fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89, fastdeploy::enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, 5>,
|
||||
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else {
|
||||
using TileShape = cutlass::gemm::GemmShape<32, 128, 128>;
|
||||
using WarpShape = cutlass::gemm::GemmShape<32, 64, 64>;
|
||||
|
||||
return fastdeploy::fallback_cutlass_gemm_caller<
|
||||
fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89, fastdeploy::enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, 4>,
|
||||
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct sm89_int8_config_M16 {
|
||||
// M in [1, 16]
|
||||
using WarpShape = typename cutlass::gemm::GemmShape<16, 64, 64>;
|
||||
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
static void dispatch(paddle::Tensor& out, paddle::Tensor const& a,
|
||||
paddle::Tensor const& b, EpilogueArgs&&... args) {
|
||||
static_assert(std::is_same<InType, int8_t>());
|
||||
PD_CHECK(a.dtype() == paddle::DataType::INT8);
|
||||
|
||||
using FallbackGemm =
|
||||
typename sm89_int8_fallback_gemm<InType, OutType,
|
||||
Epilogue>::Cutlass2xGemm;
|
||||
|
||||
uint32_t const n = out.dims()[1];
|
||||
uint32_t const np2 = next_pow_2(n);
|
||||
|
||||
if (np2 <= 8192) {
|
||||
using TileShape = cutlass::gemm::GemmShape<16, 64, 128>;
|
||||
|
||||
return fastdeploy::fallback_cutlass_gemm_caller<
|
||||
fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89, fastdeploy::enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, 5>,
|
||||
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else {
|
||||
using TileShape = cutlass::gemm::GemmShape<16, 128, 128>;
|
||||
|
||||
return fastdeploy::fallback_cutlass_gemm_caller<
|
||||
fastdeploy::cutlass_2x_gemm<cutlass::arch::Sm89, fastdeploy::enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, 4>,
|
||||
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
inline void cutlass_gemm_sm89_int8_dispatch(paddle::Tensor& out,
|
||||
paddle::Tensor const& a,
|
||||
paddle::Tensor const& b,
|
||||
EpilogueArgs&&... args) {
|
||||
static_assert(std::is_same<InType, int8_t>());
|
||||
PD_CHECK(a.dtype() == paddle::DataType::INT8);
|
||||
PD_CHECK(b.dtype() == paddle::DataType::INT8);
|
||||
|
||||
uint32_t const m = a.dims()[0];
|
||||
uint32_t const mp2 =
|
||||
std::max(static_cast<uint32_t>(16), next_pow_2(m)); // next power of 2
|
||||
|
||||
if (mp2 <= 16) {
|
||||
// M in [1, 16]
|
||||
return sm89_int8_config_M16::dispatch<InType, OutType, Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else if (mp2 <= 32) {
|
||||
// M in (16, 32]
|
||||
return sm89_int8_config_M32::dispatch<InType, OutType, Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else if (mp2 <= 64) {
|
||||
// M in (32, 64]
|
||||
return sm89_int8_config_M64::dispatch<InType, OutType, Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else if (mp2 <= 128) {
|
||||
// M in (64, 128]
|
||||
return sm89_int8_config_M128::dispatch<InType, OutType, Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else if (mp2 <= 256) {
|
||||
// M in (128, 256]
|
||||
return sm89_int8_config_M256::dispatch<InType, OutType, Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else {
|
||||
// M in (256, inf)
|
||||
return sm89_int8_config_default::dispatch<InType, OutType, Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace fastdeploy
|
||||
@@ -0,0 +1,37 @@
// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm90.cu

#include "c3x/scaled_mm_helper.hpp"
#include "c3x/scaled_mm_kernels.hpp"

/*
   This file defines quantized GEMM operations using the CUTLASS 3.x API, for
   NVIDIA GPUs with sm90a (Hopper).
*/

#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90

void cutlass_scaled_mm_sm90(paddle::Tensor &c, paddle::Tensor const &a,
                            paddle::Tensor const &b,
                            paddle::Tensor const &a_scales,
                            paddle::Tensor const &b_scales,
                            paddle::optional<paddle::Tensor> const &bias) {
  dispatch_scaled_mm(c, a, b, a_scales, b_scales, bias,
                     fastdeploy::cutlass_scaled_mm_sm90_fp8,
                     fastdeploy::cutlass_scaled_mm_sm90_int8);
}

void cutlass_scaled_mm_azp_sm90(paddle::Tensor& out, paddle::Tensor const& a,
                                paddle::Tensor const& b,
                                paddle::Tensor const& a_scales,
                                paddle::Tensor const& b_scales,
                                paddle::Tensor const& azp_adj,
                                paddle::optional<paddle::Tensor> const& azp,
                                paddle::optional<paddle::Tensor> const& bias) {
  PD_CHECK(a_scales.dtype() == paddle::DataType::FLOAT32);
  PD_CHECK(b_scales.dtype() == paddle::DataType::FLOAT32);

  fastdeploy::cutlass_scaled_mm_azp_sm90_int8(out, a, b, a_scales, b_scales,
                                              azp_adj, azp, bias);
}

#endif
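// Illustrative sketch (assumption): dispatch_scaled_mm, declared in
// c3x/scaled_mm_helper.hpp, is assumed to route on the input dtype, calling
// the FP8 implementation for float8_e4m3 inputs and the int8 implementation
// otherwise, roughly as in the disabled code below. The real helper may also
// validate scales and bias before dispatching.
#if 0
void dispatch_scaled_mm_sketch(paddle::Tensor& c, paddle::Tensor const& a,
                               paddle::Tensor const& b,
                               paddle::Tensor const& a_scales,
                               paddle::Tensor const& b_scales,
                               paddle::optional<paddle::Tensor> const& bias) {
  if (a.dtype() == paddle::DataType::FLOAT8_E4M3FN) {
    fastdeploy::cutlass_scaled_mm_sm90_fp8(c, a, b, a_scales, b_scales, bias);
  } else {
    fastdeploy::cutlass_scaled_mm_sm90_int8(c, a, b, a_scales, b_scales, bias);
  }
}
#endif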
224
custom_ops/gpu_ops/cutlass_kernels/w8a8/scaled_mm_entry.cu
Normal file
@@ -0,0 +1,224 @@
|
||||
// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu
|
||||
|
||||
#pragma once
|
||||
#include "helper.h"
|
||||
#include <iostream>
|
||||
|
||||
void cutlass_scaled_mm_sm75(paddle::Tensor &c, paddle::Tensor const &a,
|
||||
paddle::Tensor const &b,
|
||||
paddle::Tensor const &a_scales,
|
||||
paddle::Tensor const &b_scales,
|
||||
paddle::optional<paddle::Tensor> const &bias);
|
||||
|
||||
void cutlass_scaled_mm_sm80(paddle::Tensor &c, paddle::Tensor const &a,
|
||||
paddle::Tensor const &b,
|
||||
paddle::Tensor const &a_scales,
|
||||
paddle::Tensor const &b_scales,
|
||||
paddle::optional<paddle::Tensor> const &bias);
|
||||
|
||||
void cutlass_scaled_mm_sm89(paddle::Tensor &c, paddle::Tensor const &a,
|
||||
paddle::Tensor const &b,
|
||||
paddle::Tensor const &a_scales,
|
||||
paddle::Tensor const &b_scales,
|
||||
paddle::optional<paddle::Tensor> const &bias);
|
||||
|
||||
#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90
|
||||
void cutlass_scaled_mm_sm90(paddle::Tensor &c, paddle::Tensor const &a,
|
||||
paddle::Tensor const &b,
|
||||
paddle::Tensor const &a_scales,
|
||||
paddle::Tensor const &b_scales,
|
||||
paddle::optional<paddle::Tensor> const &bias);
|
||||
#endif
|
||||
|
||||
void cutlass_scaled_mm_azp_sm75(paddle::Tensor& c, paddle::Tensor const& a,
|
||||
paddle::Tensor const& b,
|
||||
paddle::Tensor const& a_scales,
|
||||
paddle::Tensor const& b_scales,
|
||||
paddle::Tensor const& azp_adj,
|
||||
paddle::optional<paddle::Tensor> const& azp,
|
||||
paddle::optional<paddle::Tensor> const& bias);
|
||||
|
||||
void cutlass_scaled_mm_azp_sm80(paddle::Tensor& c, paddle::Tensor const& a,
|
||||
paddle::Tensor const& b,
|
||||
paddle::Tensor const& a_scales,
|
||||
paddle::Tensor const& b_scales,
|
||||
paddle::Tensor const& azp_adj,
|
||||
paddle::optional<paddle::Tensor> const& azp,
|
||||
paddle::optional<paddle::Tensor> const& bias);
|
||||
|
||||
void cutlass_scaled_mm_azp_sm89(paddle::Tensor& c, paddle::Tensor const& a,
|
||||
paddle::Tensor const& b,
|
||||
paddle::Tensor const& a_scales,
|
||||
paddle::Tensor const& b_scales,
|
||||
paddle::Tensor const& azp_adj,
|
||||
paddle::optional<paddle::Tensor> const& azp,
|
||||
paddle::optional<paddle::Tensor> const& bias);
|
||||
|
||||
#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90
|
||||
void cutlass_scaled_mm_azp_sm90(paddle::Tensor& c, paddle::Tensor const& a,
|
||||
paddle::Tensor const& b,
|
||||
paddle::Tensor const& a_scales,
|
||||
paddle::Tensor const& b_scales,
|
||||
paddle::Tensor const& azp_adj,
|
||||
paddle::optional<paddle::Tensor> const& azp,
|
||||
paddle::optional<paddle::Tensor> const& bias);
|
||||
#endif
|
||||
|
||||
bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability) {
|
||||
// CUTLASS FP8 kernels need at least
|
||||
// CUDA 12.0 on SM90 systems (Hopper)
|
||||
// CUDA 12.4 on SM89 systems (Lovelace)
|
||||
|
||||
#if defined CUDA_VERSION
|
||||
if (cuda_device_capability >= 90) {
|
||||
return CUDA_VERSION >= 12000;
|
||||
} else if (cuda_device_capability >= 89) {
|
||||
return CUDA_VERSION >= 12040;
|
||||
}
|
||||
#endif
|
||||
|
||||
return false;
|
||||
}
|
||||
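// Illustrative sketch: callers are assumed to use this predicate to decide
// whether an FP8 path can be taken on the current device before dispatching
// FP8 kernels; the fallback named in the comment below is a placeholder for
// whatever alternative path the caller actually has.
#if 0
void choose_path_example(int64_t capability) {
  if (cutlass_scaled_mm_supports_fp8(capability)) {
    // safe to run the CUTLASS FP8 scaled_mm kernels
  } else {
    // fall back to an int8 or non-CUTLASS path
  }
}
#endif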
|
||||
void CutlassScaledMm(paddle::Tensor &c, paddle::Tensor const &a,
|
||||
paddle::Tensor const &b, paddle::Tensor const &a_scales,
|
||||
paddle::Tensor const &b_scales,
|
||||
paddle::optional<paddle::Tensor> const &bias) {
|
||||
// Checks for conformality
|
||||
PD_CHECK(a.dims().size() == 2 && b.dims().size() == 2 &&
|
||||
c.dims().size() == 2);
|
||||
PD_CHECK(c.dims()[0] == a.dims()[0] && a.dims()[1] == b.dims()[1] &&
|
||||
b.dims()[0] == c.dims()[1]);
|
||||
|
||||
// Check for strides and alignment
|
||||
PD_CHECK(a.strides()[1] == 1 && c.strides()[1] == 1); // Row-major
|
||||
PD_CHECK(b.strides()[1] == 1); // Column-major
|
||||
PD_CHECK(c.strides()[0] % 16 == 0 &&
|
||||
b.strides()[0] % 16 == 0); // 16 Byte Alignment
|
||||
|
||||
if (bias) {
|
||||
PD_CHECK(bias->numel() == b.dims()[0] && bias->is_contiguous() &&
|
||||
bias->dims().size() == 1);
|
||||
}
|
||||
|
||||
int32_t version_num = GetGPUComputeCapability(a.place().GetDeviceId());
|
||||
|
||||
// Guard against compilation issues for sm90 kernels
|
||||
#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90
|
||||
if (version_num >= 90 && version_num < 100) {
|
||||
// Hopper
|
||||
cutlass_scaled_mm_sm90(c, a, b, a_scales, b_scales, bias);
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined ENABLE_SCALED_MM_C2X && ENABLE_SCALED_MM_C2X
|
||||
if (version_num == 89) {
|
||||
// Ada Lovelace
|
||||
cutlass_scaled_mm_sm89(c, a, b, a_scales, b_scales, bias);
|
||||
return;
|
||||
}
|
||||
|
||||
if (version_num >= 80) {
|
||||
// Ampere
|
||||
cutlass_scaled_mm_sm80(c, a, b, a_scales, b_scales, bias);
|
||||
return;
|
||||
}
|
||||
|
||||
if (version_num >= 75) {
|
||||
// Turing
|
||||
cutlass_scaled_mm_sm75(c, a, b, a_scales, b_scales, bias);
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
|
||||
PADDLE_THROW(phi::errors::Unimplemented(
|
||||
"No compiled cutlass_scaled_mm for a compute capability less than "
|
||||
"CUDA device capability: %d",
|
||||
version_num));
|
||||
}
|
||||
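// Illustrative sketch: shapes and layouts that satisfy the checks above.
// a and c are row-major, b is stored with unit stride along dim 1 so that it
// acts as the column-major operand, and both scale tensors are fp32. The
// shapes noted in the comments are assumptions of this sketch; only the
// argument order of CutlassScaledMm is taken from this file.
#if 0
void scaled_mm_example(paddle::Tensor& c,         // [m, n], row-major
                       paddle::Tensor const& a,   // [m, k], int8, row-major
                       paddle::Tensor const& b,   // [n, k], int8 (column-major k x n)
                       paddle::Tensor const& a_scales,    // fp32, 1 or m values
                       paddle::Tensor const& b_scales) {  // fp32, 1 or n values
  CutlassScaledMm(c, a, b, a_scales, b_scales,
                  paddle::optional<paddle::Tensor>());    // no bias
}
#endif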
|
||||
void CutlassScaledMmAzp(paddle::Tensor& c, paddle::Tensor const& a,
|
||||
paddle::Tensor const& b,
|
||||
paddle::Tensor const& a_scales,
|
||||
paddle::Tensor const& b_scales,
|
||||
paddle::Tensor const& azp_adj,
|
||||
paddle::optional<paddle::Tensor> const& azp,
|
||||
paddle::optional<paddle::Tensor> const& bias) {
|
||||
// Checks for conformality
|
||||
PD_CHECK(a.dims().size() == 2 && b.dims().size() == 2 &&
|
||||
c.dims().size() == 2);
|
||||
PD_CHECK(c.dims()[0] == a.dims()[0] && a.dims()[1] == b.dims()[1] &&
|
||||
b.dims()[0] == c.dims()[1]);
|
||||
PD_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.dims()[0]);
|
||||
PD_CHECK(b_scales.numel() == 1 || b_scales.numel() == b.dims()[0]);
|
||||
|
||||
// Check for strides and alignment
|
||||
PD_CHECK(a.strides()[1] == 1 && c.strides()[1] == 1); // Row-major
|
||||
PD_CHECK(b.strides()[1] == 1); // Column-major
|
||||
PD_CHECK(c.strides()[0] % 16 == 0 &&
|
||||
b.strides()[0] % 16 == 0); // 16 Byte Alignment
|
||||
PD_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
|
||||
|
||||
// bias, azp, azp_adj are all 1d
|
||||
// bias and azp_adj have n elements, azp has m elements
|
||||
if (bias) {
|
||||
PD_CHECK(bias->numel() == b.dims()[0] && bias->is_contiguous());
|
||||
}
|
||||
if (azp) {
|
||||
PD_CHECK(azp->numel() == a.dims()[0] && azp->is_contiguous());
|
||||
}
|
||||
PD_CHECK(azp_adj.numel() == b.dims()[0] && azp_adj.is_contiguous());
|
||||
|
||||
// azp & bias types
|
||||
PD_CHECK(azp_adj.dtype() == paddle::DataType::INT32);
|
||||
PD_CHECK(!azp || azp->dtype() == paddle::DataType::INT32);
|
||||
PD_CHECK(!bias || bias->dtype() == c.dtype(),
|
||||
"currently bias dtype must match output dtype ", c.dtype());
|
||||
|
||||
int32_t version_num = GetGPUComputeCapability(a.place().GetDeviceId());
|
||||
|
||||
#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90
|
||||
if (version_num >= 90) {
|
||||
cutlass_scaled_mm_azp_sm90(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined ENABLE_SCALED_MM_C2X && ENABLE_SCALED_MM_C2X
|
||||
if (version_num == 89) {
|
||||
// Ada Lovelace
|
||||
cutlass_scaled_mm_azp_sm89(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
|
||||
return;
|
||||
}
|
||||
|
||||
if (version_num >= 80) {
|
||||
// Ampere
|
||||
cutlass_scaled_mm_azp_sm80(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
|
||||
return;
|
||||
}
|
||||
|
||||
// Turing
|
||||
PD_CHECK(version_num >= 75);
|
||||
cutlass_scaled_mm_azp_sm75(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
|
||||
return;
|
||||
#endif
|
||||
|
||||
PADDLE_THROW(phi::errors::Unimplemented(
|
||||
"No compiled cutlass_scaled_mm_azp for a compute capability less than "
|
||||
"CUDA device capability: %d",
|
||||
version_num));
|
||||
}
|
||||
|
||||
|
||||
PD_BUILD_STATIC_OP(cutlass_scaled_mm)
|
||||
.Inputs({"c", "a", "b", "a_scales", "b_scales", paddle::Optional("bias")})
|
||||
.Outputs({"c_out"})
|
||||
.SetInplaceMap({{"c", "c_out"}})
|
||||
.SetKernelFn(PD_KERNEL(CutlassScaledMm));
|
||||
|
||||
PD_BUILD_STATIC_OP(cutlass_scaled_mm_azp)
|
||||
.Inputs({"c", "a", "b", "a_scales", "b_scales", "azp_adj", paddle::Optional("azp"), paddle::Optional("bias")})
|
||||
.Outputs({"c_out"})
|
||||
.SetInplaceMap({{"c", "c_out"}})
|
||||
.SetKernelFn(PD_KERNEL(CutlassScaledMmAzp));
|
||||
@@ -1,27 +0,0 @@
# DeepGEMM

DeepGEMM installation guide

## Installation

First install the custom ops, and make sure CUTLASS has already been `git clone`d into [custom_ops/third_party/cutlass](../../third_party/cutlass).

Install deep_gemm:

```bash
# Make symbolic links for third-party (CUTLASS and CuTe) include directories
python setup.py develop

# Add the project path to PYTHONPATH
export PYTHONPATH=$(pwd):$PYTHONPATH

# or install directly
python setup.py install
```

### Test

```bash
# Test all GEMM implementations (normal, contiguous-grouped and masked-grouped)
python tests/test_core.py
```
@@ -1,31 +0,0 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# The file has been adapted from DeepSeek DeepGEMM project
# Copyright (c) 2025 DeepSeek
# Licensed under the MIT License - https://github.com/deepseek-ai/DeepGEMM/blob/main/LICENSE
"""This module contains all JIT kernels used in GEMM"""
from . import jit
from .jit_kernels import (
    ceil_div,
    gemm_fp8_fp8_bf16_nt,
    get_col_major_tma_aligned_tensor,
    get_col_major_tma_aligned_tensor_prefill,
    get_m_alignment_for_contiguous_layout,
    get_num_sms,
    m_grouped_gemm_fp8_fp8_bf16_nt_contiguous,
    m_grouped_gemm_fp8_fp8_bf16_nt_masked,
    set_num_sms,
)
from .utils import bench, calc_diff, get_cuda_home
@@ -1,462 +0,0 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// The file has been adapted from DeepSeek DeepGEMM project
|
||||
// Copyright (c) 2025 DeepSeek
|
||||
// Licensed under the MIT License - https://github.com/deepseek-ai/DeepGEMM/blob/main/LICENSE
|
||||
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Wunknown-attributes"
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/arch/barrier.h>
|
||||
#include <cutlass/arch/reg_reconfig.h>
|
||||
|
||||
#include <cute/arch/cluster_sm90.hpp>
|
||||
#include <cute/arch/copy_sm90_desc.hpp>
|
||||
#include <cute/arch/copy_sm90_tma.hpp>
|
||||
|
||||
#include "mma_utils.cuh"
|
||||
#include "scheduler.cuh"
|
||||
#include "tma_utils.cuh"
|
||||
#include "utils.cuh"
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
enum class Layout {
|
||||
RowMajor,
|
||||
ColMajor
|
||||
};
|
||||
|
||||
template <uint32_t kNumTMAThreads, uint32_t kNumMathThreadsPerGroup>
|
||||
__device__ __host__ constexpr int get_num_threads_per_sm(int block_m) {
|
||||
DG_STATIC_ASSERT(kNumMathThreadsPerGroup == 128, "Only support 128 threads per math group");
|
||||
return (block_m == 64 ? 1 : 2) * kNumMathThreadsPerGroup + kNumTMAThreads;
|
||||
}
|
||||
|
||||
template <uint32_t SHAPE_N, uint32_t SHAPE_K,
|
||||
uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
|
||||
uint32_t kNumGroups, uint32_t kNumStages,
|
||||
uint32_t kNumTMAThreads, uint32_t kNumMathThreadsPerGroup,
|
||||
uint32_t kNumTMAMulticast,
|
||||
GemmType kGemmType>
|
||||
__global__ void __launch_bounds__(get_num_threads_per_sm<kNumTMAThreads, kNumMathThreadsPerGroup>(BLOCK_M), 1)
|
||||
fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
|
||||
uint32_t shape_m,
|
||||
const __grid_constant__ CUtensorMap tensor_map_a,
|
||||
const __grid_constant__ CUtensorMap tensor_map_b,
|
||||
const __grid_constant__ CUtensorMap tensor_map_scales_a,
|
||||
const __grid_constant__ CUtensorMap tensor_map_d) {
|
||||
#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__)
|
||||
// Scaling checks
|
||||
DG_STATIC_ASSERT(BLOCK_K == 128, "Only support per-128-channel FP8 scaling");
|
||||
DG_STATIC_ASSERT(ceil_div(BLOCK_N, BLOCK_K) == 1, "Too much B scales in a single block");
|
||||
|
||||
// Types
|
||||
using WGMMA = typename FP8MMASelector<BLOCK_N>::type;
|
||||
using Barrier = cutlass::arch::ClusterTransactionBarrier;
|
||||
|
||||
// Shared memory
|
||||
static constexpr int kMustUseUniformedScaleB = (BLOCK_K % BLOCK_N == 0);
|
||||
    static constexpr uint32_t SMEM_D_SIZE = BLOCK_M * BLOCK_N * sizeof(__nv_bfloat16);
    static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_fp8_e4m3);
    static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_fp8_e4m3);
    static constexpr uint32_t SMEM_SCALES_A_SIZE_PER_STAGE = BLOCK_M * sizeof(float);
    static constexpr uint32_t SHAPE_K_SCALES = ceil_div(SHAPE_K, BLOCK_K);
    static constexpr uint32_t SMEM_SCALES_B_SIZE = ceil_div<uint32_t>(SHAPE_K_SCALES * (kMustUseUniformedScaleB ? 1 : 2) * sizeof(float), sizeof(Barrier)) * sizeof(Barrier);

    // Configs
    constexpr uint32_t kFullKOfAllStages = kNumStages * BLOCK_K;
    constexpr uint32_t kNumThreads = get_num_threads_per_sm<kNumTMAThreads, kNumMathThreadsPerGroup>(BLOCK_M);
    constexpr uint32_t kNumMathThreads = kNumThreads - kNumTMAThreads;
    constexpr uint32_t kNumIterations = ceil_div(SHAPE_K, kFullKOfAllStages);
    const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
    const uint32_t lane_idx = get_lane_id();

    // Prefetch TMA descriptors at the very beginning
    if (threadIdx.x == kNumMathThreads) {
        cute::prefetch_tma_descriptor(reinterpret_cast<cute::TmaDescriptor const*>(&tensor_map_a));
        cute::prefetch_tma_descriptor(reinterpret_cast<cute::TmaDescriptor const*>(&tensor_map_b));
        cute::prefetch_tma_descriptor(reinterpret_cast<cute::TmaDescriptor const*>(&tensor_map_scales_a));
        cute::prefetch_tma_descriptor(reinterpret_cast<cute::TmaDescriptor const*>(&tensor_map_d));
    }
    __syncwarp();

    // Align to 1024 bytes for swizzle-128B
    extern __shared__ __align__(1024) uint8_t smem_buffer[];
    DG_STATIC_ASSERT(SMEM_D_SIZE % 1024 == 0, "Shared memory of A/B must be aligned to 1024 bytes");

    // Data on shared memory
    auto smem_d = reinterpret_cast<__nv_bfloat16*>(smem_buffer);
    __nv_fp8_e4m3* smem_a[kNumStages];
    __nv_fp8_e4m3* smem_b[kNumStages];
    float* smem_scales_a[kNumStages];
    float* smem_scales_b;

    // TMA barriers for both divisible and non-divisible cases
    Barrier* full_barriers[kNumStages];
    Barrier* empty_barriers[kNumStages];

    // Fill shared memory pointers
    #pragma unroll
    for (int i = 0; i < kNumStages; ++ i) {
        smem_a[i] = reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + i * SMEM_A_SIZE_PER_STAGE);
        smem_b[i] = reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE);
        smem_scales_a[i] = reinterpret_cast<float*>(smem_buffer + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE) + i * SMEM_SCALES_A_SIZE_PER_STAGE);
    }
    smem_scales_b = reinterpret_cast<float*>(smem_buffer + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE));

    // Fill barriers
    auto barrier_start_ptr = reinterpret_cast<Barrier*>(reinterpret_cast<uint8_t*>(smem_scales_b) + SMEM_SCALES_B_SIZE);
    #pragma unroll
    for (int i = 0; i < kNumStages; ++ i) {
        full_barriers[i] = barrier_start_ptr + i;
        empty_barriers[i] = barrier_start_ptr + kNumStages + i;
    }

    // Initialize barriers
    DG_STATIC_ASSERT(kNumTMAMulticast <= 32, "Too many TMA multicasts");
    if (threadIdx.x == kNumMathThreads) {
        #pragma unroll
        for (int i = 0; i < kNumStages; ++ i) {
            full_barriers[i]->init(1);
            empty_barriers[i]->init(kNumTMAMulticast * kNumMathThreads / 32);
        }

        // Make the initialized barriers visible in the async proxy
        cutlass::arch::fence_view_async_shared();
        (kNumTMAMulticast > 1) ? cutlass::arch::fence_barrier_init() : void();
    }

    // Synchronize all threads to make the barriers visible in the normal memory model
    (kNumTMAMulticast > 1) ? cute::cluster_sync() : __syncthreads();

    // For pipeline unrolling
    struct DivisibleK {};
    struct NotDivisibleK {};
    auto launch_k_iterations = [](const auto& func) {
        if constexpr (SHAPE_K % kFullKOfAllStages == 0) {
            for (int k_iter = 0; k_iter < kNumIterations; ++ k_iter)
                func(k_iter, DivisibleK{});
        } else {
            for (int k_iter = 0; k_iter < kNumIterations - 1; ++ k_iter)
                func(k_iter, DivisibleK{});
            func(kNumIterations - 1, NotDivisibleK{});
        }
    };

    // Register reconfigurations
    constexpr int kNumTMARegisters = 40;
    constexpr int kNumMathRegisters = 232;

    // Block scheduler
    uint32_t m_block_idx, n_block_idx;
    auto scheduler = Scheduler<kGemmType, SHAPE_N, BLOCK_M, BLOCK_N, kNumGroups, kNumTMAMulticast>(shape_m, grouped_layout);

    if (threadIdx.x >= kNumMathThreads) {
        // TMA warp-group for loading data
        cutlass::arch::warpgroup_reg_dealloc<kNumTMARegisters>();

        // NOTES: only one thread (or warp) will be used
        if (threadIdx.x == kNumMathThreads) {
            // Persistently schedule over blocks
            while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
                launch_k_iterations([&](int k_iter, auto type) {
                    constexpr bool kHasDivisibleStages = std::is_same_v<decltype(type), DivisibleK>;
                    constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K;
                    DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages");

                    #pragma unroll
                    for (uint32_t s = 0; s < kNumInnerStages; ++ s) {
                        // Wait for the consumer to release this stage
                        empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter + 1) & 1);

                        // Issue TMA A with broadcasting
                        auto& full_barrier = *full_barriers[s];
                        int k_idx = k_iter * kFullKOfAllStages + s * BLOCK_K;
                        tma_copy<kNumTMAMulticast>(&tensor_map_a, reinterpret_cast<uint64_t*>(&full_barrier),
                                                   smem_a[s], k_idx, scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx));
                        tma_copy<kNumTMAMulticast>(&tensor_map_scales_a, reinterpret_cast<uint64_t*>(&full_barrier),
                                                   smem_scales_a[s], m_block_idx * BLOCK_M,
                                                   scheduler.get_global_idx(SHAPE_K_SCALES, 1, k_idx / BLOCK_K));

                        // Issue TMA B without broadcasting
                        tma_copy(&tensor_map_b, reinterpret_cast<uint64_t*>(&full_barrier),
                                 smem_b[s], k_idx, scheduler.get_global_idx<false>(SHAPE_N, BLOCK_N, n_block_idx, m_block_idx));
                        full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE);
                    }

                    // Wait on the unaligned (padding) stages
                    #pragma unroll
                    for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) {
                        empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter + 1) & 1);
                        full_barriers[s]->arrive();
                    }
                });
            }

            // To safely destruct the distributed shared barriers, we need another round of empty waits
            if constexpr (kNumTMAMulticast > 1) {
                #pragma unroll
                for (uint32_t s = 0; s < kNumStages; ++ s)
                    empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + 1) & 1);
            }
        }
    } else {
        // Math warp-groups for WGMMA
        cutlass::arch::warpgroup_reg_alloc<kNumMathRegisters>();

        // NOTES: use `__shfl_sync` to encourage NVCC to use unified registers
        const auto math_wg_idx = __shfl_sync(0xffffffff, threadIdx.x / kNumMathThreadsPerGroup, 0);
        const auto r_0 = warp_idx * 16 + lane_idx / 4, r_1 = r_0 + 8;

        // Persistently schedule over blocks
        while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
            // Decide the number of B scales to load
            DG_STATIC_ASSERT(SHAPE_N % 8 == 0, "Invalid shape N");
            uint32_t num_former_iters = BLOCK_N / 8, num_full_iters = num_former_iters;
            if constexpr (not kMustUseUniformedScaleB) {
                num_former_iters = min(BLOCK_N, BLOCK_K - n_block_idx * BLOCK_N % BLOCK_K) / 8;
                num_full_iters = min(SHAPE_N - n_block_idx * BLOCK_N, BLOCK_N) / 8;
            }
            uint32_t num_scales_b = SHAPE_K_SCALES * (num_former_iters >= num_full_iters ? 1 : 2);

            // Load B scales with the math warp-groups
            // NOTES: except for the first warp, we want to overlap loading B scales with TMA stores between tasks
            if (threadIdx.x >= 32) {
                auto num_previous_lines = scheduler.get_global_idx<false>(ceil_div(SHAPE_N, BLOCK_K), 0, 0, m_block_idx);
                auto local_scales_b = scales_b + (num_previous_lines + ((n_block_idx * BLOCK_N) / BLOCK_K)) * SHAPE_K_SCALES;
                #pragma unroll
                for (uint32_t i = threadIdx.x - 32; i < num_scales_b; i += kNumMathThreads - 32)
                    st_shared(smem_scales_b + i, __ldg(local_scales_b + i));
            }
            cutlass::arch::NamedBarrier(kNumMathThreads).sync();

            // Accumulation for WGMMA or CUDA promotion
            float accum[WGMMA::kNumAccum], final_accum[WGMMA::kNumAccum] = {0};

            // Empty-barrier arrival
            auto empty_barrier_arrive = [&](int s) {
                if constexpr (kNumTMAMulticast == 1) {
                    lane_idx == 0 ? empty_barriers[s]->arrive() : void();
                } else {
                    lane_idx < kNumTMAMulticast ? empty_barriers[s]->arrive(lane_idx) : void();
                }
            };

            // Launch MMAs
            launch_k_iterations([&](int k_iter, auto type) {
                constexpr bool kHasDivisibleStages = std::is_same_v<decltype(type), DivisibleK>;
                constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K;
                DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages");

                #pragma unroll
                for (int s = 0; s < kNumInnerStages; ++ s) {
                    // Read B scales
                    float scale_b_0 = ld_shared(smem_scales_b + k_iter * kNumStages + s), scale_b_1;
                    // NOTES: even though some blocks do not need the second row, we still load one to stay aligned with the other blocks
                    if constexpr (not kMustUseUniformedScaleB)
                        scale_b_1 = ld_shared(smem_scales_b + k_iter * kNumStages + s + SHAPE_K_SCALES);

                    // Wait for TMA arrivals
                    full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1);

                    // Read A scales
                    // NOTES: all shared memory reads must happen before `warpgroup_arrive`, so the next scheduled block cannot pollute the results
                    auto scale_a_0 = ld_shared(smem_scales_a[s] + r_0), scale_a_1 = ld_shared(smem_scales_a[s] + r_1);

                    // Commit WGMMA instructions
                    #pragma unroll
                    for (int i = 0; i < WGMMA::kNumAccum; ++ i)
                        warpgroup_fence_operand(accum[i]);
                    warpgroup_arrive();
                    #pragma unroll
                    for (int k = 0; k < BLOCK_K / WGMMA::K; ++ k) {
                        auto desc_a = make_smem_desc(smem_a[s] + math_wg_idx * WGMMA::M * BLOCK_K + k * WGMMA::K, 1);
                        auto desc_b = make_smem_desc(smem_b[s] + k * WGMMA::K, 1);
                        WGMMA::wgmma(desc_a, desc_b, accum, k);
                    }
                    warpgroup_commit_batch();
                    #pragma unroll
                    for (int i = 0; i < WGMMA::kNumAccum; ++ i)
                        warpgroup_fence_operand(accum[i]);
                    warpgroup_wait<0>();

                    // Notify barrier arrival
                    empty_barrier_arrive(s);

                    // Promote with scales
                    float scale_0_0 = scale_a_0 * scale_b_0, scale_1_0 = scale_a_1 * scale_b_0;
                    float scale_0_1, scale_1_1;
                    if constexpr (not kMustUseUniformedScaleB)
                        scale_0_1 = scale_a_0 * scale_b_1, scale_1_1 = scale_a_1 * scale_b_1;
                    #pragma unroll
                    for (int i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
                        bool predicate = kMustUseUniformedScaleB or i < num_former_iters;
                        final_accum[i * 4 + 0] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 0];
                        final_accum[i * 4 + 1] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 1];
                        final_accum[i * 4 + 2] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 2];
                        final_accum[i * 4 + 3] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 3];
                    }
                }

                // Wait on the unaligned (padding) stages
                #pragma unroll
                for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) {
                    full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1);
                    empty_barrier_arrive(s);
                }
            });

            // Write back to shared memory using STSM
            DG_STATIC_ASSERT(WGMMA::kNumAccum % 4 == 0, "Invalid STSM x2 vectorization");
            #pragma unroll
            for (auto i = 0; i < WGMMA::kNumAccum / 8; ++ i) {
                SM90_U32x4_STSM_N<nv_bfloat162>::copy(
                    __float22bfloat162_rn({final_accum[i * 8 + 0], final_accum[i * 8 + 1]}),
                    __float22bfloat162_rn({final_accum[i * 8 + 2], final_accum[i * 8 + 3]}),
                    __float22bfloat162_rn({final_accum[i * 8 + 4], final_accum[i * 8 + 5]}),
                    __float22bfloat162_rn({final_accum[i * 8 + 6], final_accum[i * 8 + 7]}),
                    smem_d + (warp_idx * 16 + lane_idx % 16) * BLOCK_N + i * 16 + 8 * (lane_idx / 16)
                );
            }
            if constexpr (WGMMA::kNumAccum % 8 != 0) {
                SM90_U32x2_STSM_N<nv_bfloat162>::copy(
                    __float22bfloat162_rn({final_accum[WGMMA::kNumAccum / 8 * 8 + 0], final_accum[WGMMA::kNumAccum / 8 * 8 + 1]}),
                    __float22bfloat162_rn({final_accum[WGMMA::kNumAccum / 8 * 8 + 2], final_accum[WGMMA::kNumAccum / 8 * 8 + 3]}),
                    smem_d + (warp_idx * 16 + lane_idx % 16) * BLOCK_N + WGMMA::kNumAccum / 8 * 16
                );
            }
            cute::tma_store_fence();
            cutlass::arch::NamedBarrier(kNumMathThreads).sync();

            // Use TMA store to write back to global memory
            if (threadIdx.x == 0) {
                cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_d, n_block_idx * BLOCK_N,
                                              scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx));
                cute::tma_store_arrive();
                cute::tma_store_wait<0>();
            }
            __syncwarp();
        }
    }
#else
    if (blockIdx.x == 0 and threadIdx.x == 0)
        DG_DEVICE_ASSERT(false and "This kernel only supports sm_90a");
#endif
}
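
// NOTES (illustrative, not part of the original kernel): the dynamic shared memory the host must
// request for this kernel is the sum of the regions laid out above: the BF16 D tile, the per-stage
// FP8 A/B tiles and A scales, the B scales, and the full/empty barrier pairs. A minimal sketch of
// that bookkeeping, assuming the same constants as the kernel (the helper name `smem_size_hint` is
// hypothetical and only documents the layout; the real launcher may compute this elsewhere):
//
//     template <uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
//               uint32_t SHAPE_K, uint32_t kNumStages, bool kMustUseUniformedScaleB>
//     constexpr uint32_t smem_size_hint() {
//         constexpr uint32_t smem_d = BLOCK_M * BLOCK_N * sizeof(__nv_bfloat16);
//         constexpr uint32_t smem_a = kNumStages * BLOCK_M * BLOCK_K * sizeof(__nv_fp8_e4m3);
//         constexpr uint32_t smem_b = kNumStages * BLOCK_N * BLOCK_K * sizeof(__nv_fp8_e4m3);
//         constexpr uint32_t smem_scales_a = kNumStages * BLOCK_M * sizeof(float);
//         constexpr uint32_t shape_k_scales = ceil_div(SHAPE_K, BLOCK_K);
//         constexpr uint32_t smem_scales_b = ceil_div<uint32_t>(
//             shape_k_scales * (kMustUseUniformedScaleB ? 1 : 2) * sizeof(float), sizeof(Barrier)) * sizeof(Barrier);
//         constexpr uint32_t barriers = 2 * kNumStages * sizeof(Barrier);  // full + empty barriers
//         return smem_d + smem_a + smem_b + smem_scales_a + smem_scales_b + barriers;
//     }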

template <uint32_t SHAPE_N, uint32_t SHAPE_K,
          uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
          uint32_t kNumGroups, uint32_t kNumStages,
          uint32_t kNumTMAMulticast,
          GemmType kGemmType>
class Gemm {
private:
    using Barrier = cuda::barrier<cuda::thread_scope_block>;

public:
    Gemm() = default;

    static void run(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
                    uint32_t shape_m,
                    const CUtensorMap& tma_a_desc,
                    const CUtensorMap& tma_b_desc,
                    const CUtensorMap& tma_scales_a_desc,
                    const CUtensorMap& tma_d_desc,
                    cudaStream_t stream,
                    int num_sms, uint32_t smem_size) {
        // NOTES: we must use 4 warps to do TMA, because `setmaxnreg.aligned` requires 4 warps
        constexpr uint32_t kNumTMAThreads = 128;
        constexpr uint32_t kNumMathThreadsPerGroup = 128;
        auto kernel = fp8_gemm_kernel<SHAPE_N, SHAPE_K, BLOCK_M, BLOCK_N, BLOCK_K,
                                      kNumGroups, kNumStages, kNumTMAThreads, kNumMathThreadsPerGroup,
                                      kNumTMAMulticast, kGemmType>;
        DG_HOST_ASSERT(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size) == cudaSuccess);

        // Cluster launch
        cudaLaunchConfig_t config;
        config.gridDim = num_sms;
        config.blockDim = get_num_threads_per_sm<kNumTMAThreads, kNumMathThreadsPerGroup>(BLOCK_M);
        config.dynamicSmemBytes = smem_size;
        config.stream = stream;

        // Clusters for TMA multicast
        // NOTES: `>= 4` cluster size will cause performance degradation
        cudaLaunchAttribute attr;
        attr.id = cudaLaunchAttributeClusterDimension;
        attr.val.clusterDim = {kNumTMAMulticast, 1, 1};
        config.attrs = &attr;
        config.numAttrs = 1;

        // Launch
        auto status = cudaLaunchKernelEx(&config, kernel,
                                         gmem_d, scales_b, grouped_layout,
                                         shape_m,
                                         tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc);
        DG_HOST_ASSERT(status == cudaSuccess);
    }

    template <typename T>
    static CUtensorMap make_2d_tma_a_desc(T* global_address, uint32_t shape_m) {
        return make_2d_tma_desc(global_address, Layout::RowMajor,
                                shape_m * (kGemmType == GemmType::GroupedMasked ? kNumGroups : 1), SHAPE_K, BLOCK_M, BLOCK_K);
    }

    template <typename T>
    static CUtensorMap make_2d_tma_b_desc(T* global_address) {
        return make_2d_tma_desc(global_address, Layout::ColMajor,
                                SHAPE_K, SHAPE_N * (kGemmType != GemmType::Normal ? kNumGroups : 1), BLOCK_K, BLOCK_N);
    }

    template <typename T>
    static CUtensorMap make_2d_tma_d_desc(T* global_address, uint32_t shape_m) {
        return make_2d_tma_desc(global_address, Layout::RowMajor,
                                shape_m * (kGemmType == GemmType::GroupedMasked ? kNumGroups : 1), SHAPE_N,
                                min(BLOCK_M, shape_m), BLOCK_N,
                                CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE);
    }

    template <typename T>
    static CUtensorMap make_2d_tma_scales_a_desc(T* global_address, uint32_t shape_m) {
        // Make TMA aligned to 16 bytes
        constexpr uint32_t kAlignment = 16 / sizeof(T);
        shape_m = ceil_div(shape_m, kAlignment) * kAlignment;

        return make_2d_tma_desc(global_address, Layout::ColMajor,
                                shape_m, ceil_div(SHAPE_K, BLOCK_K) * (kGemmType == GemmType::GroupedMasked ? kNumGroups : 1), BLOCK_M, 1,
                                CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE);
    }

    template <typename T>
    static CUtensorMap make_2d_tma_desc(
            T* global_address, Layout layout,
            uint32_t gmem_rows, uint32_t gmem_cols,
            uint32_t smem_rows, uint32_t smem_cols,
            CUtensorMapSwizzle swizzle_type = CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B) {
        if (layout == Layout::RowMajor) {
            uint64_t gmem_dim[2] = {gmem_cols, gmem_rows};
            uint32_t smem_dim[2] = {smem_cols, smem_rows};
            return make_2d_tma_copy_desc(global_address, gmem_dim, gmem_cols * sizeof(T), smem_dim, swizzle_type);
        } else {
            uint64_t gmem_dim[2] = {gmem_rows, gmem_cols};
            uint32_t smem_dim[2] = {smem_rows, smem_cols};
            return make_2d_tma_copy_desc(global_address, gmem_dim, gmem_rows * sizeof(T), smem_dim, swizzle_type);
        }
    }
};
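
// NOTES (illustrative, not from the original source): a host-side caller is expected to build the
// four TMA descriptors with the `make_2d_tma_*_desc` helpers above and then invoke `run`. A minimal
// sketch, assuming a `GemmType::Normal` instantiation and that `num_sms` and `smem_size` have
// already been chosen by the caller:
//
//     using GemmKernel = deep_gemm::Gemm<SHAPE_N, SHAPE_K, BLOCK_M, BLOCK_N, BLOCK_K,
//                                        /*kNumGroups=*/1, kNumStages, kNumTMAMulticast,
//                                        deep_gemm::GemmType::Normal>;
//     auto tma_a = GemmKernel::make_2d_tma_a_desc(gmem_a, shape_m);
//     auto tma_b = GemmKernel::make_2d_tma_b_desc(gmem_b);
//     auto tma_scales_a = GemmKernel::make_2d_tma_scales_a_desc(scales_a, shape_m);
//     auto tma_d = GemmKernel::make_2d_tma_d_desc(gmem_d, shape_m);
//     GemmKernel::run(gmem_d, scales_b, /*grouped_layout=*/nullptr, shape_m,
//                     tma_a, tma_b, tma_scales_a, tma_d, stream, num_sms, smem_size);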

}  // namespace deep_gemm

#pragma clang diagnostic pop
@@ -1,903 +0,0 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// The file has been adapted from DeepSeek DeepGEMM project
// Copyright (c) 2025 DeepSeek
// Licensed under the MIT License - https://github.com/deepseek-ai/DeepGEMM/blob/main/LICENSE

#pragma once

#include <cuda.h>

#include "utils.cuh"

namespace deep_gemm {

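// The WGMMA tile wrappers below all share M = 64 and K = 32 and differ only in N; each struct wraps
// a single `wgmma.mma_async.sync.aligned.m64n<N>k32.f32.e4m3.e4m3` instruction and exposes the
// number of FP32 accumulator registers held per thread as `kNumAccum = M * N / 128`. The array
// overload of `wgmma` simply forwards `d[0..kNumAccum-1]` to the register form. Illustrative use
// from a math warp-group (a sketch only; the descriptors and the fence/commit/wait helpers are
// defined further below in this file):
//
//     using MMA = SM90_64x64x32_F32E4M3E4M3_SS;
//     float accum[MMA::kNumAccum] = {0};
//     warpgroup_arrive();
//     MMA::wgmma(desc_a, desc_b, accum, /*scale_d=*/true);
//     warpgroup_commit_batch();
//     warpgroup_wait<0>();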
struct SM90_64x16x32_F32E4M3E4M3_SS {
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
|
||||
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
|
||||
bool scale_d) {
|
||||
asm volatile("{\n"
|
||||
".reg .pred p;\n"
|
||||
"setp.ne.b32 p, %10, 0;\n"
|
||||
"wgmma.mma_async.sync.aligned.m64n16k32.f32.e4m3.e4m3"
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7},"
|
||||
" %8,"
|
||||
" %9,"
|
||||
" p , 1, 1;\n"
|
||||
"}\n"
|
||||
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07)
|
||||
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
|
||||
}
|
||||
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
|
||||
wgmma(desc_a, desc_b,
|
||||
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
|
||||
scale_d);
|
||||
}
|
||||
|
||||
static constexpr int M = 64;
|
||||
static constexpr int N = 16;
|
||||
static constexpr int K = 32;
|
||||
static constexpr int kNumAccum = M * N / 128;
|
||||
};
|
||||
|
||||
struct SM90_64x24x32_F32E4M3E4M3_SS {
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
|
||||
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
|
||||
float& d08, float& d09, float& d10, float& d11,
|
||||
bool scale_d) {
|
||||
asm volatile("{\n"
|
||||
".reg .pred p;\n"
|
||||
"setp.ne.b32 p, %14, 0;\n"
|
||||
"wgmma.mma_async.sync.aligned.m64n24k32.f32.e4m3.e4m3"
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
||||
" %8, %9, %10, %11},"
|
||||
" %12,"
|
||||
" %13,"
|
||||
" p , 1, 1;\n"
|
||||
"}\n"
|
||||
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
||||
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11)
|
||||
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
|
||||
}
|
||||
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
|
||||
wgmma(desc_a, desc_b,
|
||||
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
|
||||
d[8], d[9], d[10], d[11],
|
||||
scale_d);
|
||||
}
|
||||
|
||||
static constexpr int M = 64;
|
||||
static constexpr int N = 24;
|
||||
static constexpr int K = 32;
|
||||
static constexpr int kNumAccum = M * N / 128;
|
||||
};
|
||||
|
||||
struct SM90_64x32x32_F32E4M3E4M3_SS {
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
|
||||
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
|
||||
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
|
||||
bool scale_d) {
|
||||
asm volatile("{\n"
|
||||
".reg .pred p;\n"
|
||||
"setp.ne.b32 p, %18, 0;\n"
|
||||
"wgmma.mma_async.sync.aligned.m64n32k32.f32.e4m3.e4m3"
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
||||
" %8, %9, %10, %11, %12, %13, %14, %15},"
|
||||
" %16,"
|
||||
" %17,"
|
||||
" p , 1, 1;\n"
|
||||
"}\n"
|
||||
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
||||
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15)
|
||||
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
|
||||
}
|
||||
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
|
||||
wgmma(desc_a, desc_b,
|
||||
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
|
||||
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
|
||||
scale_d);
|
||||
}
|
||||
|
||||
static constexpr int M = 64;
|
||||
static constexpr int N = 32;
|
||||
static constexpr int K = 32;
|
||||
static constexpr int kNumAccum = M * N / 128;
|
||||
};
|
||||
|
||||
struct SM90_64x40x32_F32E4M3E4M3_SS {
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
|
||||
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
|
||||
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
|
||||
float& d16, float& d17, float& d18, float& d19,
|
||||
bool scale_d) {
|
||||
asm volatile("{\n"
|
||||
".reg .pred p;\n"
|
||||
"setp.ne.b32 p, %22, 0;\n"
|
||||
"wgmma.mma_async.sync.aligned.m64n40k32.f32.e4m3.e4m3"
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
||||
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
||||
" %16, %17, %18, %19},"
|
||||
" %20,"
|
||||
" %21,"
|
||||
" p , 1, 1;\n"
|
||||
"}\n"
|
||||
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
||||
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
||||
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19)
|
||||
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
|
||||
}
|
||||
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
|
||||
wgmma(desc_a, desc_b,
|
||||
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
|
||||
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
|
||||
d[16], d[17], d[18], d[19],
|
||||
scale_d);
|
||||
}
|
||||
|
||||
static constexpr int M = 64;
|
||||
static constexpr int N = 40;
|
||||
static constexpr int K = 32;
|
||||
static constexpr int kNumAccum = M * N / 128;
|
||||
};
|
||||
|
||||
struct SM90_64x48x32_F32E4M3E4M3_SS {
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
|
||||
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
|
||||
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
|
||||
float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
|
||||
bool scale_d) {
|
||||
asm volatile("{\n"
|
||||
".reg .pred p;\n"
|
||||
"setp.ne.b32 p, %26, 0;\n"
|
||||
"wgmma.mma_async.sync.aligned.m64n48k32.f32.e4m3.e4m3"
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
||||
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
||||
" %16, %17, %18, %19, %20, %21, %22, %23},"
|
||||
" %24,"
|
||||
" %25,"
|
||||
" p , 1, 1;\n"
|
||||
"}\n"
|
||||
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
||||
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
||||
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23)
|
||||
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
|
||||
}
|
||||
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
|
||||
wgmma(desc_a, desc_b,
|
||||
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
|
||||
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
|
||||
d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
|
||||
scale_d);
|
||||
}
|
||||
|
||||
static constexpr int M = 64;
|
||||
static constexpr int N = 48;
|
||||
static constexpr int K = 32;
|
||||
static constexpr int kNumAccum = M * N / 128;
|
||||
};
|
||||
|
||||
struct SM90_64x56x32_F32E4M3E4M3_SS {
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
|
||||
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
|
||||
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
|
||||
float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
|
||||
float& d24, float& d25, float& d26, float& d27,
|
||||
bool scale_d) {
|
||||
asm volatile("{\n"
|
||||
".reg .pred p;\n"
|
||||
"setp.ne.b32 p, %30, 0;\n"
|
||||
"wgmma.mma_async.sync.aligned.m64n56k32.f32.e4m3.e4m3"
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
||||
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
||||
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
||||
" %24, %25, %26, %27}, "
|
||||
" %28,"
|
||||
" %29,"
|
||||
" p , 1, 1;\n"
|
||||
"}\n"
|
||||
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
||||
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
||||
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
|
||||
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27)
|
||||
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
|
||||
}
|
||||
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
|
||||
wgmma(desc_a, desc_b,
|
||||
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
|
||||
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
|
||||
d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
|
||||
d[24], d[25], d[26], d[27],
|
||||
scale_d);
|
||||
}
|
||||
|
||||
static constexpr int M = 64;
|
||||
static constexpr int N = 56;
|
||||
static constexpr int K = 32;
|
||||
static constexpr int kNumAccum = M * N / 128;
|
||||
};
|
||||
|
||||
struct SM90_64x64x32_F32E4M3E4M3_SS {
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
|
||||
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
|
||||
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
|
||||
float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
|
||||
float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31,
|
||||
bool scale_d) {
|
||||
asm volatile("{\n"
|
||||
".reg .pred p;\n"
|
||||
"setp.ne.b32 p, %34, 0;\n"
|
||||
"wgmma.mma_async.sync.aligned.m64n64k32.f32.e4m3.e4m3"
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
||||
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
||||
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
||||
" %24, %25, %26, %27, %28, %29, %30, %31}, "
|
||||
" %32,"
|
||||
" %33,"
|
||||
" p , 1, 1;\n"
|
||||
"}\n"
|
||||
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
||||
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
||||
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
|
||||
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31)
|
||||
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
|
||||
}
|
||||
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
|
||||
wgmma(desc_a, desc_b,
|
||||
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
|
||||
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
|
||||
d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
|
||||
d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31],
|
||||
scale_d);
|
||||
}
|
||||
|
||||
static constexpr int M = 64;
|
||||
static constexpr int N = 64;
|
||||
static constexpr int K = 32;
|
||||
static constexpr int kNumAccum = M * N / 128;
|
||||
};
|
||||
|
||||
struct SM90_64x72x32_F32E4M3E4M3_SS {
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
|
||||
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
|
||||
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
|
||||
float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
|
||||
float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31,
|
||||
float& d32, float& d33, float& d34, float& d35,
|
||||
bool scale_d) {
|
||||
asm volatile("{\n"
|
||||
".reg .pred p;\n"
|
||||
"setp.ne.b32 p, %38, 0;\n"
|
||||
"wgmma.mma_async.sync.aligned.m64n72k32.f32.e4m3.e4m3"
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
||||
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
||||
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
||||
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
||||
" %32, %33, %34, %35}, "
|
||||
" %36,"
|
||||
" %37,"
|
||||
" p , 1, 1;\n"
|
||||
"}\n"
|
||||
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
||||
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
||||
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
|
||||
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
|
||||
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35)
|
||||
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
|
||||
}
|
||||
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
|
||||
wgmma(desc_a, desc_b,
|
||||
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
|
||||
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
|
||||
d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
|
||||
d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31],
|
||||
d[32], d[33], d[34], d[35],
|
||||
scale_d);
|
||||
}
|
||||
|
||||
static constexpr int M = 64;
|
||||
static constexpr int N = 72;
|
||||
static constexpr int K = 32;
|
||||
static constexpr int kNumAccum = M * N / 128;
|
||||
};
|
||||
|
||||
struct SM90_64x80x32_F32E4M3E4M3_SS {
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
|
||||
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
|
||||
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
|
||||
float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
|
||||
float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31,
|
||||
float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39,
|
||||
bool scale_d) {
|
||||
asm volatile("{\n"
|
||||
".reg .pred p;\n"
|
||||
"setp.ne.b32 p, %42, 0;\n"
|
||||
"wgmma.mma_async.sync.aligned.m64n80k32.f32.e4m3.e4m3"
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
||||
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
||||
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
||||
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
||||
" %32, %33, %34, %35, %36, %37, %38, %39}, "
|
||||
" %40,"
|
||||
" %41,"
|
||||
" p , 1, 1;\n"
|
||||
"}\n"
|
||||
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
||||
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
||||
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
|
||||
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
|
||||
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39)
|
||||
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
|
||||
}
|
||||
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
|
||||
wgmma(desc_a, desc_b,
|
||||
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
|
||||
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
|
||||
d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
|
||||
d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31],
|
||||
d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39],
|
||||
scale_d);
|
||||
}
|
||||
|
||||
static constexpr int M = 64;
|
||||
static constexpr int N = 80;
|
||||
static constexpr int K = 32;
|
||||
static constexpr int kNumAccum = M * N / 128;
|
||||
};
|
||||
|
||||
struct SM90_64x88x32_F32E4M3E4M3_SS {
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
|
||||
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
|
||||
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
|
||||
float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
|
||||
float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31,
|
||||
float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39,
|
||||
float& d40, float& d41, float& d42, float& d43,
|
||||
bool scale_d) {
|
||||
asm volatile("{\n"
|
||||
".reg .pred p;\n"
|
||||
"setp.ne.b32 p, %46, 0;\n"
|
||||
"wgmma.mma_async.sync.aligned.m64n88k32.f32.e4m3.e4m3"
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
||||
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
||||
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
||||
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
||||
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
||||
" %40, %41, %42, %43}, "
|
||||
" %44,"
|
||||
" %45,"
|
||||
" p , 1, 1;\n"
|
||||
"}\n"
|
||||
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
||||
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
||||
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
|
||||
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
|
||||
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
|
||||
"+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43)
|
||||
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
|
||||
}
|
||||
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
|
||||
wgmma(desc_a, desc_b,
|
||||
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
|
||||
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
|
||||
d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
|
||||
d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31],
|
||||
d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39],
|
||||
d[40], d[41], d[42], d[43],
|
||||
scale_d);
|
||||
}
|
||||
|
||||
static constexpr int M = 64;
|
||||
static constexpr int N = 88;
|
||||
static constexpr int K = 32;
|
||||
static constexpr int kNumAccum = M * N / 128;
|
||||
};
|
||||
|
||||
struct SM90_64x96x32_F32E4M3E4M3_SS {
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
|
||||
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
|
||||
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
|
||||
float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
|
||||
float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31,
|
||||
float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39,
|
||||
float& d40, float& d41, float& d42, float& d43, float& d44, float& d45, float& d46, float& d47,
|
||||
bool scale_d) {
|
||||
asm volatile("{\n"
|
||||
".reg .pred p;\n"
|
||||
"setp.ne.b32 p, %50, 0;\n"
|
||||
"wgmma.mma_async.sync.aligned.m64n96k32.f32.e4m3.e4m3"
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
||||
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
||||
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
||||
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
||||
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
||||
" %40, %41, %42, %43, %44, %45, %46, %47}, "
|
||||
" %48,"
|
||||
" %49,"
|
||||
" p , 1, 1;\n"
|
||||
"}\n"
|
||||
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
||||
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
||||
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
|
||||
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
|
||||
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
|
||||
"+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47)
|
||||
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
|
||||
}
|
||||
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
|
||||
wgmma(desc_a, desc_b,
|
||||
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
|
||||
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
|
||||
d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
|
||||
d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31],
|
||||
d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39],
|
||||
d[40], d[41], d[42], d[43], d[44], d[45], d[46], d[47],
|
||||
scale_d);
|
||||
}
|
||||
|
||||
static constexpr int M = 64;
|
||||
static constexpr int N = 96;
|
||||
static constexpr int K = 32;
|
||||
static constexpr int kNumAccum = M * N / 128;
|
||||
};
|
||||
|
||||
struct SM90_64x104x32_F32E4M3E4M3_SS {
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
|
||||
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
|
||||
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
|
||||
float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
|
||||
float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31,
|
||||
float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39,
|
||||
float& d40, float& d41, float& d42, float& d43, float& d44, float& d45, float& d46, float& d47,
|
||||
float& d48, float& d49, float& d50, float& d51,
|
||||
bool scale_d) {
|
||||
asm volatile("{\n"
|
||||
".reg .pred p;\n"
|
||||
"setp.ne.b32 p, %54, 0;\n"
|
||||
"wgmma.mma_async.sync.aligned.m64n104k32.f32.e4m3.e4m3"
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
||||
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
||||
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
||||
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
||||
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
||||
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
||||
" %48, %49, %50, %51}, "
|
||||
" %52,"
|
||||
" %53,"
|
||||
" p , 1, 1;\n"
|
||||
"}\n"
|
||||
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
||||
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
||||
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
|
||||
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
|
||||
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
|
||||
"+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
|
||||
"+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51)
|
||||
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
|
||||
}
|
||||
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
|
||||
wgmma(desc_a, desc_b,
|
||||
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
|
||||
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
|
||||
d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
|
||||
d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31],
|
||||
d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39],
|
||||
d[40], d[41], d[42], d[43], d[44], d[45], d[46], d[47],
|
||||
d[48], d[49], d[50], d[51],
|
||||
scale_d);
|
||||
}
|
||||
|
||||
static constexpr int M = 64;
|
||||
static constexpr int N = 104;
|
||||
static constexpr int K = 32;
|
||||
static constexpr int kNumAccum = M * N / 128;
|
||||
};
|
||||
|
||||
struct SM90_64x112x32_F32E4M3E4M3_SS {
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
|
||||
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
|
||||
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
|
||||
float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
|
||||
float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31,
|
||||
float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39,
|
||||
float& d40, float& d41, float& d42, float& d43, float& d44, float& d45, float& d46, float& d47,
|
||||
float& d48, float& d49, float& d50, float& d51, float& d52, float& d53, float& d54, float& d55,
|
||||
bool scale_d) {
|
||||
asm volatile("{\n"
|
||||
".reg .pred p;\n"
|
||||
"setp.ne.b32 p, %58, 0;\n"
|
||||
"wgmma.mma_async.sync.aligned.m64n112k32.f32.e4m3.e4m3"
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
||||
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
||||
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
||||
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
||||
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
||||
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
||||
" %48, %49, %50, %51, %52, %53, %54, %55}, "
|
||||
" %56,"
|
||||
" %57,"
|
||||
" p , 1, 1;\n"
|
||||
"}\n"
|
||||
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
||||
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
||||
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
|
||||
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
|
||||
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
|
||||
"+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
|
||||
"+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55)
|
||||
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
|
||||
}
|
||||
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
|
||||
wgmma(desc_a, desc_b,
|
||||
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
|
||||
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
|
||||
d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
|
||||
d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31],
|
||||
d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39],
|
||||
d[40], d[41], d[42], d[43], d[44], d[45], d[46], d[47],
|
||||
d[48], d[49], d[50], d[51], d[52], d[53], d[54], d[55],
|
||||
scale_d);
|
||||
}
|
||||
|
||||
static constexpr int M = 64;
|
||||
static constexpr int N = 112;
|
||||
static constexpr int K = 32;
|
||||
static constexpr int kNumAccum = M * N / 128;
|
||||
};
|
||||
|
||||
struct SM90_64x120x32_F32E4M3E4M3_SS {
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
|
||||
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
|
||||
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
|
||||
float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
|
||||
float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31,
|
||||
float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39,
|
||||
float& d40, float& d41, float& d42, float& d43, float& d44, float& d45, float& d46, float& d47,
|
||||
float& d48, float& d49, float& d50, float& d51, float& d52, float& d53, float& d54, float& d55,
|
||||
float& d56, float& d57, float& d58, float& d59,
|
||||
bool scale_d) {
|
||||
asm volatile("{\n"
|
||||
".reg .pred p;\n"
|
||||
"setp.ne.b32 p, %62, 0;\n"
|
||||
"wgmma.mma_async.sync.aligned.m64n120k32.f32.e4m3.e4m3"
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
||||
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
||||
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
||||
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
||||
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
||||
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
||||
" %48, %49, %50, %51, %52, %53, %54, %55, "
|
||||
" %56, %57, %58, %59}, "
|
||||
" %60,"
|
||||
" %61,"
|
||||
" p , 1, 1;\n"
|
||||
"}\n"
|
||||
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
||||
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
||||
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
|
||||
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
|
||||
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
|
||||
"+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
|
||||
"+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55),
|
||||
"+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59)
|
||||
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
|
||||
}
|
||||
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
|
||||
wgmma(desc_a, desc_b,
|
||||
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
|
||||
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
|
||||
d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
|
||||
d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31],
|
||||
d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39],
|
||||
d[40], d[41], d[42], d[43], d[44], d[45], d[46], d[47],
|
||||
d[48], d[49], d[50], d[51], d[52], d[53], d[54], d[55],
|
||||
d[56], d[57], d[58], d[59],
|
||||
scale_d);
|
||||
}
|
||||
|
||||
static constexpr int M = 64;
|
||||
static constexpr int N = 120;
|
||||
static constexpr int K = 32;
|
||||
static constexpr int kNumAccum = M * N / 128;
|
||||
};
|
||||
|
||||
struct SM90_64x128x32_F32E4M3E4M3_SS {
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
|
||||
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
|
||||
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
|
||||
float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
|
||||
float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31,
|
||||
float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39,
|
||||
float& d40, float& d41, float& d42, float& d43, float& d44, float& d45, float& d46, float& d47,
|
||||
float& d48, float& d49, float& d50, float& d51, float& d52, float& d53, float& d54, float& d55,
|
||||
float& d56, float& d57, float& d58, float& d59, float& d60, float& d61, float& d62, float& d63,
|
||||
bool scale_d) {
|
||||
asm volatile("{\n"
|
||||
".reg .pred p;\n"
|
||||
"setp.ne.b32 p, %66, 0;\n"
|
||||
"wgmma.mma_async.sync.aligned.m64n128k32.f32.e4m3.e4m3"
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
||||
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
||||
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
||||
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
||||
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
||||
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
||||
" %48, %49, %50, %51, %52, %53, %54, %55, "
|
||||
" %56, %57, %58, %59, %60, %61, %62, %63}, "
|
||||
" %64,"
|
||||
" %65,"
|
||||
" p , 1, 1;\n"
|
||||
"}\n"
|
||||
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
||||
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
||||
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
|
||||
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
|
||||
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
|
||||
"+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
|
||||
"+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55),
|
||||
"+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63)
|
||||
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
|
||||
}
|
||||
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
|
||||
wgmma(desc_a, desc_b,
|
||||
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
|
||||
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
|
||||
d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
|
||||
d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31],
|
||||
d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39],
|
||||
d[40], d[41], d[42], d[43], d[44], d[45], d[46], d[47],
|
||||
d[48], d[49], d[50], d[51], d[52], d[53], d[54], d[55],
|
||||
d[56], d[57], d[58], d[59], d[60], d[61], d[62], d[63],
|
||||
scale_d);
|
||||
}
|
||||
|
||||
static constexpr int M = 64;
|
||||
static constexpr int N = 128;
|
||||
static constexpr int K = 32;
|
||||
static constexpr int kNumAccum = M * N / 128;
|
||||
};
|
||||
|
||||
struct SM90_64x192x32_F32E4M3E4M3_SS {
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
|
||||
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
|
||||
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
|
||||
float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
|
||||
float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31,
|
||||
float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39,
|
||||
float& d40, float& d41, float& d42, float& d43, float& d44, float& d45, float& d46, float& d47,
|
||||
float& d48, float& d49, float& d50, float& d51, float& d52, float& d53, float& d54, float& d55,
|
||||
float& d56, float& d57, float& d58, float& d59, float& d60, float& d61, float& d62, float& d63,
|
||||
float& d64, float& d65, float& d66, float& d67, float& d68, float& d69, float& d70, float& d71,
|
||||
float& d72, float& d73, float& d74, float& d75, float& d76, float& d77, float& d78, float& d79,
|
||||
float& d80, float& d81, float& d82, float& d83, float& d84, float& d85, float& d86, float& d87,
|
||||
float& d88, float& d89, float& d90, float& d91, float& d92, float& d93, float& d94, float& d95,
|
||||
bool scale_d) {
|
||||
asm volatile("{\n"
|
||||
".reg .pred p;\n"
|
||||
"setp.ne.b32 p, %98, 0;\n"
|
||||
"wgmma.mma_async.sync.aligned.m64n192k32.f32.e4m3.e4m3"
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
||||
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
||||
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
||||
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
||||
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
||||
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
||||
" %48, %49, %50, %51, %52, %53, %54, %55, "
|
||||
" %56, %57, %58, %59, %60, %61, %62, %63, "
|
||||
" %64, %65, %66, %67, %68, %69, %70, %71, "
|
||||
" %72, %73, %74, %75, %76, %77, %78, %79, "
|
||||
" %80, %81, %82, %83, %84, %85, %86, %87, "
|
||||
" %88, %89, %90, %91, %92, %93, %94, %95}, "
|
||||
" %96,"
|
||||
" %97,"
|
||||
" p , 1, 1;\n"
|
||||
"}\n"
|
||||
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
||||
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
||||
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
|
||||
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
|
||||
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
|
||||
"+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
|
||||
"+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55),
|
||||
"+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63),
|
||||
"+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71),
|
||||
"+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79),
|
||||
"+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87),
|
||||
"+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95)
|
||||
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
|
||||
}
|
||||
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
|
||||
wgmma(desc_a, desc_b,
|
||||
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
|
||||
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
|
||||
d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
|
||||
d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31],
|
||||
d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39],
|
||||
d[40], d[41], d[42], d[43], d[44], d[45], d[46], d[47],
|
||||
d[48], d[49], d[50], d[51], d[52], d[53], d[54], d[55],
|
||||
d[56], d[57], d[58], d[59], d[60], d[61], d[62], d[63],
|
||||
d[64], d[65], d[66], d[67], d[68], d[69], d[70], d[71],
|
||||
d[72], d[73], d[74], d[75], d[76], d[77], d[78], d[79],
|
||||
d[80], d[81], d[82], d[83], d[84], d[85], d[86], d[87],
|
||||
d[88], d[89], d[90], d[91], d[92], d[93], d[94], d[95],
|
||||
scale_d);
|
||||
}
|
||||
|
||||
static constexpr int M = 64;
|
||||
static constexpr int N = 192;
|
||||
static constexpr int K = 32;
|
||||
static constexpr int kNumAccum = M * N / 128;
|
||||
};
|
||||
|
||||
template <typename dtype_t>
|
||||
struct SM90_U32x2_STSM_N {
|
||||
__device__ __forceinline__ static void
|
||||
copy(dtype_t src_0, dtype_t src_1, void* smem_dst) {
|
||||
const uint32_t src[2] = {*reinterpret_cast<uint32_t*>(&src_0), *reinterpret_cast<uint32_t*>(&src_1)};
|
||||
asm volatile("stmatrix.sync.aligned.x2.m8n8.shared.b16 [%0], {%1, %2};\n"
|
||||
:: "l"(smem_dst), "r"(src[0]), "r"(src[1]));
|
||||
}
|
||||
};
|
||||
|
||||
template <typename dtype_t>
|
||||
struct SM90_U32x4_STSM_N {
|
||||
__device__ __forceinline__ static void
|
||||
copy(dtype_t src_0, dtype_t src_1, dtype_t src_2, dtype_t src_3, void* smem_dst) {
|
||||
const uint32_t src[4] = {*reinterpret_cast<uint32_t*>(&src_0), *reinterpret_cast<uint32_t*>(&src_1),
|
||||
*reinterpret_cast<uint32_t*>(&src_2), *reinterpret_cast<uint32_t*>(&src_3)};
|
||||
asm volatile("stmatrix.sync.aligned.x4.m8n8.shared.b16 [%0], {%1, %2, %3, %4};\n"
|
||||
:: "l"(smem_dst), "r"(src[0]), "r"(src[1]), "r"(src[2]), "r"(src[3]));
|
||||
}
|
||||
};
|
||||
|
||||
__device__ void warpgroup_arrive() {
|
||||
asm volatile("wgmma.fence.sync.aligned;\n" ::: "memory");
|
||||
}
|
||||
|
||||
__device__ void warpgroup_commit_batch() {
|
||||
asm volatile("wgmma.commit_group.sync.aligned;\n" ::: "memory");
|
||||
}
|
||||
|
||||
__device__ void warpgroup_fence_operand(float& reg) {
|
||||
asm volatile("" : "+f"(reg) :: "memory");
|
||||
}
|
||||
|
||||
__forceinline__ __device__ uint32_t get_lane_id() {
|
||||
uint32_t lane_id;
|
||||
asm("mov.u32 %0, %laneid;" : "=r"(lane_id));
|
||||
return lane_id;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ uint32_t ld_shared(const uint32_t* __restrict__ ptr) {
|
||||
uint32_t ret;
|
||||
asm volatile("ld.shared.u32 %0, [%1];" : "=r"(ret) : "l"(ptr));
|
||||
return ret;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ int4 ld_shared(const int4* __restrict__ ptr) {
|
||||
int4 ret;
|
||||
asm volatile("ld.shared.v4.s32 {%0, %1, %2, %3}, [%4];" : "=r"(ret.x), "=r"(ret.y), "=r"(ret.z), "=r"(ret.w) : "l"(ptr));
|
||||
return ret;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ float ld_shared(const float* __restrict__ ptr) {
|
||||
float ret;
|
||||
asm volatile("ld.shared.f32 %0, [%1];" : "=f"(ret) : "l"(ptr));
|
||||
return ret;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void st_shared(const float* ptr, float val) {
|
||||
asm volatile("st.shared.f32 [%0], %1;" :: "l"(ptr), "f"(val));
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void st_shared(const uint32_t* ptr, uint32_t val) {
|
||||
asm volatile("st.shared.u32 [%0], %1;" :: "l"(ptr), "r"(val));
|
||||
}
|
||||
|
||||
template <int N>
|
||||
__device__ void warpgroup_wait() {
|
||||
DG_STATIC_ASSERT(N >= 0 and N <= 7, "WGMMA wait: N must be in range [0, 7]");
|
||||
asm volatile("wgmma.wait_group.sync.aligned %0;\n" :: "n"(N) : "memory");
|
||||
}
|
||||
|
||||
union GmmaDescriptor {
|
||||
__host__ __device__ constexpr GmmaDescriptor() noexcept: desc_(0) {}
|
||||
|
||||
__host__ __device__ constexpr GmmaDescriptor(uint64_t desc) noexcept: desc_(desc) {}
|
||||
|
||||
__host__ __device__ constexpr GmmaDescriptor(GmmaDescriptor const &t) noexcept: desc_(t.desc_) {}
|
||||
|
||||
__host__ __device__ constexpr GmmaDescriptor(GmmaDescriptor &&t) noexcept: desc_(t.desc_) {}
|
||||
|
||||
__host__ __device__ constexpr GmmaDescriptor &operator=(GmmaDescriptor const &t) noexcept {
|
||||
desc_ = t.desc_;
|
||||
return *this;
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr GmmaDescriptor &operator=(GmmaDescriptor &&t) noexcept {
|
||||
desc_ = t.desc_;
|
||||
return *this;
|
||||
}
|
||||
|
||||
uint64_t desc_;
|
||||
uint32_t reg32_[2];
|
||||
uint16_t reg16_[4];
|
||||
|
||||
struct {
|
||||
uint16_t start_address_: 14, : 2;
|
||||
uint16_t leading_byte_offset_: 14, : 2;
|
||||
uint16_t stride_byte_offset_: 14, : 2;
|
||||
uint8_t : 1, base_offset_: 3, : 4;
|
||||
uint8_t : 6, layout_type_: 2;
|
||||
} bitfield;
|
||||
|
||||
// Decay to an `uint64_t`
|
||||
__host__ __device__ constexpr operator uint64_t() const noexcept { return desc_; }
|
||||
};
|
||||
|
||||
template <class PointerType>
|
||||
__device__ GmmaDescriptor make_smem_desc(PointerType smem_ptr, int layout_type,
|
||||
int leading_byte_offset = 0,
|
||||
int stride_byte_offset = 1024) {
|
||||
GmmaDescriptor desc;
|
||||
auto uint_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
|
||||
desc.bitfield.start_address_ = uint_ptr >> 4;
|
||||
desc.bitfield.layout_type_ = layout_type;
|
||||
desc.bitfield.leading_byte_offset_ = leading_byte_offset >> 4;
|
||||
desc.bitfield.stride_byte_offset_ = stride_byte_offset >> 4;
|
||||
desc.bitfield.base_offset_ = 0;
|
||||
return desc;
|
||||
}
|
||||
|
||||
template <int N>
|
||||
struct FP8MMASelector {
|
||||
static constexpr auto select_type() {
|
||||
if constexpr (N == 16) return SM90_64x16x32_F32E4M3E4M3_SS();
|
||||
if constexpr (N == 24) return SM90_64x24x32_F32E4M3E4M3_SS();
|
||||
if constexpr (N == 32) return SM90_64x32x32_F32E4M3E4M3_SS();
|
||||
if constexpr (N == 40) return SM90_64x40x32_F32E4M3E4M3_SS();
|
||||
if constexpr (N == 48) return SM90_64x48x32_F32E4M3E4M3_SS();
|
||||
if constexpr (N == 56) return SM90_64x56x32_F32E4M3E4M3_SS();
|
||||
if constexpr (N == 64) return SM90_64x64x32_F32E4M3E4M3_SS();
|
||||
if constexpr (N == 72) return SM90_64x72x32_F32E4M3E4M3_SS();
|
||||
if constexpr (N == 80) return SM90_64x80x32_F32E4M3E4M3_SS();
|
||||
if constexpr (N == 88) return SM90_64x88x32_F32E4M3E4M3_SS();
|
||||
if constexpr (N == 96) return SM90_64x96x32_F32E4M3E4M3_SS();
|
||||
if constexpr (N == 104) return SM90_64x104x32_F32E4M3E4M3_SS();
|
||||
if constexpr (N == 112) return SM90_64x112x32_F32E4M3E4M3_SS();
|
||||
if constexpr (N == 120) return SM90_64x120x32_F32E4M3E4M3_SS();
|
||||
if constexpr (N == 128) return SM90_64x128x32_F32E4M3E4M3_SS();
|
||||
if constexpr (N == 192) return SM90_64x192x32_F32E4M3E4M3_SS();
|
||||
}
|
||||
|
||||
using type = decltype(select_type());
|
||||
};
|
||||
|
||||
} // namespace deep_gemm
|
||||
@@ -1,121 +0,0 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// The file has been adapted from DeepSeek DeepGEMM project
|
||||
// Copyright (c) 2025 DeepSeek
|
||||
// Licensed under the MIT License - https://github.com/deepseek-ai/DeepGEMM/blob/main/LICENSE
|
||||
|
||||
#include "utils.cuh"
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
enum class GemmType {
|
||||
Normal,
|
||||
GroupedContiguous,
|
||||
GroupedMasked
|
||||
};
|
||||
|
||||
#pragma clang diagnostic push
|
||||
#pragma ide diagnostic ignored "cppcoreguidelines-pro-type-member-init"
|
||||
template <GemmType kGemmType,
|
||||
uint32_t SHAPE_N, uint32_t BLOCK_M, uint32_t BLOCK_N,
|
||||
uint32_t kNumGroups, uint32_t kNumTMAMulticast,
|
||||
uint32_t kNumNBlocks = ceil_div(SHAPE_N, BLOCK_N),
|
||||
uint32_t kNumNBlocksPerGroup = 16>
|
||||
struct Scheduler {
|
||||
int current_iter = -1;
|
||||
uint32_t num_aligned_m_blocks;
|
||||
|
||||
// For normal GEMM
|
||||
// Maybe not used in the masked grouped GEMM
|
||||
uint32_t num_blocks;
|
||||
|
||||
// For grouped GEMM
|
||||
int* grouped_layout;
|
||||
// Only used for masked layout
|
||||
uint32_t curr_group_idx, curr_cumsum;
|
||||
|
||||
__device__ __forceinline__ explicit Scheduler(const uint32_t shape_m,
|
||||
int* grouped_layout = nullptr) {
|
||||
num_aligned_m_blocks = ceil_div(shape_m, BLOCK_M);
|
||||
if constexpr (kGemmType == GemmType::Normal) {
|
||||
num_blocks = num_aligned_m_blocks * kNumNBlocks;
|
||||
} else if (kGemmType == GemmType::GroupedContiguous) {
|
||||
num_blocks = num_aligned_m_blocks * kNumNBlocks;
|
||||
this->grouped_layout = grouped_layout;
|
||||
} else if (kGemmType == GemmType::GroupedMasked) {
|
||||
curr_group_idx = curr_cumsum = 0;
|
||||
this->grouped_layout = grouped_layout;
|
||||
}
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void get_swizzled_block_idx(const uint32_t num_m_blocks, int block_idx, uint32_t& m_block_idx, uint32_t& n_block_idx) {
|
||||
DG_STATIC_ASSERT(kNumNBlocksPerGroup % kNumTMAMulticast == 0, "Invalid group size");
|
||||
|
||||
// Swizzle for better L2 usages
|
||||
auto num_blocks_per_group = num_m_blocks * kNumNBlocksPerGroup;
|
||||
auto group_idx = block_idx / num_blocks_per_group;
|
||||
auto first_n_block_idx = group_idx * kNumNBlocksPerGroup;
|
||||
auto num_n_blocks_in_group = min(kNumNBlocksPerGroup, kNumNBlocks - first_n_block_idx);
|
||||
auto in_group_idx = block_idx % num_blocks_per_group;
|
||||
m_block_idx = in_group_idx / num_n_blocks_in_group;
|
||||
n_block_idx = first_n_block_idx + in_group_idx % num_n_blocks_in_group;
|
||||
}
|
||||
|
||||
template <bool kIgnoreGroupedForGroupedContiguous=true>
|
||||
__device__ __forceinline__ uint32_t get_global_idx(const uint32_t shape_dim, const uint32_t block_size,
|
||||
const uint32_t& block_idx, const uint32_t& m_block_idx=0) {
|
||||
if constexpr (kGemmType == GemmType::Normal) {
|
||||
return block_idx * block_size;
|
||||
} else if (kGemmType == GemmType::GroupedContiguous) {
|
||||
auto offset = kIgnoreGroupedForGroupedContiguous ? 0 : __ldg(grouped_layout + m_block_idx * BLOCK_M);
|
||||
return offset * shape_dim + block_idx * block_size;
|
||||
} else if (kGemmType == GemmType::GroupedMasked) {
|
||||
return curr_group_idx * shape_dim + block_idx * block_size;
|
||||
}
|
||||
}
|
||||
|
||||
__device__ __forceinline__ bool get_next_block(uint32_t& m_block_idx, uint32_t& n_block_idx) {
|
||||
const auto next_block_idx = (++ current_iter) * gridDim.x + blockIdx.x;
|
||||
|
||||
if constexpr (kGemmType == GemmType::GroupedMasked) {
|
||||
uint32_t num_m_blocks;
|
||||
while (true) {
|
||||
// End of the task
|
||||
if (curr_group_idx == kNumGroups)
|
||||
return false;
|
||||
|
||||
// Within current group
|
||||
num_m_blocks = ceil_div(static_cast<uint32_t>(__ldg(grouped_layout + curr_group_idx)), BLOCK_M);
|
||||
auto current_m_block_cumsum = curr_cumsum + num_m_blocks;
|
||||
if (next_block_idx < current_m_block_cumsum * kNumNBlocks)
|
||||
break;
|
||||
|
||||
// Move to check the next group
|
||||
curr_group_idx ++, curr_cumsum = current_m_block_cumsum;
|
||||
}
|
||||
|
||||
get_swizzled_block_idx(num_m_blocks, next_block_idx - curr_cumsum * kNumNBlocks, m_block_idx, n_block_idx);
|
||||
} else {
|
||||
if (next_block_idx >= num_blocks)
|
||||
return false;
|
||||
|
||||
get_swizzled_block_idx(num_aligned_m_blocks, next_block_idx, m_block_idx, n_block_idx);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
};
|
||||
#pragma clang diagnostic pop
|
||||
|
||||
} // namespace deep_gemm
|
||||
@@ -1,116 +0,0 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// The file has been adapted from DeepSeek DeepGEMM project
|
||||
// Copyright (c) 2025 DeepSeek
|
||||
// Licensed under the MIT License - https://github.com/deepseek-ai/DeepGEMM/blob/main/LICENSE
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cassert>
|
||||
#include <cuda.h>
|
||||
#include <cudaTypedefs.h>
|
||||
#include <cuda_fp8.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <cuda/barrier>
|
||||
|
||||
#include "utils.cuh"
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
template <class T>
|
||||
constexpr CUtensorMapDataType get_CUtensorMapDataType() {
|
||||
if constexpr (std::is_same<T, uint8_t>::value) {
|
||||
return CU_TENSOR_MAP_DATA_TYPE_UINT8;
|
||||
} else if constexpr (std::is_same<T, __nv_fp8_e4m3>::value) {
|
||||
return CU_TENSOR_MAP_DATA_TYPE_UINT8;
|
||||
} else if constexpr (std::is_same<T, __nv_fp8_e5m2>::value) {
|
||||
return CU_TENSOR_MAP_DATA_TYPE_UINT8;
|
||||
} else if constexpr (std::is_same<T, uint16_t>::value) {
|
||||
return CU_TENSOR_MAP_DATA_TYPE_UINT16;
|
||||
} else if constexpr (std::is_same<T, uint32_t>::value) {
|
||||
return CU_TENSOR_MAP_DATA_TYPE_UINT32;
|
||||
} else if constexpr (std::is_same<T, uint64_t>::value) {
|
||||
return CU_TENSOR_MAP_DATA_TYPE_UINT64;
|
||||
} else if constexpr (std::is_same<T, int32_t>::value) {
|
||||
return CU_TENSOR_MAP_DATA_TYPE_INT32;
|
||||
} else if constexpr (std::is_same<T, int64_t>::value) {
|
||||
return CU_TENSOR_MAP_DATA_TYPE_INT64;
|
||||
} else if constexpr (std::is_same<T, __half>::value) {
|
||||
return CU_TENSOR_MAP_DATA_TYPE_FLOAT16;
|
||||
} else if constexpr (std::is_same<T, float>::value) {
|
||||
return CU_TENSOR_MAP_DATA_TYPE_FLOAT32;
|
||||
} else if constexpr (std::is_same<T, __nv_bfloat16>::value) {
|
||||
return CU_TENSOR_MAP_DATA_TYPE_BFLOAT16;
|
||||
} else if constexpr (std::is_same<T, double>::value) {
|
||||
return CU_TENSOR_MAP_DATA_TYPE_FLOAT64;
|
||||
}
|
||||
}
|
||||
|
||||
PFN_cuTensorMapEncodeTiled get_cuTensorMapEncodeTiled() {
|
||||
// Get pointer to `cuTensorMapEncodeTiled`
|
||||
cudaDriverEntryPointQueryResult driver_status;
|
||||
void* cuTensorMapEncodeTiled_ptr = nullptr;
|
||||
|
||||
/*
|
||||
#if CUDA_VERSION >= 12050
|
||||
cudaGetDriverEntryPointByVersion("cuTensorMapEncodeTiled", &cuTensorMapEncodeTiled_ptr, 12000,
|
||||
cudaEnableDefault, &driver_status);
|
||||
#else
|
||||
*/
|
||||
cudaGetDriverEntryPoint("cuTensorMapEncodeTiled", &cuTensorMapEncodeTiled_ptr,
|
||||
cudaEnableDefault, &driver_status);
|
||||
//#endif
|
||||
|
||||
if (driver_status != cudaDriverEntryPointSuccess)
|
||||
throw std::runtime_error("driver_status != cudaDriverEntryPointSuccess");
|
||||
return reinterpret_cast<PFN_cuTensorMapEncodeTiled>(cuTensorMapEncodeTiled_ptr);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
CUtensorMap make_2d_tma_copy_desc(T* global_address, uint64_t gmem_dim[2],
|
||||
uint64_t stride_in_bytes, uint32_t smem_dim[2],
|
||||
CUtensorMapSwizzle swizzle_type,
|
||||
PFN_cuTensorMapEncodeTiled encode_func = nullptr) {
|
||||
CUtensorMap tensor_map{};
|
||||
constexpr uint32_t rank = 2;
|
||||
uint64_t global_stride[rank - 1] = {stride_in_bytes};
|
||||
uint32_t elem_strides[rank] = {1, 1};
|
||||
|
||||
if (encode_func == nullptr)
|
||||
encode_func = get_cuTensorMapEncodeTiled();
|
||||
|
||||
auto result = encode_func(
|
||||
&tensor_map, get_CUtensorMapDataType<typename std::remove_cv<T>::type>(), rank,
|
||||
global_address, gmem_dim, global_stride, smem_dim, elem_strides,
|
||||
CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE, swizzle_type,
|
||||
CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_L2_256B,
|
||||
CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);
|
||||
DG_HOST_ASSERT(result == CUDA_SUCCESS);
|
||||
return tensor_map;
|
||||
}
|
||||
|
||||
template <uint32_t kNumTMAMulticast = 1>
|
||||
__device__ __forceinline__ void
|
||||
tma_copy(void const* desc_ptr, uint64_t* barrier_ptr, void* smem_ptr,
|
||||
int32_t const& crd_0, int32_t const& crd_1) {
|
||||
constexpr auto cache_hint = static_cast<uint64_t>(cute::TMA::CacheHintSm90::EVICT_NORMAL);
|
||||
if constexpr (kNumTMAMulticast == 1) {
|
||||
cute::SM90_TMA_LOAD_2D::copy(desc_ptr, barrier_ptr, cache_hint, smem_ptr, crd_0, crd_1);
|
||||
} else if (cute::block_rank_in_cluster() == 0) {
|
||||
cute::SM90_TMA_LOAD_MULTICAST_2D::copy(desc_ptr, barrier_ptr, (1 << kNumTMAMulticast) - 1, cache_hint, smem_ptr, crd_0, crd_1);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace deep_gemm
|
||||
@@ -1,66 +0,0 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// The file has been adapted from DeepSeek DeepGEMM project
|
||||
// Copyright (c) 2025 DeepSeek
|
||||
// Licensed under the MIT License - https://github.com/deepseek-ai/DeepGEMM/blob/main/LICENSE
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <exception>
|
||||
|
||||
#ifdef __CLION_IDE__
|
||||
__host__ __device__ __forceinline__ void host_device_printf(const char* format, ...) { asm volatile("trap;"); }
|
||||
#define printf host_device_printf
|
||||
#endif
|
||||
|
||||
class AssertionException : public std::exception {
|
||||
private:
|
||||
std::string message{};
|
||||
|
||||
public:
|
||||
explicit AssertionException(const std::string& message) : message(message) {}
|
||||
|
||||
const char *what() const noexcept override { return message.c_str(); }
|
||||
};
|
||||
|
||||
#ifndef DG_HOST_ASSERT
|
||||
#define DG_HOST_ASSERT(cond) \
|
||||
do { \
|
||||
if (not (cond)) { \
|
||||
printf("Assertion failed: %s:%d, condition: %s\n", \
|
||||
__FILE__, __LINE__, #cond); \
|
||||
throw AssertionException("Assertion failed: " #cond); \
|
||||
} \
|
||||
} while (0)
|
||||
#endif
|
||||
|
||||
#ifndef DG_DEVICE_ASSERT
|
||||
#define DG_DEVICE_ASSERT(cond) \
|
||||
do { \
|
||||
if (not (cond)) { \
|
||||
printf("Assertion failed: %s:%d, condition: %s\n", __FILE__, __LINE__, #cond); \
|
||||
asm("trap;"); \
|
||||
} \
|
||||
} while (0)
|
||||
#endif
|
||||
|
||||
#ifndef DG_STATIC_ASSERT
|
||||
#define DG_STATIC_ASSERT(cond, reason) static_assert(cond, reason)
|
||||
#endif
|
||||
|
||||
template <typename T>
|
||||
__device__ __host__ constexpr T ceil_div(T a, T b) {
|
||||
return (a + b - 1) / b;
|
||||
}
|
||||
@@ -1,21 +0,0 @@
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# The file has been adapted from DeepSeek DeepGEMM project
|
||||
# Copyright (c) 2025 DeepSeek
|
||||
# Licensed under the MIT License - https://github.com/deepseek-ai/DeepGEMM/blob/main/LICENSE
|
||||
"""Compiler module"""
|
||||
from .compiler import build, get_nvcc_compiler
|
||||
from .runtime import Runtime
|
||||
from .template import cpp_format, generate
|
||||
@@ -1,208 +0,0 @@
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# The file has been adapted from DeepSeek DeepGEMM project
|
||||
# Copyright (c) 2025 DeepSeek
|
||||
# Licensed under the MIT License - https://github.com/deepseek-ai/DeepGEMM/blob/main/LICENSE
|
||||
"""compiler"""
|
||||
import functools
|
||||
import hashlib
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
import uuid
|
||||
from typing import Tuple
|
||||
|
||||
from ..utils import get_cuda_home
|
||||
from . import interleave_ffma
|
||||
from .runtime import Runtime, RuntimeCache
|
||||
from .template import typename_map
|
||||
|
||||
runtime_cache = RuntimeCache()
|
||||
|
||||
|
||||
def hash_to_hex(s: str) -> str:
|
||||
"""Hash string s into hexadecimal format"""
|
||||
md5 = hashlib.md5()
|
||||
md5.update(s.encode("utf-8"))
|
||||
return md5.hexdigest()[0:12]
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def get_jit_include_dir() -> str:
|
||||
"""Get jit include dir"""
|
||||
return f"{os.path.dirname(os.path.abspath(__file__))}/../include"
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def get_deep_gemm_version() -> str:
|
||||
"""Get deep gemm version"""
|
||||
# Update include directories
|
||||
include_dir = f"{get_jit_include_dir()}/deep_gemm"
|
||||
assert os.path.exists(
|
||||
include_dir
|
||||
), f"Cannot find GEMM include directory {include_dir}"
|
||||
md5 = hashlib.md5()
|
||||
for filename in filter(
|
||||
lambda x: x.endswith(".cuh"), sorted(os.listdir(include_dir))
|
||||
):
|
||||
with open(f"{include_dir}/{filename}", "rb") as f:
|
||||
md5.update(f.read())
|
||||
|
||||
# Update `interleave_ffma.py`
|
||||
with open(
|
||||
f"{os.path.dirname(os.path.realpath(__file__))}/interleave_ffma.py", "rb"
|
||||
) as f:
|
||||
md5.update(f.read())
|
||||
return md5.hexdigest()[0:12]
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def get_nvcc_compiler() -> Tuple[str, str]:
|
||||
"""Get default NVCC compiler"""
|
||||
paths = []
|
||||
if os.getenv("DG_NVCC_COMPILER"):
|
||||
paths.append(os.getenv("DG_NVCC_COMPILER"))
|
||||
CUDA_HOME = get_cuda_home()
|
||||
paths.append(f"{CUDA_HOME}/bin/nvcc")
|
||||
|
||||
# Try to find the first available NVCC compiler
|
||||
least_version_required = "12.3"
|
||||
version_pattern = re.compile(r"release (\d+\.\d+)")
|
||||
for path in paths:
|
||||
if os.path.exists(path):
|
||||
match = version_pattern.search(os.popen(f"{path} --version").read())
|
||||
assert match, f"Cannot get the version of NVCC compiler {path}"
|
||||
version = match.group(1)
|
||||
assert (
|
||||
version >= least_version_required
|
||||
), f"NVCC {path} version {version} is lower than {least_version_required}"
|
||||
return path, version
|
||||
raise RuntimeError("Cannot find any available NVCC compiler")
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def get_default_user_dir():
|
||||
"""Get default user dir"""
|
||||
if "DG_CACHE_DIR" in os.environ:
|
||||
path = os.getenv("DG_CACHE_DIR")
|
||||
os.makedirs(path, exist_ok=True)
|
||||
return path
|
||||
return os.path.expanduser("~") + "/.deep_gemm"
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def get_tmp_dir():
|
||||
"""Get temporary dir"""
|
||||
return f"{get_default_user_dir()}/tmp"
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def get_cache_dir():
|
||||
"""Get cache dir"""
|
||||
return f"{get_default_user_dir()}/cache"
|
||||
|
||||
|
||||
def make_tmp_dir():
|
||||
"""Make temporary dir"""
|
||||
tmp_dir = get_tmp_dir()
|
||||
os.makedirs(tmp_dir, exist_ok=True)
|
||||
return tmp_dir
|
||||
|
||||
|
||||
def put(path, data, is_binary=False):
|
||||
"""Put data to specified path"""
|
||||
# Write and do POSIX atomic replace
|
||||
tmp_file_path = f"{make_tmp_dir()}/file.tmp.{str(uuid.uuid4())}.{hash_to_hex(path)}"
|
||||
with open(tmp_file_path, "wb" if is_binary else "w") as f:
|
||||
f.write(data)
|
||||
os.replace(tmp_file_path, path)
|
||||
|
||||
|
||||
def build(name: str, arg_defs: tuple, code: str) -> Runtime:
|
||||
"""Build kernel"""
|
||||
# Compiler flags
|
||||
nvcc_flags = [
|
||||
"-std=c++17",
|
||||
"-shared",
|
||||
"-O3",
|
||||
"--expt-relaxed-constexpr",
|
||||
"--expt-extended-lambda",
|
||||
"-gencode=arch=compute_90a,code=sm_90a",
|
||||
"--ptxas-options=--register-usage-level=10"
|
||||
+ (",--verbose" if "DG_PTXAS_VERBOSE" in os.environ else ""),
|
||||
# Suppress some unnecessary warnings, such as unused variables for certain `constexpr` branch cases
|
||||
"--diag-suppress=177,174,940",
|
||||
]
|
||||
cxx_flags = ["-fPIC", "-O3", "-Wno-deprecated-declarations", "-Wno-abi"]
|
||||
flags = [*nvcc_flags, f'--compiler-options={",".join(cxx_flags)}']
|
||||
include_dirs = [get_jit_include_dir()]
|
||||
|
||||
# Build signature
|
||||
enable_sass_opt = (
|
||||
get_nvcc_compiler()[1] <= "12.8"
|
||||
and int(os.getenv("DG_DISABLE_FFMA_INTERLEAVE", 0)) == 0
|
||||
)
|
||||
signature = f"{name}$${get_deep_gemm_version()}$${code}$${get_nvcc_compiler()}$${flags}$${enable_sass_opt}"
|
||||
name = f"kernel.{name}.{hash_to_hex(signature)}"
|
||||
path = f"{get_cache_dir()}/{name}"
|
||||
|
||||
# Check runtime cache or file system hit
|
||||
global runtime_cache
|
||||
if runtime_cache[path] is not None:
|
||||
if os.getenv("DG_JIT_DEBUG", None):
|
||||
print(f"Using cached JIT runtime {name} during build")
|
||||
return runtime_cache[path]
|
||||
|
||||
# Write the code
|
||||
os.makedirs(path, exist_ok=True)
|
||||
args_path = f"{path}/kernel.args"
|
||||
src_path = f"{path}/kernel.cu"
|
||||
put(
|
||||
args_path,
|
||||
", ".join(
|
||||
[f"('{arg_def[0]}', {typename_map[arg_def[1]]})" for arg_def in arg_defs]
|
||||
),
|
||||
)
|
||||
put(src_path, code)
|
||||
|
||||
# Compile into a temporary SO file
|
||||
so_path = f"{path}/kernel.so"
|
||||
tmp_so_path = (
|
||||
f"{make_tmp_dir()}/nvcc.tmp.{str(uuid.uuid4())}.{hash_to_hex(so_path)}.so"
|
||||
)
|
||||
# Compile
|
||||
command = [
|
||||
get_nvcc_compiler()[0],
|
||||
src_path,
|
||||
"-o",
|
||||
tmp_so_path,
|
||||
*flags,
|
||||
*[f"-I{d}" for d in include_dirs],
|
||||
]
|
||||
if os.getenv("DG_JIT_DEBUG", None) or os.getenv("DG_JIT_PRINT_NVCC_COMMAND", False):
|
||||
print(f"Compiling JIT runtime {name} with command {command}")
|
||||
return_code = subprocess.check_call(command)
|
||||
assert return_code == 0, f"Failed to compile {src_path}"
|
||||
|
||||
# Interleave FFMA reuse
|
||||
if enable_sass_opt:
|
||||
interleave_ffma.process(tmp_so_path)
|
||||
|
||||
# Atomic replace SO file
|
||||
os.replace(tmp_so_path, so_path)
|
||||
|
||||
# Put cache and return
|
||||
runtime_cache[path] = Runtime(path)
|
||||
return runtime_cache[path]
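
# Illustrative sketch only (not part of the original module): how a caller is expected to
# combine `generate` (from .template) with `build`. The kernel name, argument list, and
# body below are hypothetical; real callers (e.g. the JIT tuner) pass a full CUDA launch
# body, and compilation still requires a working NVCC per `get_nvcc_compiler`.
def _example_build_usage() -> Runtime:
    from .template import generate

    arg_defs = (("m", int), ("smem_size", int))
    code = generate(
        includes=(),  # extra includes beyond the preloaded ones
        arg_defs=arg_defs,
        body="// hypothetical kernel launch would go here\n__return_code = 0;\n",
    )
    return build("example_kernel", arg_defs, code)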
|
||||
@@ -1,173 +0,0 @@
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# The file has been adapted from DeepSeek DeepGEMM project
|
||||
# Copyright (c) 2025 DeepSeek
|
||||
# Licensed under the MIT License - https://github.com/deepseek-ai/DeepGEMM/blob/main/LICENSE
|
||||
"""interleave ffma"""
|
||||
import argparse
|
||||
import mmap
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
|
||||
from ..utils import get_cuda_home
|
||||
|
||||
|
||||
def run_cuobjdump(file_path):
|
||||
"""Run cuobjdump on the given file path and returns its output as a string."""
|
||||
CUDA_HOME = get_cuda_home()
|
||||
command = [f"{CUDA_HOME}/bin/cuobjdump", "-sass", file_path]
|
||||
result = subprocess.run(
|
||||
command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True
|
||||
)
|
||||
assert result.returncode == 0
|
||||
return result.stdout
|
||||
|
||||
|
||||
def extract_ffma(sass):
|
||||
"""Extract all FFMA instructions from the SASS code."""
|
||||
lines = sass.splitlines()
|
||||
collected = []
|
||||
current = []
|
||||
|
||||
arch_name, func_name = "N/A", "N/A"
|
||||
skip_next_line = False
|
||||
for line in lines:
|
||||
if "code for" in line:
|
||||
arch_name = line.strip()[len("code for "):]  # strip the literal prefix (lstrip would strip a character set)
|
||||
elif "Function :" in line:
|
||||
func_name = line.strip()[len("Function :"):].strip()  # strip the literal prefix (lstrip would strip a character set)
|
||||
elif "FFMA" in line:
|
||||
current.append(line)
|
||||
skip_next_line = True
|
||||
elif skip_next_line:
|
||||
current.append(line)
|
||||
skip_next_line = False
|
||||
else:
|
||||
if len(current) >= 16:
|
||||
assert len(current) % 2 == 0
|
||||
collected.append((f"{arch_name}::{func_name}", current))
|
||||
current = []
|
||||
|
||||
if os.getenv("DG_PRINT_REG_REUSE", None):
|
||||
print(f"Found {len(collected)} FFMA segments")
|
||||
return collected
|
||||
|
||||
|
||||
def extract_hex_from_line(line):
|
||||
"""Extract hexadecimal number from the given line using regular expression."""
|
||||
match = re.search(r"/\*\s*(0x[0-9a-fA-F]+)\s*\*/", line)
|
||||
assert match
|
||||
return int(match.group(1), 16)
|
||||
|
||||
|
||||
def validate(m, offset, le_bytes, num_lines):
|
||||
"""Validate that the memory region contains the expected bytes starting from the specified offset"""
|
||||
assert len(le_bytes) == num_lines // 2
|
||||
assert m[offset : offset + 16] == le_bytes[0]
|
||||
for i in range(1, num_lines // 2):
|
||||
if m[offset + i * 16 : offset + i * 16 + 16] != le_bytes[i]:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def parse_registers(line):
|
||||
"""Parse register names from the given line"""
|
||||
line = re.sub(r"/\*.*?\*/", "", line)
|
||||
line = line.replace(";", "")
|
||||
tokens = line.strip().split(",")
|
||||
registers = []
|
||||
for token in tokens:
|
||||
token = token.strip()
|
||||
words = token.split()
|
||||
for word in words:
|
||||
if word.startswith("R"):
|
||||
reg = word.split(".")[0]
|
||||
registers.append(reg)
|
||||
return registers
|
||||
|
||||
|
||||
def modify_segment(m, name, ffma_lines):
|
||||
"""Modify the segment so it can be used multiple times"""
|
||||
num_lines = len(ffma_lines)
|
||||
assert num_lines % 2 == 0
|
||||
|
||||
le_bytes, new_le_bytes = [], []
|
||||
reused_list = []
|
||||
dst_reg_set = set()
|
||||
last_reused, last_dst_reg = False, ""
|
||||
num_changed = 0
|
||||
for i in range(num_lines // 2):
|
||||
dst_reg = parse_registers(ffma_lines[i * 2])[-2]
|
||||
low_line, high_line = ffma_lines[i * 2], ffma_lines[i * 2 + 1]
|
||||
low_hex, high_hex = extract_hex_from_line(low_line), extract_hex_from_line(
|
||||
high_line
|
||||
)
|
||||
le_bytes.append(low_hex.to_bytes(8, "little") + high_hex.to_bytes(8, "little"))
|
||||
reused = (high_hex & 0x0800000000000000) != 0
|
||||
if reused:
|
||||
is_first_occurred = dst_reg not in dst_reg_set
|
||||
if is_first_occurred or (last_reused and dst_reg == last_dst_reg):
|
||||
# Modify the `reuse` and `yield` bits
|
||||
assert high_hex & 0x0800200000000000, f"{hex(high_hex)}"
|
||||
high_hex ^= 0x0800200000000000
|
||||
reused = False
|
||||
num_changed += 1
|
||||
else:
|
||||
reused_list.append(i)
|
||||
dst_reg_set.add(dst_reg)
|
||||
new_le_bytes.append(
|
||||
low_hex.to_bytes(8, "little") + high_hex.to_bytes(8, "little")
|
||||
)
|
||||
last_reused, last_dst_reg = reused, dst_reg
|
||||
if os.getenv("DG_PRINT_REG_REUSE", None):
|
||||
print(
|
||||
f" > segment `{name}` new reused list ({num_changed} changed): {reused_list}"
|
||||
)
|
||||
|
||||
# Find the offset
|
||||
offsets = []
|
||||
offset = m.find(le_bytes[0])
|
||||
while offset != -1:
|
||||
offsets.append(offset)
|
||||
offset = m.find(le_bytes[0], offset + 1)
|
||||
offsets = list(filter(lambda x: validate(m, x, le_bytes, num_lines), offsets))
|
||||
|
||||
# Replace with `new_le_bytes`
|
||||
for offset in offsets:
|
||||
for i in range(num_lines // 2):
|
||||
m[offset + i * 16 : offset + i * 16 + 16] = new_le_bytes[i]
|
||||
|
||||
|
||||
def process(path):
|
||||
"""Process the given path"""
|
||||
if os.getenv("DG_PRINT_REG_REUSE", None):
|
||||
print(f"Processing {path}")
|
||||
output = run_cuobjdump(path)
|
||||
segments = extract_ffma(output)
|
||||
with open(path, "r+b") as f:
|
||||
mm = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_WRITE)
|
||||
for segment in segments:
|
||||
modify_segment(mm, *segment)
|
||||
mm.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
"""Main function"""
|
||||
parser = argparse.ArgumentParser(description="Interleave FFMA reg reuse")
|
||||
parser.add_argument("--so", help="Path to the SO file")
|
||||
args = parser.parse_args()
|
||||
|
||||
process(args.so)
|
||||
@@ -1,100 +0,0 @@
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# The file has been adapted from DeepSeek DeepGEMM project
|
||||
# Copyright (c) 2025 DeepSeek
|
||||
# Licensed under the MIT License - https://github.com/deepseek-ai/DeepGEMM/blob/main/LICENSE
|
||||
"""Runtime"""
|
||||
import ctypes
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
import paddle
|
||||
from paddle import Tensor
|
||||
|
||||
from .template import map_ctype
|
||||
|
||||
|
||||
class Runtime:
|
||||
"""A callable class that wraps CUDA kernel execution"""
|
||||
|
||||
def __init__(self, path: str) -> None:
|
||||
self.path = path
|
||||
self.lib = None
|
||||
self.args = None
|
||||
|
||||
assert self.is_path_valid(self.path)
|
||||
|
||||
@staticmethod
|
||||
def is_path_valid(path: str) -> bool:
|
||||
"""Check whether the given path contains all necessary files"""
|
||||
# Exists and is a directory
|
||||
if not os.path.exists(path) or not os.path.isdir(path):
|
||||
return False
|
||||
|
||||
# Contains all necessary files
|
||||
files = ["kernel.cu", "kernel.args", "kernel.so"]
|
||||
return all(os.path.exists(os.path.join(path, file)) for file in files)
|
||||
|
||||
def __call__(self, *args) -> int:
|
||||
"""Call the wrapped function"""
|
||||
# Load SO file
|
||||
if self.lib is None or self.args is None:
|
||||
self.lib = ctypes.CDLL(os.path.join(self.path, "kernel.so"))
|
||||
with open(os.path.join(self.path, "kernel.args"), "r") as f:
|
||||
self.args = eval(f.read(), {"paddle": paddle})
|
||||
|
||||
# Check args and launch
|
||||
assert len(args) == len(
|
||||
self.args
|
||||
), f"Expected {len(self.args)} arguments, got {len(args)}"
|
||||
cargs = []
|
||||
for arg, (name, dtype) in zip(args, self.args):
|
||||
if isinstance(arg, Tensor):
|
||||
assert (
|
||||
arg.dtype == dtype
|
||||
), f"Expected tensor dtype `{dtype}` for `{name}`, got `{arg.dtype}`"
|
||||
else:
|
||||
assert isinstance(
|
||||
arg, dtype
|
||||
), f"Expected built-in type `{dtype}` for `{name}`, got `{type(arg)}`"
|
||||
cargs.append(map_ctype(arg))
|
||||
|
||||
return_code = ctypes.c_int(0)
|
||||
self.lib.launch(*cargs, ctypes.byref(return_code))
|
||||
return return_code.value
|
||||
|
||||
|
||||
class RuntimeCache:
|
||||
"""A cache for Runtimes"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.cache = {}
|
||||
|
||||
def __getitem__(self, path: str) -> Optional[Runtime]:
|
||||
"""Get a cached Runtime"""
|
||||
# In Python runtime
|
||||
if path in self.cache:
|
||||
return self.cache[path]
|
||||
|
||||
# Already compiled
|
||||
if os.path.exists(path) and Runtime.is_path_valid(path):
|
||||
runtime = Runtime(path)
|
||||
self.cache[path] = runtime
|
||||
return runtime
|
||||
return None
|
||||
|
||||
def __setitem__(self, path, runtime) -> None:
|
||||
"""Set a new Runtime"""
|
||||
self.cache[path] = runtime
|
||||
@@ -1,150 +0,0 @@
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# The file has been adapted from DeepSeek DeepGEMM project
|
||||
# Copyright (c) 2025 DeepSeek
|
||||
# Licensed under the MIT License - https://github.com/deepseek-ai/DeepGEMM/blob/main/LICENSE
|
||||
"""Template"""
|
||||
import copy
|
||||
import ctypes
|
||||
import os
|
||||
from typing import Any, Dict, Iterable, Tuple
|
||||
|
||||
import paddle
|
||||
from paddle import Tensor
|
||||
|
||||
# Name map for Python `eval`
|
||||
typename_map: Dict[Any, str] = {
|
||||
**{t: t.__name__ for t in (bool, int, float)},
|
||||
paddle.int32: "paddle.int32",
|
||||
paddle.float32: "paddle.float32",
|
||||
paddle.bfloat16: "paddle.bfloat16",
|
||||
paddle.float8_e4m3fn: "paddle.float8_e4m3fn",
|
||||
paddle.device.cuda.Stream: "paddle.device.cuda.Stream",
|
||||
}
|
||||
# `ctype` map for Python casting
|
||||
ctype_map: Dict[Any, Any] = {
|
||||
**{t: getattr(ctypes, f"c_{t.__name__}") for t in (bool, int, float)},
|
||||
**{
|
||||
t: ctypes.c_void_p
|
||||
for t in (
|
||||
paddle.int32,
|
||||
paddle.float32,
|
||||
paddle.bfloat16,
|
||||
paddle.float8_e4m3fn,
|
||||
paddle.device.cuda.Stream,
|
||||
)
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
# Type map for both Python API and source code usages
|
||||
genc_map = {
|
||||
bool: ("bool", "bool"),
|
||||
int: ("int", "int"),
|
||||
float: ("float", "float"),
|
||||
paddle.int32: ("void*", "int*"),
|
||||
paddle.float32: ("void*", "float*"),
|
||||
paddle.bfloat16: ("void*", "__nv_bfloat16*"),
|
||||
paddle.float8_e4m3fn: ("void*", "__nv_fp8_e4m3*"),
|
||||
paddle.device.cuda.Stream: ("void*", "cudaStream_t"),
|
||||
}
|
||||
|
||||
|
||||
def map_ctype(value: Any) -> Any:
|
||||
"""Map python types to corresponding ctypes"""
|
||||
ctype = ctype_map[value.dtype if isinstance(value, Tensor) else type(value)]
|
||||
if isinstance(value, Tensor):
|
||||
return ctype(value.data_ptr())
|
||||
if isinstance(value, paddle.device.cuda.Stream):
|
||||
return ctype(value.cuda_stream)
|
||||
return ctype(value)
|
||||
|
||||
|
||||
def cpp_format(template: str, keys: Dict[str, Any]) -> str:
|
||||
"""Format template string using given dict"""
|
||||
# We don't use `str.format` because it's not safe for C++ {} braces
|
||||
new_template = copy.deepcopy(template)
|
||||
for key, value in keys.items():
|
||||
new_template = new_template.replace(f"{{{key}}}", f"{value}")
|
||||
return new_template
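
# Illustrative sketch only (not part of the original module): `cpp_format` substitutes
# `{key}` placeholders while leaving ordinary C++ braces untouched, which is exactly why
# `str.format` is avoided above. The template string below is hypothetical.
def _example_cpp_format() -> str:
    template = "constexpr auto N = {N};\nif (flag) { run(); }"
    return cpp_format(template, {"N": 128})
    # -> 'constexpr auto N = 128;\nif (flag) { run(); }'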
|
||||
|
||||
|
||||
def generate(includes: Iterable[str], arg_defs: Iterable[Tuple], body: str) -> str:
|
||||
"""Generate CPP source code"""
|
||||
# Common prefix
|
||||
code = "// DeepGEMM auto-generated JIT CUDA source file\n\n"
|
||||
|
||||
# Includes
|
||||
preload_sys_includes = [
|
||||
"<cuda.h>",
|
||||
"<cuda_fp8.h>",
|
||||
"<cuda_runtime.h>",
|
||||
"<iostream>",
|
||||
]
|
||||
preload_package_includes = ['"cutlass/cutlass.h"']
|
||||
|
||||
assert isinstance(includes, list) or isinstance(includes, tuple)
|
||||
sys_includes = sorted(
|
||||
list(
|
||||
set(
|
||||
preload_sys_includes
|
||||
+ [include for include in includes if include.startswith("<")]
|
||||
)
|
||||
)
|
||||
)
|
||||
package_includes = sorted(
|
||||
list(
|
||||
set(
|
||||
preload_package_includes
|
||||
+ [include for include in includes if include.startswith('"')]
|
||||
)
|
||||
)
|
||||
)
|
||||
code += "\n".join(f"#include {include}" for include in sys_includes) + "\n\n"
|
||||
code += "\n".join(f"#include {include}" for include in package_includes) + "\n\n"
|
||||
|
||||
# Function signature
|
||||
raw = "__raw_"
|
||||
get_def = (
|
||||
lambda n, t: f"{genc_map[t][0]} "
|
||||
+ (raw if genc_map[t][0] != genc_map[t][1] else "")
|
||||
+ n
|
||||
)
|
||||
code += 'extern "C" void launch('
|
||||
code += ", ".join(
|
||||
[get_def(*arg_def) for arg_def in arg_defs]
|
||||
+ [
|
||||
"int& __return_code",
|
||||
]
|
||||
)
|
||||
code += ") {\n"
|
||||
|
||||
# Cast raw types
|
||||
code += " // Cast raw types (if needed)\n"
|
||||
for arg_name, arg_type in arg_defs:
|
||||
if genc_map[arg_type][0] != genc_map[arg_type][1]:
|
||||
code += f" auto {arg_name} = reinterpret_cast<{genc_map[arg_type][1]}>({raw}{arg_name});\n"
|
||||
|
||||
# Function body
|
||||
code += "\n".join([((" " if line else "") + line) for line in body.split("\n")])
|
||||
|
||||
# End the function
|
||||
code += "}\n\n"
|
||||
|
||||
# Debug print
|
||||
if os.getenv("DG_JIT_DEBUG", None):
|
||||
print(f"Generated code:\n{code}")
|
||||
|
||||
return code
|
||||
@@ -1,31 +0,0 @@
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# The file has been adapted from DeepSeek DeepGEMM project
|
||||
# Copyright (c) 2025 DeepSeek
|
||||
# Licensed under the MIT License - https://github.com/deepseek-ai/DeepGEMM/blob/main/LICENSE
|
||||
"""initialize"""
|
||||
from .gemm import gemm_fp8_fp8_bf16_nt
|
||||
from .m_grouped_gemm import (
|
||||
m_grouped_gemm_fp8_fp8_bf16_nt_contiguous,
|
||||
m_grouped_gemm_fp8_fp8_bf16_nt_masked,
|
||||
)
|
||||
from .utils import (
|
||||
ceil_div,
|
||||
get_col_major_tma_aligned_tensor,
|
||||
get_col_major_tma_aligned_tensor_prefill,
|
||||
get_m_alignment_for_contiguous_layout,
|
||||
get_num_sms,
|
||||
set_num_sms,
|
||||
)
|
||||
@@ -1,266 +0,0 @@
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# The file has been adapted from DeepSeek DeepGEMM project
|
||||
# Copyright (c) 2025 DeepSeek
|
||||
# Licensed under the MIT License - https://github.com/deepseek-ai/DeepGEMM/blob/main/LICENSE
|
||||
"""FP8 GEMM kernels"""
|
||||
import functools
|
||||
from typing import Tuple
|
||||
|
||||
import paddle
|
||||
from paddle import Tensor
|
||||
|
||||
from .tuner import jit_tuner
|
||||
from .utils import (
|
||||
ceil_div,
|
||||
get_m_alignment_for_contiguous_layout,
|
||||
get_num_sms,
|
||||
)
|
||||
|
||||
# C++ code templates
|
||||
includes = ('"deep_gemm/fp8_gemm.cuh"',)
|
||||
template = """
|
||||
using namespace deep_gemm;
|
||||
|
||||
// Templated args from Python JIT call
|
||||
constexpr auto N = {N}, K = {K};
|
||||
constexpr auto BLOCK_M = {BLOCK_M};
|
||||
constexpr auto BLOCK_N = {BLOCK_N};
|
||||
constexpr auto kNumStages = {NUM_STAGES};
|
||||
constexpr auto kNumTMAMulticast = {NUM_TMA_MULTICAST};
|
||||
|
||||
// Make a templated GEMM
|
||||
using GemmType = Gemm<N, K, BLOCK_M, BLOCK_N, 128, 1, kNumStages, kNumTMAMulticast, GemmType::Normal>;
|
||||
|
||||
// Launch kernel
|
||||
auto tma_a_desc = GemmType::make_2d_tma_a_desc(lhs, m);
|
||||
auto tma_b_desc = GemmType::make_2d_tma_b_desc(rhs);
|
||||
auto tma_scales_a_desc = GemmType::make_2d_tma_scales_a_desc(lhs_scales, m);
|
||||
auto tma_d_desc = GemmType::make_2d_tma_d_desc(out, m);
|
||||
GemmType::run(out, rhs_scales, nullptr,
|
||||
m,
|
||||
tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc,
|
||||
stream, num_sms, smem_size);
|
||||
"""
|
||||
|
||||
|
||||
def is_tma_multicast_legal(
|
||||
n: int, block_n: int, num_tma_multicast: int, num_sms: int
|
||||
) -> bool:
|
||||
"""Check whether it's legal to have multiple multicasts per SM."""
|
||||
if num_tma_multicast == 1:
|
||||
return True
|
||||
return (n % (block_n * num_tma_multicast) == 0) and num_sms % num_tma_multicast == 0
|
||||
|
||||
|
||||
def get_smem_size(
|
||||
num_stages: int, k: int, block_m: int, block_n: int, block_k: int = 128
|
||||
) -> int:
|
||||
"""Get shared memory size needed for each stage"""
|
||||
smem_d = block_m * block_n * 2
|
||||
smem_a_per_stage = block_m * block_k
|
||||
smem_scales_a_per_stage = block_m * 4
|
||||
smem_b_per_stage = block_n * block_k
|
||||
smem_scales_b = ceil_div(k, block_k) * 4
|
||||
smem_barrier = num_stages * 8 * 2
|
||||
|
||||
smem_size = 0
|
||||
smem_size += smem_d
|
||||
smem_size += num_stages * smem_a_per_stage
|
||||
smem_size += num_stages * smem_scales_a_per_stage
|
||||
smem_size += num_stages * smem_b_per_stage
|
||||
smem_size += ceil_div(smem_scales_b * (1 if block_k % block_n == 0 else 2), 8) * 8
|
||||
smem_size += smem_barrier
|
||||
return smem_size
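
# Worked example (illustration only, values are hypothetical): for num_stages=5, k=7168,
# block_m=128, block_n=128, block_k=128 the terms above are
#   smem_d = 128*128*2 = 32768, A tiles = 5*16384, A scales = 5*512, B tiles = 5*16384,
#   B scales = ceil_div(56*4, 8)*8 = 224, barriers = 5*8*2 = 80,
# giving 199472 bytes in total, which fits the 232448-byte SM90 budget used below.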
|
||||
|
||||
|
||||
def get_best_configs(
|
||||
m: int,
|
||||
n: int,
|
||||
k: int,
|
||||
num_groups: int,
|
||||
num_sms: int,
|
||||
is_grouped_contiguous: bool = False,
|
||||
) -> Tuple[int, int, int, int, int]:
|
||||
"""Find the optimal configuration"""
|
||||
if not is_grouped_contiguous:
|
||||
# TODO: for some cases, smaller M block is better, add them into tuning space
|
||||
block_ms = (64 if m <= 64 else 128,)
|
||||
else:
|
||||
block_ms = (get_m_alignment_for_contiguous_layout(),)
|
||||
block_ns = tuple(range(16, 129, 8))
|
||||
|
||||
fix_wave_saturate = lambda x: num_sms if x == 0 else x
|
||||
get_num_waves = lambda bm, bn: (
|
||||
ceil_div(ceil_div(m, bm) * ceil_div(n, bn) * num_groups, num_sms)
|
||||
if bm
|
||||
else None
|
||||
)
|
||||
get_last_wave_util = lambda bm, bn: fix_wave_saturate(
|
||||
(ceil_div(m, bm) * ceil_div(n, bn) * num_groups) % num_sms
|
||||
)
|
||||
|
||||
# Decide block sizes by waves
|
||||
best_block_m, best_block_n = None, None
|
||||
for block_m in block_ms:
|
||||
for block_n in block_ns:
|
||||
success = False
|
||||
num_waves, best_num_waves = get_num_waves(block_m, block_n), get_num_waves(
|
||||
best_block_m, best_block_n
|
||||
)
|
||||
if best_block_m is None or best_block_n is None:
|
||||
success = True
|
||||
elif num_waves < best_num_waves:
|
||||
success = True
|
||||
elif num_waves == best_num_waves:
|
||||
# Check last wave utilization
|
||||
util = get_last_wave_util(block_m, block_n)
|
||||
best_util = get_last_wave_util(best_block_m, best_block_n)
|
||||
success = util > best_util or (
|
||||
util == best_util
|
||||
and (
|
||||
block_m > best_block_m
|
||||
or (block_m == best_block_m and block_n < best_block_n)
|
||||
)
|
||||
)
|
||||
best_block_m, best_block_n = (
|
||||
(block_m, block_n) if success else (best_block_m, best_block_n)
|
||||
)
|
||||
assert best_block_m is not None and best_block_n is not None
|
||||
|
||||
# Always pick the longest one
|
||||
# NOTES: for double B scales, the best number of stages may be reduced
|
||||
best_num_stages, best_smem_size, sm90_capacity = None, None, 232448
|
||||
for num_stages in (6, 5, 4) if 128 % best_block_n != 0 else (8, 7, 6, 5, 4):
|
||||
best_smem_size = get_smem_size(num_stages, k, best_block_m, best_block_n)
|
||||
if best_smem_size <= sm90_capacity:
|
||||
best_num_stages = num_stages
|
||||
break
|
||||
assert best_num_stages is not None
|
||||
|
||||
# Decide the number of TMA multicast
|
||||
best_num_tma_multicast = 1
|
||||
if (
|
||||
m >= 1024
|
||||
and is_tma_multicast_legal(n, best_block_n, 2, num_sms)
|
||||
and num_groups == 1
|
||||
):
|
||||
best_num_tma_multicast = 2
|
||||
|
||||
return (
|
||||
best_block_m,
|
||||
best_block_n,
|
||||
best_num_stages,
|
||||
best_num_tma_multicast,
|
||||
best_smem_size,
|
||||
)
|
||||
|
||||
|
||||
@functools.lru_cache()
|
||||
def auto_tuning_with_compilation(m, n, k):
|
||||
"""Compile and tune the GEMM"""
|
||||
global includes, template
|
||||
num_sms = get_num_sms()
|
||||
block_m, block_n, num_stages, num_tma_multicast, smem_size = get_best_configs(
|
||||
m, n, k, 1, num_sms
|
||||
)
|
||||
runtime = jit_tuner.compile_and_tune(
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
name="gemm_fp8_fp8_bf16_nt",
|
||||
keys={
|
||||
"BLOCK_M": block_m,
|
||||
"BLOCK_N": block_n,
|
||||
"K": k,
|
||||
"N": n,
|
||||
"NUM_STAGES": num_stages,
|
||||
"NUM_TMA_MULTICAST": num_tma_multicast,
|
||||
},
|
||||
space=(),
|
||||
includes=includes,
|
||||
arg_defs=(
|
||||
("lhs", paddle.float8_e4m3fn),
|
||||
("lhs_scales", paddle.float32),
|
||||
("rhs", paddle.float8_e4m3fn),
|
||||
("rhs_scales", paddle.float32),
|
||||
("out", paddle.bfloat16),
|
||||
("m", int),
|
||||
("stream", paddle.device.cuda.Stream),
|
||||
("num_sms", int),
|
||||
("smem_size", int),
|
||||
),
|
||||
template=template,
|
||||
)
|
||||
return runtime, num_sms, smem_size
|
||||
|
||||
|
||||
def gemm_fp8_fp8_bf16_nt(
|
||||
lhs: Tuple[Tensor, Tensor], rhs: Tuple[Tensor, Tensor], out: Tensor
|
||||
) -> None:
|
||||
"""
|
||||
Do a normal GEMM with FP8 inputs and BF16 output, with 1x128 LHS scaling and 128x128 RHS scaling.
|
||||
LHS, RHS, RHS scaling factors, and output tensors must be in contiguous format.
|
||||
RHS and RHS scaling factors are required to be transposed.
|
||||
The LHS scaling tensor requires a TMA-aligned transposed format; if your input does not match this requirement,
|
||||
this function will transpose it with a set of slow Paddle operations.
|
||||
|
||||
Arguments:
|
||||
lhs: the first element is an FP8 tensor (typed `paddle.float8_e4m3fn`) of shape `[m, k]`,
|
||||
the second element is an FP32 1x128 scaling tensor for LHS of shape `[m, ⌈k / 128⌉]`.
|
||||
rhs: the first element is an FP8 tensor (typed `paddle.float8_e4m3fn`) of shape `[n, k]`.
|
||||
the second element is an FP32 128x128 scaling tensor for RHS of shape `[⌈n / 128⌉, ⌈k / 128⌉]`.
|
||||
out: the BF16 output tensor of shape `[m, n]`, representing the result.
|
||||
"""
|
||||
lhs, lhs_scales = lhs
|
||||
rhs, rhs_scales = rhs
|
||||
m, k = lhs.shape
|
||||
n, k_ = rhs.shape
|
||||
# m_, n_ = out.shape
|
||||
# assert n % 64 == 0 and k % 128 == 0
|
||||
|
||||
# Type and shape checks
|
||||
# assert m == m_ and n == n_ and k == k_
|
||||
# assert n > 0 and k > 0
|
||||
# assert lhs_scales.shape == [m, (k + 127) // 128]
|
||||
# assert rhs_scales.shape == [(n + 127) // 128, (k + 127) // 128]
|
||||
# assert lhs.dtype == paddle.float8_e4m3fn and lhs_scales.dtype == paddle.float32
|
||||
# assert rhs.dtype == paddle.float8_e4m3fn and rhs_scales.dtype == paddle.float32
|
||||
# assert out.dtype == paddle.bfloat16
|
||||
# assert lhs.is_contiguous() and rhs.is_contiguous() and out.is_contiguous()
|
||||
|
||||
# LHS scales must be transposed for TMA load, but not for RHS scales
|
||||
# NOTES: `get_tma_aligned_lhs_scales` may launch a kernel if not processed by previous kernels
|
||||
# TODO: NOT NEED get_col_major_tma_aligned_tensor!!!
|
||||
# lhs_scales = get_col_major_tma_aligned_tensor_prefill(lhs_scales)
|
||||
# assert rhs_scales.is_contiguous()
|
||||
|
||||
# Do nothing if `m` is zero
|
||||
if m == 0:
|
||||
return
|
||||
runtime, num_sms, smem_size = auto_tuning_with_compilation(m, n, k)
|
||||
args = (
|
||||
lhs,
|
||||
lhs_scales,
|
||||
rhs,
|
||||
rhs_scales,
|
||||
out,
|
||||
m,
|
||||
paddle.device.cuda.current_stream(),
|
||||
num_sms,
|
||||
smem_size,
|
||||
)
|
||||
# Run the kernel.
|
||||
runtime(*args)
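
# Illustrative sketch only (not part of the original module): a minimal wrapper showing the
# expected calling convention. The caller must already hold FP8 operands and FP32 scaling
# tensors with the shapes documented above; the tensor names here are hypothetical.
def _example_gemm_fp8_fp8_bf16_nt(lhs_fp8, lhs_scales, rhs_fp8, rhs_scales):
    m, _ = lhs_fp8.shape  # [m, k], paddle.float8_e4m3fn
    n, _ = rhs_fp8.shape  # [n, k], paddle.float8_e4m3fn
    out = paddle.empty([m, n], dtype=paddle.bfloat16)
    gemm_fp8_fp8_bf16_nt((lhs_fp8, lhs_scales), (rhs_fp8, rhs_scales), out)
    return out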
|
||||
@@ -1,329 +0,0 @@
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# The file has been adapted from DeepSeek DeepGEMM project
|
||||
# Copyright (c) 2025 DeepSeek
|
||||
# Licensed under the MIT License - https://github.com/deepseek-ai/DeepGEMM/blob/main/LICENSE
|
||||
"""m grouped gemm"""
|
||||
import functools
|
||||
from typing import Tuple
|
||||
|
||||
import paddle
|
||||
from paddle import Tensor
|
||||
|
||||
from .gemm import get_best_configs
|
||||
from .tuner import jit_tuner
|
||||
from .utils import (
|
||||
get_col_major_tma_aligned_tensor,
|
||||
get_col_major_tma_aligned_tensor_prefill,
|
||||
get_num_sms,
|
||||
)
|
||||
|
||||
# C++ code templates
|
||||
includes = ('"deep_gemm/fp8_gemm.cuh"',)
|
||||
template = """
|
||||
using namespace deep_gemm;
|
||||
|
||||
// Templated args from Python JIT call
|
||||
constexpr auto N = {N}, K = {K};
|
||||
constexpr auto BLOCK_M = {BLOCK_M};
|
||||
constexpr auto BLOCK_N = {BLOCK_N};
|
||||
constexpr auto kNumStages = {NUM_STAGES};
|
||||
constexpr auto kNumTMAMulticast = {NUM_TMA_MULTICAST};
|
||||
|
||||
// Make a templated grouped GEMM
|
||||
using GemmType = Gemm<N, K, BLOCK_M, BLOCK_N, 128, {NUM_GROUPS}, kNumStages, kNumTMAMulticast, GemmType::{GEMM_TYPE}>;
|
||||
|
||||
// Launch kernel
|
||||
auto tma_a_desc = GemmType::make_2d_tma_a_desc(lhs, m);
|
||||
auto tma_b_desc = GemmType::make_2d_tma_b_desc(rhs);
|
||||
auto tma_scales_a_desc = GemmType::make_2d_tma_scales_a_desc(lhs_scales, m);
|
||||
auto tma_d_desc = GemmType::make_2d_tma_d_desc(out, m);
|
||||
GemmType::run(out, rhs_scales, grouped_layout,
|
||||
m,
|
||||
tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc,
|
||||
stream, num_sms, smem_size);
|
||||
"""
|
||||
|
||||
from collections import OrderedDict
|
||||
|
||||
|
||||
class LRUCache:
|
||||
"""
|
||||
A minimal LRU cache built on OrderedDict
|
||||
"""
|
||||
def __init__(self, capacity: int):
|
||||
self.capacity = capacity
|
||||
self.cache = OrderedDict()
|
||||
|
||||
def get(self, key):
|
||||
"""
|
||||
Get a value from the cache, marking the key as most recently used
|
||||
"""
|
||||
if key in self.cache:
|
||||
# If the key exists, move it to the end of the OrderedDict to mark it as most recently used
|
||||
self.cache.move_to_end(key)
|
||||
return self.cache[key]
|
||||
return None
|
||||
|
||||
def put(self, key, value):
|
||||
"""
|
||||
Insert or update a key-value pair, evicting the least recently used entry when the cache is full
|
||||
"""
|
||||
if key in self.cache:
|
||||
# If the key already exists, remove it first
|
||||
del self.cache[key]
|
||||
elif len(self.cache) == self.capacity:
|
||||
# If the cache is full, evict the oldest entry
|
||||
self.cache.popitem(last=False)
|
||||
# Insert the new key-value pair at the end of the OrderedDict
|
||||
self.cache[key] = value
|
||||
|
||||
|
||||
grouped_gemm_masked_keys = LRUCache(10)
|
||||
|
||||
|
||||
@functools.lru_cache()
|
||||
def auto_tuning_with_compilation_grouped_gemm_contiguous(m, n, k, num_groups, num_sms):
|
||||
"""auto tuning gemm"""
|
||||
global includes, template
|
||||
block_m, block_n, num_stages, num_tma_multicast, smem_size = get_best_configs(
|
||||
m, n, k, 1, num_sms, is_grouped_contiguous=True
|
||||
)
|
||||
runtime = jit_tuner.compile_and_tune(
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
name="m_grouped_gemm_fp8_fp8_bf16_nt",
|
||||
keys={
|
||||
"BLOCK_M": block_m,
|
||||
"BLOCK_N": block_n,
|
||||
"GEMM_TYPE": "GroupedContiguous",
|
||||
"K": k,
|
||||
"N": n,
|
||||
"NUM_GROUPS": num_groups,
|
||||
"NUM_STAGES": num_stages,
|
||||
"NUM_TMA_MULTICAST": num_tma_multicast,
|
||||
},
|
||||
space=(),
|
||||
includes=includes,
|
||||
arg_defs=(
|
||||
("lhs", paddle.float8_e4m3fn),
|
||||
("lhs_scales", paddle.float32),
|
||||
("rhs", paddle.float8_e4m3fn),
|
||||
("rhs_scales", paddle.float32),
|
||||
("out", paddle.bfloat16),
|
||||
("grouped_layout", paddle.int32),
|
||||
("m", int),
|
||||
("num_groups", int),
|
||||
("stream", paddle.device.cuda.Stream),
|
||||
("num_sms", int),
|
||||
("smem_size", int),
|
||||
),
|
||||
template=template,
|
||||
)
|
||||
return runtime, num_sms, smem_size
|
||||
|
||||
|
||||
def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
|
||||
lhs: Tuple[Tensor, Tensor],
|
||||
rhs: Tuple[Tensor, Tensor],
|
||||
out: Tensor,
|
||||
m_indices: Tensor,
|
||||
) -> None:
|
||||
"""
|
||||
Do a grouped GEMM (contiguous format) with FP8 inputs and BF16 output,
|
||||
with 1x128 LHS scaling and 128x128 RHS scaling.
|
||||
LHS, RHS, RHS scaling factors, and output tensors must be in contiguous format.
|
||||
RHS and RHS scaling factors are required to be transposed.
|
||||
The LHS scaling tensor requires a TMA-aligned transposed format; if your input does not match this requirement,
|
||||
this function will transpose it with a set of slow Paddle operations.
|
||||
On the M axis, inputs are grouped into several batches, whose batch sizes are aligned to
|
||||
`get_m_alignment_for_contiguous_layout()` (128).
|
||||
|
||||
Arguments:
|
||||
lhs: the first element is an FP8 tensor (typed `paddle.float8_e4m3fn`) of shape `[m_sum, k]`,
|
||||
the second element is an FP32 1x128 scaling tensor for LHS of shape `[m_sum, ⌈k / 128⌉]`.
|
||||
rhs: the first element is an FP8 tensor (typed `paddle.float8_e4m3fn`) of shape `[num_groups, n, k]`.
|
||||
the second element is an FP32 128x128 scaling tensor for RHS of shape `[num_groups, ⌈n / 128⌉, ⌈k / 128⌉]`.
|
||||
out: the BF16 output tensor of shape `[m_sum, n]`, representing the result.
|
||||
m_indices: a tensor of shape `[m_sum]` with type `paddle.int32`.
|
||||
`m_indices[i]` records the group to which the i-th row of the LHS belongs,
|
||||
which means that the i-th row of the LHS matrix will be multiplied with `rhs[m_indices[i]]`.
|
||||
Values of `m_indices` in every-m-alignment-block must also be the same.
|
||||
`-1` in this tensor indicates no RHS matrix selected,
|
||||
the kernel will skip the computation for that aligned block.
|
||||
"""
|
||||
lhs, lhs_scales = lhs
|
||||
rhs, rhs_scales = rhs
|
||||
m, k = lhs.shape
|
||||
num_groups, n, k_ = rhs.shape
|
||||
# TODO: NOT NEED get_col_major_tma_aligned_tensor!!!
|
||||
lhs_scales = get_col_major_tma_aligned_tensor_prefill(lhs_scales)
|
||||
|
||||
# Do nothing if `m` is zero
|
||||
if m == 0:
|
||||
return
|
||||
# Auto-tuning with compilation
|
||||
global includes, template
|
||||
num_sms = get_num_sms()
|
||||
runtime, num_sms, smem_size = auto_tuning_with_compilation_grouped_gemm_contiguous(
|
||||
m, n, k, num_groups, num_sms
|
||||
)
|
||||
|
||||
args = (
|
||||
lhs,
|
||||
lhs_scales,
|
||||
rhs,
|
||||
rhs_scales,
|
||||
out,
|
||||
m_indices,
|
||||
m,
|
||||
num_groups,
|
||||
paddle.device.cuda.current_stream(),
|
||||
num_sms,
|
||||
smem_size,
|
||||
)
|
||||
runtime(*args)
|
||||
|
||||
|
||||
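# --- Editor's sketch (not part of the original file): typical call pattern for the
# contiguous grouped GEMM above. Shapes are hypothetical and the FP8/scale tensors are
# uninitialized placeholders, so this only illustrates layout and launch, not accuracy.
def _example_contiguous_grouped_gemm():
    num_groups, m_per_group, k, n = 4, 256, 7168, 4096
    m_sum = num_groups * m_per_group  # each group is a multiple of the 128-row alignment
    lhs = (
        paddle.empty((m_sum, k), dtype=paddle.float8_e4m3fn),
        paddle.empty((m_sum, k // 128), dtype=paddle.float32),
    )
    rhs = (
        paddle.empty((num_groups, n, k), dtype=paddle.float8_e4m3fn),
        paddle.empty((num_groups, n // 128, k // 128), dtype=paddle.float32),
    )
    out = paddle.empty((m_sum, n), dtype=paddle.bfloat16)
    # Every row of a 128-aligned block carries the same group id (-1 would skip a block).
    m_indices = paddle.flatten(
        paddle.expand(
            paddle.unsqueeze(paddle.arange(0, num_groups, dtype=paddle.int32), -1),
            shape=[num_groups, m_per_group],
        )
    )
    m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs, rhs, out, m_indices)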
def m_grouped_gemm_fp8_fp8_bf16_nt_masked(
|
||||
lhs: Tuple[Tensor, Tensor],
|
||||
rhs: Tuple[Tensor, Tensor],
|
||||
out: Tensor,
|
||||
masked_m: Tensor,
|
||||
expected_m: int,
|
||||
) -> None:
|
||||
"""
|
||||
Do a grouped GEMM (masked format) with FP8 inputs and BF16 output, with 1x128 LHS scaling and 128x128 RHS scaling.
|
||||
LHS, RHS, RHS scaling factors, and output tensors must be in contiguous format.
|
||||
RHS and RHS scaling factors are required to be transposed.
|
||||
The LHS scaling tensor must be in TMA-aligned transposed format; if the input does not meet this requirement,
this function will transpose it with a set of slow Paddle operations.
Moreover, this alignment requirement differs from the contiguous-format kernel: each batch must be
transposed separately.
|
||||
|
||||
Arguments:
|
||||
lhs: the first element is an FP8 tensor (typed `paddle.float8_e4m3fn`) of shape `[num_groups, m_max, k]`,
|
||||
the second element is an FP32 1x128 scaling tensor for LHS of shape `[num_groups, m_max, ⌈k / 128⌉]`.
|
||||
rhs: the first element is an FP8 tensor (typed `paddle.float8_e4m3fn`) of shape `[num_groups, n, k]`.
|
||||
the second element is an FP32 128x128 scaling tensor for RHS of shape `[num_groups, ⌈n / 128⌉, ⌈k / 128⌉]`.
|
||||
out: the BF16 output tensor of shape `[num_groups, m_max, n]`, representing the result.
|
||||
masked_m: a tensor of shape `[num_groups]`; `masked_m[i]` records the actual number of rows of `lhs[i]`
to compute in the i-th group.
expected_m: a hint (a CPU-side value) for the expected M of each group;
setting it correctly may lead to better performance.
|
||||
"""
|
||||
lhs, lhs_scales = lhs
|
||||
rhs, rhs_scales = rhs
|
||||
num_groups, m, k = lhs.shape
|
||||
num_groups_, n, k_ = rhs.shape
|
||||
# num_groups__, m_, n_ = out.shape
|
||||
# assert len(masked_m.shape) == 1
|
||||
# num_groups___ = masked_m.shape[0]
|
||||
|
||||
# Type and shape checks
|
||||
# assert num_groups == num_groups_ == num_groups__ == num_groups___
|
||||
# assert m == m_ and n == n_ and k == k_
|
||||
# assert expected_m > 0 and m > 0 and n > 0 and k > 0 and num_groups > 0
|
||||
# assert lhs_scales.shape == [num_groups, m, (k + 127) // 128]
|
||||
# assert rhs_scales.shape == [num_groups, (n + 127) // 128, (k + 127) // 128]
|
||||
# assert lhs.dtype == paddle.float8_e4m3fn and lhs_scales.dtype == paddle.float32
|
||||
# assert rhs.dtype == paddle.float8_e4m3fn and rhs_scales.dtype == paddle.float32
|
||||
# assert out.dtype == paddle.bfloat16
|
||||
# assert masked_m.dtype == paddle.int32
|
||||
# assert lhs.is_contiguous() and rhs.is_contiguous()
|
||||
# assert out.is_contiguous() and masked_m.is_contiguous()
|
||||
|
||||
# LHS scales must be transposed for TMA load, but not for RHS scales
|
||||
lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales)
|
||||
# assert rhs_scales.is_contiguous()
|
||||
|
||||
# Auto-tuning with compilation
|
||||
global includes, template
|
||||
num_sms = get_num_sms()
|
||||
|
||||
input_keys = (expected_m, n, k, num_groups, num_sms)
|
||||
if grouped_gemm_masked_keys.get(input_keys) is None:
|
||||
block_m, block_n, num_stages, num_tma_multicast, smem_size = get_best_configs(
|
||||
expected_m, n, k, num_groups, num_sms
|
||||
)
|
||||
|
||||
args = (
|
||||
lhs,
|
||||
lhs_scales,
|
||||
rhs,
|
||||
rhs_scales,
|
||||
out,
|
||||
masked_m,
|
||||
m,
|
||||
paddle.device.cuda.current_stream(),
|
||||
num_sms,
|
||||
smem_size,
|
||||
)
|
||||
|
||||
runtime = jit_tuner.compile_and_tune_group_gemm_masked(
|
||||
name="m_grouped_gemm_fp8_fp8_bf16_nt",
|
||||
keys={
|
||||
"N": n,
|
||||
"K": k,
|
||||
"BLOCK_M": block_m,
|
||||
"BLOCK_N": block_n,
|
||||
"NUM_GROUPS": num_groups,
|
||||
"NUM_STAGES": num_stages,
|
||||
"NUM_TMA_MULTICAST": num_tma_multicast,
|
||||
"GEMM_TYPE": "GroupedMasked",
|
||||
},
|
||||
space=(),
|
||||
includes=includes,
|
||||
arg_defs=(
|
||||
("lhs", paddle.float8_e4m3fn),
|
||||
("lhs_scales", paddle.float32),
|
||||
("rhs", paddle.float8_e4m3fn),
|
||||
("rhs_scales", paddle.float32),
|
||||
("out", paddle.bfloat16),
|
||||
("grouped_layout", paddle.int32),
|
||||
("m", int),
|
||||
("stream", paddle.device.cuda.Stream),
|
||||
("num_sms", int),
|
||||
("smem_size", int),
|
||||
),
|
||||
template=template,
|
||||
args=args,
|
||||
)
|
||||
|
||||
grouped_gemm_masked_keys.put(input_keys, (runtime, smem_size))
|
||||
else:
|
||||
runtime, smem_size = grouped_gemm_masked_keys.get(input_keys)
|
||||
args = (
|
||||
lhs,
|
||||
lhs_scales,
|
||||
rhs,
|
||||
rhs_scales,
|
||||
out,
|
||||
masked_m,
|
||||
m,
|
||||
paddle.device.cuda.current_stream(),
|
||||
num_sms,
|
||||
smem_size,
|
||||
)
|
||||
|
||||
# Extra checks for TMA store
|
||||
# if num_groups > 1 and m > block_m:
|
||||
# assert (
|
||||
# m % block_m == 0
|
||||
# ), f"For masked grouped GEMM, shape M should be multiple of the block M (current block M: {block_m})"
|
||||
|
||||
# Run the kernel
|
||||
runtime(*args)
|
||||
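# --- Editor's sketch (not part of the original file): typical call pattern for the
# masked grouped GEMM above, mirroring test_m_grouped_gemm_masked in the bundled tests.
# Shapes are hypothetical and the FP8/scale tensors are uninitialized placeholders.
def _example_masked_grouped_gemm():
    num_groups, m_max, k, n = 4, 384, 7168, 4096
    lhs = (
        paddle.empty((num_groups, m_max, k), dtype=paddle.float8_e4m3fn),
        paddle.empty((num_groups, m_max, k // 128), dtype=paddle.float32),
    )
    rhs = (
        paddle.empty((num_groups, n, k), dtype=paddle.float8_e4m3fn),
        paddle.empty((num_groups, n // 128, k // 128), dtype=paddle.float32),
    )
    out = paddle.empty((num_groups, m_max, n), dtype=paddle.bfloat16)
    # Only the first masked_m[i] rows of lhs[i] are computed for group i.
    masked_m = paddle.to_tensor([128, 256, 64, 384], dtype=paddle.int32)
    expected_m = 256  # CPU-side hint for the typical per-group M; tune for the workload
    m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs, rhs, out, masked_m, expected_m)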
@@ -1,181 +0,0 @@
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# The file has been adapted from DeepSeek DeepGEMM project
|
||||
# Copyright (c) 2025 DeepSeek
|
||||
# Licensed under the MIT License - https://github.com/deepseek-ai/DeepGEMM/blob/main/LICENSE
|
||||
"""tune gemm kernels"""
|
||||
import copy
|
||||
import os
|
||||
from typing import Any, Dict
|
||||
|
||||
import paddle
|
||||
|
||||
from ..jit import Runtime, build, cpp_format, generate
|
||||
|
||||
|
||||
class JITTuner:
|
||||
"""A tuner that compiles and auto-tunes group gemm masked kernels"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.tuned = {}
|
||||
|
||||
def compile_and_tune_group_gemm_masked(
|
||||
self,
|
||||
name: str,
|
||||
keys: Dict[str, Any],
|
||||
space: tuple,
|
||||
includes: tuple,
|
||||
arg_defs: tuple,
|
||||
template: str,
|
||||
args: tuple,
|
||||
) -> Runtime:
|
||||
"""Compile and tune a group gemm masked kernel"""
|
||||
# NOTES: we always assume the space and template will not change
|
||||
# We also assume the GPU device will not be changed
|
||||
# NOTES: the function must have no accumulated side effects
|
||||
keys = {k: keys[k] for k in sorted(keys.keys())}
|
||||
signature = (name, f"{keys}")
|
||||
if signature in self.tuned:
|
||||
if os.getenv("DG_JIT_DEBUG", None):
|
||||
print(f"Using cached JIT kernel {name} with keys {keys}")
|
||||
return self.tuned[signature]
|
||||
|
||||
if os.getenv("DG_JIT_DEBUG", None):
|
||||
print(f"Auto-tuning JIT kernel {name} with keys {keys}")
|
||||
|
||||
assert signature not in self.tuned
|
||||
assert args is not None
|
||||
space = (dict(),) if len(space) == 0 else space
|
||||
|
||||
kernels = []
|
||||
for tuned_keys in space:
|
||||
assert isinstance(tuned_keys, dict)
|
||||
full_keys = copy.deepcopy(keys)
|
||||
full_keys.update(tuned_keys)
|
||||
code = generate(includes, arg_defs, cpp_format(template, full_keys))
|
||||
|
||||
# Illegal build must raise errors
|
||||
kernels.append((build(name, arg_defs, code), tuned_keys))
|
||||
|
||||
best_runtime, best_time, best_keys = None, None, None
|
||||
for runtime, tuned_keys in kernels:
|
||||
if len(space) > 1:
|
||||
# Check kernel validity
|
||||
return_code = runtime(*args)
|
||||
if return_code != 0:
|
||||
# Skip illegal kernels, e.g. those requiring more shared memory than the device provides
|
||||
if os.getenv("DG_JIT_DEBUG", None):
|
||||
print(
|
||||
f"Illegal JIT kernel {name} with keys {keys} and tuned keys {tuned_keys}: "
|
||||
f"error code {return_code}"
|
||||
)
|
||||
continue
|
||||
|
||||
# Measure performance with an L2 flush and a large GEMM kernel beforehand to reduce overhead between kernels
|
||||
start_event = paddle.device.cuda.Event(enable_timing=True)
|
||||
end_event = paddle.device.cuda.Event(enable_timing=True)
|
||||
paddle.empty(int(256e6 // 4), dtype=paddle.int32).zero_()
|
||||
paddle.randn((8192, 8192), dtype=paddle.float32) @ paddle.randn((8192, 8192), dtype=paddle.float32)
|
||||
start_event.record()
|
||||
for i in range(20):
|
||||
assert runtime(*args) == 0
|
||||
end_event.record()
|
||||
end_event.synchronize()
|
||||
elapsed_time = start_event.elapsed_time(end_event)
|
||||
else:
|
||||
elapsed_time = 0
|
||||
|
||||
# Compare if better
|
||||
if best_time is None or elapsed_time < best_time:
|
||||
best_runtime, best_time, best_keys = runtime, elapsed_time, tuned_keys
|
||||
if os.getenv("DG_JIT_DEBUG", None):
|
||||
print(
|
||||
f"Tuned JIT kernel {name} with keys {keys} and tuned keys {tuned_keys} has time {elapsed_time}"
|
||||
)
|
||||
assert (
|
||||
best_runtime is not None
|
||||
), f"Failed to tune JIT kernel {name} with keys {keys}"
|
||||
|
||||
# Cache the best runtime and return
|
||||
if os.getenv("DG_JIT_DEBUG", None) or os.getenv("DG_PRINT_AUTOTUNE", None):
|
||||
print(
|
||||
f"Best JIT kernel {name} with keys {keys} has tuned keys {best_keys} and time {best_time}"
|
||||
)
|
||||
self.tuned[signature] = best_runtime
|
||||
return best_runtime
|
||||
|
||||
def compile_and_tune(
|
||||
self,
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
name: str,
|
||||
keys: Dict[str, Any],
|
||||
space: tuple,
|
||||
includes: tuple,
|
||||
arg_defs: tuple,
|
||||
template: str,
|
||||
# args: tuple,
|
||||
) -> Runtime:
|
||||
"""Compile and tune a kernel"""
|
||||
# NOTES: we always assume the space and template will not change
|
||||
# We also assume the GPU device will not be changed
|
||||
# NOTES: the function must have no accumulated side effects
|
||||
signature = (name, m, k, n)
|
||||
if signature in self.tuned:
|
||||
return self.tuned[signature]
|
||||
# keys = {k: keys[k] for k in sorted(keys.keys())}
|
||||
# signature = (name, f"{keys}")
|
||||
# if signature in self.tuned:
|
||||
# return self.tuned[signature]
|
||||
space = (dict(),) if len(space) == 0 else space
|
||||
|
||||
kernels = []
|
||||
for tuned_keys in space:
|
||||
assert isinstance(tuned_keys, dict)
|
||||
full_keys = copy.deepcopy(keys)
|
||||
full_keys.update(tuned_keys)
|
||||
code = generate(includes, arg_defs, cpp_format(template, full_keys))
|
||||
|
||||
# Illegal build must raise errors
|
||||
kernels.append((build(name, arg_defs, code), tuned_keys))
|
||||
|
||||
best_runtime, best_time, best_keys = None, None, None
|
||||
for runtime, tuned_keys in kernels:
|
||||
elapsed_time = 0
|
||||
|
||||
# Compare if better
|
||||
if best_time is None or elapsed_time < best_time:
|
||||
best_runtime, best_time, best_keys = runtime, elapsed_time, tuned_keys
|
||||
if os.getenv("DG_JIT_DEBUG", None):
|
||||
print(
|
||||
f"Tuned JIT kernel {name} with keys {keys} and tuned keys {tuned_keys} has time {elapsed_time}"
|
||||
)
|
||||
assert (
|
||||
best_runtime is not None
|
||||
), f"Failed to tune JIT kernel {name} with keys {keys}"
|
||||
|
||||
# Cache the best runtime and return
|
||||
if os.getenv("DG_JIT_DEBUG", None) or os.getenv("DG_PRINT_AUTOTUNE", None):
|
||||
print(
|
||||
f"Best JIT kernel {name} with keys {keys} has tuned keys {best_keys} and time {best_time}"
|
||||
)
|
||||
self.tuned[signature] = best_runtime
|
||||
return best_runtime
|
||||
|
||||
|
||||
jit_tuner = JITTuner()
|
||||
@@ -1,151 +0,0 @@
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# The file has been adapted from DeepSeek DeepGEMM project
|
||||
# Copyright (c) 2025 DeepSeek
|
||||
# Licensed under the MIT License - https://github.com/deepseek-ai/DeepGEMM/blob/main/LICENSE
|
||||
"""Utility functions"""
|
||||
import paddle
|
||||
from paddle import Tensor
|
||||
|
||||
_num_sms = None
|
||||
|
||||
|
||||
def set_num_sms(num_sms: int) -> None:
|
||||
"""
|
||||
Set the maximum SM count for all GEMM kernels to use.
|
||||
|
||||
Arguments:
|
||||
num_sms: the desired maximum SM count for all GEMM kernels to use.
|
||||
"""
|
||||
global _num_sms
|
||||
assert (
|
||||
0
|
||||
< num_sms
|
||||
<= paddle.device.cuda.get_device_properties(device="cuda").multi_processor_count
|
||||
)
|
||||
_num_sms = num_sms
|
||||
|
||||
|
||||
def get_num_sms() -> int:
|
||||
"""
|
||||
Get the current maximum limit of SM count for all GEMM kernels to use.
|
||||
If the count is never specified, the function will return the number of device SMs.
|
||||
|
||||
Returns:
|
||||
Current maximum limit of SM count for all GEMM kernels to use.
|
||||
"""
|
||||
global _num_sms
|
||||
if _num_sms is None:
|
||||
_num_sms = paddle.device.cuda.get_device_properties().multi_processor_count
|
||||
return _num_sms
|
||||
|
||||
|
||||
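# Editor's sketch (not part of the original file): limiting the SM budget with the
# helpers above; the 120-SM figure is hypothetical.
def _example_sm_count_usage():
    device_sms = get_num_sms()           # defaults to the device multiprocessor count
    set_num_sms(min(120, device_sms))    # must stay within (0, device_sms]
    assert get_num_sms() <= device_sms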
def ceil_div(x: int, y: int) -> int:
|
||||
"""
|
||||
Perform ceiling division of two integers.
|
||||
|
||||
Args:
|
||||
x: the dividend.
|
||||
y: the divisor.
|
||||
|
||||
Returns:
|
||||
The result of the ceiling division.
|
||||
"""
|
||||
return (x + y - 1) // y
|
||||
|
||||
|
||||
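# Editor's sketch (not part of the original file): quick checks of the ceiling division above.
def _example_ceil_div():
    assert ceil_div(256, 128) == 2
    assert ceil_div(257, 128) == 3
    assert ceil_div(1, 128) == 1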
def get_m_alignment_for_contiguous_layout():
|
||||
"""
|
||||
When we do a grouped GEMM in contiguous format, the LHS is grouped into several batches along the M axis.
Since each GEMM block deals with exactly one sub-matrix of the RHS, the batch sizes above must be aligned
with the GEMM block shape.
|
||||
|
||||
Returns:
|
||||
Group-level alignment requirement for grouped contiguous layout, which is always 128.
|
||||
"""
|
||||
return 128
|
||||
|
||||
|
||||
def get_tma_aligned_size(x: int, element_size: int) -> int:
|
||||
"""
|
||||
Global memory address of TMA must be 16-byte aligned.
|
||||
Since we use column-major layout for the LHS scaling tensor,
|
||||
the M-axis of the LHS scaling tensor needs to be padded to a multiple of 16 bytes.
|
||||
|
||||
Arguments:
|
||||
x: original M-axis shape of the LHS scaling tensor.
|
||||
element_size: element size of the LHS scaling tensor.
|
||||
|
||||
Returns:
|
||||
M-axis shape of the LHS scaling tensor after padding.
|
||||
"""
|
||||
tma_alignment_bytes = 16
|
||||
assert tma_alignment_bytes % element_size == 0
|
||||
alignment = tma_alignment_bytes // element_size
|
||||
return ceil_div(x, alignment) * alignment
|
||||
|
||||
|
||||
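# Editor's sketch (not part of the original file): the padding rule above for FP32 scales,
# where 16-byte alignment / 4-byte elements means M is rounded up to a multiple of 4.
def _example_tma_aligned_size():
    assert get_tma_aligned_size(128, 4) == 128  # already aligned
    assert get_tma_aligned_size(130, 4) == 132  # rounded up to the next multiple of 4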
def get_col_major_tma_aligned_tensor(x: Tensor) -> Tensor:
|
||||
"""
|
||||
Returns the TMA-aligned transposed format of the input tensor.
In this Paddle port the input is assumed to already be in column-major layout and 16-byte aligned along
the M axis (the requirement for the LHS scaling tensor in DeepGEMM), so the tensor is returned unchanged.
|
||||
|
||||
Arguments:
|
||||
x: usually the LHS scaling tensor in GEMM.
|
||||
|
||||
Returns:
|
||||
The LHS scaling tensor of TMA-aligned transposed format.
|
||||
"""
|
||||
return x
|
||||
|
||||
|
||||
def get_col_major_tma_aligned_tensor_prefill(x: Tensor) -> Tensor:
|
||||
"""
|
||||
Returns the TMA-aligned transposed (column-major) format of the input tensor.
This prefill variant assumes a 2-D scaling tensor and always materializes a column-major copy whose
M axis is padded with `get_tma_aligned_size`, so it meets the LHS scaling tensor requirement in DeepGEMM.
|
||||
|
||||
Arguments:
|
||||
x: usually the LHS scaling tensor in GEMM.
|
||||
|
||||
Returns:
|
||||
The LHS scaling tensor of TMA-aligned transposed format.
|
||||
"""
|
||||
# NOTES: for the extreme performance, you may rewrite/fuse this function in CUDA
|
||||
# assert x.dim() in (2, 3)
|
||||
# remove_dim = False
|
||||
# if x.dim() == 2:
|
||||
x, remove_dim = x.unsqueeze(0), True
|
||||
|
||||
b, m, n = x.shape
|
||||
aligned_m = get_tma_aligned_size(m, x.element_size())
|
||||
|
||||
# The last kernel gives a column-major TMA aligned layout
|
||||
# if (
|
||||
# x.strides[0] == aligned_m * n
|
||||
# and x.strides[1] == 1
|
||||
# and x.strides[2] == aligned_m
|
||||
# ):
|
||||
# return x.squeeze(0) if remove_dim else x
|
||||
|
||||
# Normal layout requires transposing
|
||||
aligned_x = paddle.transpose(
|
||||
paddle.empty((b, n, aligned_m), dtype=x.dtype), perm=[0, 2, 1]
|
||||
)
|
||||
aligned_x[:, :m, :] = x
|
||||
aligned_x = aligned_x[:, :m, :]
|
||||
return aligned_x.squeeze(0)
|
||||
# return aligned_x.squeeze(0) if remove_dim else aligned_x
|
||||
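# Editor's sketch (not part of the original file): the prefill helper above keeps the logical
# [m, k/128] shape but backs it with a column-major buffer padded along M for TMA loads.
# The shape below is hypothetical.
def _example_prefill_aligned_scales():
    scales = paddle.rand((130, 56), dtype=paddle.float32)  # e.g. [m_sum, ceil(k / 128)]
    aligned = get_col_major_tma_aligned_tensor_prefill(scales)
    assert aligned.shape == [130, 56]  # storage is padded to get_tma_aligned_size(130, 4) == 132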
@@ -1,137 +0,0 @@
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# The file has been adapted from DeepSeek DeepGEMM project
|
||||
# Copyright (c) 2025 DeepSeek
|
||||
# Licensed under the MIT License - https://github.com/deepseek-ai/DeepGEMM/blob/main/LICENSE
|
||||
"""Utilities"""
|
||||
import os
|
||||
import sys
|
||||
|
||||
import paddle
|
||||
|
||||
|
||||
def bench(fn, num_warmups: int = 5, num_tests: int = 10, high_precision: bool = False):
|
||||
"""Benchmark function `fn` using CUDA events."""
|
||||
# Flush L2 cache with 256 MB data
|
||||
paddle.device.cuda.synchronize()
|
||||
cache = paddle.empty(int(256e6 // 4), dtype=paddle.int32)
|
||||
cache.zero_()
|
||||
|
||||
# Warmup
|
||||
for _ in range(num_warmups):
|
||||
fn()
|
||||
|
||||
# Add a large kernel to eliminate the CPU launch overhead
|
||||
if high_precision:
|
||||
x = paddle.randn((8192, 8192), dtype=paddle.float32)
|
||||
y = paddle.randn((8192, 8192), dtype=paddle.float32)
|
||||
x @ y
|
||||
|
||||
# Testing
|
||||
start_event = paddle.device.cuda.Event(enable_timing=True)
|
||||
end_event = paddle.device.cuda.Event(enable_timing=True)
|
||||
start_event.record()
|
||||
for i in range(num_tests):
|
||||
fn()
|
||||
end_event.record()
|
||||
paddle.device.cuda.synchronize()
|
||||
|
||||
return start_event.elapsed_time(end_event) / num_tests
|
||||
|
||||
|
||||
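# Editor's sketch (not part of the original file): timing an arbitrary callable with bench().
# The matmul here is only a stand-in workload.
def _example_bench_usage():
    x = paddle.randn((4096, 4096), dtype=paddle.float32)
    y = paddle.randn((4096, 4096), dtype=paddle.float32)
    avg_ms = bench(lambda: x @ y, num_warmups=5, num_tests=10)
    print(f"average matmul time: {avg_ms:.3f} ms")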
def get_cuda_home():
|
||||
"""Get Cuda home directory"""
|
||||
cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH")
|
||||
if cuda_home:
|
||||
return cuda_home
|
||||
|
||||
try:
|
||||
which_cmd = "which nvcc"
|
||||
|
||||
nvcc_path = os.popen(which_cmd).read().strip()
|
||||
if nvcc_path:
|
||||
return os.path.dirname(os.path.dirname(nvcc_path))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return None
|
||||
|
||||
|
||||
class EmptySuppress:
|
||||
"""Empty context manager"""
|
||||
def __enter__(self):
|
||||
"""Empty context manager"""
|
||||
return self
|
||||
|
||||
def __exit__(self, *_):
|
||||
"""Empty exit method"""
|
||||
pass
|
||||
|
||||
|
||||
class SuppressStdoutStderr:
|
||||
"""Context manager that redirects stdout and stderr"""
|
||||
def __enter__(self):
|
||||
"""Suppress stdout and stderr"""
|
||||
self.outnull_file = open(os.devnull, "w")
|
||||
self.errnull_file = open(os.devnull, "w")
|
||||
|
||||
self.old_stdout_fileno_undup = sys.stdout.fileno()
|
||||
self.old_stderr_fileno_undup = sys.stderr.fileno()
|
||||
|
||||
self.old_stdout_fileno = os.dup(sys.stdout.fileno())
|
||||
self.old_stderr_fileno = os.dup(sys.stderr.fileno())
|
||||
|
||||
self.old_stdout = sys.stdout
|
||||
self.old_stderr = sys.stderr
|
||||
|
||||
os.dup2(self.outnull_file.fileno(), self.old_stdout_fileno_undup)
|
||||
os.dup2(self.errnull_file.fileno(), self.old_stderr_fileno_undup)
|
||||
|
||||
sys.stdout = self.outnull_file
|
||||
sys.stderr = self.errnull_file
|
||||
return self
|
||||
|
||||
def __exit__(self, *_):
|
||||
"""Restore stdout and stderr"""
|
||||
sys.stdout = self.old_stdout
|
||||
sys.stderr = self.old_stderr
|
||||
|
||||
os.dup2(self.old_stdout_fileno, self.old_stdout_fileno_undup)
|
||||
os.dup2(self.old_stderr_fileno, self.old_stderr_fileno_undup)
|
||||
|
||||
os.close(self.old_stdout_fileno)
|
||||
os.close(self.old_stderr_fileno)
|
||||
|
||||
self.outnull_file.close()
|
||||
self.errnull_file.close()
|
||||
|
||||
|
||||
def calc_diff(x, y):
|
||||
"""Calculate difference between two vectors"""
|
||||
x, y = x.astype(paddle.float64), y.astype(paddle.float64)
|
||||
denominator = (x * x + y * y).sum()
|
||||
sim = 2 * (x * y).sum() / denominator
|
||||
return 1 - sim
|
||||
|
||||
|
||||
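# Editor's sketch (not part of the original file): calc_diff returns ~0 for identical inputs;
# the GEMM tests in this package treat diff < 0.001 as a pass.
def _example_calc_diff():
    a = paddle.randn((128, 128), dtype=paddle.float32)
    assert float(calc_diff(a, a)) < 1e-6
    assert float(calc_diff(a, a * 1.001)) < 1e-3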
def count_bytes(tensors):
|
||||
"""Count number of bytes used by tensors"""
|
||||
total = 0
|
||||
for t in tensors:
|
||||
if isinstance(t, tuple):
|
||||
total += count_bytes(t)
|
||||
else:
|
||||
total += t.numel() * t.element_size()
|
||||
return total
|
||||
@@ -1,110 +0,0 @@
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# The file has been adapted from DeepSeek DeepGEMM project
|
||||
# Copyright (c) 2025 DeepSeek
|
||||
# Licensed under the MIT License - https://github.com/deepseek-ai/DeepGEMM/blob/main/LICENSE
|
||||
"""Setup script for deep_gemm"""
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
|
||||
import setuptools
|
||||
from setuptools.command.build_py import build_py
|
||||
from setuptools.command.develop import develop
|
||||
|
||||
current_dir = os.path.dirname(os.path.realpath(__file__))
|
||||
jit_include_dirs = ("deep_gemm/include/deep_gemm",)
|
||||
third_party_include_dirs = (
|
||||
"../../third_party/cutlass/include/cute",
|
||||
"../../third_party/cutlass/include/cutlass",
|
||||
)
|
||||
|
||||
|
||||
class PostDevelopCommand(develop):
|
||||
"""Custom develop command that makes symbolic links to third-party include directories"""
|
||||
def run(self):
|
||||
"""Run the custom develop command"""
|
||||
develop.run(self)
|
||||
self.make_jit_include_symlinks()
|
||||
|
||||
@staticmethod
|
||||
def make_jit_include_symlinks():
|
||||
"""Make symbolic links of jit include directories"""
|
||||
# Make symbolic links of third-party include directories
|
||||
for d in third_party_include_dirs:
|
||||
dirname = d.split("/")[-1]
|
||||
src_dir = f"{current_dir}/{d}"
|
||||
dst_dir = f"{current_dir}/deep_gemm/include/{dirname}"
|
||||
assert os.path.exists(src_dir)
|
||||
if os.path.exists(dst_dir):
|
||||
assert os.path.islink(dst_dir)
|
||||
os.unlink(dst_dir)
|
||||
os.symlink(src_dir, dst_dir, target_is_directory=True)
|
||||
|
||||
|
||||
class CustomBuildPy(build_py):
|
||||
"""Custom build command that prepares the include files before building"""
|
||||
def run(self):
|
||||
"""Run the custom build command"""
|
||||
# First, prepare the include directories
|
||||
self.prepare_includes()
|
||||
|
||||
# Then run the regular build
|
||||
build_py.run(self)
|
||||
|
||||
def prepare_includes(self):
|
||||
"""Prepare the include directories"""
|
||||
# Create temporary build directory instead of modifying package directory
|
||||
build_include_dir = os.path.join(self.build_lib, "deep_gemm/include")
|
||||
os.makedirs(build_include_dir, exist_ok=True)
|
||||
|
||||
# Copy third-party includes to the build directory
|
||||
for d in third_party_include_dirs:
|
||||
dirname = d.split("/")[-1]
|
||||
src_dir = os.path.join(current_dir, d)
|
||||
dst_dir = os.path.join(build_include_dir, dirname)
|
||||
|
||||
# Remove existing directory if it exists
|
||||
if os.path.exists(dst_dir):
|
||||
shutil.rmtree(dst_dir)
|
||||
|
||||
# Copy the directory
|
||||
shutil.copytree(src_dir, dst_dir)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
cmd = ["git", "rev-parse", "--short", "HEAD"]
|
||||
revision = "+" + subprocess.check_output(cmd).decode("ascii").rstrip()
|
||||
except:
|
||||
revision = ""
|
||||
|
||||
setuptools.setup(
|
||||
name="deep_gemm",
|
||||
version="1.0.0" + revision,
|
||||
packages=["deep_gemm", "deep_gemm/jit", "deep_gemm/jit_kernels"],
|
||||
package_data={
|
||||
"deep_gemm": [
|
||||
"include/deep_gemm/**/*",
|
||||
"include/cute/**/*",
|
||||
"include/cutlass/**/*",
|
||||
]
|
||||
},
|
||||
cmdclass={
|
||||
"develop": PostDevelopCommand,
|
||||
"build_py": CustomBuildPy,
|
||||
},
|
||||
)
|
||||
@@ -1,205 +0,0 @@
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# The file has been adapted from DeepSeek DeepGEMM project
|
||||
# Copyright (c) 2025 DeepSeek
|
||||
# Licensed under the MIT License - https://github.com/deepseek-ai/DeepGEMM/blob/main/LICENSE
|
||||
|
||||
import random
|
||||
from typing import Tuple
|
||||
|
||||
import fastdeploy.model_executor.ops.gpu.deep_gemm as deep_gemm
|
||||
import paddle
|
||||
from fastdeploy.model_executor.ops.gpu.deep_gemm import (
|
||||
calc_diff,
|
||||
ceil_div,
|
||||
get_col_major_tma_aligned_tensor,
|
||||
)
|
||||
from paddle import Tensor
|
||||
|
||||
|
||||
def per_token_cast_to_fp8(x: Tensor) -> Tuple[Tensor, Tensor]:
|
||||
assert x.dim() == 2 and x.shape[1] % 128 == 0
|
||||
m, n = x.shape
|
||||
x_view = paddle.view(x, (m, -1, 128))
|
||||
x_abs = paddle.abs(x_view).astype(paddle.float32)
|
||||
x_amax = paddle.amax(x_abs, axis=2)
|
||||
x_amax = paddle.view(x_amax, (m, -1))
|
||||
x_amax = paddle.clip(x_amax, min=1e-4)
|
||||
scaled_x = x_view * (448.0 / x_amax.unsqueeze(2))
|
||||
scaled_x_converted = paddle.view(scaled_x.astype(paddle.float8_e4m3fn), (m, n))
|
||||
|
||||
x_amax_scaled = paddle.view((x_amax / 448.0), (m, -1))
|
||||
|
||||
result = (scaled_x_converted, x_amax_scaled)
|
||||
return result
|
||||
|
||||
|
||||
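# Editor's sketch (not part of the original test): round-tripping the per-token cast above.
# Dequantizing with the returned 1x128 scales should recover x up to FP8 quantization error;
# the 1e-2 tolerance is a loose, hypothetical bound.
def _example_per_token_roundtrip():
    x = paddle.randn((128, 256), dtype=paddle.bfloat16)
    x_fp8, scales = per_token_cast_to_fp8(x)
    x_deq = paddle.view(x_fp8.astype(paddle.float32), (128, -1, 128)) * scales.unsqueeze(2)
    assert float(calc_diff(paddle.view(x_deq, (128, 256)), x.astype(paddle.float32))) < 1e-2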
def per_block_cast_to_fp8(x: Tensor) -> Tuple[Tensor, Tensor]:
|
||||
assert x.dim() == 2
|
||||
m, n = x.shape
|
||||
x_padded = paddle.zeros(
|
||||
(ceil_div(m, 128) * 128, ceil_div(n, 128) * 128), dtype=x.dtype
|
||||
)
|
||||
x_padded[:m, :n] = x
|
||||
x_view = paddle.view(x_padded, (-1, 128, x_padded.shape[1] // 128, 128))
|
||||
|
||||
x_abs = paddle.abs(x_view).astype(paddle.float32)
|
||||
x_amax = paddle.amax(x_abs, axis=(1, 3), keepdim=True)
|
||||
x_amax = paddle.clip(x_amax, min=1e-4)
|
||||
x_scaled = (x_view * (448.0 / x_amax)).astype(paddle.float8_e4m3fn)
|
||||
|
||||
return x_scaled.view_as(x_padded)[:m, :n].contiguous(), (
|
||||
paddle.view(x_amax / 448.0, (x_view.shape[0], x_view.shape[2]))
|
||||
)
|
||||
|
||||
|
||||
def construct(
|
||||
m: int, k: int, n: int
|
||||
) -> Tuple[Tuple[Tensor, Tensor], Tuple[Tensor, Tensor], Tensor, Tensor]:
|
||||
x = paddle.randn((m, k), dtype=paddle.bfloat16)
|
||||
y = paddle.randn((n, k), dtype=paddle.bfloat16)
|
||||
out = paddle.empty((m, n), dtype=paddle.bfloat16)
|
||||
ref_out = x @ y.t()
|
||||
|
||||
x_fp8, y_fp8 = per_token_cast_to_fp8(x), per_block_cast_to_fp8(y)
|
||||
# Transpose earlier so that the testing will not trigger transposing kernels
|
||||
x_fp8 = (x_fp8[0], get_col_major_tma_aligned_tensor(x_fp8[1]))
|
||||
return x_fp8, y_fp8, out, ref_out
|
||||
|
||||
|
||||
def construct_grouped(
|
||||
num_groups: int, m: int, k: int, n: int, is_masked: bool
|
||||
) -> Tuple[Tuple[Tensor, Tensor], Tuple[Tensor, Tensor], Tensor, Tensor]:
|
||||
# x_np = np.full((num_groups, m, k), 3)
|
||||
# y_np = np.full((num_groups, n, k), 2)
|
||||
# x=paddle.to_tensor(x_np).astype(paddle.bfloat16)
|
||||
# y=paddle.to_tensor(y_np).astype(paddle.bfloat16)
|
||||
x = paddle.randn((num_groups, m, k), dtype=paddle.bfloat16)
|
||||
y = paddle.randn((num_groups, n, k), dtype=paddle.bfloat16)
|
||||
out = paddle.empty((num_groups, m, n), dtype=paddle.bfloat16)
|
||||
ref_out = paddle.einsum("gmk,gnk->gmn", x, y)
|
||||
|
||||
assert m % 4 == 0, f"TMA alignment error: {m}"
|
||||
x_fp8 = (
|
||||
paddle.empty_like(x, dtype=paddle.float8_e4m3fn),
|
||||
paddle.empty((num_groups, m, k // 128), dtype=paddle.float32),
|
||||
)
|
||||
y_fp8 = (
|
||||
paddle.empty_like(y, dtype=paddle.float8_e4m3fn),
|
||||
paddle.empty((num_groups, (n + 127) // 128, k // 128), dtype=paddle.float32),
|
||||
)
|
||||
for i in range(num_groups):
|
||||
# x_fp8[0][i], x_fp8[1][i] = per_token_cast_to_fp8(x[i])
|
||||
# y_fp8[0][i], y_fp8[1][i] = per_block_cast_to_fp8(y[i])
|
||||
x_fp8_0_i, x_fp8_1_i = per_token_cast_to_fp8(x[i])
|
||||
paddle.assign(x_fp8_0_i, x_fp8[0][i])
|
||||
paddle.assign(x_fp8_1_i, x_fp8[1][i])
|
||||
y_fp8_0_i, y_fp8_1_i = per_block_cast_to_fp8(y[i])
|
||||
paddle.assign(y_fp8_0_i, y_fp8[0][i])
|
||||
paddle.assign(y_fp8_1_i, y_fp8[1][i])
|
||||
|
||||
# For non-masked input, we must merge the group and M dims
|
||||
if not is_masked:
|
||||
x_fp8 = (
|
||||
paddle.view(x_fp8[0], (-1, k)),
|
||||
per_token_cast_to_fp8(paddle.view(x, (-1, k)))[1],
|
||||
)
|
||||
out, ref_out = paddle.view(out, (-1, n)), paddle.view(ref_out, (-1, n))
|
||||
|
||||
# Transpose earlier so that the testing will not trigger transposing kernels
|
||||
x_fp8 = (x_fp8[0], get_col_major_tma_aligned_tensor(x_fp8[1]))
|
||||
return x_fp8, y_fp8, out, ref_out
|
||||
|
||||
|
||||
def test_gemm() -> None:
|
||||
print("Testing GEMM:")
|
||||
for m in (64,):
|
||||
for k, n in [
|
||||
(7168, 2112),
|
||||
]:
|
||||
x_fp8, y_fp8, out, ref_out = construct(m, k, n)
|
||||
deep_gemm.gemm_fp8_fp8_bf16_nt(x_fp8, y_fp8, out)
|
||||
diff = calc_diff(out, ref_out)
|
||||
assert diff < 0.001, f"{m=}, {k=}, {n=}, {diff:.5f}"
|
||||
|
||||
print()
|
||||
|
||||
|
||||
def test_m_grouped_gemm_contiguous() -> None:
|
||||
print("Testing grouped contiguous GEMM:")
|
||||
|
||||
for num_groups, m, k, n in ((4, 8192, 7168, 4096),):
|
||||
# TODO: make a stronger test
|
||||
x_fp8, y_fp8, out, ref_out = construct_grouped(
|
||||
num_groups, m, k, n, is_masked=False
|
||||
)
|
||||
m_indices = paddle.arange(0, num_groups, dtype=paddle.int32)
|
||||
# m_indices = m_indices.unsqueeze(-1).expand(num_groups, m).contiguous().view(-1)
|
||||
m_indices = paddle.flatten(
|
||||
paddle.expand(paddle.unsqueeze(m_indices, -1), shape=[num_groups, m])
|
||||
)
|
||||
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
|
||||
x_fp8, y_fp8, out, m_indices
|
||||
)
|
||||
diff = calc_diff(out, ref_out)
|
||||
print("diff:", diff)
|
||||
assert diff < 0.001, f"m={m * num_groups}, {k=}, {n=}, {diff:.5f}"
|
||||
print()
|
||||
|
||||
|
||||
def test_m_grouped_gemm_masked() -> None:
|
||||
print("Testing grouped masked GEMM:")
|
||||
|
||||
for num_groups, m in ((1, 1024),):
|
||||
for k, n in ((7168, 4096),):
|
||||
# Test correctness
|
||||
masked_m_candidates = list(
|
||||
filter(lambda candidate: candidate <= m, (64, 128, 192, 256, 320, 384))
|
||||
)
|
||||
for i in range(10):
|
||||
x_fp8, y_fp8, out, ref_out = construct_grouped(
|
||||
num_groups, m, k, n, is_masked=True
|
||||
)
|
||||
masked_m = paddle.empty((num_groups,), dtype=paddle.int32)
|
||||
for j in range(num_groups):
|
||||
masked_m[j] = random.choice(masked_m_candidates)
|
||||
# expected_m = min(int(masked_m.float().mean()) + 1, m)
|
||||
masked_m_float = paddle.cast(masked_m, "float32")
|
||||
masked_m_mean = paddle.mean(masked_m_float)
|
||||
masked_m_mean_int = paddle.cast(masked_m_mean, "int32")
|
||||
expected_m = min(masked_m_mean_int + 1, m)
|
||||
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(
|
||||
x_fp8, y_fp8, out, masked_m, expected_m
|
||||
)
|
||||
for j in range(num_groups):
|
||||
diff = calc_diff(
|
||||
out[j, : masked_m[j].item()], ref_out[j, : masked_m[j].item()]
|
||||
)
|
||||
print("diff:", diff)
|
||||
assert (
|
||||
diff < 0.001
|
||||
), f"{m=}, {k=}, {n=}, {j=}, masked_m={masked_m[j]}, {num_groups=}, {diff:.5f}"
|
||||
|
||||
print()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
paddle.seed(0)
|
||||
random.seed(0)
|
||||
print("Library path:")
|
||||
print(f" > {deep_gemm.__path__}\n")
|
||||
test_gemm()
|
||||
test_m_grouped_gemm_contiguous()
|
||||
test_m_grouped_gemm_masked()
|
||||
@@ -14,27 +14,21 @@
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "fp8_common.h" // NOLINT
|
||||
#include "fp8_gemm_fused/fuse_block_gemm_act_template_3x.h"
|
||||
#include "fp8_common.h" // NOLINT
|
||||
#include "fp8_gemm_fused/fp8_fp8_gemm_scale_bias_act.h"
|
||||
#include "fp8_gemm_fused/fuse_block_gemm_act_template_3x.h"
|
||||
|
||||
std::vector<paddle::Tensor> cutlass_fp8_fp8_half_block_gemm_fused(
|
||||
const paddle::Tensor& x,
|
||||
const paddle::Tensor& y,
|
||||
const paddle::Tensor& x_scale,
|
||||
const paddle::Tensor& y_scale,
|
||||
const paddle::optional<paddle::Tensor>& bias,
|
||||
bool trans_x,
|
||||
bool trans_y,
|
||||
std::string output_dtype,
|
||||
std::string activation_type) {
|
||||
const paddle::Tensor &x, const paddle::Tensor &y,
|
||||
const paddle::Tensor &x_scale, const paddle::Tensor &y_scale,
|
||||
const paddle::optional<paddle::Tensor> &bias, bool trans_x, bool trans_y,
|
||||
std::string output_dtype, std::string activation_type) {
|
||||
paddle::Tensor out;
|
||||
void* out_ptr = nullptr;
|
||||
const void* x_ptr = nullptr;
|
||||
const void* x_scale_ptr = nullptr;
|
||||
const void* y_ptr = nullptr;
|
||||
const void* y_scale_ptr = nullptr;
|
||||
|
||||
void *out_ptr = nullptr;
|
||||
const void *x_ptr = nullptr;
|
||||
const void *x_scale_ptr = nullptr;
|
||||
const void *y_ptr = nullptr;
|
||||
const void *y_scale_ptr = nullptr;
|
||||
|
||||
auto place = x.place();
|
||||
cudaStream_t stream = x.stream();
|
||||
@@ -46,7 +40,7 @@ std::vector<paddle::Tensor> cutlass_fp8_fp8_half_block_gemm_fused(
|
||||
int K = 0;
|
||||
int N = 0;
|
||||
int ldd = 0;
|
||||
|
||||
|
||||
int lda = x.dims()[rank - 1];
|
||||
int ldb = y.dims()[rank - 1];
|
||||
|
||||
@@ -72,16 +66,16 @@ std::vector<paddle::Tensor> cutlass_fp8_fp8_half_block_gemm_fused(
|
||||
}
|
||||
|
||||
std::string input_dtype = "";
|
||||
x_scale_ptr = reinterpret_cast<const void*>(x_scale.data<float>());
|
||||
y_scale_ptr = reinterpret_cast<const void*>(y_scale.data<float>());
|
||||
x_scale_ptr = reinterpret_cast<const void *>(x_scale.data<float>());
|
||||
y_scale_ptr = reinterpret_cast<const void *>(y_scale.data<float>());
|
||||
if (x.dtype() == phi::DataType::FLOAT8_E4M3FN) {
|
||||
input_dtype = "float8_e4m3fn";
|
||||
x_ptr = reinterpret_cast<const void*>(x.data<phi::dtype::float8_e4m3fn>());
|
||||
y_ptr = reinterpret_cast<const void*>(y.data<phi::dtype::float8_e4m3fn>());
|
||||
x_ptr = reinterpret_cast<const void *>(x.data<phi::dtype::float8_e4m3fn>());
|
||||
y_ptr = reinterpret_cast<const void *>(y.data<phi::dtype::float8_e4m3fn>());
|
||||
} else if (x.dtype() == phi::DataType::FLOAT8_E5M2) {
|
||||
input_dtype = "float8_e5m2";
|
||||
x_ptr = reinterpret_cast<const void*>(x.data<phi::dtype::float8_e5m2>());
|
||||
y_ptr = reinterpret_cast<const void*>(y.data<phi::dtype::float8_e5m2>());
|
||||
x_ptr = reinterpret_cast<const void *>(x.data<phi::dtype::float8_e5m2>());
|
||||
y_ptr = reinterpret_cast<const void *>(y.data<phi::dtype::float8_e5m2>());
|
||||
} else {
|
||||
PADDLE_THROW(phi::errors::Fatal(
|
||||
"fp8_fp8_half_gemm_fused only support e4m3 and e5m2 input"));
|
||||
@@ -93,10 +87,10 @@ std::vector<paddle::Tensor> cutlass_fp8_fp8_half_block_gemm_fused(
|
||||
|
||||
if (output_dtype == "bfloat16") {
|
||||
out = paddle::empty(out_shape, paddle::DataType::BFLOAT16, x.place());
|
||||
out_ptr = reinterpret_cast<void*>(out.data<phi::dtype::bfloat16>());
|
||||
out_ptr = reinterpret_cast<void *>(out.data<phi::dtype::bfloat16>());
|
||||
} else if (output_dtype == "float16") {
|
||||
out = paddle::empty(out_shape, paddle::DataType::FLOAT16, x.place());
|
||||
out_ptr = reinterpret_cast<void*>(out.data<phi::dtype::float16>());
|
||||
out_ptr = reinterpret_cast<void *>(out.data<phi::dtype::float16>());
|
||||
} else {
|
||||
PADDLE_THROW(phi::errors::Fatal(
|
||||
"fp8_fp8_half_gemm_fused only support bfloat16 and float16 output"));
|
||||
@@ -110,84 +104,68 @@ std::vector<paddle::Tensor> cutlass_fp8_fp8_half_block_gemm_fused(
|
||||
std::string fuse_gemm_config =
|
||||
input_dtype + "_" + output_dtype + "_" + isbias + "_" + act;
|
||||
|
||||
void* bias_data = nullptr;
|
||||
void *bias_data = nullptr;
|
||||
std::vector<int64_t> bias_dims{};
|
||||
if (bias) {
|
||||
bias_dims = common::vectorize(bias.get().dims());
|
||||
if (output_dtype == "bfloat16") {
|
||||
bias_data = reinterpret_cast<void*>(const_cast<phi::dtype::bfloat16*>(
|
||||
bias_data = reinterpret_cast<void *>(const_cast<phi::dtype::bfloat16 *>(
|
||||
bias.get().data<phi::dtype::bfloat16>()));
|
||||
} else {
|
||||
bias_data = reinterpret_cast<void*>(const_cast<phi::dtype::float16*>(
|
||||
bias_data = reinterpret_cast<void *>(const_cast<phi::dtype::float16 *>(
|
||||
bias.get().data<phi::dtype::float16>()));
|
||||
}
|
||||
}
|
||||
|
||||
GemmEpilogueAllParams params = {x_ptr,
|
||||
y_ptr,
|
||||
out_ptr,
|
||||
1,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
lda,
|
||||
ldb,
|
||||
ldd,
|
||||
batch_count,
|
||||
place,
|
||||
stream,
|
||||
sm_version,
|
||||
0.01, // for leaky_relu
|
||||
bias_data,
|
||||
bias_dims,
|
||||
fuse_gemm_config,
|
||||
0,
|
||||
nullptr,
|
||||
nullptr,
|
||||
x_scale_ptr,
|
||||
y_scale_ptr};
|
||||
|
||||
GemmEpilogueAllParams params = {x_ptr, y_ptr, out_ptr,
|
||||
1, M, N,
|
||||
K, lda, ldb,
|
||||
ldd, batch_count, place,
|
||||
stream, sm_version,
|
||||
0.01, // for leaky_relu
|
||||
bias_data, bias_dims, fuse_gemm_config,
|
||||
0, nullptr, nullptr,
|
||||
x_scale_ptr, y_scale_ptr};
|
||||
|
||||
fp8_fp8_block_gemm_scale_bias_act(params);
|
||||
|
||||
|
||||
return {out};
|
||||
}
|
||||
|
||||
std::vector<std::vector<int64_t>> CutlassFp8Fp8HalfBlockGemmFusedInferShape(
|
||||
const std::vector<int64_t>& x_shape,
|
||||
const std::vector<int64_t>& y_shape,
|
||||
const std::vector<int64_t>& x_scale_shape,
|
||||
const std::vector<int64_t>& y_scale_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& bias_shape,
|
||||
bool trans_x,
|
||||
bool trans_y){
|
||||
PADDLE_ENFORCE_EQ(x_shape.size(),
|
||||
y_shape.size(),
|
||||
const std::vector<int64_t> &x_shape, const std::vector<int64_t> &y_shape,
|
||||
const std::vector<int64_t> &x_scale_shape,
|
||||
const std::vector<int64_t> &y_scale_shape,
|
||||
const paddle::optional<std::vector<int64_t>> &bias_shape, bool trans_x,
|
||||
bool trans_y) {
|
||||
PADDLE_ENFORCE_EQ(x_shape.size(), y_shape.size(),
|
||||
phi::errors::InvalidArgument(
|
||||
"The rank of input X and Y should be equal, but received X's rank is %d, Y's rank is %d.",
|
||||
x_shape.size(),
|
||||
y_shape.size()));
|
||||
PADDLE_ENFORCE_EQ(x_shape.size(),
|
||||
x_scale_shape.size(),
|
||||
"The rank of input X and Y should be equal, but "
|
||||
"received X's rank is %d, Y's rank is %d.",
|
||||
x_shape.size(), y_shape.size()));
|
||||
PADDLE_ENFORCE_EQ(x_shape.size(), x_scale_shape.size(),
|
||||
phi::errors::InvalidArgument(
|
||||
"The rank of input X and X_scale should be equal, but received X's rank is %d, X_scale's rank is %d.",
|
||||
x_shape.size(),
|
||||
x_scale_shape.size()));
|
||||
PADDLE_ENFORCE_EQ(y_shape.size(),
|
||||
y_scale_shape.size(),
|
||||
"The rank of input X and X_scale should be equal, but "
|
||||
"received X's rank is %d, X_scale's rank is %d.",
|
||||
x_shape.size(), x_scale_shape.size()));
|
||||
PADDLE_ENFORCE_EQ(y_shape.size(), y_scale_shape.size(),
|
||||
phi::errors::InvalidArgument(
|
||||
"The rank of input Y and Y_scale should be equal, but received Y's rank is %d, Y_scale's rank is %d.",
|
||||
y_shape.size(),
|
||||
y_scale_shape.size()));
|
||||
"The rank of input Y and Y_scale should be equal, but "
|
||||
"received Y's rank is %d, Y_scale's rank is %d.",
|
||||
y_shape.size(), y_scale_shape.size()));
|
||||
int rank = x_shape.size();
|
||||
int M = 0;
|
||||
int N = 0;
|
||||
if ((x_shape[rank - 1] + 127) / 128 != x_scale_shape[rank - 2]){
|
||||
PADDLE_THROW(phi::errors::Fatal(
|
||||
"cutlass_fp8_fp8_half_block_gemm_fused only support x_scale's dim[-2] * 128 = x's dim[-1]."));
|
||||
if ((x_shape[rank - 1] + 127) / 128 != x_scale_shape[rank - 2]) {
|
||||
PADDLE_THROW(
|
||||
phi::errors::Fatal("cutlass_fp8_fp8_half_block_gemm_fused only support "
|
||||
"x_scale's dim[-2] * 128 = x's dim[-1]."));
|
||||
}
|
||||
if (((y_shape[rank - 1] + 127) / 128 != y_scale_shape[rank - 1]) || ((y_shape[rank - 2] + 127) / 128 != y_scale_shape[rank - 2])){
|
||||
PADDLE_THROW(phi::errors::Fatal(
|
||||
"cutlass_fp8_fp8_half_block_gemm_fused only support input y_scale's dim[-2:] * 128 = y's dim[-2:]."));
|
||||
if (((y_shape[rank - 1] + 127) / 128 != y_scale_shape[rank - 1]) ||
|
||||
((y_shape[rank - 2] + 127) / 128 != y_scale_shape[rank - 2])) {
|
||||
PADDLE_THROW(
|
||||
phi::errors::Fatal("cutlass_fp8_fp8_half_block_gemm_fused only support "
|
||||
"input y_scale's dim[-2:] * 128 = y's dim[-2:]."));
|
||||
}
|
||||
|
||||
if (!trans_x) {
|
||||
@@ -208,31 +186,25 @@ std::vector<std::vector<int64_t>> CutlassFp8Fp8HalfBlockGemmFusedInferShape(
|
||||
}
|
||||
|
||||
std::vector<paddle::DataType> CutlassFp8Fp8HalfBlockGemmFusedInferDtype(
|
||||
const paddle::DataType& x_type,
|
||||
const paddle::DataType& y_type,
|
||||
const paddle::DataType& x_scale_type,
|
||||
const paddle::DataType& y_scale_type,
|
||||
const paddle::optional<paddle::DataType>& bias_type,
|
||||
bool trans_x,
|
||||
bool trans_y,
|
||||
std::string output_dtype) {
|
||||
paddle::DataType data_type;
|
||||
if (output_dtype == "bfloat16")
|
||||
data_type = paddle::DataType::BFLOAT16;
|
||||
else if (output_dtype == "float16")
|
||||
data_type = paddle::DataType::FLOAT16;
|
||||
else
|
||||
PD_THROW(
|
||||
"cutlass_fp8_fp8_half_gemm only support bfloat16 and float16 output");
|
||||
return {data_type};
|
||||
const paddle::DataType &x_type, const paddle::DataType &y_type,
|
||||
const paddle::DataType &x_scale_type, const paddle::DataType &y_scale_type,
|
||||
const paddle::optional<paddle::DataType> &bias_type, bool trans_x,
|
||||
bool trans_y, std::string output_dtype) {
|
||||
paddle::DataType data_type;
|
||||
if (output_dtype == "bfloat16")
|
||||
data_type = paddle::DataType::BFLOAT16;
|
||||
else if (output_dtype == "float16")
|
||||
data_type = paddle::DataType::FLOAT16;
|
||||
else
|
||||
PD_THROW("cutlass_fp8_fp8_half_block_gemm only support bfloat16 and "
|
||||
"float16 output");
|
||||
return {data_type};
|
||||
}
|
||||
|
||||
PD_BUILD_STATIC_OP(cutlass_fp8_fp8_half_block_gemm_fused)
|
||||
.Inputs({"x", "y", "x_sacle", "y_scale", paddle::Optional("bias")})
|
||||
.Attrs({"transpose_x: bool",
|
||||
"transpose_y: bool",
|
||||
"output_dtype: std::string",
|
||||
"act: std::string"})
|
||||
.Attrs({"transpose_x: bool", "transpose_y: bool",
|
||||
"output_dtype: std::string", "act: std::string"})
|
||||
.Outputs({"out"})
|
||||
.SetKernelFn(PD_KERNEL(cutlass_fp8_fp8_half_block_gemm_fused))
|
||||
.SetInferShapeFn(PD_INFER_SHAPE(CutlassFp8Fp8HalfBlockGemmFusedInferShape))
|
||||
|
||||
@@ -16,7 +16,7 @@
|
||||
#include "helper.h"
|
||||
|
||||
namespace {
|
||||
int sharedMemoryOpen(const char *name, size_t sz, sharedMemoryInfo *info) {
|
||||
int sharedMemoryOpen2(const char *name, size_t sz, sharedMemoryInfo *info) {
|
||||
info->size = sz;
|
||||
info->shmFd = shm_open(name, O_RDWR, 0777);
|
||||
if (info->shmFd < 0) {
|
||||
@@ -40,7 +40,7 @@ std::vector<paddle::Tensor> GetDataPtrIpc(const paddle::Tensor &tmp_input,
|
||||
auto out_data_ptr_tensor_ptr = out_data_ptr_tensor.data<int64_t>();
|
||||
volatile shmStruct *shm = NULL;
|
||||
sharedMemoryInfo info;
|
||||
if (sharedMemoryOpen(shm_name.c_str(), sizeof(shmStruct), &info) != 0) {
|
||||
if (sharedMemoryOpen2(shm_name.c_str(), sizeof(shmStruct), &info) != 0) {
|
||||
printf("Failed to create shared memory slab\n");
|
||||
printf("Func GetDataPtrIpc. Shm_name: %s\n", shm_name.c_str());
|
||||
exit(EXIT_FAILURE);
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/extension.h"
|
||||
#include <map>
|
||||
|
||||
std::vector<paddle::Tensor> GetMmSplitFuse(const paddle::Tensor& task_input_ids,
|
||||
const paddle::Tensor& task_image_type_ids,
|
||||
@@ -60,6 +61,7 @@ std::vector<paddle::Tensor> GetMmSplitFuse(const paddle::Tensor& task_input_ids,
|
||||
st_idx += cur_st_len;
|
||||
}
|
||||
}
|
||||
|
||||
while (idx < seq_lens_origin) {
|
||||
idx = idx + split_fuse_text_size;
|
||||
if (idx >= seq_lens_origin) {
|
||||
|
||||
@@ -70,8 +70,15 @@ void GetOutputEp(const paddle::Tensor& x,
|
||||
#endif
|
||||
msg_queue_id = inference_msg_queue_id_from_env;
|
||||
}
|
||||
static key_t key = ftok("/dev/shm", msg_queue_id);
|
||||
static int msgid = msgget(key, IPC_CREAT | 0666);
|
||||
#ifdef GET_OUTPUT_DEBUG
|
||||
std::cout << "msg_queue_id is: "
|
||||
<< msg_queue_id << std::endl;
|
||||
#endif
|
||||
// static key_t key = ftok("/dev/shm", msg_queue_id);
|
||||
// static int msgid = msgget(key, IPC_CREAT | 0666);
|
||||
|
||||
key_t key = ftok("/dev/shm", msg_queue_id);
|
||||
int msgid = msgget(key, IPC_CREAT | 0666);
|
||||
|
||||
#ifdef GET_OUTPUT_DEBUG
|
||||
std::cout << "get_output_key: " << key << std::endl;
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "glog/logging.h"
|
||||
#include <fcntl.h>
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
@@ -22,25 +23,25 @@
|
||||
#include <sys/stat.h>
|
||||
#include <sys/types.h>
|
||||
#include <unistd.h>
|
||||
#include "glog/logging.h"
|
||||
|
||||
#ifdef PADDLE_WITH_HIP
|
||||
#include <hip/hip_bfloat16.h>
|
||||
#include <hip/hip_fp16.h>
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <hipcub/hipcub.hpp>
|
||||
#include <hiprand.h>
|
||||
#include <hiprand_kernel.h>
|
||||
#include <hipcub/hipcub.hpp>
|
||||
namespace cub = hipcub;
|
||||
#else
|
||||
#include <cub/cub.cuh>
|
||||
#endif
|
||||
#include "nlohmann/json.hpp"
|
||||
#include <fstream>
|
||||
#include <iostream>
|
||||
#include "nlohmann/json.hpp"
|
||||
|
||||
#include "paddle/extension.h"
|
||||
#include "paddle/phi/core/allocator.h"
|
||||
#include "paddle/phi/core/cuda_stream.h"
|
||||
#include "paddle/phi/core/dense_tensor.h"
|
||||
|
||||
#ifndef PD_BUILD_STATIC_OP
|
||||
@@ -49,218 +50,211 @@ namespace cub = hipcub;
|
||||
|
||||
using json = nlohmann::json;
|
||||
|
||||
#define CUDA_CHECK(call) \
|
||||
do { \
|
||||
const cudaError_t error_code = call; \
|
||||
if (error_code != cudaSuccess) { \
|
||||
std::printf("at %s:%d - %s.\n", \
|
||||
__FILE__, \
|
||||
__LINE__, \
|
||||
cudaGetErrorString(error_code)); \
|
||||
exit(1); \
|
||||
} \
|
||||
} while (0)
|
||||
#define CUDA_CHECK(call) \
|
||||
do { \
|
||||
const cudaError_t error_code = call; \
|
||||
if (error_code != cudaSuccess) { \
|
||||
std::printf("at %s:%d - %s.\n", __FILE__, __LINE__, \
|
||||
cudaGetErrorString(error_code)); \
|
||||
exit(1); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
#ifdef PADDLE_WITH_HIP
|
||||
template <size_t kBlockSize = 256, size_t kNumWaves = 16>
|
||||
inline hipError_t GetNumBlocks(int64_t n, int *num_blocks) {
|
||||
int dev;
|
||||
{
|
||||
hipError_t err = hipGetDevice(&dev);
|
||||
if (err != hipSuccess) {
|
||||
return err;
|
||||
}
|
||||
int dev;
|
||||
{
|
||||
hipError_t err = hipGetDevice(&dev);
|
||||
if (err != hipSuccess) {
|
||||
return err;
|
||||
}
|
||||
int sm_count;
|
||||
{
|
||||
hipError_t err = hipDeviceGetAttribute(
|
||||
&sm_count, hipDeviceAttributeMultiprocessorCount, dev);
|
||||
if (err != hipSuccess) {
|
||||
return err;
|
||||
}
|
||||
}
|
||||
int sm_count;
|
||||
{
|
||||
hipError_t err = hipDeviceGetAttribute(
|
||||
&sm_count, hipDeviceAttributeMultiprocessorCount, dev);
|
||||
if (err != hipSuccess) {
|
||||
return err;
|
||||
}
|
||||
int tpm;
|
||||
{
|
||||
hipError_t err = hipDeviceGetAttribute(
|
||||
&tpm, hipDeviceAttributeMaxThreadsPerMultiProcessor, dev);
|
||||
if (err != hipSuccess) {
|
||||
return err;
|
||||
}
|
||||
}
|
||||
int tpm;
|
||||
{
|
||||
hipError_t err = hipDeviceGetAttribute(
|
||||
&tpm, hipDeviceAttributeMaxThreadsPerMultiProcessor, dev);
|
||||
if (err != hipSuccess) {
|
||||
return err;
|
||||
}
|
||||
*num_blocks = std::max<int>(
|
||||
1,
|
||||
std::min<int64_t>((n + kBlockSize - 1) / kBlockSize,
|
||||
sm_count * tpm / kBlockSize * kNumWaves));
|
||||
return hipSuccess;
|
||||
}
|
||||
*num_blocks = std::max<int>(
|
||||
1, std::min<int64_t>((n + kBlockSize - 1) / kBlockSize,
|
||||
sm_count * tpm / kBlockSize * kNumWaves));
|
||||
return hipSuccess;
|
||||
}
|
||||
#else
|
||||
template <size_t kBlockSize = 256, size_t kNumWaves = 16>
|
||||
inline cudaError_t GetNumBlocks(int64_t n, int *num_blocks) {
|
||||
int dev;
|
||||
{
|
||||
cudaError_t err = cudaGetDevice(&dev);
|
||||
if (err != cudaSuccess) {
|
||||
return err;
|
||||
}
|
||||
int dev;
|
||||
{
|
||||
cudaError_t err = cudaGetDevice(&dev);
|
||||
if (err != cudaSuccess) {
|
||||
return err;
|
||||
}
|
||||
int sm_count;
|
||||
{
|
||||
cudaError_t err = cudaDeviceGetAttribute(
|
||||
&sm_count, cudaDevAttrMultiProcessorCount, dev);
|
||||
if (err != cudaSuccess) {
|
||||
return err;
|
||||
}
|
||||
}
|
||||
int sm_count;
|
||||
{
|
||||
cudaError_t err =
|
||||
cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev);
|
||||
if (err != cudaSuccess) {
|
||||
return err;
|
||||
}
|
||||
int tpm;
|
||||
{
|
||||
cudaError_t err = cudaDeviceGetAttribute(
|
||||
&tpm, cudaDevAttrMaxThreadsPerMultiProcessor, dev);
|
||||
if (err != cudaSuccess) {
|
||||
return err;
|
||||
}
|
||||
}
|
||||
int tpm;
|
||||
{
|
||||
cudaError_t err = cudaDeviceGetAttribute(
|
||||
&tpm, cudaDevAttrMaxThreadsPerMultiProcessor, dev);
|
||||
if (err != cudaSuccess) {
|
||||
return err;
|
||||
}
|
||||
*num_blocks = std::max<int>(
|
||||
1,
|
||||
std::min<int64_t>((n + kBlockSize - 1) / kBlockSize,
|
||||
sm_count * tpm / kBlockSize * kNumWaves));
|
||||
return cudaSuccess;
|
||||
}
|
||||
*num_blocks = std::max<int>(
|
||||
1, std::min<int64_t>((n + kBlockSize - 1) / kBlockSize,
|
||||
sm_count * tpm / kBlockSize * kNumWaves));
|
||||
return cudaSuccess;
|
||||
}
|
||||
|
||||
inline int GetGPUComputeCapability(int id) {
|
||||
int major, minor;
|
||||
auto major_error_code =
|
||||
cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, id);
|
||||
auto minor_error_code =
|
||||
cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, id);
|
||||
return major * 10 + minor;
|
||||
int major, minor;
|
||||
auto major_error_code =
|
||||
cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, id);
|
||||
auto minor_error_code =
|
||||
cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, id);
|
||||
return major * 10 + minor;
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
template <paddle::DataType D>
|
||||
class PDTraits;
|
||||
inline constexpr uint32_t next_pow_2(uint32_t const num) {
|
||||
if (num <= 1)
|
||||
return num;
|
||||
return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
|
||||
}
|
||||
|
||||
template <>
|
||||
class PDTraits<paddle::DataType::FLOAT32> {
|
||||
public:
|
||||
typedef float DataType;
|
||||
typedef float data_t;
|
||||
template <paddle::DataType D> class PDTraits;
|
||||
|
||||
template <> class PDTraits<paddle::DataType::FLOAT32> {
|
||||
public:
|
||||
typedef float DataType;
|
||||
typedef float data_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
class PDTraits<paddle::DataType::FLOAT16> {
|
||||
public:
|
||||
typedef half DataType;
|
||||
typedef paddle::float16 data_t;
|
||||
template <> class PDTraits<paddle::DataType::FLOAT16> {
|
||||
public:
|
||||
typedef half DataType;
|
||||
typedef paddle::float16 data_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
class PDTraits<paddle::DataType::BFLOAT16> {
|
||||
public:
|
||||
template <> class PDTraits<paddle::DataType::BFLOAT16> {
|
||||
public:
|
||||
#ifdef PADDLE_WITH_HIP
|
||||
typedef hip_bfloat16 DataType;
|
||||
typedef hip_bfloat16 DataType;
|
||||
#else
|
||||
typedef __nv_bfloat16 DataType;
|
||||
typedef __nv_bfloat16 DataType;
|
||||
#endif
|
||||
typedef paddle::bfloat16 data_t;
|
||||
typedef paddle::bfloat16 data_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
class PDTraits<paddle::DataType::INT8> {
|
||||
public:
|
||||
typedef int8_t DataType;
|
||||
typedef int8_t data_t;
|
||||
template <> class PDTraits<paddle::DataType::INT8> {
|
||||
public:
|
||||
typedef int8_t DataType;
|
||||
typedef int8_t data_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
class PDTraits<paddle::DataType::UINT8> {
|
||||
public:
|
||||
typedef uint8_t DataType;
|
||||
typedef uint8_t data_t;
|
||||
template <> class PDTraits<paddle::DataType::UINT8> {
|
||||
public:
|
||||
typedef uint8_t DataType;
|
||||
typedef uint8_t data_t;
|
||||
};
|
||||
|
||||
template <typename T, int Size>
|
||||
struct alignas(sizeof(T) * Size) AlignedVector {
|
||||
T val[Size];
|
||||
template <typename T, int Size> struct alignas(sizeof(T) * Size) AlignedVector {
|
||||
T val[Size];
|
||||
|
||||
HOSTDEVICE inline const T &operator[](int i) const { return val[i]; }
|
||||
HOSTDEVICE inline T &operator[](int i) { return val[i]; }
|
||||
HOSTDEVICE inline const T &operator[](int i) const { return val[i]; }
|
||||
HOSTDEVICE inline T &operator[](int i) { return val[i]; }
|
||||
};
|
||||
|
||||
template <typename T, int Size>
|
||||
HOSTDEVICE inline void Load(const T *addr, AlignedVector<T, Size> *vec) {
|
||||
const AlignedVector<T, Size> *addr_vec =
|
||||
reinterpret_cast<const AlignedVector<T, Size> *>(addr);
|
||||
*vec = *addr_vec;
|
||||
const AlignedVector<T, Size> *addr_vec =
|
||||
reinterpret_cast<const AlignedVector<T, Size> *>(addr);
|
||||
*vec = *addr_vec;
|
||||
}
|
||||
|
||||
template <typename T, int Size>
|
||||
HOSTDEVICE inline void Store(const AlignedVector<T, Size> &vec, T *addr) {
|
||||
AlignedVector<T, Size> *addr_vec =
|
||||
reinterpret_cast<AlignedVector<T, Size> *>(addr);
|
||||
*addr_vec = vec;
|
||||
AlignedVector<T, Size> *addr_vec =
|
||||
reinterpret_cast<AlignedVector<T, Size> *>(addr);
|
||||
*addr_vec = vec;
|
||||
}
|
||||
|
||||
template <int Size>
HOSTDEVICE inline void Store(const AlignedVector<__nv_bfloat16, Size> &vec,
                             int8_t *addr) {
  printf("Error: Store __nv_bfloat16 to int8_t is not supported!");
}

template <int Size>
HOSTDEVICE inline void Store(const AlignedVector<half, Size> &vec,
                             int8_t *addr) {
  printf("Error: Store half to int8_t is not supported!");
}
|
||||
|
||||
constexpr int VEC_16B = 16;

template <typename T> __device__ T max_func(const T a, const T b) {
  return a > b ? a : b;
}

template <typename T> struct MaxOp {
  __device__ __forceinline__ T operator()(const T &a, const T &b) const {
    return max_func(a, b);
  }
};
|
||||
|
||||
inline int GetBlockSize(int vocab_size) {
  if (vocab_size > 512) {
    return 1024;
  } else if (vocab_size > 256) {
    return 512;
  } else if (vocab_size > 128) {
    return 256;
  } else if (vocab_size > 64) {
    return 128;
  } else {
    return 64;
  }
}
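// Hedged example (not in the original diff): GetBlockSize rounds the row width
// up to the next power-of-two bucket so one thread block can cover a row; a
// launch built on it might look like the hypothetical call below.
//
//   const int block_size = GetBlockSize(vocab_size);   // e.g. 1024 for a 50k vocab
//   some_rowwise_kernel<<<batch_size, block_size, 0, stream>>>(...);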
|
||||
|
||||
inline json readJsonFromFile(const std::string &filePath) {
  std::ifstream file(filePath);
  if (!file.is_open()) {
    throw std::runtime_error("Unable to open file: " + filePath);
  }

  json j;
  file >> j;
  return j;
}
|
||||
|
||||
#define cudaCheckError()                                                      \
  {                                                                           \
    cudaError_t e = cudaGetLastError();                                       \
    if (e != cudaSuccess) {                                                   \
      std::cerr << "CUDA Error " << __FILE__ << ":" << __LINE__ << ": "       \
                << cudaGetErrorString(e) << std::endl;                        \
      exit(EXIT_FAILURE);                                                     \
    }                                                                         \
  }
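// Hedged usage note (assumption): cudaCheckError() is meant to be dropped in
// right after a kernel launch or async API call while debugging, e.g.
//
//   my_kernel<<<grid, block, 0, stream>>>(args...);  // hypothetical launch
//   cudaCheckError();  // aborts with file/line if the launch failed
//
// It only inspects the last recorded error, so it catches launch-time
// failures; errors raised inside the kernel surface on the next
// synchronizing call instead.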
|
||||
|
||||
// place must be an existing place object and cannot use paddle::CPUPlace() or
|
||||
// paddle::GPUPlace()
|
||||
@@ -269,220 +263,194 @@ inline json readJsonFromFile(const std::string &filePath) {
|
||||
inline paddle::Tensor GetEmptyTensor(const common::DDim &dims,
                                     const paddle::DataType &dtype,
                                     const paddle::Place &place) {
  auto *allocator = paddle::GetAllocator(place);
  phi::DenseTensor dense_tensor;
  dense_tensor.Resize(dims);
  dense_tensor.AllocateFrom(allocator, dtype,
                            dense_tensor.numel() * phi::SizeOf(dtype));
  return paddle::Tensor(std::make_shared<phi::DenseTensor>(dense_tensor));
}

inline paddle::Tensor GetEmptyTensor(const common::DDim &dims,
                                     const common::DDim &strides,
                                     const paddle::DataType &dtype,
                                     const paddle::Place &place) {
  auto *allocator = paddle::GetAllocator(place);
  phi::DenseTensor dense_tensor;
  dense_tensor.Resize(dims);
  dense_tensor.AllocateFrom(allocator, dtype,
                            dense_tensor.numel() * phi::SizeOf(dtype));
  dense_tensor.set_strides(strides);
  return paddle::Tensor(std::make_shared<phi::DenseTensor>(dense_tensor));
}
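// Hedged sketch (not from the original diff): both overloads return an
// uninitialized DenseTensor-backed paddle::Tensor; the second additionally
// fixes custom strides. A typical call inside a custom op could look like:
//
//   auto tmp = GetEmptyTensor({batch_size, hidden_size},
//                             paddle::DataType::FLOAT16,
//                             input.place());
//
// The shape and dtype arguments above are illustrative only.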
|
||||
#endif
|
||||
|
||||
__global__ void free_and_dispatch_block(
    bool *stop_flags, int *seq_lens_this_time, int *seq_lens_decoder,
    int *block_tables, int *encoder_block_lens, bool *is_block_step,
    int *step_block_list, // [bsz]
    int *step_len, int *recover_block_list, int *recover_len,
    int *need_block_list, int *need_block_len, int *used_list_len,
    int *free_list, int *free_list_len, int64_t *first_token_ids, const int bsz,
    const int block_size, const int block_num_per_seq,
    const int max_decoder_block_num);
|
||||
|
||||
__global__ void speculate_free_and_dispatch_block(
    bool *stop_flags, int *seq_lens_this_time, int *seq_lens_decoder,
    int *block_tables, int *encoder_block_lens, bool *is_block_step,
    int *step_block_list, // [bsz]
    int *step_len, int *recover_block_list, int *recover_len,
    int *need_block_list, int *need_block_len, int *used_list_len,
    int *free_list, int *free_list_len, int64_t *first_token_ids,
    int *accept_num, const int bsz, const int block_size,
    const int block_num_per_seq, const int max_decoder_block_num,
    const int max_draft_tokens);
|
||||
|
||||
__device__ bool speculate_free_and_dispatch_block(const int &qid,
|
||||
int *need_block_list,
|
||||
const int &need_block_len);
|
||||
|
||||
static std::string global_base64_chars = // NOLINT
    "Tokp9lA/BjimRVKx32edMPFftOzsbNQ8C15Xn+YUEGc4WD0uLIq7hyJ6vZaHSwrg";
|
||||
|
||||
// Base64 encode function
inline std::string base64_encode(const std::string &input) {
  std::string ret;
  int i = 0;
  int j = 0;
  unsigned char char_array_3[3];
  unsigned char char_array_4[4];

  for (const auto &c : input) {
    char_array_3[i++] = c;
    if (i == 3) {
      char_array_4[0] = (char_array_3[0] & 0xfc) >> 2;
      char_array_4[1] =
          ((char_array_3[0] & 0x03) << 4) + ((char_array_3[1] & 0xf0) >> 4);
      char_array_4[2] =
          ((char_array_3[1] & 0x0f) << 2) + ((char_array_3[2] & 0xc0) >> 6);
      char_array_4[3] = char_array_3[2] & 0x3f;

      for (i = 0; i < 4; i++) {
        ret += global_base64_chars[char_array_4[i]];
      }
      i = 0;
    }
  }

  if (i) {
    for (j = i; j < 3; j++) {
      char_array_3[j] = '\0';
    }

    char_array_4[0] = (char_array_3[0] & 0xfc) >> 2;
    char_array_4[1] =
        ((char_array_3[0] & 0x03) << 4) + ((char_array_3[1] & 0xf0) >> 4);
    char_array_4[2] =
        ((char_array_3[1] & 0x0f) << 2) + ((char_array_3[2] & 0xc0) >> 6);
    char_array_4[3] = char_array_3[2] & 0x3f;

    for (j = 0; j < i + 1; j++) {
      ret += global_base64_chars[char_array_4[j]];
    }

    while (i++ < 3) {
      ret += '=';
    }
  }

  return ret;
}
|
||||
|
||||
// Base64 decode function
inline std::string base64_decode(const std::string &encoded_string) {
  int in_len = encoded_string.size();
  int i = 0;
  int j = 0;
  int in_ = 0;
  unsigned char char_array_4[4], char_array_3[3];
  std::string ret;

  while (in_len-- && (encoded_string[in_] != '=') &&
         (isalnum(encoded_string[in_]) || (encoded_string[in_] == '+') ||
          (encoded_string[in_] == '/'))) {
    char_array_4[i++] = encoded_string[in_];
    in_++;
    if (i == 4) {
      for (i = 0; i < 4; i++) {
        char_array_4[i] = global_base64_chars.find(char_array_4[i]);
      }

      char_array_3[0] =
          (char_array_4[0] << 2) + ((char_array_4[1] & 0x30) >> 4);
      char_array_3[1] =
          ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2);
      char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3];

      for (i = 0; i < 3; i++) {
        ret += char_array_3[i];
      }
      i = 0;
    }
  }

  if (i) {
    for (j = i; j < 4; j++) {
      char_array_4[j] = 0;
    }

    for (j = 0; j < 4; j++) {
      char_array_4[j] = global_base64_chars.find(char_array_4[j]);
    }

    char_array_3[0] = (char_array_4[0] << 2) + ((char_array_4[1] & 0x30) >> 4);
    char_array_3[1] =
        ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2);
    char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3];

    for (j = 0; j < i - 1; j++) {
      ret += char_array_3[j];
    }
  }

  return ret;
}
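// Hedged usage sketch (assumption): base64_encode/base64_decode share the
// shuffled alphabet in global_base64_chars, so they round-trip with each other
// but are not interchangeable with a standard Base64 implementation.
//
//   std::string blob = "cache-meta";        // hypothetical payload
//   std::string enc = base64_encode(blob);
//   assert(base64_decode(enc) == blob);     // round trip holds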
|
||||
|
||||
template <typename T>
inline T get_relative_best(nlohmann::json *json_data,
                           const std::string &target_key,
                           const T &default_value) {
  if (json_data->contains(target_key)) {
    return json_data->at(target_key);
  } else {
    // std::cerr << "The key " << target_key << " is not found in the JSON
    // data." << std::endl;
    return default_value;
  }
}
|
||||
|
||||
__device__ inline bool is_in_end(const int64_t id, const int64_t *end_ids,
                                 int length) {
  bool flag = false;
  for (int i = 0; i < length; i++) {
    if (id == end_ids[i]) {
      return true;
    }
  }
  return flag;
}
|
||||
|
||||
template <typename T> inline __device__ __host__ T div_up(T m, T n) {
  return (m + n - 1) / n;
}

template <typename T>
__device__ __inline__ T ClipFunc(const T v, const T min, const T max) {
  if (v > max)
    return max;
  if (v < min)
    return min;
  return v;
}
|
||||
|
||||
template <typename T>
|
||||
@@ -499,31 +467,49 @@ static void PrintMatrix3(const T *mat_d, int num, std::string name) {
|
||||
if (std::is_same<T, int8_t>::value || std::is_same<T, uint8_t>::value) {
|
||||
ss << static_cast<int>(tmp[i]) << std::endl;
|
||||
} else {
|
||||
ss << std::setprecision(8) << (float)(tmp[i]) << std::endl; // NOLINT
|
||||
|
||||
}
|
||||
}
|
||||
outfile << ss.str();
|
||||
outfile.close();
|
||||
}
|
||||
|
||||
__forceinline__ __device__ uint32_t ld_flag_acquire(uint32_t *flag_addr,
                                                    int mode = 0) {
  uint32_t flag;
  if (mode == 0) {
    asm volatile("ld.acquire.sys.global.b32 %0, [%1];"
                 : "=r"(flag)
                 : "l"(flag_addr));
  } else if (mode == 1) {
    asm volatile("ld.acquire.gpu.global.b32 %0, [%1];"
                 : "=r"(flag)
                 : "l"(flag_addr));
  } else {
    asm volatile("ld.acquire.cta.global.b32 %0, [%1];"
                 : "=r"(flag)
                 : "l"(flag_addr));
  }
  return flag;
}

__forceinline__ __device__ void st_flag_release(uint32_t *flag_addr,
                                                uint32_t flag, int mode = 0) {
  if (mode == 0) {
    asm volatile("st.release.sys.global.b32 [%1], %0;" ::"r"(flag),
                 "l"(flag_addr));
  } else if (mode == 1) {
    asm volatile("st.release.gpu.global.b32 [%1], %0;" ::"r"(flag),
                 "l"(flag_addr));
  } else {
    asm volatile("st.release.cta.global.b32 [%1], %0;" ::"r"(flag),
                 "l"(flag_addr));
  }
}
|
||||
|
||||
inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) {
|
||||
int max_shared_mem_per_block_opt_in = 0;
|
||||
cudaDeviceGetAttribute(&max_shared_mem_per_block_opt_in,
|
||||
cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
|
||||
return max_shared_mem_per_block_opt_in;
|
||||
}
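// Hedged sketch (not part of the original diff): the opt-in limit returned
// above is usually paired with cudaFuncSetAttribute so a kernel may request
// more dynamic shared memory than the 48 KB default; `big_smem_kernel` is a
// hypothetical kernel name.
//
//   int max_smem = get_cuda_max_shared_memory_per_block_opt_in(device_id);
//   cudaFuncSetAttribute(big_smem_kernel,
//                        cudaFuncAttributeMaxDynamicSharedMemorySize,
//                        max_smem);
//   big_smem_kernel<<<grid, block, max_smem, stream>>>(...);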
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@@ -11,19 +12,67 @@
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "fstream"
|
||||
#include "helper.h"
|
||||
#include "iomanip"
|
||||
#include "fstream"
|
||||
#include "iostream"
|
||||
#include "iomanip"
|
||||
#include <nvml.h>
|
||||
#include <iostream>
|
||||
#include <nvml.h>
|
||||
// #define PRINT_GPU_MEMORY
|
||||
// 函数用于获取 NVIDIA GPU 显存信息
|
||||
bool getNvidiaGPUMemoryUsage(int callLine) {
|
||||
#ifndef PRINT_GPU_MEMORY
|
||||
return true;
|
||||
#endif
|
||||
// 初始化 NVML
|
||||
nvmlReturn_t result;
|
||||
result = nvmlInit();
|
||||
if (NVML_SUCCESS != result) {
|
||||
std::cerr << callLine << ": Failed to initialize NVML: " << nvmlErrorString(result) << std::endl;
|
||||
return false;
|
||||
}
|
||||
// 获取 GPU 设备数量
|
||||
unsigned int deviceCount;
|
||||
result = nvmlDeviceGetCount(&deviceCount);
|
||||
if (NVML_SUCCESS != result) {
|
||||
std::cerr << callLine << ": Failed to get device count: " << nvmlErrorString(result) << std::endl;
|
||||
nvmlShutdown();
|
||||
return false;
|
||||
}
|
||||
// 遍历每个 GPU 设备
|
||||
for (unsigned int i = 0; i < deviceCount; ++i) {
|
||||
nvmlDevice_t device;
|
||||
result = nvmlDeviceGetHandleByIndex(i, &device);
|
||||
if (NVML_SUCCESS != result) {
|
||||
std::cerr << callLine << ": Failed to get device handle for device " << i << ": " << nvmlErrorString(result) << std::endl;
|
||||
continue;
|
||||
}
|
||||
// 获取显存信息
|
||||
nvmlMemory_t memory;
|
||||
result = nvmlDeviceGetMemoryInfo(device, &memory);
|
||||
if (NVML_SUCCESS != result) {
|
||||
std::cerr << callLine << ": Failed to get memory info for device " << i << ": " << nvmlErrorString(result) << std::endl;
|
||||
continue;
|
||||
}
|
||||
// 只打印一行信息并显示调用函数时的行号
|
||||
std::cout << callLine << ": GPU " << i << " - Total: " << memory.total / (1024 * 1024)
|
||||
<< " MiB, Used: " << memory.used / (1024 * 1024)
|
||||
<< " MiB, Free: " << memory.free / (1024 * 1024) << " MiB" << std::endl;
|
||||
}
|
||||
// 清理 NVML 资源
|
||||
nvmlShutdown();
|
||||
return true;
|
||||
}
|
||||
|
||||
// #define DEBUG_IPC_SENT
// #define DEBUG_IPC_SENT_SYNC_AND_PRINT

template <typename T>
void sent_key_value_by_remote_ptr(
    const T* local_key_tensor_base_ptr,   // gpu ptr
    const T* local_value_tensor_base_ptr, // gpu ptr
    const int32_t* local_block_ids_ptr,   // cpu ptr
    const int32_t* remote_block_ids_ptr,
    const int32_t block_num,
    const int64_t block_idx_stride,
@@ -32,170 +81,163 @@ void sent_key_value_by_remote_ptr(
    const int32_t remote_device_id,
    T* remote_key_tensor_base_ptr,   // gpu ptr
    T* remote_value_tensor_base_ptr, // gpu ptr
    cudaStream_t stream) {
  for (int block_idx = 0; block_idx < block_num; ++block_idx) {
    const T* local_key_tensor_sent_ptr =
        local_key_tensor_base_ptr +
        local_block_ids_ptr[block_idx] * block_idx_stride;
    T* remote_key_tensor_sent_ptr =
        remote_key_tensor_base_ptr +
        remote_block_ids_ptr[block_idx] * block_idx_stride;
#ifdef DEBUG_IPC_SENT
    std::cout << "remote_key_tensor_sent_ptr:"
              << (int64_t)remote_key_tensor_sent_ptr
              << " local_key_tensor_sent_ptr:"
              << (int64_t)local_key_tensor_sent_ptr
              << " local_device_id:" << local_device_id
              << " remote_device_id:" << remote_device_id
              << " block_idx_stride:" << block_idx_stride
              << " block_size_byte:" << block_size_byte
              << " stream: " << stream
              << " local_block_ids: " << local_block_ids_ptr[block_idx]
              << " remote_block_ids: " << remote_block_ids_ptr[block_idx]
              << std::endl;
#endif
#ifdef DEBUG_IPC_SENT_SYNC_AND_PRINT
    cudaDeviceSynchronize();
    PrintMatrix<T>(
        reinterpret_cast<const T*>(local_key_tensor_sent_ptr),
        128 * 1,
        "ipc_send_src_key.datatxt." + std::to_string(local_device_id),
        128 * 1);
    cudaDeviceSynchronize();
#endif
#ifndef DEBUG_IPC_SENT_SYNC_AND_PRINT
    cudaMemcpyPeerAsync(
        reinterpret_cast<void*>(remote_key_tensor_sent_ptr),
        remote_device_id,
        reinterpret_cast<const void*>(local_key_tensor_sent_ptr),
        local_device_id,
        block_size_byte,
        stream);
#endif
#ifdef DEBUG_IPC_SENT_SYNC_AND_PRINT
    cudaMemcpyPeer(reinterpret_cast<void*>(remote_key_tensor_sent_ptr),
                   remote_device_id,
                   reinterpret_cast<const void*>(local_key_tensor_sent_ptr),
                   local_device_id,
                   block_size_byte);
#endif
    cudaError_t err = cudaGetLastError();
    if (err != cudaSuccess) {
      printf("CUDA Error: %s\n", cudaGetErrorString(err));
    }
#ifdef DEBUG_IPC_SENT_SYNC_AND_PRINT
    cudaDeviceSynchronize();
    PrintMatrix<T>(
        reinterpret_cast<T*>(remote_key_tensor_sent_ptr),
        128 * 1,
        "ipc_send_tgt_key.datatxt." + std::to_string(local_device_id),
        128 * 1);
    cudaDeviceSynchronize();
#endif
    const T* local_value_tensor_sent_ptr =
        local_value_tensor_base_ptr +
        local_block_ids_ptr[block_idx] * block_idx_stride;
    T* remote_value_tensor_sent_ptr =
        remote_value_tensor_base_ptr +
        remote_block_ids_ptr[block_idx] * block_idx_stride;
#ifdef DEBUG_IPC_SENT
    std::cout << "remote_value_tensor_sent_ptr:"
              << (int64_t)remote_value_tensor_sent_ptr
              << " local_value_tensor_sent_ptr:"
              << (int64_t)local_value_tensor_sent_ptr
              << " local_device_id:" << local_device_id
              << " remote_device_id:" << remote_device_id
              << " block_idx_stride:" << block_idx_stride
              << " block_size_byte:" << block_size_byte
              << " stream: " << stream
              << " local_block_ids: " << local_block_ids_ptr[block_idx]
              << " remote_block_ids: " << remote_block_ids_ptr[block_idx]
              << std::endl;
#endif
#ifdef DEBUG_IPC_SENT_SYNC_AND_PRINT
    cudaDeviceSynchronize();
    PrintMatrix<T>(
        reinterpret_cast<const T*>(local_value_tensor_sent_ptr),
        128 * 1,
        "ipc_send_src_value.datatxt." + std::to_string(local_device_id),
        128 * 1);
    cudaDeviceSynchronize();
#endif
#ifndef DEBUG_IPC_SENT_SYNC_AND_PRINT
    cudaMemcpyPeerAsync(
        reinterpret_cast<void*>(remote_value_tensor_sent_ptr),
        remote_device_id,
        reinterpret_cast<const void*>(local_value_tensor_sent_ptr),
        local_device_id,
        block_size_byte,
        stream);
#endif
#ifdef DEBUG_IPC_SENT_SYNC_AND_PRINT
    cudaMemcpyPeer(
        reinterpret_cast<void*>(remote_value_tensor_sent_ptr),
        remote_device_id,
        reinterpret_cast<const void*>(local_value_tensor_sent_ptr),
        local_device_id,
        block_size_byte);
    cudaDeviceSynchronize();
#endif
    err = cudaGetLastError();
    if (err != cudaSuccess) {
      printf("CUDA Error: %s\n", cudaGetErrorString(err));
    }
#ifdef DEBUG_IPC_SENT_SYNC_AND_PRINT
    PrintMatrix<T>(
        reinterpret_cast<T*>(remote_value_tensor_sent_ptr),
        128 * 1,
        "ipc_send_tgt_value.datatxt." + std::to_string(local_device_id),
        128 * 1);
    cudaDeviceSynchronize();
#endif
  }
}
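// Hedged note (assumption, not from the original diff): cudaMemcpyPeerAsync
// between local_device_id and remote_device_id only avoids a host bounce when
// peer access (or an IPC mapping of the remote pointer) has already been set
// up by the caller, for example something like:
//
//   cudaSetDevice(local_device_id);
//   cudaDeviceEnablePeerAccess(remote_device_id, 0);  // 0 = reserved flags
//
// Otherwise the runtime transparently stages the copy through host memory.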
|
||||
void SentKeyValueByRemotePtr(const paddle::Tensor& local_key_tensor,
                             const paddle::Tensor& local_value_tensor,
                             const paddle::Tensor& local_block_ids,   // cpu
                             const paddle::Tensor& remote_block_ids,  // cpu
                             const paddle::Tensor& remote_key_tensor,
                             const paddle::Tensor& remote_value_tensor,
                             const int& block_num,
                             const int& local_device_id,
                             const int& remote_device_id,
                             const int64_t& cuda_stream_raw) {
  std::vector<int64_t> cache_key_tensor_shape = local_key_tensor.shape();
  getNvidiaGPUMemoryUsage(__LINE__);
  // auto cuda_stream = local_key_tensor.stream();
  cudaStream_t cuda_stream = (cudaStream_t)cuda_stream_raw;
  getNvidiaGPUMemoryUsage(__LINE__);
  // const cudaStream_t cuda_stream = *(reinterpret_cast<const cudaStream_t*>(&stream));
#ifdef DEBUG_IPC_SENT
  std::cout << "#### 000" << std::endl;
#endif

  int32_t total_block_num_local = cache_key_tensor_shape[0];
  int32_t kv_num_head_local = cache_key_tensor_shape[1];
  int32_t block_size_local = cache_key_tensor_shape[2];
  int32_t hidden_size_local = cache_key_tensor_shape[3];
  getNvidiaGPUMemoryUsage(__LINE__);

  auto local_block_ids_ptr = local_block_ids.data<int32_t>();    // cpu
  auto remote_block_ids_ptr = remote_block_ids.data<int32_t>();  // cpu
  auto remote_key_ptr = remote_key_tensor.data<int64_t>()[0];
  auto remote_value_ptr = remote_value_tensor.data<int64_t>()[0];
  getNvidiaGPUMemoryUsage(__LINE__);

#ifdef DEBUG_IPC_SENT
  std::cout << "#### 1111"
            << " remote_key_ptr: " << remote_key_ptr
            << " remote_value_ptr: " << remote_value_ptr << std::endl;
#endif
  getNvidiaGPUMemoryUsage(__LINE__);
  int64_t block_idx_stride =
      kv_num_head_local * block_size_local * hidden_size_local;
  auto local_key_tensor_ptr = local_key_tensor.data();
  auto local_value_tensor_ptr = local_value_tensor.data();
  getNvidiaGPUMemoryUsage(__LINE__);
#ifdef DEBUG_IPC_SENT
  std::cout << "#### 2222" << std::endl;
#endif

  switch (local_key_tensor.type()) {
    case paddle::DataType::BFLOAT16: {
      using dataT = __nv_bfloat16;
      // std::cout<<"#### cache type __nv_bfloat16" << std::endl;
      return sent_key_value_by_remote_ptr<dataT>(
          reinterpret_cast<const dataT*>(local_key_tensor_ptr),
|
||||
@@ -209,10 +251,11 @@ void SentKeyValueByRemotePtr(const paddle::Tensor& local_key_tensor,
|
||||
remote_device_id,
|
||||
reinterpret_cast<dataT*>((void*)remote_key_ptr),
|
||||
reinterpret_cast<dataT*>((void*)remote_value_ptr),
|
||||
cuda_stream);
|
||||
|
||||
}
|
||||
case paddle::DataType::FLOAT16: {
|
||||
using dataT = half;
|
||||
|
||||
return sent_key_value_by_remote_ptr<dataT>(
|
||||
reinterpret_cast<const dataT*>(local_key_tensor_ptr),
|
||||
reinterpret_cast<const dataT*>(local_value_tensor_ptr),
|
||||
@@ -225,10 +268,11 @@ void SentKeyValueByRemotePtr(const paddle::Tensor& local_key_tensor,
|
||||
remote_device_id,
|
||||
reinterpret_cast<dataT*>((void*)remote_key_ptr),
|
||||
reinterpret_cast<dataT*>((void*)remote_value_ptr),
|
||||
cuda_stream);
|
||||
|
||||
}
|
||||
case paddle::DataType::INT8: {
|
||||
using dataT = int8_t;
|
||||
|
||||
return sent_key_value_by_remote_ptr<dataT>(
|
||||
reinterpret_cast<const dataT*>(local_key_tensor_ptr),
|
||||
reinterpret_cast<const dataT*>(local_value_tensor_ptr),
|
||||
@@ -241,10 +285,11 @@ void SentKeyValueByRemotePtr(const paddle::Tensor& local_key_tensor,
|
||||
remote_device_id,
|
||||
reinterpret_cast<dataT*>((void*)remote_key_ptr),
|
||||
reinterpret_cast<dataT*>((void*)remote_value_ptr),
|
||||
cuda_stream);
|
||||
|
||||
}
|
||||
case paddle::DataType::UINT8: {
|
||||
using dataT = uint8_t;
|
||||
|
||||
// std::cout<<"#### cache type uint8" << std::endl;
|
||||
return sent_key_value_by_remote_ptr<dataT>(
|
||||
reinterpret_cast<const dataT*>(local_key_tensor_ptr),
|
||||
@@ -258,20 +303,33 @@ void SentKeyValueByRemotePtr(const paddle::Tensor& local_key_tensor,
|
||||
remote_device_id,
|
||||
reinterpret_cast<dataT*>((void*)remote_key_ptr),
|
||||
reinterpret_cast<dataT*>((void*)remote_value_ptr),
|
||||
cuda_stream);
|
||||
|
||||
}
|
||||
}
|
||||
// using dataT=std::remove_pointer<decltype(local_block_ids_ptr)>;
|
||||
}
|
||||
|
||||
void SentKeyValueByRemotePtrBlockSync(const paddle::Tensor& local_key_tensor,
|
||||
const paddle::Tensor& local_value_tensor,
|
||||
const int64_t& cuda_stream_raw) {
|
||||
cudaStream_t cuda_stream = (cudaStream_t)cuda_stream_raw;
|
||||
cudaStreamSynchronize(cuda_stream);
|
||||
}
|
||||
|
||||
PD_BUILD_STATIC_OP(ipc_sent_key_value_cache_by_remote_ptr)
    .Inputs({"local_key_tensor", "local_value_tensor", "local_block_ids", "remote_block_ids", "remote_key_tensor", "remote_value_tensor"})
    .Attrs({"block_num: int",
            "local_device_id: int",
            "remote_device_id: int",
            "cuda_stream_raw: int64_t"})
    .Outputs({"local_key_tensor_out", "local_value_tensor_out"})
    .SetInplaceMap({{"local_key_tensor", "local_key_tensor_out"},
                    {"local_value_tensor", "local_value_tensor_out"}})
    .SetKernelFn(PD_KERNEL(SentKeyValueByRemotePtr));
|
||||
|
||||
PD_BUILD_STATIC_OP(ipc_sent_key_value_cache_by_remote_ptr_block_sync)
|
||||
.Inputs({"local_key_tensor", "local_value_tensor"})
|
||||
.Attrs({"cuda_stream_raw: int64_t"})
|
||||
.Outputs({"local_key_tensor_out", "local_value_tensor_out"})
|
||||
.SetInplaceMap({{"local_key_tensor", "local_key_tensor_out"},{"local_value_tensor","local_value_tensor_out"}})
|
||||
.SetKernelFn(PD_KERNEL(SentKeyValueByRemotePtrBlockSync));
|
||||
custom_ops/gpu_ops/moe/deepgemm_preprocess.cu (new file, 61 lines)
@@ -0,0 +1,61 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "helper.h"
|
||||
#include "paddle/extension.h"
|
||||
|
||||
template <typename scalar_t>
|
||||
__global__ void
|
||||
cuda_kernel(const scalar_t *__restrict__ topk_ids, int32_t *__restrict__ res,
|
||||
int32_t *__restrict__ res_padded, size_t numel, int num_experts) {
|
||||
|
||||
extern __shared__ int32_t tokens_per_ep[];
|
||||
|
||||
for (size_t i = threadIdx.x; i < num_experts; i += blockDim.x) {
|
||||
tokens_per_ep[i] = 0;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
for (size_t i = threadIdx.x; i < numel; i += blockDim.x) {
|
||||
int32_t expert_id = topk_ids[i];
|
||||
if(expert_id >= 0) atomicAdd(&tokens_per_ep[expert_id], 1);
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
for (size_t i = threadIdx.x; i < num_experts; i += blockDim.x) {
|
||||
res[i] = tokens_per_ep[i];
|
||||
res_padded[i] = (res[i] + 127) / 128 * 128;
|
||||
}
|
||||
}
|
||||
|
||||
paddle::Tensor count_tokens_per_expert_func(const paddle::Tensor &topk_ids,
|
||||
int64_t num_experts) {
|
||||
|
||||
int topk_ids_numel = topk_ids.shape()[0] * topk_ids.shape()[1];
|
||||
|
||||
auto token_nums_per_expert = paddle::empty(
|
||||
{2, num_experts}, paddle::DataType::INT32, topk_ids.place());
|
||||
|
||||
auto stream = topk_ids.stream();
|
||||
using scalar_t = int64_t;
|
||||
|
||||
cuda_kernel<<<1, 1024, num_experts * sizeof(int32_t), stream>>>(
|
||||
topk_ids.data<scalar_t>(), token_nums_per_expert.data<int32_t>(),
|
||||
token_nums_per_expert.data<int32_t>() + num_experts, topk_ids_numel,
|
||||
num_experts);
|
||||
return token_nums_per_expert;
|
||||
}
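// Hedged sketch (assumption, not part of the original diff):
// count_tokens_per_expert_func returns a [2, num_experts] INT32 tensor where
// row 0 holds the raw per-expert token counts and row 1 holds the same counts
// rounded up to a multiple of 128, presumably the alignment the grouped GEMM
// expects. A hypothetical C++ call site:
//
//   auto counts = count_tokens_per_expert_func(topk_ids, num_experts);
//   const int32_t *raw    = counts.data<int32_t>();                // row 0
//   const int32_t *padded = counts.data<int32_t>() + num_experts;  // row 1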
|
||||
|
||||
|
||||
@@ -30,9 +30,9 @@ namespace cg = cooperative_groups;
|
||||
template<typename T>
__device__ T warpReduceSum(T val){
  for(int lane_mask = 16; lane_mask > 0; lane_mask /=2){
    val += __shfl_down_sync(0xffffffff, val, lane_mask);
  }
  return val;
}
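// Hedged sketch (assumption, not part of the original diff): with
// __shfl_down_sync only lane 0 ends up holding the full warp sum, so a
// block-wide sum built on warpReduceSum usually stages per-warp partials in
// shared memory, roughly like:
//
//   __shared__ int warp_partials[32];
//   int v = warpReduceSum<int>(thread_val);
//   if (lane_id == 0) warp_partials[warp_id] = v;  // lane 0 owns the warp sum
//   __syncthreads();
//   // the first warp then reduces warp_partials with a second
//   // warpReduceSum call, as get_expert_token_num does below.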
|
||||
|
||||
__global__ void get_expert_token_num(
|
||||
@@ -88,7 +88,7 @@ __global__ void get_expert_token_num(
|
||||
sum = (threadIdx.x < KNWARPS) ? warp_sum[laneId] : 0;
|
||||
sum_padded = (threadIdx.x < KNWARPS) ? warp_sum[laneId + KNWARPS] : 0;
|
||||
if (warpId == 0) {
|
||||
sum = warpReduceSum<int>(sum);
|
||||
|
||||
sum_padded = warpReduceSum<int>(sum_padded);
|
||||
}
|
||||
if (threadIdx.x == 0) {
|
||||
@@ -167,7 +167,7 @@ __global__ void combine_prmt_back_kernel(
|
||||
#pragma unroll
|
||||
for (int vid = 0; vid < VEC_SIZE; vid++) {
|
||||
res_vec[vid] += static_cast<T>(
|
||||
row_scale * static_cast<float>(load_vec[vid]) +
|
||||
|
||||
static_cast<float>(bias_vec[vid]));
|
||||
}
|
||||
} else {
|
||||
@@ -497,7 +497,7 @@ std::vector<paddle::Tensor> EPMoeExpertDispatch(
|
||||
place);
|
||||
auto num_experts_per_rank_tensor = GetEmptyTensor(
|
||||
{num_experts_per_rank},
|
||||
paddle::DataType::INT32,
|
||||
|
||||
place);
|
||||
auto expert_idx_per_token = GetEmptyTensor(
|
||||
{token_nums_this_rank}, paddle::DataType::INT64, place);
|
||||
@@ -619,7 +619,7 @@ PD_BUILD_STATIC_OP(ep_moe_expert_dispatch)
|
||||
"cumsum_idx_gpu",
|
||||
"expert_idx_per_token"})
|
||||
.Attrs({
|
||||
"token_nums_per_expert: std::vector<int>",
|
||||
"token_nums_per_expert: std::vector<int>",
|
||||
"token_nums_this_rank: int",
|
||||
"moe_quant_type: std::string"
|
||||
})
|
||||
@@ -672,18 +672,21 @@ __global__ void permute_x_fp8_kernel(const T *src_x,
|
||||
const int hidden_size_int4 = hidden_size / vec_size;
|
||||
const int hidden_size_scale = hidden_size / 128;
|
||||
const int hidden_size_scale_int4 = hidden_size_scale / scale_vec_size;
|
||||
const int token_nums_feed_to_ffn = token_nums_per_expert_cum[NUM_EXPERTS_PER_RANK-1];
|
||||
// prmt
|
||||
for (int64_t s_token_idx = src_token_idx; s_token_idx < token_nums_this_rank_padded; s_token_idx += gridDim.x) {
|
||||
if (tid == 0) {
|
||||
for (int i = 0; i < NUM_EXPERTS_PER_RANK; i++) {
|
||||
const int start_idx = i == 0 ? 0 : token_nums_per_expert_cum[i - 1];
|
||||
const int end_idx = token_nums_per_expert_cum[i];
|
||||
if (s_token_idx >= start_idx && s_token_idx < end_idx) {
|
||||
m_indices[s_token_idx] = i;
|
||||
break;
|
||||
}
|
||||
for (int64_t s_token_idx = src_token_idx; s_token_idx < token_nums_feed_to_ffn; s_token_idx += gridDim.x) {
|
||||
|
||||
// the m_indices[s_token_idx] must be a value `i` in [0, NUM_EXPERTS_PER_RANK)
|
||||
// here we parallel wo find the `i` we want.
|
||||
for (int i = threadIdx.x; i < NUM_EXPERTS_PER_RANK; i+= blockDim.x) {
|
||||
const int start_idx = i == 0 ? 0 : token_nums_per_expert_cum[i - 1];
|
||||
const int end_idx = token_nums_per_expert_cum[i];
|
||||
if (s_token_idx >= start_idx && s_token_idx < end_idx) {
|
||||
if ((s_token_idx - start_idx) < token_nums_per_expert[i]) m_indices[s_token_idx] = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (s_token_idx < num_rows) {
|
||||
const int64_t *topk_idx_now = topk_idx + s_token_idx * moe_topk;
|
||||
#pragma unroll
|
||||
@@ -738,7 +741,8 @@ void EPMoeDispatchFP8Kernel(const paddle::Tensor& input,
|
||||
paddle::Tensor* m_indices) {
|
||||
auto stream = input.stream();
|
||||
auto place = input.place();
|
||||
const int gridx = min(132 * 8, num_rows);
|
||||
// const int gridx = min(132 * 8, num_rows);
|
||||
const int gridx = 132 * 8;
|
||||
if (num_experts_per_rank == 8) {
|
||||
permute_x_fp8_kernel<phi::dtype::float8_e4m3fn, 8><<<gridx, 512, 0, stream>>>(
|
||||
input.data<phi::dtype::float8_e4m3fn>(),
|
||||
@@ -831,8 +835,31 @@ void EPMoeDispatchFP8Kernel(const paddle::Tensor& input,
|
||||
token_nums_per_expert_padded_cumsum->data<int64_t>(),
|
||||
m_indices->data<int>()
|
||||
);
|
||||
} else if (num_experts_per_rank == 128) {
|
||||
permute_x_fp8_kernel<phi::dtype::float8_e4m3fn, 128><<<gridx, 512, 0, stream>>>(
|
||||
input.data<phi::dtype::float8_e4m3fn>(),
|
||||
scale.data<float>(),
|
||||
topk_ids.data<int64_t>(),
|
||||
topk_weights.data<float>(),
|
||||
token_nums_per_expert.data<int>(),
|
||||
token_nums_per_expert_padded.data<int>(),
|
||||
moe_topk,
|
||||
num_rows,
|
||||
token_nums_this_rank,
|
||||
token_nums_this_rank_padded,
|
||||
hidden_size,
|
||||
permute_input->data<phi::dtype::float8_e4m3fn>(),
|
||||
permute_scale->data<float>(),
|
||||
permute_indices_per_token->data<int>(),
|
||||
dst_weights->data<float>(),
|
||||
dst_indices->data<int>(),
|
||||
cumsum_idx_gpu->data<int>(),
|
||||
token_nums_per_expert_cumsum->data<int64_t>(),
|
||||
token_nums_per_expert_padded_cumsum->data<int64_t>(),
|
||||
m_indices->data<int>()
|
||||
);
|
||||
} else {
|
||||
PD_THROW("Not dispatching this num_experts_per_rank for EPMoeDispatchFP8Kernel");
|
||||
PD_THROW("Not dispatching this num_experts_per_rank(", num_experts_per_rank, ") for EPMoeDispatchFP8Kernel");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -842,10 +869,8 @@ std::vector<paddle::Tensor> EPMoeExpertDispatchFP8(
|
||||
const paddle::Tensor& scale,
|
||||
const paddle::Tensor& topk_ids,
|
||||
const paddle::Tensor& topk_weights,
|
||||
const std::vector<int>& token_nums_per_expert,
|
||||
const std::vector<int>& token_nums_per_expert_padded,
|
||||
const int token_nums_this_rank,
|
||||
const int token_nums_this_rank_padded) {
|
||||
const paddle::Tensor& num_experts_per_rank_tensor,
|
||||
const paddle::Tensor& num_experts_per_rank_padded_tensor) {
|
||||
const auto input_type = input.dtype();
|
||||
const int moe_topk = topk_ids.dims()[1];
|
||||
auto place = input.place();
|
||||
@@ -859,7 +884,10 @@ std::vector<paddle::Tensor> EPMoeExpertDispatchFP8(
|
||||
}
|
||||
const int num_rows = token_rows;
|
||||
const int hidden_size = input.dims()[input_dims.size() - 1];
|
||||
const int num_experts_per_rank = token_nums_per_expert.size();
|
||||
const int num_experts_per_rank = num_experts_per_rank_tensor.dims()[0];
|
||||
|
||||
int32_t token_nums_this_rank_padded = token_rows * moe_topk + num_experts_per_rank * (128-1);
|
||||
// token_nums_this_rank_padded = token_nums_this_rank_padded_useless;
|
||||
|
||||
auto permute_input = GetEmptyTensor(
|
||||
{token_nums_this_rank_padded, hidden_size},
|
||||
@@ -869,30 +897,8 @@ std::vector<paddle::Tensor> EPMoeExpertDispatchFP8(
|
||||
{token_nums_this_rank_padded, hidden_size / 128},
|
||||
paddle::DataType::FLOAT32,
|
||||
place);
|
||||
auto num_experts_per_rank_tensor = GetEmptyTensor(
|
||||
{num_experts_per_rank},
|
||||
paddle::DataType::INT32,
|
||||
place);
|
||||
auto num_experts_per_rank_padded_tensor = GetEmptyTensor(
|
||||
{num_experts_per_rank},
|
||||
paddle::DataType::INT32,
|
||||
place);
|
||||
auto m_indices = GetEmptyTensor(
|
||||
{token_nums_this_rank_padded},
|
||||
paddle::DataType::INT32,
|
||||
place);
|
||||
cudaMemcpyAsync(
|
||||
num_experts_per_rank_tensor.data<int>(),
|
||||
token_nums_per_expert.data(),
|
||||
num_experts_per_rank * sizeof(int),
|
||||
cudaMemcpyHostToDevice,
|
||||
input.stream());
|
||||
cudaMemcpyAsync(
|
||||
num_experts_per_rank_padded_tensor.data<int>(),
|
||||
token_nums_per_expert_padded.data(),
|
||||
num_experts_per_rank * sizeof(int),
|
||||
cudaMemcpyHostToDevice,
|
||||
input.stream());
|
||||
|
||||
auto m_indices = paddle::full({token_nums_this_rank_padded}, -1, paddle::DataType::INT32, place);
|
||||
auto token_nums_per_expert_cumsum = GetEmptyTensor({num_experts_per_rank}, paddle::DataType::INT64, place);
|
||||
auto token_nums_per_expert_padded_cumsum = GetEmptyTensor({num_experts_per_rank}, paddle::DataType::INT64, place);
|
||||
auto dst_weights = GetEmptyTensor({token_nums_this_rank_padded}, paddle::DataType::FLOAT32, place);
|
||||
@@ -908,8 +914,8 @@ std::vector<paddle::Tensor> EPMoeExpertDispatchFP8(
|
||||
num_experts_per_rank_padded_tensor,
|
||||
moe_topk,
|
||||
num_rows,
|
||||
token_nums_this_rank,
|
||||
token_nums_this_rank_padded,
|
||||
-1,
|
||||
-1,
|
||||
hidden_size,
|
||||
num_experts_per_rank,
|
||||
&permute_input,
|
||||
@@ -932,61 +938,8 @@ std::vector<paddle::Tensor> EPMoeExpertDispatchFP8(
|
||||
m_indices};
|
||||
}
|
||||
|
||||
|
||||
std::vector<std::vector<int64_t>> EPMoeExpertDispatchFP8InferShape(
|
||||
const std::vector<int64_t>& input_shape,
|
||||
const std::vector<int64_t>& scale_shape,
|
||||
const std::vector<int64_t>& topk_ids_shape,
|
||||
const std::vector<int64_t>& topk_weights_shape,
|
||||
const std::vector<int>& token_nums_per_expert,
|
||||
const std::vector<int>& token_nums_per_expert_padded,
|
||||
const int token_nums_this_rank,
|
||||
const int token_nums_this_rank_padded) {
|
||||
int token_rows = -1; // real token row
|
||||
int moe_topk = topk_ids_shape[1];
|
||||
if (input_shape.size() == 3) {
|
||||
token_rows = input_shape[0] * input_shape[1];
|
||||
} else {
|
||||
token_rows = input_shape[0];
|
||||
}
|
||||
const int expert_num = token_nums_per_expert.size(); // 本地专家个数
|
||||
const int num_rows = token_rows;
|
||||
const int hidden_size = input_shape[input_shape.size() - 1];
|
||||
|
||||
return {{token_nums_this_rank_padded, hidden_size}, // x
|
||||
{token_nums_this_rank_padded, hidden_size / 128}, // scale
|
||||
{expert_num, num_rows},
|
||||
{expert_num},
|
||||
{expert_num},
|
||||
{token_nums_this_rank_padded},
|
||||
{num_rows, expert_num},
|
||||
{expert_num},
|
||||
{token_nums_this_rank_padded}}; // dst_idx per expert
|
||||
}
|
||||
|
||||
std::vector<paddle::DataType> EPMoeExpertDispatchFP8InferDtype(
|
||||
const paddle::DataType& input_dtype,
|
||||
const paddle::DataType& scale_dtype,
|
||||
const paddle::DataType& topk_ids_dtype,
|
||||
const paddle::DataType& topk_weights_dtype,
|
||||
const std::vector<int>& token_nums_per_expert,
|
||||
const std::vector<int>& token_nums_per_expert_padded,
|
||||
const int token_nums_this_rank,
|
||||
const int token_nums_this_rank_padded) {
|
||||
return {input_dtype,
|
||||
paddle::DataType::FLOAT32,
|
||||
paddle::DataType::INT32,
|
||||
paddle::DataType::INT64,
|
||||
paddle::DataType::INT64,
|
||||
paddle::DataType::FLOAT32,
|
||||
paddle::DataType::INT32,
|
||||
paddle::DataType::INT32,
|
||||
paddle::DataType::INT32};
|
||||
}
|
||||
|
||||
|
||||
PD_BUILD_STATIC_OP(ep_moe_expert_dispatch_fp8)
|
||||
.Inputs({"input", "scale", "topk_ids", "topk_weights"})
|
||||
.Inputs({"input", "scale", "topk_ids", "topk_weights", "num_experts_per_rank_tensor", "num_experts_per_rank_padded_tensor"})
|
||||
.Outputs({"permute_input",
|
||||
"permute_scale",
|
||||
"permute_indices_per_token",
|
||||
@@ -996,12 +949,4 @@ PD_BUILD_STATIC_OP(ep_moe_expert_dispatch_fp8)
|
||||
"dst_indices",
|
||||
"cumsum_idx_gpu",
|
||||
"m_indices"})
|
||||
.Attrs({
|
||||
"token_nums_per_expert: std::vector<int>",
|
||||
"token_nums_per_expert_padded: std::vector<int>",
|
||||
"token_nums_this_rank: int",
|
||||
"token_nums_this_rank_padded: int",
|
||||
})
|
||||
.SetKernelFn(PD_KERNEL(EPMoeExpertDispatchFP8))
|
||||
.SetInferShapeFn(PD_INFER_SHAPE(EPMoeExpertDispatchFP8InferShape))
|
||||
.SetInferDtypeFn(PD_INFER_DTYPE(EPMoeExpertDispatchFP8InferDtype));
|
||||
.SetKernelFn(PD_KERNEL(EPMoeExpertDispatchFP8));
|
||||
|
||||
@@ -17,6 +17,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/numeric_conversion.h"
|
||||
#include "cutlass_extensions/wint_type_traits.h"
|
||||
#include "helper.h"
|
||||
#include "moe/fused_moe_helper.h"
|
||||
|
||||
@@ -71,9 +72,9 @@ void FusedMoeKernel(const paddle::Tensor& input,
|
||||
|
||||
auto* output_data = output->data<data_t>();
|
||||
|
||||
auto fp16_moe_gemm_runner = MoeGemmRunner<DataType_, DataType_>();
|
||||
auto int8_moe_gemm_runner = MoeGemmRunner<DataType_, uint8_t>();
|
||||
auto int4_moe_gemm_runner = MoeGemmRunner<DataType_, cutlass::uint4b_t>();
|
||||
auto fp16_moe_gemm_runner = MoeGemmRunner<DataType_, cutlass::WintQuantTraits<DataType_, cutlass::WintQuantMethod::kNone>>();
|
||||
auto int8_moe_gemm_runner = MoeGemmRunner<DataType_, cutlass::WintQuantTraits<DataType_, cutlass::WintQuantMethod::kWeightOnlyInt8>>();
|
||||
auto int4_moe_gemm_runner = MoeGemmRunner<DataType_, cutlass::WintQuantTraits<DataType_, cutlass::WintQuantMethod::kWeightOnlyInt4>>();
|
||||
|
||||
using NvType = typename traits_::DataType;
|
||||
auto moe_compute = MoeHelper<data_t, NvType>(quant_method,
|
||||
|
||||
@@ -15,6 +15,7 @@ limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass_extensions/wint_type_traits.h"
|
||||
#include "cutlass_kernels/moe_gemm/fused_moe_gemm_kernels.h"
|
||||
#include "moe/fused_moe_op.h"
|
||||
|
||||
@@ -53,11 +54,16 @@ void moe_token_type_ids_kernelLauncher(T *gating_output,
|
||||
|
||||
template <typename T, typename NvType> class MoeHelper {
|
||||
public:
|
||||
MoeHelper(const std::string gemm_method,
|
||||
MoeGemmRunner<NvType, NvType> *fp16_moe_gemm_runner,
|
||||
MoeGemmRunner<NvType, uint8_t> *int8_moe_gemm_runner,
|
||||
MoeGemmRunner<NvType, cutlass::uint4b_t> *int4_moe_gemm_runner,
|
||||
int layernum = 0)
|
||||
using Fp16Traits = cutlass::WintQuantTraits<NvType, cutlass::WintQuantMethod::kNone>;
|
||||
using Int8Traits = cutlass::WintQuantTraits<NvType, cutlass::WintQuantMethod::kWeightOnlyInt8>;
|
||||
using Int4Traits = cutlass::WintQuantTraits<NvType, cutlass::WintQuantMethod::kWeightOnlyInt4>;
|
||||
|
||||
MoeHelper(
|
||||
const std::string gemm_method,
|
||||
MoeGemmRunner<NvType, Fp16Traits> *fp16_moe_gemm_runner,
|
||||
MoeGemmRunner<NvType, Int8Traits> *int8_moe_gemm_runner,
|
||||
MoeGemmRunner<NvType, Int4Traits> *int4_moe_gemm_runner,
|
||||
int layernum = 0)
|
||||
: gemm_method_(gemm_method), fp16_moe_gemm_runner_(fp16_moe_gemm_runner),
|
||||
int8_moe_gemm_runner_(int8_moe_gemm_runner),
|
||||
int4_moe_gemm_runner_(int4_moe_gemm_runner), layernum_(layernum) {}
|
||||
@@ -254,6 +260,7 @@ public:
|
||||
total_rows_before_expert_, stream);
|
||||
|
||||
if (gemm_method_ == "weight_only_int8") {
|
||||
typename Int8Traits::Arguments ffn1_quant_args;
|
||||
int8_moe_gemm_runner_->moe_gemm_bias_act(
|
||||
reinterpret_cast<NvType *>(permuted_data_),
|
||||
reinterpret_cast<const uint8_t *>(ffn1_weight->data<int8_t>()),
|
||||
@@ -262,8 +269,9 @@ public:
|
||||
reinterpret_cast<NvType *>(fc1_out), total_rows_before_expert_,
|
||||
-1, // useless
|
||||
expanded_active_expert_rows, inter_size, hidden_size, num_experts,
|
||||
"none", stream);
|
||||
ffn1_quant_args, "none", stream);
|
||||
} else if (gemm_method_ == "weight_only_int4") {
|
||||
typename Int4Traits::Arguments ffn1_quant_args;
|
||||
int4_moe_gemm_runner_->moe_gemm_bias_act(
|
||||
reinterpret_cast<NvType *>(permuted_data_),
|
||||
reinterpret_cast<const cutlass::uint4b_t *>(
|
||||
@@ -273,8 +281,9 @@ public:
|
||||
reinterpret_cast<NvType *>(fc1_out), total_rows_before_expert_,
|
||||
-1, // useless
|
||||
expanded_active_expert_rows, inter_size, hidden_size, num_experts,
|
||||
"none", stream);
|
||||
ffn1_quant_args, "none", stream);
|
||||
} else {
|
||||
typename Fp16Traits::Arguments ffn1_quant_args;
|
||||
fp16_moe_gemm_runner_->moe_gemm_bias_act(
|
||||
reinterpret_cast<NvType *>(permuted_data_),
|
||||
reinterpret_cast<const NvType *>(ffn1_weight->data<T>()), nullptr,
|
||||
@@ -282,7 +291,7 @@ public:
|
||||
reinterpret_cast<NvType *>(fc1_out), total_rows_before_expert_,
|
||||
-1, // useless
|
||||
expanded_active_expert_rows, inter_size, hidden_size, num_experts,
|
||||
"none", stream);
|
||||
ffn1_quant_args, "none", stream);
|
||||
}
|
||||
|
||||
if (moe_type == "ffn") {
|
||||
@@ -295,6 +304,7 @@ public:
|
||||
T *fc2_result = fc2_output_tensor.data<T>();
|
||||
|
||||
if (gemm_method_ == "weight_only_int8") {
|
||||
typename Int8Traits::Arguments ffn2_quant_args;
|
||||
int8_moe_gemm_runner_->moe_gemm(
|
||||
reinterpret_cast<NvType *>(act_out),
|
||||
reinterpret_cast<const uint8_t *>(ffn2_weight->data<int8_t>()),
|
||||
@@ -302,8 +312,9 @@ public:
|
||||
reinterpret_cast<NvType *>(fc2_result), total_rows_before_expert_,
|
||||
-1, // useless
|
||||
expanded_active_expert_rows, hidden_size, inter_size / 2,
|
||||
num_experts, stream);
|
||||
num_experts, ffn2_quant_args, stream);
|
||||
} else if (gemm_method_ == "weight_only_int4") {
|
||||
typename Int4Traits::Arguments ffn2_quant_args;
|
||||
int4_moe_gemm_runner_->moe_gemm(
|
||||
reinterpret_cast<NvType *>(act_out),
|
||||
reinterpret_cast<const cutlass::uint4b_t *>(
|
||||
@@ -312,15 +323,16 @@ public:
|
||||
reinterpret_cast<NvType *>(fc2_result), total_rows_before_expert_,
|
||||
-1, // useless
|
||||
expanded_active_expert_rows, hidden_size, inter_size / 2,
|
||||
num_experts, stream);
|
||||
num_experts, ffn2_quant_args, stream);
|
||||
} else {
|
||||
typename Fp16Traits::Arguments ffn2_quant_args;
|
||||
fp16_moe_gemm_runner_->moe_gemm(
|
||||
reinterpret_cast<NvType *>(act_out),
|
||||
reinterpret_cast<const NvType *>(ffn2_weight->data<T>()), nullptr,
|
||||
reinterpret_cast<NvType *>(fc2_result), total_rows_before_expert_,
|
||||
-1, // useless
|
||||
expanded_active_expert_rows, hidden_size, inter_size / 2,
|
||||
num_experts, stream);
|
||||
num_experts, ffn2_quant_args, stream);
|
||||
}
|
||||
|
||||
finalize_moe_routing_kernelLauncher<T>::run(
|
||||
@@ -343,9 +355,9 @@ public:
|
||||
|
||||
private:
|
||||
std::string gemm_method_;
|
||||
MoeGemmRunner<NvType, NvType> *fp16_moe_gemm_runner_;
|
||||
MoeGemmRunner<NvType, uint8_t> *int8_moe_gemm_runner_;
|
||||
MoeGemmRunner<NvType, cutlass::uint4b_t> *int4_moe_gemm_runner_;
|
||||
MoeGemmRunner<NvType, Fp16Traits> *fp16_moe_gemm_runner_;
|
||||
MoeGemmRunner<NvType, Int8Traits> *int8_moe_gemm_runner_;
|
||||
MoeGemmRunner<NvType, Int4Traits> *int4_moe_gemm_runner_;
|
||||
int layernum_;
|
||||
CubKeyValueSorter sorter_;
|
||||
};
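// A minimal, self-contained sketch of the dispatch pattern used by the class
// above: the quantization-method string picks a traits type at compile time,
// and an (empty) Arguments object of that traits type is threaded through to
// the GEMM call. The *Sketch types, RunWithQuantTraits and the lambdas are
// illustrative stand-ins, not the runner API from this file; std::string is
// assumed to be available, as it already is in this translation unit.
namespace moe_dispatch_sketch {

struct Fp16TraitsSketch { struct Arguments {}; };
struct Int8TraitsSketch { struct Arguments {}; };
struct Int4TraitsSketch { struct Arguments {}; };

template <typename QuantTraits, typename Fn>
void RunWithQuantTraits(Fn &&launch) {
  typename QuantTraits::Arguments quant_args;  // per-method extra arguments
  launch(quant_args);                          // forwarded to the MoE GEMM call
}

inline void DispatchByGemmMethod(const std::string &gemm_method) {
  if (gemm_method == "weight_only_int8") {
    RunWithQuantTraits<Int8TraitsSketch>([](const auto &) { /* int8 GEMM */ });
  } else if (gemm_method == "weight_only_int4") {
    RunWithQuantTraits<Int4TraitsSketch>([](const auto &) { /* int4 GEMM */ });
  } else {
    RunWithQuantTraits<Fp16TraitsSketch>([](const auto &) { /* fp16 GEMM */ });
  }
}

}  // namespace moe_dispatch_sketch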
|
||||
|
||||
@@ -17,11 +17,11 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/numeric_conversion.h"
|
||||
#include "moe/fused_moe_helper.h"
|
||||
#include "moe/fused_moe_imp_op.h"
|
||||
#include <cuda.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include "moe/fused_moe_imp_op.h"
|
||||
#include "moe/fused_moe_helper.h"
|
||||
#include "cutlass/numeric_conversion.h"
|
||||
// Ignore CUTLASS warnings about type punning
|
||||
#pragma GCC diagnostic push
|
||||
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
|
||||
@@ -37,8 +37,8 @@
|
||||
namespace phi {
|
||||
|
||||
struct GpuLaunchConfig {
  dim3 block_per_grid;
  dim3 thread_per_block;
};
|
||||
|
||||
inline GpuLaunchConfig Get1DBlocksAnd2DGridsMoe(const int64_t cols) {
|
||||
@@ -57,13 +57,30 @@ inline GpuLaunchConfig Get1DBlocksAnd2DGridsMoe(const int64_t cols) {
|
||||
return config;
|
||||
}
|
||||
|
||||
constexpr static int FINALIZE_THREADS_PER_BLOCK = 256;
|
||||
template <class T, class U>
|
||||
__host__ __device__ constexpr static U arrayConvert(T const& input)
|
||||
{
|
||||
using Type = typename U::Element;
|
||||
static_assert(T::kElements == U::kElements);
|
||||
U u;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < U::kElements; i++)
|
||||
{
|
||||
u[i] = static_cast<Type>(input[i]);
|
||||
}
|
||||
return u;
|
||||
}
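// Illustrative use of arrayConvert defined above: widen a packed fragment of T
// (for example half) to float for accumulation, the same way the finalize
// kernel later in this header converts its InputElem fragments to ComputeElem.
// The fragment width of 8 is arbitrary; cutlass::Array comes from the CUTLASS
// headers already included by this file.
template <typename T>
__host__ __device__ cutlass::Array<float, 8>
widen_fragment_example(cutlass::Array<T, 8> const &in) {
  // kElements matches on both sides, so arrayConvert's static_assert holds.
  return arrayConvert<cutlass::Array<T, 8>, cutlass::Array<float, 8>>(in);
}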
|
||||
|
||||
// ====================== Softmax things ===============================
|
||||
// We have our own implementation of softmax here so we can support transposing
|
||||
// the output in the softmax kernel when we extend this module to support
|
||||
// expert-choice routing.
|
||||
template <typename T, int TPB>
|
||||
__launch_bounds__(TPB) __global__
|
||||
void group_moe_softmax(const T *input, T *output, T *softmax_max_prob,
|
||||
void group_moe_softmax(const T* input,
|
||||
T* output,
|
||||
T* softmax_max_prob,
|
||||
const int64_t num_cols,
|
||||
const int64_t softmax_num_rows) {
|
||||
using BlockReduce = cub::BlockReduce<float, TPB>;
|
||||
@@ -82,6 +99,7 @@ __launch_bounds__(TPB) __global__
|
||||
cub::Sum sum;
|
||||
float threadData(-FLT_MAX);
|
||||
|
||||
|
||||
for (int ii = threadIdx.x; ii < num_cols; ii += TPB) {
|
||||
const int idx = thread_row_offset + ii;
|
||||
threadData = max(static_cast<float>(input[idx]), threadData);
|
||||
@@ -133,11 +151,14 @@ __launch_bounds__(TPB) __global__
|
||||
}
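// A standalone sketch of the block-level softmax pattern the kernels above
// follow: reduce the row maximum, exponentiate with the maximum subtracted for
// numerical stability, reduce the sum, then scale by 1/sum. One thread block
// handles one row. The kernel name and launch shape are illustrative only;
// cub and FLT_MAX are already available in this header.
template <int TPB>
__global__ void row_softmax_sketch(const float *input, float *output,
                                   const int num_cols) {
  using BlockReduce = cub::BlockReduce<float, TPB>;
  __shared__ typename BlockReduce::TempStorage tmp;
  __shared__ float row_max, row_sum_inv;

  const int row_offset = blockIdx.x * num_cols;

  // 1) Row maximum.
  float thread_max = -FLT_MAX;
  for (int i = threadIdx.x; i < num_cols; i += TPB)
    thread_max = max(thread_max, input[row_offset + i]);
  const float m = BlockReduce(tmp).Reduce(thread_max, cub::Max());
  if (threadIdx.x == 0) row_max = m;
  __syncthreads();

  // 2) Sum of shifted exponentials.
  float thread_sum = 0.f;
  for (int i = threadIdx.x; i < num_cols; i += TPB)
    thread_sum += __expf(input[row_offset + i] - row_max);
  const float z = BlockReduce(tmp).Sum(thread_sum);
  if (threadIdx.x == 0) row_sum_inv = 1.f / z;
  __syncthreads();

  // 3) Normalize.
  for (int i = threadIdx.x; i < num_cols; i += TPB)
    output[row_offset + i] = __expf(input[row_offset + i] - row_max) * row_sum_inv;
}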
|
||||
|
||||
template <typename T, int TPB, typename IdxT = int>
|
||||
__launch_bounds__(TPB) __global__
|
||||
void moe_top_k(const T *inputs_after_softmax, T *output, IdxT *indices,
|
||||
int *source_rows, T *softmax_max_prob,
|
||||
const int64_t num_experts, const int64_t k,
|
||||
const int64_t num_rows) {
|
||||
__launch_bounds__(TPB) __global__ void moe_top_k(const T* inputs_after_softmax,
|
||||
T* output,
|
||||
IdxT* indices,
|
||||
int* source_rows,
|
||||
T* softmax_max_prob,
|
||||
const int64_t num_experts,
|
||||
const int64_t k,
|
||||
const int64_t num_rows) {
|
||||
using cub_kvp = cub::KeyValuePair<int, T>;
|
||||
using BlockReduce = cub::BlockReduce<cub_kvp, TPB>;
|
||||
__shared__ typename BlockReduce::TempStorage tmpStorage;
|
||||
@@ -155,7 +176,7 @@ __launch_bounds__(TPB) __global__
|
||||
|
||||
for (int k_idx = 0; k_idx < k; ++k_idx) {
|
||||
thread_kvp.key = 0;
|
||||
thread_kvp.value = T(-1.f); // This is OK because inputs are probabilities
|
||||
|
||||
cub_kvp inp_kvp;
|
||||
for (int expert = threadIdx.x; expert < num_experts; expert += TPB) {
|
||||
@@ -188,9 +209,10 @@ __launch_bounds__(TPB) __global__
|
||||
}
|
||||
|
||||
template <typename T, int TPB>
|
||||
__launch_bounds__(TPB) __global__
|
||||
void moe_softmax(const T *input, T *output, const int64_t num_cols,
|
||||
const int64_t num_rows) {
|
||||
__launch_bounds__(TPB) __global__ void moe_softmax(const T* input,
|
||||
T* output,
|
||||
const int64_t num_cols,
|
||||
const int64_t num_rows) {
|
||||
using BlockReduce = cub::BlockReduce<float, TPB>;
|
||||
__shared__ typename BlockReduce::TempStorage tmpStorage;
|
||||
|
||||
@@ -240,10 +262,14 @@ __launch_bounds__(TPB) __global__
|
||||
}
|
||||
|
||||
template <typename T, int TPB, typename IdxT = int>
|
||||
__launch_bounds__(TPB) __global__
|
||||
void moe_top_k(const T *inputs_after_softmax, const T *bias, T *output,
|
||||
IdxT *indices, int *source_rows, const int64_t num_experts,
|
||||
const int64_t k, const int64_t num_rows) {
|
||||
__launch_bounds__(TPB) __global__ void moe_top_k(const T* inputs_after_softmax,
|
||||
const T* bias,
|
||||
T* output,
|
||||
IdxT* indices,
|
||||
int* source_rows,
|
||||
const int64_t num_experts,
|
||||
const int64_t k,
|
||||
const int64_t num_rows) {
|
||||
using cub_kvp = cub::KeyValuePair<int, T>;
|
||||
using BlockReduce = cub::BlockReduce<cub_kvp, TPB>;
|
||||
__shared__ typename BlockReduce::TempStorage tmpStorage;
|
||||
@@ -261,14 +287,13 @@ __launch_bounds__(TPB) __global__
|
||||
|
||||
for (int k_idx = 0; k_idx < k; ++k_idx) {
|
||||
thread_kvp.key = 0;
|
||||
thread_kvp.value = T(-1.f); // This is OK because inputs are probabilities
|
||||
thread_kvp.value = T(-1.f); // This is OK because inputs are probabilities
|
||||
|
||||
cub_kvp inp_kvp;
|
||||
for (int expert = threadIdx.x; expert < num_experts; expert += TPB) {
|
||||
const int idx = thread_read_offset + expert;
|
||||
inp_kvp.key = expert;
|
||||
inp_kvp.value = bias ? inputs_after_softmax[idx] + bias[expert]
|
||||
: inputs_after_softmax[idx];
|
||||
inp_kvp.value = bias ? inputs_after_softmax[idx] + bias[expert] : inputs_after_softmax[idx] ;
|
||||
|
||||
for (int prior_k = 0; prior_k < k_idx; ++prior_k) {
|
||||
const IdxT prior_winning_expert = indices[k * block_row + prior_k];
|
||||
@@ -285,9 +310,7 @@ __launch_bounds__(TPB) __global__
|
||||
BlockReduce(tmpStorage).Reduce(thread_kvp, arg_max);
|
||||
if (threadIdx.x == 0) {
|
||||
const int idx = k * block_row + k_idx;
|
||||
output[idx] =
|
||||
bias ? inputs_after_softmax[thread_read_offset + result_kvp.key]
|
||||
: result_kvp.value;
|
||||
output[idx] = bias ? inputs_after_softmax[thread_read_offset + result_kvp.key]: result_kvp.value;
|
||||
indices[idx] = should_process_row ? result_kvp.key : num_experts;
|
||||
source_rows[idx] = k_idx * num_rows + block_row;
|
||||
}
|
||||
@@ -296,11 +319,14 @@ __launch_bounds__(TPB) __global__
|
||||
}
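// A compact sketch of the selection scheme moe_top_k uses: k passes of an
// arg-max over the experts, where each pass skips experts that already won in
// an earlier pass (the kernel does the same with a block-wide cub::ArgMax
// reduction). Serial reference code; the names are illustrative.
inline void top_k_by_repeated_argmax(const float *row_probs,
                                     const int num_experts, const int k,
                                     int *out_indices, float *out_values) {
  for (int k_idx = 0; k_idx < k; ++k_idx) {
    int best = -1;
    float best_val = -1.f;  // safe lower bound: inputs are probabilities
    for (int expert = 0; expert < num_experts; ++expert) {
      bool already_taken = false;
      for (int prior = 0; prior < k_idx; ++prior) {
        already_taken |= (out_indices[prior] == expert);
      }
      if (!already_taken && row_probs[expert] > best_val) {
        best_val = row_probs[expert];
        best = expert;
      }
    }
    out_indices[k_idx] = best;
    out_values[k_idx] = best_val;
  }
}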
|
||||
|
||||
template <typename T, int TPB, typename IdxT = int>
|
||||
__launch_bounds__(TPB) __global__
|
||||
void moe_softmax_top_k_fused(const T *input, const T *bias, T *output,
|
||||
IdxT *indices, int *source_rows,
|
||||
const int64_t num_experts, const int64_t k,
|
||||
const int64_t num_rows) {
|
||||
__launch_bounds__(TPB) __global__ void moe_softmax_top_k_fused(const T* input,
|
||||
const T* bias,
|
||||
T* output,
|
||||
IdxT* indices,
|
||||
int* source_rows,
|
||||
const int64_t num_experts,
|
||||
const int64_t k,
|
||||
const int64_t num_rows) {
|
||||
// softmax
|
||||
using BlockReduce = cub::BlockReduce<float, TPB>;
|
||||
__shared__ typename BlockReduce::TempStorage tmpStorage;
|
||||
@@ -313,12 +339,11 @@ __launch_bounds__(TPB) __global__
|
||||
return;
|
||||
}
|
||||
const int64_t thread_row_offset = globalIdx * num_experts;
|
||||
const int64_t idx = thread_row_offset + threadIdx.x;
|
||||
const int64_t idx = thread_row_offset+threadIdx.x;
|
||||
|
||||
cub::Sum sum;
|
||||
|
||||
float threadData =
|
||||
(threadIdx.x < num_experts) ? static_cast<float>(input[idx]) : (-FLT_MAX);
|
||||
float threadData = (threadIdx.x < num_experts) ? static_cast<float>(input[idx]) :(-FLT_MAX);
|
||||
|
||||
const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, cub::Max());
|
||||
if (threadIdx.x == 0) {
|
||||
@@ -335,10 +360,10 @@ __launch_bounds__(TPB) __global__
|
||||
normalizing_factor = 1.f / Z;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
|
||||
T val = T(threadDataExp * normalizing_factor);
|
||||
|
||||
// top_k
|
||||
using cub_kvp = cub::KeyValuePair<int, T>;
|
||||
using BlockReduceP = cub::BlockReduce<cub_kvp, TPB>;
|
||||
__shared__ typename BlockReduceP::TempStorage tmpStorageP;
|
||||
@@ -348,11 +373,11 @@ __launch_bounds__(TPB) __global__
|
||||
|
||||
for (int k_idx = 0; k_idx < k; ++k_idx) {
|
||||
thread_kvp.key = 0;
|
||||
thread_kvp.value = T(-1.f); // This is OK because inputs are probabilities
|
||||
|
||||
thread_kvp.value = T(-1.f); // This is OK because inputs are probabilities
|
||||
|
||||
if (threadIdx.x < num_experts) {
|
||||
cub_kvp inp_kvp;
|
||||
int expert = threadIdx.x;
|
||||
inp_kvp.key = expert;
|
||||
inp_kvp.value = bias ? val + bias[expert] : val;
|
||||
|
||||
@@ -370,8 +395,7 @@ __launch_bounds__(TPB) __global__
|
||||
BlockReduceP(tmpStorageP).Reduce(thread_kvp, arg_max);
|
||||
if (threadIdx.x == 0) {
|
||||
const int cur_idx = k * globalIdx + k_idx;
|
||||
output[cur_idx] =
|
||||
bias ? (result_kvp.value - bias[result_kvp.key]) : result_kvp.value;
|
||||
output[cur_idx] = bias ? (result_kvp.value - bias[result_kvp.key]) : result_kvp.value;
|
||||
indices[cur_idx] = result_kvp.key;
|
||||
source_rows[cur_idx] = k_idx * num_rows + globalIdx;
|
||||
}
|
||||
@@ -380,11 +404,14 @@ __launch_bounds__(TPB) __global__
|
||||
}
|
||||
|
||||
template <typename T, int TPB, typename IdxT = int>
|
||||
__launch_bounds__(TPB) __global__
|
||||
void moe_top_k_normed(const T *inputs_after_softmax, const T *bias,
|
||||
T *output, IdxT *indices, int *source_rows,
|
||||
const int64_t num_experts, const int64_t k,
|
||||
const int64_t num_rows) {
|
||||
__launch_bounds__(TPB) __global__ void moe_top_k_normed(const T* inputs_after_softmax,
|
||||
const T* bias,
|
||||
T* output,
|
||||
IdxT* indices,
|
||||
int* source_rows,
|
||||
const int64_t num_experts,
|
||||
const int64_t k,
|
||||
const int64_t num_rows) {
|
||||
using cub_kvp = cub::KeyValuePair<int, T>;
|
||||
using BlockReduce = cub::BlockReduce<cub_kvp, TPB>;
|
||||
__shared__ typename BlockReduce::TempStorage tmpStorage;
|
||||
@@ -403,18 +430,17 @@ __launch_bounds__(TPB) __global__
|
||||
|
||||
extern __shared__ char smem[];
|
||||
|
||||
T *row_outputs = reinterpret_cast<T *>(smem);
|
||||
T* row_outputs = reinterpret_cast<T*>(smem);
|
||||
|
||||
for (int k_idx = 0; k_idx < k; ++k_idx) {
|
||||
thread_kvp.key = 0;
|
||||
thread_kvp.value = T(-1.f); // This is OK because inputs are probabilities
|
||||
thread_kvp.value = T(-1.f); // This is OK because inputs are probabilities
|
||||
|
||||
cub_kvp inp_kvp;
|
||||
for (int expert = threadIdx.x; expert < num_experts; expert += TPB) {
|
||||
const int idx = thread_read_offset + expert;
|
||||
inp_kvp.key = expert;
|
||||
inp_kvp.value = bias ? inputs_after_softmax[idx] + bias[expert]
|
||||
: inputs_after_softmax[idx];
|
||||
inp_kvp.value = bias ? inputs_after_softmax[idx] + bias[expert] : inputs_after_softmax[idx] ;
|
||||
|
||||
for (int prior_k = 0; prior_k < k_idx; ++prior_k) {
|
||||
const int prior_winning_expert = indices[k * block_row + prior_k];
|
||||
@@ -431,14 +457,11 @@ __launch_bounds__(TPB) __global__
|
||||
BlockReduce(tmpStorage).Reduce(thread_kvp, arg_max);
|
||||
if (threadIdx.x == 0) {
|
||||
const int idx = k * block_row + k_idx;
|
||||
// output[idx] = bias ? inputs_after_softmax[thread_read_offset +
|
||||
// result_kvp.key]: result_kvp.value;
|
||||
// output[idx] = bias ? inputs_after_softmax[thread_read_offset + result_kvp.key]: result_kvp.value;
|
||||
indices[idx] = should_process_row ? result_kvp.key : num_experts;
|
||||
source_rows[idx] = k_idx * num_rows + block_row;
|
||||
|
||||
T row_out =
|
||||
bias ? inputs_after_softmax[thread_read_offset + result_kvp.key]
|
||||
: result_kvp.value;
|
||||
T row_out = bias ? inputs_after_softmax[thread_read_offset + result_kvp.key]: result_kvp.value;
|
||||
row_outputs[k_idx] = row_out;
|
||||
weight_sum += row_out;
|
||||
}
|
||||
@@ -453,10 +476,16 @@ __launch_bounds__(TPB) __global__
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
template <typename T, int TPB, typename IdxT = int>
|
||||
__launch_bounds__(TPB) __global__ void moe_softmax_top_k_normed_fused(
|
||||
const T *input, const T *bias, T *output, IdxT *indices, int *source_rows,
|
||||
const int64_t num_experts, const int64_t k, const int64_t num_rows) {
|
||||
__launch_bounds__(TPB) __global__ void moe_softmax_top_k_normed_fused(const T* input,
|
||||
const T* bias,
|
||||
T* output,
|
||||
IdxT* indices,
|
||||
int* source_rows,
|
||||
const int64_t num_experts,
|
||||
const int64_t k,
|
||||
const int64_t num_rows) {
|
||||
// softmax
|
||||
using BlockReduce = cub::BlockReduce<float, TPB>;
|
||||
__shared__ typename BlockReduce::TempStorage tmpStorage;
|
||||
@@ -469,12 +498,11 @@ __launch_bounds__(TPB) __global__ void moe_softmax_top_k_normed_fused(
|
||||
return;
|
||||
}
|
||||
const int64_t thread_row_offset = globalIdx * num_experts;
|
||||
const int64_t idx = thread_row_offset + threadIdx.x;
|
||||
const int64_t idx = thread_row_offset+threadIdx.x;
|
||||
|
||||
cub::Sum sum;
|
||||
|
||||
float threadData =
|
||||
(threadIdx.x < num_experts) ? static_cast<float>(input[idx]) : (-FLT_MAX);
|
||||
float threadData = (threadIdx.x < num_experts) ? static_cast<float>(input[idx]) :(-FLT_MAX);
|
||||
|
||||
const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, cub::Max());
|
||||
if (threadIdx.x == 0) {
|
||||
@@ -490,12 +518,12 @@ __launch_bounds__(TPB) __global__ void moe_softmax_top_k_normed_fused(
|
||||
if (threadIdx.x == 0) {
|
||||
normalizing_factor = 1.f / Z;
|
||||
}
|
||||
|
||||
|
||||
__syncthreads();
|
||||
|
||||
|
||||
T val = T(threadDataExp * normalizing_factor);
|
||||
|
||||
// top_k
|
||||
using cub_kvp = cub::KeyValuePair<int, T>;
|
||||
using BlockReduceP = cub::BlockReduce<cub_kvp, TPB>;
|
||||
__shared__ typename BlockReduceP::TempStorage tmpStorageP;
|
||||
@@ -505,15 +533,15 @@ __launch_bounds__(TPB) __global__ void moe_softmax_top_k_normed_fused(
|
||||
|
||||
T weight_sum = static_cast<T>(0);
|
||||
extern __shared__ char smem[];
|
||||
T *row_outputs = reinterpret_cast<T *>(smem);
|
||||
T* row_outputs = reinterpret_cast<T*>(smem);
|
||||
|
||||
for (int k_idx = 0; k_idx < k; ++k_idx) {
|
||||
thread_kvp.key = 0;
|
||||
thread_kvp.value = T(-1.f); // This is OK because inputs are probabilities
|
||||
thread_kvp.value = T(-1.f); // This is OK because inputs are probabilities
|
||||
|
||||
if (threadIdx.x < num_experts) {
|
||||
cub_kvp inp_kvp;
|
||||
int expert = threadIdx.x;
|
||||
inp_kvp.key = expert;
|
||||
inp_kvp.value = bias ? val + bias[expert] : val;
|
||||
|
||||
@@ -532,8 +560,7 @@ __launch_bounds__(TPB) __global__ void moe_softmax_top_k_normed_fused(
|
||||
if (threadIdx.x == 0) {
|
||||
const int cur_idx = k * globalIdx + k_idx;
|
||||
|
||||
T row_out =
|
||||
bias ? (result_kvp.value - bias[result_kvp.key]) : result_kvp.value;
|
||||
T row_out = bias ? (result_kvp.value - bias[result_kvp.key]) : result_kvp.value;
|
||||
row_outputs[k_idx] = row_out;
|
||||
weight_sum += row_out;
|
||||
|
||||
@@ -665,11 +692,19 @@ __launch_bounds__(TPB) __global__ void moe_redundant_top_k_normed(const T* input
|
||||
k.
|
||||
*/
|
||||
|
||||
template <typename T, int VPT, int NUM_EXPERTS, int WARPS_PER_CTA,
|
||||
int BYTES_PER_LDG, typename IdxT = int>
|
||||
__launch_bounds__(WARPS_PER_CTA *WARP_SIZE) __global__
|
||||
void topk_gating_softmax(const T *input, T *output, const int64_t num_rows,
|
||||
IdxT *indices, int *source_rows, const int64_t k) {
|
||||
template <typename T,
|
||||
int VPT,
|
||||
int NUM_EXPERTS,
|
||||
int WARPS_PER_CTA,
|
||||
int BYTES_PER_LDG,
|
||||
typename IdxT = int>
|
||||
__launch_bounds__(WARPS_PER_CTA * WARP_SIZE) __global__
|
||||
void topk_gating_softmax(const T* input,
|
||||
T* output,
|
||||
const int64_t num_rows,
|
||||
IdxT* indices,
|
||||
int* source_rows,
|
||||
const int64_t k) {
|
||||
// We begin by enforcing compile time assertions and setting up compile time
|
||||
// constants.
|
||||
static_assert(VPT == (VPT & -VPT), "VPT must be power of 2");
|
||||
@@ -722,19 +757,18 @@ __launch_bounds__(WARPS_PER_CTA *WARP_SIZE) __global__
|
||||
const int thread_row = warp_base_row + thread_row_in_warp;
|
||||
|
||||
// Threads with indices out of bounds should early exit here.
|
||||
if (thread_row >= num_rows)
|
||||
return;
|
||||
if (thread_row >= num_rows) return;
|
||||
const bool should_process_row = true;
|
||||
|
||||
// We finally start setting up the read pointers for each thread. First, each
|
||||
// thread jumps to the start of the row it will read.
|
||||
const T *thread_row_ptr = input + thread_row * ELTS_PER_ROW;
|
||||
const T* thread_row_ptr = input + thread_row * ELTS_PER_ROW;
|
||||
|
||||
// Now, we compute the group each thread belongs to in order to determine the
|
||||
// first column to start loads.
|
||||
const int thread_group_idx = threadIdx.x % THREADS_PER_ROW;
|
||||
const int first_elt_read_by_thread = thread_group_idx * ELTS_PER_LDG;
|
||||
const T *thread_read_ptr = thread_row_ptr + first_elt_read_by_thread;
|
||||
const T* thread_read_ptr = thread_row_ptr + first_elt_read_by_thread;
|
||||
|
||||
// Determine the pointer type to use to read in the data depending on the
|
||||
// BYTES_PER_LDG template param. In theory, this can support all powers of 2
|
||||
@@ -743,10 +777,10 @@ __launch_bounds__(WARPS_PER_CTA *WARP_SIZE) __global__
|
||||
|
||||
// Finally, we pull in the data from global mem
|
||||
cutlass::Array<T, VPT> row_chunk_input;
|
||||
AccessType *row_chunk_vec_ptr =
|
||||
reinterpret_cast<AccessType *>(&row_chunk_input);
|
||||
const AccessType *vec_thread_read_ptr =
|
||||
reinterpret_cast<const AccessType *>(thread_read_ptr);
|
||||
AccessType* row_chunk_vec_ptr =
|
||||
reinterpret_cast<AccessType*>(&row_chunk_input);
|
||||
const AccessType* vec_thread_read_ptr =
|
||||
reinterpret_cast<const AccessType*>(thread_read_ptr);
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < LDG_PER_THREAD; ++ii) {
|
||||
row_chunk_vec_ptr[ii] = vec_thread_read_ptr[ii * THREADS_PER_ROW];
|
||||
@@ -771,8 +805,9 @@ __launch_bounds__(WARPS_PER_CTA *WARP_SIZE) __global__
|
||||
// threads. We use a butterfly reduce.
|
||||
#pragma unroll
|
||||
for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) {
|
||||
thread_max = max(thread_max, __shfl_xor_sync(0xFFFFFFFF, thread_max, mask,
|
||||
THREADS_PER_ROW));
|
||||
thread_max =
|
||||
max(thread_max,
|
||||
__shfl_xor_sync(0xFFFFFFFF, thread_max, mask, THREADS_PER_ROW));
|
||||
}
|
||||
|
||||
// From this point, thread_max in every thread holds the maximum within the row.
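// Standalone sketch of the butterfly (XOR-shuffle) reduction used just above:
// after log2(WIDTH) exchange steps every lane in the group holds the maximum,
// so no extra broadcast is needed. WIDTH must be a power of two no larger than
// the warp size; the helper name is illustrative.
template <int WIDTH>
__device__ inline float group_max_butterfly(float v) {
#pragma unroll
  for (int mask = WIDTH / 2; mask > 0; mask /= 2) {
    v = max(v, __shfl_xor_sync(0xFFFFFFFF, v, mask, WIDTH));
  }
  return v;
}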
|
||||
@@ -885,7 +920,8 @@ __launch_bounds__(WARPS_PER_CTA *WARP_SIZE) __global__
|
||||
namespace detail {
|
||||
// Constructs some constants needed to partition the work across threads at
|
||||
// compile time.
|
||||
template <typename T, int EXPERTS, int BYTES_PER_LDG> struct TopkConstants {
|
||||
template <typename T, int EXPERTS, int BYTES_PER_LDG>
|
||||
struct TopkConstants {
|
||||
static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(T);
|
||||
static_assert(EXPERTS / (ELTS_PER_LDG * WARP_SIZE) == 0 ||
|
||||
EXPERTS % (ELTS_PER_LDG * WARP_SIZE) == 0,
|
||||
@@ -896,14 +932,17 @@ template <typename T, int EXPERTS, int BYTES_PER_LDG> struct TopkConstants {
|
||||
static constexpr int THREADS_PER_ROW = EXPERTS / VPT;
|
||||
static constexpr int ROWS_PER_WARP = WARP_SIZE / THREADS_PER_ROW;
|
||||
};
|
||||
} // namespace detail
|
||||
|
||||
template <typename T, int EXPERTS, int WARPS_PER_TB, typename IdxT = int>
|
||||
void topk_gating_softmax_launcher_helper(const T *input, T *output,
|
||||
IdxT *indices, int *source_row,
|
||||
void topk_gating_softmax_launcher_helper(const T* input,
|
||||
T* output,
|
||||
IdxT* indices,
|
||||
int* source_row,
|
||||
const int64_t num_rows,
|
||||
const int64_t num_experts,
|
||||
const int64_t k, cudaStream_t stream) {
|
||||
const int64_t k,
|
||||
cudaStream_t stream) {
|
||||
static constexpr uint64_t MAX_BYTES_PER_LDG = 16;
|
||||
static constexpr int BYTES_PER_LDG =
|
||||
std::min(MAX_BYTES_PER_LDG, sizeof(T) * EXPERTS);
|
||||
@@ -915,46 +954,52 @@ void topk_gating_softmax_launcher_helper(const T *input, T *output,
|
||||
|
||||
dim3 block_dim(WARP_SIZE, WARPS_PER_TB);
|
||||
topk_gating_softmax<T, VPT, EXPERTS, WARPS_PER_TB, BYTES_PER_LDG>
|
||||
<<<num_blocks, block_dim, 0, stream>>>(input, output, num_rows, indices,
|
||||
source_row, k);
|
||||
<<<num_blocks, block_dim, 0, stream>>>(
|
||||
input, output, num_rows, indices, source_row, k);
|
||||
}
|
||||
|
||||
template <typename T, typename IdxT = int>
|
||||
struct topk_gating_softmax_kernelLauncher {
|
||||
static void run(const T *input, const T *gating_correction_bias, T *output,
|
||||
T *softmax, IdxT *indices, int *source_row,
|
||||
T *softmax_max_prob, const int64_t num_rows,
|
||||
const int64_t num_experts, const int64_t k,
|
||||
const bool group_moe, cudaStream_t stream,
|
||||
const bool topk_only_mode = false) {
|
||||
if (topk_only_mode) {
|
||||
static constexpr int TPB = 256;
|
||||
const auto config_topk = Get1DBlocksAnd2DGridsMoe(num_rows);
|
||||
moe_top_k<T, TPB><<<config_topk.block_per_grid, TPB, 0, stream>>>(
|
||||
input, gating_correction_bias, output, indices, source_row,
|
||||
num_experts, k, num_rows);
|
||||
return;
|
||||
}
|
||||
static constexpr int WARPS_PER_TB = 4;
|
||||
struct topk_gating_softmax_kernelLauncher{
|
||||
|
||||
#define LAUNCH_TOPK_GATING_SOFTMAX_HELPER(N) \
|
||||
static void run(const T* input,
|
||||
const T* gating_correction_bias,
|
||||
T* output,
|
||||
T* softmax,
|
||||
IdxT* indices,
|
||||
int* source_row,
|
||||
T* softmax_max_prob,
|
||||
const int64_t num_rows,
|
||||
const int64_t num_experts,
|
||||
const int64_t k,
|
||||
const bool group_moe,
|
||||
cudaStream_t stream,
|
||||
const bool topk_only_mode = false) {
|
||||
if (topk_only_mode) {
|
||||
static constexpr int TPB = 256;
|
||||
const auto config_topk = Get1DBlocksAnd2DGridsMoe(num_rows);
|
||||
moe_top_k<T, TPB><<<config_topk.block_per_grid, TPB, 0, stream>>>(
|
||||
input, gating_correction_bias, output, indices, source_row, num_experts, k, num_rows);
|
||||
return;
|
||||
}
|
||||
static constexpr int WARPS_PER_TB = 4;
|
||||
|
||||
#define LAUNCH_TOPK_GATING_SOFTMAX_HELPER(N) \
|
||||
case N: { \
|
||||
topk_gating_softmax_launcher_helper<T, N, WARPS_PER_TB>( \
|
||||
input, output, indices, source_row, num_rows, num_experts, k, stream); \
|
||||
break; \
|
||||
}
|
||||
int64_t tem_num_experts = num_experts;
|
||||
if (gating_correction_bias != nullptr)
|
||||
tem_num_experts = 0;
|
||||
switch (tem_num_experts) {
|
||||
LAUNCH_TOPK_GATING_SOFTMAX_HELPER(2)
|
||||
LAUNCH_TOPK_GATING_SOFTMAX_HELPER(4)
|
||||
LAUNCH_TOPK_GATING_SOFTMAX_HELPER(8)
|
||||
LAUNCH_TOPK_GATING_SOFTMAX_HELPER(16)
|
||||
LAUNCH_TOPK_GATING_SOFTMAX_HELPER(32)
|
||||
LAUNCH_TOPK_GATING_SOFTMAX_HELPER(64)
|
||||
LAUNCH_TOPK_GATING_SOFTMAX_HELPER(128)
|
||||
LAUNCH_TOPK_GATING_SOFTMAX_HELPER(256)
|
||||
int64_t tem_num_experts = num_experts;
|
||||
if(gating_correction_bias != nullptr) tem_num_experts = 0;
|
||||
switch (tem_num_experts) {
|
||||
LAUNCH_TOPK_GATING_SOFTMAX_HELPER(2)
|
||||
LAUNCH_TOPK_GATING_SOFTMAX_HELPER(4)
|
||||
LAUNCH_TOPK_GATING_SOFTMAX_HELPER(8)
|
||||
LAUNCH_TOPK_GATING_SOFTMAX_HELPER(16)
|
||||
LAUNCH_TOPK_GATING_SOFTMAX_HELPER(32)
|
||||
LAUNCH_TOPK_GATING_SOFTMAX_HELPER(64)
|
||||
LAUNCH_TOPK_GATING_SOFTMAX_HELPER(128)
|
||||
LAUNCH_TOPK_GATING_SOFTMAX_HELPER(256)
|
||||
|
||||
default: {
|
||||
static constexpr int TPB = 256;
|
||||
@@ -964,24 +1009,40 @@ struct topk_gating_softmax_kernelLauncher {
|
||||
const auto config_softmax = Get1DBlocksAnd2DGridsMoe(softmax_num_rows);
|
||||
group_moe_softmax<T, TPB>
|
||||
<<<config_softmax.block_per_grid, TPB, 0, stream>>>(
|
||||
input, softmax, softmax_max_prob, group_experts,
|
||||
input,
|
||||
softmax,
|
||||
softmax_max_prob,
|
||||
group_experts,
|
||||
softmax_num_rows);
|
||||
const auto config_topk = Get1DBlocksAnd2DGridsMoe(num_rows);
|
||||
moe_top_k<T, TPB><<<config_topk.block_per_grid, TPB, 0, stream>>>(
|
||||
softmax, output, indices, source_row, softmax_max_prob, num_experts,
|
||||
k, num_rows);
|
||||
moe_top_k<T, TPB>
|
||||
<<<config_topk.block_per_grid, TPB, 0, stream>>>(softmax,
|
||||
output,
|
||||
indices,
|
||||
source_row,
|
||||
softmax_max_prob,
|
||||
num_experts,
|
||||
k,
|
||||
num_rows);
|
||||
} else {
|
||||
const auto config_topk = Get1DBlocksAnd2DGridsMoe(num_rows);
|
||||
moe_softmax<T, TPB><<<config_topk.block_per_grid, TPB, 0, stream>>>(
|
||||
input, softmax, num_experts, num_rows);
|
||||
moe_top_k<T, TPB><<<config_topk.block_per_grid, TPB, 0, stream>>>(
|
||||
softmax, gating_correction_bias, output, indices, source_row,
|
||||
num_experts, k, num_rows);
|
||||
moe_top_k<T, TPB>
|
||||
<<<config_topk.block_per_grid, TPB, 0, stream>>>(softmax,
|
||||
gating_correction_bias,
|
||||
output,
|
||||
indices,
|
||||
source_row,
|
||||
num_experts,
|
||||
k,
|
||||
num_rows);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
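// Illustrative call site for the launcher defined above; buffer allocation is
// elided and the pointer names are placeholders. As read from the launcher,
// passing gating_correction_bias == nullptr with a power-of-two expert count
// up to 256 takes the fused topk_gating_softmax path, while other counts fall
// back to moe_softmax followed by moe_top_k (group_moe == false here).
inline void example_gating_launch(const float *logits, float *topk_weight,
                                  float *softmax_buf, int *topk_idx,
                                  int *source_row, const int64_t num_rows,
                                  const int64_t num_experts, const int64_t k,
                                  cudaStream_t stream) {
  topk_gating_softmax_kernelLauncher<float, int>::run(
      logits, /*gating_correction_bias=*/nullptr, topk_weight, softmax_buf,
      topk_idx, source_row, /*softmax_max_prob=*/nullptr,
      num_rows, num_experts, k, /*group_moe=*/false, stream,
      /*topk_only_mode=*/false);
}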
|
||||
|
||||
// ========================== Permutation things
|
||||
// =======================================
|
||||
|
||||
@@ -999,13 +1060,18 @@ struct topk_gating_softmax_kernelLauncher {
|
||||
// to row 0 in the original matrix. Thus, to know where to read in the source
|
||||
// matrix, we simply take the modulus of the expanded index.
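// Tiny illustration of the mapping described above: in the k-times duplicated
// matrix, expanded source row `e` reads from original row `e % num_rows`, and
// `e / num_rows` says which of the k duplicates it is. Host-style helper with
// illustrative names.
struct ExpandedRowRef {
  int source_row;  // row in the original [num_rows, cols] matrix
  int copy_idx;    // which of the k duplicates (0 .. k-1)
};

inline ExpandedRowRef decode_expanded_row(const int expanded_source_row,
                                          const int num_rows) {
  return {expanded_source_row % num_rows, expanded_source_row / num_rows};
}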
|
||||
|
||||
template <typename T, int VecSize, typename OutT = T>
|
||||
template <typename T, int VecSize, typename OutT=T>
|
||||
__global__ void initialize_moe_routing_kernel(
|
||||
const T *unpermuted_input, OutT *permuted_output,
|
||||
const int *expanded_dest_row_to_expanded_source_row,
|
||||
const int *expert_idx_per_token, const float *w4a8_in_scale,
|
||||
int *expanded_source_row_to_expanded_dest_row, const int64_t num_rows,
|
||||
const int64_t active_rows, const int64_t cols, const int64_t num_rows_k) {
|
||||
const T* unpermuted_input,
|
||||
OutT* permuted_output,
|
||||
const int* expanded_dest_row_to_expanded_source_row,
|
||||
const int *expert_idx_per_token,
|
||||
const float *w4a8_in_scale,
|
||||
int* expanded_source_row_to_expanded_dest_row,
|
||||
const int64_t num_rows,
|
||||
const int64_t active_rows,
|
||||
const int64_t cols,
|
||||
const int64_t num_rows_k) {
|
||||
using LoadT = AlignedVector<T, VecSize>;
|
||||
LoadT src_vec;
|
||||
|
||||
@@ -1015,24 +1081,21 @@ __global__ void initialize_moe_routing_kernel(
|
||||
// allow each threadblock to do 1 k-way reduce without atomics later in MoE. 1
|
||||
// thread block will be responsible for all k summations.
|
||||
const int expanded_dest_row = blockIdx.x + blockIdx.y * gridDim.x;
|
||||
if (expanded_dest_row >= num_rows_k)
|
||||
return;
|
||||
if (expanded_dest_row >= num_rows_k) return;
|
||||
const int expanded_source_row =
|
||||
expanded_dest_row_to_expanded_source_row[expanded_dest_row];
|
||||
if (threadIdx.x == 0) {
|
||||
expanded_source_row_to_expanded_dest_row[expanded_source_row] =
|
||||
expanded_dest_row;
|
||||
}
|
||||
|
||||
|
||||
if (expanded_dest_row < active_rows) {
|
||||
|
||||
const int expert_idx = expert_idx_per_token[expanded_dest_row];
|
||||
const float scale = w4a8_in_scale ? w4a8_in_scale[expert_idx] : -1;
|
||||
|
||||
// Duplicate and permute rows
|
||||
const int source_row = expanded_source_row % num_rows;
|
||||
|
||||
const T *source_row_ptr = unpermuted_input + source_row * cols;
|
||||
const T* source_row_ptr = unpermuted_input + source_row * cols;
|
||||
OutT *dest_row_ptr = permuted_output + expanded_dest_row * cols;
|
||||
|
||||
for (int tid = threadIdx.x * VecSize; tid < cols;
|
||||
@@ -1061,37 +1124,56 @@ __global__ void initialize_moe_routing_kernel(
|
||||
}
|
||||
|
||||
template <typename T, typename OutT = T>
|
||||
struct initialize_moe_routing_kernelLauncher {
|
||||
static void run(const T *unpermuted_input, OutT *permuted_output,
|
||||
const int *expanded_dest_row_to_expanded_source_row,
|
||||
const int *expert_idx_per_token, const float *w4a8_in_scale,
|
||||
int *expanded_source_row_to_expanded_dest_row,
|
||||
const int64_t num_rows, const int64_t active_rows,
|
||||
const int64_t cols, const int64_t k, cudaStream_t stream) {
|
||||
const int threads = std::min(cols, int64_t(1024));
|
||||
constexpr int max_pack_size = 16 / sizeof(T);
|
||||
const auto config_initialize = Get1DBlocksAnd2DGridsMoe(num_rows * k);
|
||||
if (cols % max_pack_size == 0) {
|
||||
initialize_moe_routing_kernel<T, max_pack_size, OutT>
|
||||
<<<config_initialize.block_per_grid, threads, 0, stream>>>(
|
||||
unpermuted_input, permuted_output,
|
||||
expanded_dest_row_to_expanded_source_row, expert_idx_per_token,
|
||||
w4a8_in_scale, expanded_source_row_to_expanded_dest_row, num_rows,
|
||||
k * active_rows, cols, num_rows * k);
|
||||
} else {
|
||||
initialize_moe_routing_kernel<T, 1, OutT>
|
||||
<<<config_initialize.block_per_grid, threads, 0, stream>>>(
|
||||
unpermuted_input, permuted_output,
|
||||
expanded_dest_row_to_expanded_source_row, expert_idx_per_token,
|
||||
w4a8_in_scale, expanded_source_row_to_expanded_dest_row, num_rows,
|
||||
k * active_rows, cols, num_rows * k);
|
||||
}
|
||||
struct initialize_moe_routing_kernelLauncher{
|
||||
|
||||
static void run(
|
||||
const T* unpermuted_input,
|
||||
OutT* permuted_output,
|
||||
const int* expanded_dest_row_to_expanded_source_row,
|
||||
const int *expert_idx_per_token,
|
||||
const float *w4a8_in_scale,
|
||||
int* expanded_source_row_to_expanded_dest_row,
|
||||
const int64_t num_rows,
|
||||
const int64_t active_rows,
|
||||
const int64_t cols,
|
||||
const int64_t k,
|
||||
cudaStream_t stream) {
|
||||
const int threads = std::min(cols, int64_t(1024));
|
||||
constexpr int max_pack_size = 16 / sizeof(T);
|
||||
const auto config_initialize = Get1DBlocksAnd2DGridsMoe(num_rows * k);
|
||||
if (cols % max_pack_size == 0) {
|
||||
initialize_moe_routing_kernel<T, max_pack_size>
|
||||
<<<config_initialize.block_per_grid, threads, 0, stream>>>(
|
||||
unpermuted_input,
|
||||
permuted_output,
|
||||
expanded_dest_row_to_expanded_source_row,
|
||||
expert_idx_per_token,
|
||||
w4a8_in_scale,
|
||||
expanded_source_row_to_expanded_dest_row,
|
||||
num_rows,
|
||||
k * active_rows,
|
||||
cols,
|
||||
num_rows * k);
|
||||
} else {
|
||||
initialize_moe_routing_kernel<T, 1>
|
||||
<<<config_initialize.block_per_grid, threads, 0, stream>>>(
|
||||
unpermuted_input,
|
||||
permuted_output,
|
||||
expanded_dest_row_to_expanded_source_row,
|
||||
expert_idx_per_token,
|
||||
w4a8_in_scale,
|
||||
expanded_source_row_to_expanded_dest_row,
|
||||
num_rows,
|
||||
k * active_rows,
|
||||
cols,
|
||||
num_rows * k);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// ============================== Infer GEMM sizes
|
||||
// =================================
|
||||
__device__ inline int find_total_elts_leq_target(int *sorted_indices,
|
||||
__device__ inline int find_total_elts_leq_target(int* sorted_indices,
|
||||
const int64_t arr_length,
|
||||
const int64_t target) {
|
||||
int64_t low = 0, high = arr_length - 1, target_location = -1;
|
||||
@@ -1108,10 +1190,10 @@ __device__ inline int find_total_elts_leq_target(int *sorted_indices,
|
||||
return target_location + 1;
|
||||
}
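// Host-side sketch of what find_total_elts_leq_target computes: the number of
// entries in the sorted expert-id array that are <= target, i.e. the row
// offset at which expert (target + 1) begins. Equivalent to std::upper_bound
// on the sorted array; names are illustrative.
inline int64_t total_elts_leq_target_host(const int *sorted_expert_ids,
                                          const int64_t arr_length,
                                          const int target) {
  int64_t count = 0;
  while (count < arr_length && sorted_expert_ids[count] <= target) {
    ++count;
  }
  return count;
}
// Example: sorted_expert_ids = {0, 0, 1, 1, 1, 3}, target = 1  ->  returns 5,
// so rows for experts 0..1 occupy offsets [0, 5) and expert 2's block (empty
// here) would start at offset 5.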
|
||||
|
||||
void compute_total_rows_before_expert(int *sorted_indices,
|
||||
void compute_total_rows_before_expert(int* sorted_indices,
|
||||
const int64_t total_indices,
|
||||
const int64_t num_experts,
|
||||
int64_t *total_rows_before_expert,
|
||||
int64_t* total_rows_before_expert,
|
||||
cudaStream_t stream);
|
||||
|
||||
// Final kernel to unpermute and scale
|
||||
@@ -1119,72 +1201,117 @@ void compute_total_rows_before_expert(int *sorted_indices,
|
||||
// performs the final skip connection.
|
||||
template <typename T, int RESIDUAL_NUM>
|
||||
__global__ void finalize_moe_routing_kernel(
|
||||
const T *expanded_permuted_rows, T *reduced_unpermuted_output,
|
||||
const T *bias, const float *scales,
|
||||
const int *expanded_source_row_to_expanded_dest_row,
|
||||
const int *expert_for_source_row, const int64_t cols, const int64_t k,
|
||||
const int64_t compute_bias, const bool norm_topk_prob,
|
||||
const float routed_scaling_factor, const int64_t num_rows) {
|
||||
const int original_row = blockIdx.x + blockIdx.y * gridDim.x;
|
||||
// const int original_row = blockIdx.x;
|
||||
// const int num_rows = gridDim.x;
|
||||
if (original_row >= num_rows)
|
||||
return;
|
||||
T *reduced_row_ptr = reduced_unpermuted_output + original_row * cols;
|
||||
const T* expanded_permuted_rows,
|
||||
T* reduced_unpermuted_output,
|
||||
const T* bias,
|
||||
const float* scales,
|
||||
const int* expanded_source_row_to_expanded_dest_row,
|
||||
const int* expert_for_source_row,
|
||||
const int64_t cols,
|
||||
const int64_t k,
|
||||
const int64_t compute_bias,
|
||||
const bool norm_topk_prob,
|
||||
const float routed_scaling_factor,
|
||||
const int64_t num_rows) {
|
||||
const int original_row = blockIdx.x;
|
||||
auto const offset = original_row * cols;
|
||||
|
||||
for (int tid = threadIdx.x; tid < cols; tid += blockDim.x) {
|
||||
T thread_output{0.f};
|
||||
float row_rescale{0.f};
|
||||
for (int k_idx = 0; k_idx < k; ++k_idx) {
|
||||
const int expanded_original_row = original_row + k_idx * num_rows;
|
||||
const int expanded_permuted_row =
|
||||
expanded_source_row_to_expanded_dest_row[expanded_original_row];
|
||||
T* reduced_row_ptr = reduced_unpermuted_output + offset;
|
||||
constexpr int64_t FINALIZE_ELEM_PER_THREAD
|
||||
= 128 / cutlass::sizeof_bits<T>::value;
|
||||
int64_t const start_offset = threadIdx.x;
|
||||
int64_t const stride = FINALIZE_THREADS_PER_BLOCK;
|
||||
int64_t const num_elems_in_col = cols / FINALIZE_ELEM_PER_THREAD;
|
||||
|
||||
const int64_t k_offset = original_row * k + k_idx;
|
||||
const float row_scale = scales[k_offset];
|
||||
row_rescale = row_rescale + row_scale;
|
||||
using BiasElem = cutlass::Array<T, FINALIZE_ELEM_PER_THREAD>;
|
||||
using InputElem = cutlass::Array<T, FINALIZE_ELEM_PER_THREAD>;
|
||||
using OutputElem = cutlass::Array<T, FINALIZE_ELEM_PER_THREAD>;
|
||||
using ComputeElem = cutlass::Array<float, FINALIZE_ELEM_PER_THREAD>;
|
||||
using SharedOutputElem = cutlass::Array<T, FINALIZE_ELEM_PER_THREAD>;
|
||||
|
||||
const T *expanded_permuted_rows_row_ptr =
|
||||
expanded_permuted_rows + expanded_permuted_row * cols;
|
||||
auto const* bias_v = reinterpret_cast<BiasElem const*>(bias);
|
||||
auto const* expanded_permuted_rows_v = reinterpret_cast<InputElem const*>(expanded_permuted_rows);
|
||||
auto* reduced_row_ptr_v = reinterpret_cast<OutputElem*>(reduced_row_ptr);
|
||||
|
||||
const int expert_idx = expert_for_source_row[k_offset];
|
||||
const T *bias_ptr = bias ? bias + expert_idx * cols : nullptr;
|
||||
const T bias_value = bias_ptr ? bias_ptr[tid] : T{0.f};
|
||||
#pragma unroll
|
||||
for (int elem_index = start_offset; elem_index < num_elems_in_col; elem_index += stride)
|
||||
{
|
||||
ComputeElem thread_output;
|
||||
thread_output.fill(0);
|
||||
float row_rescale{0.f};
|
||||
for (int k_idx = 0; k_idx < k; ++k_idx)
|
||||
{
|
||||
int64_t const expanded_original_row = original_row + k_idx * num_rows;
|
||||
int64_t const expanded_permuted_row = expanded_source_row_to_expanded_dest_row[expanded_original_row];
|
||||
int64_t const k_offset = original_row * k + k_idx;
|
||||
const float row_scale = scales[k_offset];
|
||||
row_rescale = row_rescale + row_scale;
|
||||
|
||||
thread_output =
|
||||
static_cast<float>(thread_output) +
|
||||
row_scale * static_cast<float>(
|
||||
expanded_permuted_rows_row_ptr[tid] +
|
||||
bias_value *
|
||||
static_cast<T>(static_cast<float>(compute_bias)));
|
||||
}
|
||||
auto const* expanded_permuted_rows_row_ptr
|
||||
= expanded_permuted_rows_v + expanded_permuted_row * num_elems_in_col;
|
||||
|
||||
thread_output = static_cast<float>(thread_output) /
|
||||
(norm_topk_prob ? row_rescale : 1.0f) *
|
||||
routed_scaling_factor;
|
||||
reduced_row_ptr[tid] = thread_output;
|
||||
int const expert_idx = expert_for_source_row[k_offset];
|
||||
auto const* bias_ptr = bias_v + expert_idx * num_elems_in_col;
|
||||
|
||||
ComputeElem bias_value;
|
||||
if (bias)
|
||||
{
|
||||
bias_value = arrayConvert<BiasElem, ComputeElem>(bias_ptr[elem_index]);
|
||||
}
|
||||
else
|
||||
{
|
||||
bias_value.fill(0);
|
||||
}
|
||||
|
||||
ComputeElem expert_result
|
||||
= arrayConvert<InputElem, ComputeElem>(expanded_permuted_rows_row_ptr[elem_index]);
|
||||
|
||||
thread_output = thread_output + row_scale * (expert_result + bias_value);
|
||||
|
||||
|
||||
}
|
||||
for (auto& elem : thread_output)
|
||||
{
|
||||
elem = elem / (norm_topk_prob ? row_rescale : 1.0f) * routed_scaling_factor;
|
||||
}
|
||||
OutputElem output_elem = arrayConvert<ComputeElem, OutputElem>(thread_output);
|
||||
reduced_row_ptr_v[elem_index] = output_elem;
|
||||
}
|
||||
}
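// Scalar reference for the per-element reduction the vectorized kernel above
// performs: accumulate scale_k * (expert_output_k + bias), then optionally
// renormalize by the sum of the top-k scales and apply the routed scaling
// factor. Names are illustrative and this mirrors only the math visible in
// the kernel body above.
inline float finalize_one_element(const float *expert_outputs,  // [k] gathered values
                                  const float *expert_bias,     // [k] bias or nullptr
                                  const float *scales,          // [k] routing weights
                                  const int k, const bool norm_topk_prob,
                                  const float routed_scaling_factor) {
  float acc = 0.f;
  float scale_sum = 0.f;
  for (int k_idx = 0; k_idx < k; ++k_idx) {
    const float b = expert_bias ? expert_bias[k_idx] : 0.f;
    scale_sum += scales[k_idx];
    acc += scales[k_idx] * (expert_outputs[k_idx] + b);
  }
  return acc / (norm_topk_prob ? scale_sum : 1.f) * routed_scaling_factor;
}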
|
||||
|
||||
template <typename T> struct finalize_moe_routing_kernelLauncher {
|
||||
template <typename T>
|
||||
struct finalize_moe_routing_kernelLauncher{
|
||||
static void run(
|
||||
const T* expanded_permuted_rows,
|
||||
T* reduced_unpermuted_output,
|
||||
const T* bias,
|
||||
const float* scales,
|
||||
const int* expanded_source_row_to_expanded_dest_row,
|
||||
const int* expert_for_source_row,
|
||||
const int64_t num_rows,
|
||||
const int64_t cols,
|
||||
const int64_t k,
|
||||
const int64_t compute_bias,
|
||||
const bool norm_topk_prob,
|
||||
const float routed_scaling_factor,
|
||||
cudaStream_t stream) {
|
||||
const int blocks = num_rows;
|
||||
const int threads = FINALIZE_THREADS_PER_BLOCK;
|
||||
|
||||
static void run(const T *expanded_permuted_rows, T *reduced_unpermuted_output,
|
||||
const T *bias, const float *scales,
|
||||
const int *expanded_source_row_to_expanded_dest_row,
|
||||
const int *expert_for_source_row, const int64_t num_rows,
|
||||
const int64_t cols, const int64_t k,
|
||||
const int64_t compute_bias, const bool norm_topk_prob,
|
||||
const float routed_scaling_factor, cudaStream_t stream) {
|
||||
const int threads = std::min(cols, int64_t(1024));
|
||||
const auto config_final = Get1DBlocksAnd2DGridsMoe(num_rows);
|
||||
|
||||
finalize_moe_routing_kernel<T, 1>
|
||||
<<<config_final.block_per_grid, threads, 0, stream>>>(
|
||||
expanded_permuted_rows, reduced_unpermuted_output, bias, scales,
|
||||
expanded_source_row_to_expanded_dest_row, expert_for_source_row,
|
||||
cols, k, compute_bias, norm_topk_prob, routed_scaling_factor,
|
||||
finalize_moe_routing_kernel<T, 1>
|
||||
<<<blocks, threads, 0, stream>>>(
|
||||
expanded_permuted_rows,
|
||||
reduced_unpermuted_output,
|
||||
bias,
|
||||
scales,
|
||||
expanded_source_row_to_expanded_dest_row,
|
||||
expert_for_source_row,
|
||||
cols,
|
||||
k,
|
||||
compute_bias,
|
||||
norm_topk_prob,
|
||||
routed_scaling_factor,
|
||||
num_rows);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace phi
368 custom_ops/gpu_ops/moe/gptq_marlin_repack.cu Normal file
@@ -0,0 +1,368 @@
|
||||
#include "moe_wna16_marlin_utils/marlin.cuh"
|
||||
#include "paddle/phi/core/enforce.h"
|
||||
#include "helper.h"
|
||||
|
||||
namespace MARLIN_NAMESPACE_NAME {
|
||||
|
||||
template <int const num_threads, int const num_bits, bool const has_perm>
|
||||
__global__ void gptq_marlin_repack_kernel(
|
||||
uint32_t const* __restrict__ b_q_weight_ptr,
|
||||
uint32_t const* __restrict__ perm_ptr, uint32_t* __restrict__ out_ptr,
|
||||
int size_k, int size_n) {
|
||||
constexpr int pack_factor = 32 / num_bits;
|
||||
|
||||
int k_tiles = size_k / tile_k_size;
|
||||
int n_tiles = size_n / tile_n_size;
|
||||
int block_k_tiles = div_ceil(k_tiles, gridDim.x);
|
||||
|
||||
auto start_k_tile = blockIdx.x * block_k_tiles;
|
||||
if (start_k_tile >= k_tiles) {
|
||||
return;
|
||||
}
|
||||
|
||||
int finish_k_tile = min(start_k_tile + block_k_tiles, k_tiles);
|
||||
|
||||
// Wait until the next thread tile has been loaded to shared memory.
|
||||
auto wait_for_stage = [&]() {
|
||||
// We only have `stages - 2` active fetches since we are double buffering
|
||||
// and can only issue the next fetch when it is guaranteed that the previous
|
||||
// shared memory load is fully complete (as it may otherwise be
|
||||
// overwritten).
|
||||
cp_async_wait<repack_stages - 2>();
|
||||
__syncthreads();
|
||||
};
|
||||
|
||||
extern __shared__ int4 sh[];
|
||||
|
||||
constexpr int perm_size = tile_k_size / 4;
|
||||
|
||||
int4* sh_perm_ptr = sh;
|
||||
int4* sh_pipe_ptr = sh_perm_ptr;
|
||||
if constexpr (has_perm) {
|
||||
sh_pipe_ptr += perm_size;
|
||||
}
|
||||
|
||||
constexpr int tile_ints = tile_k_size / pack_factor;
|
||||
|
||||
constexpr int stage_n_threads = tile_n_size / 4;
|
||||
constexpr int stage_k_threads = has_perm ? tile_k_size : tile_ints;
|
||||
constexpr int stage_size = stage_k_threads * stage_n_threads;
|
||||
|
||||
auto load_perm_to_shared = [&](int k_tile_id) {
|
||||
int first_k_int4 = (k_tile_id * tile_k_size) / 4;
|
||||
|
||||
int4 const* perm_int4_ptr = reinterpret_cast<int4 const*>(perm_ptr);
|
||||
|
||||
if (threadIdx.x < perm_size) {
|
||||
sh_perm_ptr[threadIdx.x] = perm_int4_ptr[first_k_int4 + threadIdx.x];
|
||||
}
|
||||
__syncthreads();
|
||||
};
|
||||
|
||||
auto fetch_to_shared = [&](int pipe, int k_tile_id, int n_tile_id) {
|
||||
if (n_tile_id >= n_tiles) {
|
||||
cp_async_fence();
|
||||
return;
|
||||
}
|
||||
|
||||
int first_n = n_tile_id * tile_n_size;
|
||||
|
||||
int4* sh_ptr = sh_pipe_ptr + stage_size * pipe;
|
||||
|
||||
if constexpr (has_perm) {
|
||||
if (threadIdx.x < stage_size) {
|
||||
auto k_id = threadIdx.x / stage_n_threads;
|
||||
auto n_id = threadIdx.x % stage_n_threads;
|
||||
|
||||
uint32_t const* sh_perm_int_ptr =
|
||||
reinterpret_cast<uint32_t const*>(sh_perm_ptr);
|
||||
|
||||
int src_k = sh_perm_int_ptr[k_id];
|
||||
int src_k_packed = src_k / pack_factor;
|
||||
|
||||
cp_async4(
|
||||
&sh_ptr[k_id * stage_n_threads + n_id],
|
||||
reinterpret_cast<int4 const*>(&(
|
||||
b_q_weight_ptr[src_k_packed * size_n + first_n + (n_id * 4)])));
|
||||
}
|
||||
|
||||
} else {
|
||||
if (threadIdx.x < stage_size) {
|
||||
auto k_id = threadIdx.x / stage_n_threads;
|
||||
auto n_id = threadIdx.x % stage_n_threads;
|
||||
|
||||
int first_k = k_tile_id * tile_k_size;
|
||||
int first_k_packed = first_k / pack_factor;
|
||||
|
||||
cp_async4(&sh_ptr[k_id * stage_n_threads + n_id],
|
||||
reinterpret_cast<int4 const*>(
|
||||
&(b_q_weight_ptr[(first_k_packed + k_id) * size_n +
|
||||
first_n + (n_id * 4)])));
|
||||
}
|
||||
}
|
||||
|
||||
cp_async_fence();
|
||||
};
|
||||
|
||||
auto repack_tile = [&](int pipe, int k_tile_id, int n_tile_id) {
|
||||
if (n_tile_id >= n_tiles) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto warp_id = threadIdx.x / 32;
|
||||
auto th_id = threadIdx.x % 32;
|
||||
|
||||
if (warp_id >= 4) {
|
||||
return;
|
||||
}
|
||||
|
||||
int tc_col = th_id / 4;
|
||||
int tc_row = (th_id % 4) * 2;
|
||||
|
||||
constexpr int tc_offsets[4] = {0, 1, 8, 9};
|
||||
|
||||
int cur_n = warp_id * 16 + tc_col;
|
||||
|
||||
constexpr int sh_stride = 64;
|
||||
constexpr uint32_t mask = (1 << num_bits) - 1;
|
||||
|
||||
int4* sh_stage_ptr = sh_pipe_ptr + stage_size * pipe;
|
||||
uint32_t* sh_stage_int_ptr = reinterpret_cast<uint32_t*>(sh_stage_ptr);
|
||||
|
||||
uint32_t* sh_perm_int_ptr = reinterpret_cast<uint32_t*>(sh_perm_ptr);
|
||||
|
||||
uint32_t vals[8];
|
||||
|
||||
if constexpr (has_perm) {
|
||||
for (int i = 0; i < 4; i++) {
|
||||
int k_idx = tc_row + tc_offsets[i];
|
||||
|
||||
uint32_t src_k = sh_perm_int_ptr[k_idx];
|
||||
uint32_t src_k_pos = src_k % pack_factor;
|
||||
|
||||
uint32_t b1_val = sh_stage_int_ptr[k_idx * sh_stride + cur_n];
|
||||
uint32_t b1_cur_val = (b1_val >> (src_k_pos * num_bits)) & mask;
|
||||
|
||||
uint32_t b2_val = sh_stage_int_ptr[k_idx * sh_stride + cur_n + 8];
|
||||
uint32_t b2_cur_val = (b2_val >> (src_k_pos * num_bits)) & mask;
|
||||
|
||||
vals[i] = b1_cur_val;
|
||||
vals[4 + i] = b2_cur_val;
|
||||
}
|
||||
|
||||
} else {
|
||||
uint32_t b1_vals[tile_ints];
|
||||
uint32_t b2_vals[tile_ints];
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < tile_ints; i++) {
|
||||
b1_vals[i] = sh_stage_int_ptr[cur_n + sh_stride * i];
|
||||
b2_vals[i] = sh_stage_int_ptr[cur_n + 8 + sh_stride * i];
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; i++) {
|
||||
int cur_elem = tc_row + tc_offsets[i];
|
||||
int cur_int = cur_elem / pack_factor;
|
||||
int cur_pos = cur_elem % pack_factor;
|
||||
|
||||
vals[i] = (b1_vals[cur_int] >> (cur_pos * num_bits)) & mask;
|
||||
vals[4 + i] = (b2_vals[cur_int] >> (cur_pos * num_bits)) & mask;
|
||||
}
|
||||
}
|
||||
|
||||
constexpr int tile_size = tile_k_size * tile_n_size / pack_factor;
|
||||
int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size;
|
||||
|
||||
// Result of:
|
||||
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
|
||||
if constexpr (num_bits == 4) {
|
||||
constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};
|
||||
|
||||
uint32_t res = 0;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 8; i++) {
|
||||
res |= vals[pack_idx[i]] << (i * 4);
|
||||
}
|
||||
|
||||
out_ptr[out_offset + th_id * 4 + warp_id] = res;
|
||||
|
||||
} else {
|
||||
constexpr int pack_idx[4] = {0, 2, 1, 3};
|
||||
|
||||
uint32_t res1 = 0;
|
||||
uint32_t res2 = 0;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; i++) {
|
||||
res1 |= vals[pack_idx[i]] << (i * 8);
|
||||
res2 |= vals[4 + pack_idx[i]] << (i * 8);
|
||||
}
|
||||
|
||||
out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 0] = res1;
|
||||
out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 1] = res2;
|
||||
}
|
||||
};
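// Standalone sketch of the 4-bit packing order used in repack_tile above:
// logical values v0..v7 are stored as nibbles in the interleaved order
// {0, 2, 4, 6, 1, 3, 5, 7}, matching the FasterTransformer interleaved numeric
// conversion referenced in the comment. The helper name and the example values
// are illustrative.
inline uint32_t pack8x4bit_interleaved(const uint32_t vals[8]) {
  constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};
  uint32_t res = 0;
  for (int i = 0; i < 8; i++) {
    res |= (vals[pack_idx[i]] & 0xF) << (i * 4);
  }
  return res;
}
// Example: vals = {0x0,0x1,0x2,0x3,0x4,0x5,0x6,0x7} packs to 0x75316420.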
|
||||
|
||||
auto start_pipes = [&](int k_tile_id, int n_tile_id) {
|
||||
#pragma unroll
|
||||
for (int pipe = 0; pipe < repack_stages - 1; pipe++) {
|
||||
fetch_to_shared(pipe, k_tile_id, n_tile_id + pipe);
|
||||
}
|
||||
|
||||
wait_for_stage();
|
||||
};
|
||||
#pragma unroll
|
||||
for (int k_tile_id = start_k_tile; k_tile_id < finish_k_tile; k_tile_id++) {
|
||||
int n_tile_id = 0;
|
||||
|
||||
if constexpr (has_perm) {
|
||||
load_perm_to_shared(k_tile_id);
|
||||
}
|
||||
|
||||
start_pipes(k_tile_id, n_tile_id);
|
||||
|
||||
while (n_tile_id < n_tiles) {
|
||||
#pragma unroll
|
||||
for (int pipe = 0; pipe < repack_stages; pipe++) {
|
||||
fetch_to_shared((pipe + repack_stages - 1) % repack_stages, k_tile_id,
|
||||
n_tile_id + pipe + repack_stages - 1);
|
||||
repack_tile(pipe, k_tile_id, n_tile_id + pipe);
|
||||
wait_for_stage();
|
||||
}
|
||||
n_tile_id += repack_stages;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace marlin
|
||||
|
||||
#define CALL_IF(NUM_BITS, HAS_PERM) \
|
||||
else if (num_bits == NUM_BITS && has_perm == HAS_PERM) { \
|
||||
cudaFuncSetAttribute( \
|
||||
MARLIN_NAMESPACE_NAME::gptq_marlin_repack_kernel<MARLIN_NAMESPACE_NAME::repack_threads, NUM_BITS, \
|
||||
HAS_PERM>, \
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
|
||||
MARLIN_NAMESPACE_NAME::gptq_marlin_repack_kernel<MARLIN_NAMESPACE_NAME::repack_threads, NUM_BITS, \
|
||||
HAS_PERM> \
|
||||
<<<blocks, MARLIN_NAMESPACE_NAME::repack_threads, max_shared_mem, stream>>>( \
|
||||
b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n); \
|
||||
}
|
||||
|
||||
std::vector<paddle::Tensor> gptq_marlin_repack(paddle::Tensor& b_q_weight, paddle::Tensor& perm,
|
||||
int64_t size_k, int64_t size_n,
|
||||
int64_t num_bits) {
|
||||
// Verify compatibility with marlin tile of 16x64
|
||||
PADDLE_ENFORCE(
|
||||
size_k % MARLIN_NAMESPACE_NAME::tile_k_size == 0,
|
||||
"size_k = ", size_k,
|
||||
" is not divisible by tile_k_size = ",
|
||||
MARLIN_NAMESPACE_NAME::tile_k_size);
|
||||
|
||||
PADDLE_ENFORCE(
|
||||
size_n % MARLIN_NAMESPACE_NAME::tile_n_size == 0,
|
||||
"size_n = ", size_n,
|
||||
" is not divisible by tile_n_size = ",
|
||||
MARLIN_NAMESPACE_NAME::tile_n_size);
|
||||
|
||||
PADDLE_ENFORCE(
|
||||
num_bits == 4 || num_bits == 8,
|
||||
"num_bits must be 4 or 8. Got = ", num_bits);
|
||||
|
||||
int const pack_factor = 32 / num_bits;
|
||||
|
||||
// Verify B
|
||||
// shape checks
|
||||
PADDLE_ENFORCE(
|
||||
(size_k / pack_factor) == b_q_weight.dims()[0],
|
||||
"Shape mismatch: b_q_weight.size(0) = ", b_q_weight.dims()[0]);
|
||||
|
||||
PADDLE_ENFORCE(
|
||||
b_q_weight.dims()[1] == size_n,
|
||||
"Shape mismatch: b_q_weight.size(1) = ", b_q_weight.dims()[1],
|
||||
", expected size_n = ", size_n);
|
||||
|
||||
// Verify device and strides
|
||||
PADDLE_ENFORCE(
|
||||
b_q_weight.is_gpu(),
|
||||
"b_q_weight is not on GPU");
|
||||
|
||||
PADDLE_ENFORCE(
|
||||
b_q_weight.is_contiguous(),
|
||||
"b_q_weight is not contiguous");
|
||||
|
||||
PADDLE_ENFORCE(
|
||||
b_q_weight.dtype() == phi::DataType::INT32,
|
||||
"b_q_weight type is not kInt");
|
||||
|
||||
PADDLE_ENFORCE(
|
||||
perm.is_gpu(),
|
||||
"perm is not on GPU");
|
||||
|
||||
PADDLE_ENFORCE(
|
||||
perm.is_contiguous(),
|
||||
"perm is not contiguous");
|
||||
|
||||
PADDLE_ENFORCE(
|
||||
perm.dtype() == phi::DataType::INT32,
|
||||
"perm type is not kInt");
|
||||
|
||||
// Alloc buffers
|
||||
// const at::cuda::OptionalCUDAGuard device_guard(device_of(b_q_weight));
|
||||
paddle::Tensor out = paddle::empty(
|
||||
{size_k / MARLIN_NAMESPACE_NAME::tile_size, size_n * MARLIN_NAMESPACE_NAME::tile_size / pack_factor},
|
||||
b_q_weight.dtype(),
|
||||
b_q_weight.place());
|
||||
|
||||
|
||||
// Detect if there is act_order
|
||||
bool has_perm = perm.dims()[0] != 0;
|
||||
|
||||
// Get ptrs
|
||||
uint32_t const* b_q_weight_ptr =
|
||||
reinterpret_cast<uint32_t const*>(b_q_weight.data());
|
||||
uint32_t const* perm_ptr = reinterpret_cast<uint32_t const*>(perm.data());
|
||||
uint32_t* out_ptr = reinterpret_cast<uint32_t*>(out.data());
|
||||
|
||||
// Get dev info
|
||||
int dev = b_q_weight.place().GetDeviceId();
|
||||
cudaStream_t stream = b_q_weight.stream();
|
||||
int blocks;
|
||||
cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev);
|
||||
|
||||
int max_shared_mem = 0;
|
||||
cudaDeviceGetAttribute(&max_shared_mem,
|
||||
cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
|
||||
// TORCH_CHECK(max_shared_mem > 0);
|
||||
PADDLE_ENFORCE(
|
||||
max_shared_mem > 0,
|
||||
"max_shared_mem must be > 0. Got = ", max_shared_mem);
|
||||
|
||||
|
||||
if (false) {
|
||||
}
|
||||
CALL_IF(4, false)
|
||||
CALL_IF(4, true)
|
||||
CALL_IF(8, false)
|
||||
CALL_IF(8, true)
|
||||
else {
|
||||
// TORCH_CHECK(false, "Unsupported repack config: num_bits = ", num_bits,
|
||||
// ", has_perm = ", has_perm);
|
||||
PADDLE_ENFORCE(
|
||||
false,
|
||||
"Unsupported repack config: num_bits = ", num_bits,
|
||||
", has_perm = ", has_perm);
|
||||
|
||||
}
|
||||
|
||||
return {out};
|
||||
}
|
||||
|
||||
PD_BUILD_STATIC_OP(gptq_marlin_repack)
|
||||
.Inputs({"b_q_weight", "perm"})
|
||||
.Outputs({"out"})
|
||||
.Attrs({
|
||||
"size_k: int64_t",
|
||||
"size_n: int64_t",
|
||||
"num_bits: int64_t"
|
||||
})
|
||||
.SetKernelFn(PD_KERNEL(gptq_marlin_repack));
|
||||
@@ -211,12 +211,14 @@ std::vector<std::vector<int64_t>> MoeExpertDispatchInferShape(
|
||||
const int expert_num = gating_output_shape[gating_output_shape.size() - 1];
|
||||
const int num_rows = token_rows;
|
||||
const int hidden_size = input_shape[input_shape.size() - 1];
|
||||
const int permuted_rows = num_rows == -1 ? -1 : moe_topk * num_rows;
|
||||
|
||||
return {{moe_topk * num_rows, hidden_size},
|
||||
return {{permuted_rows, hidden_size},
|
||||
{expert_num},
|
||||
{moe_topk, num_rows},
|
||||
{num_rows, moe_topk},
|
||||
{num_rows, moe_topk}};
|
||||
{num_rows, moe_topk},
|
||||
{permuted_rows}};
|
||||
}
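// Worked example for the shapes above, matching the outputs positionally and
// assuming num_rows = 128 tokens, hidden_size = 2048, expert_num = 64 and
// moe_topk = 8, so permuted_rows = 8 * 128 = 1024:
//   permute_input              -> {1024, 2048}
//   tokens_expert_prefix_sum   -> {64}
//   permute_indices_per_token  -> {8, 128}
//   topk_weight                -> {128, 8}
//   topk_idx                   -> {128, 8}
//   expert_idx_per_token       -> {1024}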
|
||||
|
||||
std::vector<paddle::DataType>
|
||||
@@ -225,7 +227,7 @@ MoeExpertDispatchInferDtype(const paddle::DataType &input_dtype,
|
||||
const paddle::optional<paddle::DataType> &bias_type,
|
||||
const int moe_topk) {
|
||||
return {input_dtype, paddle::DataType::INT64, paddle::DataType::INT32,
|
||||
paddle::DataType::FLOAT32, paddle::DataType::INT32};
|
||||
paddle::DataType::FLOAT32, paddle::DataType::INT32, paddle::DataType::INT32};
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -281,7 +283,8 @@ PD_BUILD_STATIC_OP(moe_expert_dispatch)
|
||||
paddle::Optional("gating_correction_bias"),
|
||||
paddle::Optional("w4a8_in_scale")})
|
||||
.Outputs({"permute_input", "tokens_expert_prefix_sum",
|
||||
"permute_indices_per_token", "topk_weight", "topk_idx"})
|
||||
"permute_indices_per_token", "topk_weight", "topk_idx",
|
||||
"expert_idx_per_token"})
|
||||
.Attrs({"moe_topk:int", "group_moe:bool", "topk_only_mode:bool"})
|
||||
.SetKernelFn(PD_KERNEL(MoeExpertDispatch))
|
||||
.SetInferShapeFn(PD_INFER_SHAPE(MoeExpertDispatchInferShape))
|
||||
|
||||
@@ -44,9 +44,9 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
|
||||
auto place = permute_input.place();
|
||||
auto stream = permute_input.stream();
|
||||
|
||||
auto fp16_moe_gemm_runner = MoeGemmRunner<DataType_, DataType_>();
|
||||
auto int8_moe_gemm_runner = MoeGemmRunner<DataType_, uint8_t>();
|
||||
auto int4_moe_gemm_runner = MoeGemmRunner<DataType_, cutlass::uint4b_t>();
|
||||
auto fp16_moe_gemm_runner = MoeGemmRunner<DataType_, cutlass::WintQuantTraits<DataType_, cutlass::WintQuantMethod::kNone>>();
|
||||
auto int8_moe_gemm_runner = MoeGemmRunner<DataType_, cutlass::WintQuantTraits<DataType_, cutlass::WintQuantMethod::kWeightOnlyInt8>>();
|
||||
auto int4_moe_gemm_runner = MoeGemmRunner<DataType_, cutlass::WintQuantTraits<DataType_, cutlass::WintQuantMethod::kWeightOnlyInt4>>();
|
||||
auto w4a8_moe_gemm_runner = W4A8MoeGemmRunner<DataType_, int8_t, cutlass::uint4b_t>();
|
||||
|
||||
assert(permute_input.dims().size() == 3 || permute_input.dims().size() == 2);
|
||||
@@ -109,6 +109,7 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
|
||||
const int64_t tune_total_rows = expanded_active_expert_rows;
|
||||
|
||||
if (quant_method == "weight_only_int8") {
|
||||
typename cutlass::WintQuantTraits<DataType_, cutlass::WintQuantMethod::kWeightOnlyInt8>::Arguments quant_args;
|
||||
int8_moe_gemm_runner.moe_gemm_bias_act(
|
||||
reinterpret_cast<const NvType*>(permute_input.data<data_t>()),
|
||||
reinterpret_cast<const uint8_t*>(ffn1_weight.data<int8_t>()),
|
||||
@@ -123,9 +124,11 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
|
||||
inter_size,
|
||||
hidden_size,
|
||||
num_experts,
|
||||
quant_args,
|
||||
"none",
|
||||
stream);
|
||||
} else if (quant_method == "weight_only_int4") {
|
||||
typename cutlass::WintQuantTraits<DataType_, cutlass::WintQuantMethod::kWeightOnlyInt4>::Arguments quant_args;
|
||||
int4_moe_gemm_runner.moe_gemm_bias_act(
|
||||
reinterpret_cast<const NvType*>(permute_input.data<data_t>()),
|
||||
reinterpret_cast<const cutlass::uint4b_t*>(
|
||||
@@ -141,6 +144,7 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
|
||||
inter_size,
|
||||
hidden_size,
|
||||
num_experts,
|
||||
quant_args,
|
||||
"none",
|
||||
stream);
|
||||
} else if (quant_method == "w4a8") {
|
||||
@@ -165,6 +169,7 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
|
||||
num_experts,
|
||||
stream);
|
||||
} else {
|
||||
typename cutlass::WintQuantTraits<DataType_, cutlass::WintQuantMethod::kNone>::Arguments quant_args;
|
||||
fp16_moe_gemm_runner.moe_gemm_bias_act(
|
||||
reinterpret_cast<const NvType*>(permute_input.data<data_t>()),
|
||||
reinterpret_cast<const NvType*>(ffn1_weight.data<data_t>()),
|
||||
@@ -177,6 +182,7 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
|
||||
inter_size,
|
||||
hidden_size,
|
||||
num_experts,
|
||||
quant_args,
|
||||
"none",
|
||||
stream);
|
||||
}
|
||||
@@ -190,6 +196,7 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
  auto act_out = act_out_tensor.data<data_t>();

  if (quant_method == "weight_only_int8") {
    typename cutlass::WintQuantTraits<DataType_, cutlass::WintQuantMethod::kWeightOnlyInt8>::Arguments quant_args;
    int8_moe_gemm_runner.moe_gemm(
        reinterpret_cast<const NvType*>(act_out),
        reinterpret_cast<const uint8_t*>(ffn2_weight.data<int8_t>()),
@@ -203,9 +210,11 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
        hidden_size,
        inter_size / 2,
        num_experts,
        quant_args,
        stream);

  } else if (quant_method == "weight_only_int4") {
    typename cutlass::WintQuantTraits<DataType_, cutlass::WintQuantMethod::kWeightOnlyInt4>::Arguments quant_args;
    int4_moe_gemm_runner.moe_gemm(
        reinterpret_cast<const NvType*>(act_out),
        reinterpret_cast<const cutlass::uint4b_t*>(
@@ -220,6 +229,7 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
        hidden_size,
        inter_size / 2,
        num_experts,
        quant_args,
        stream);
  } else if (quant_method == "w4a8") {
    data_t *ffn2_shift = nullptr;
@@ -262,6 +272,7 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
        num_experts,
        stream);
  } else {
    typename cutlass::WintQuantTraits<DataType_, cutlass::WintQuantMethod::kNone>::Arguments quant_args;
    fp16_moe_gemm_runner.moe_gemm(
        reinterpret_cast<const NvType*>(act_out),
        reinterpret_cast<const NvType*>(ffn2_weight.data<data_t>()),
@@ -273,6 +284,7 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
        hidden_size,
        inter_size / 2,
        num_experts,
        quant_args,
        stream);
  }
}
@@ -288,6 +300,8 @@ paddle::Tensor MoeExpertFFNFunc(
    const paddle::optional<paddle::Tensor>& ffn2_in_scale,
    const paddle::optional<paddle::Tensor>& expert_idx_per_token,
    const std::string& quant_method, const bool used_in_ep_low_latency) {

  cudaCheckError();
  const auto t_type = quant_method == "w4a8" ? ffn1_scale.get().dtype() : permute_input.dtype();
  auto ffn_out = paddle::empty_like(permute_input, t_type);

@@ -381,14 +395,14 @@ std::vector<paddle::DataType> MoeExpertFFNInferDtype(

/**
 * @brief Mixture of Experts (MoE) Feed-Forward Network Operator
 *
 *
 * This operator performs the expert computation in MoE architecture, including:
 * 1. First linear transformation (FFN1) with optional quantization
 * 2. SwiGLU activation function
 * 3. Second linear transformation (FFN2) with optional quantization
 *
 *
 * Supports multiple quantization methods including weight-only int4/int8 and w4a8 quantization.
 *
 *
 * Inputs:
 * - permute_input: Permuted input tensor organized by expert
 *                  Shape: [total_tokens * top_k, hidden_size]
@@ -416,18 +430,18 @@ std::vector<paddle::DataType> MoeExpertFFNInferDtype(
 * - expert_idx_per_token: Optional expert indices per token (w4a8 only)
 *                  Shape: [total_tokens]
 *                  dtype: int64
 *
 *
 * Outputs:
 * - output_tensor: Output tensor after MoE FFN computation
 *                  Shape: Same as permute_input
 *                  dtype: Same as input (or ffn1_scale dtype for w4a8)
 *
 *
 * Attributes:
 * - quant_method: Quantization method to use
 *                 Options: "none", "weight_only_int4", "weight_only_int8", "w4a8"
 * - used_in_ep_low_latency: Whether running in low latency mode
 *                 Affects activation function implementation
 *
 *
 * Note:
 * - w4a8 mode requires additional workspace memory allocation
 * - Low latency mode uses specialized grouped SwiGLU implementation
377 custom_ops/gpu_ops/moe/moe_ffn_wint2.cu Normal file
@@ -0,0 +1,377 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at

//     http://www.apache.org/licenses/LICENSE-2.0

// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "cutlass/numeric_conversion.h"
#include "group_swiglu_with_masked.h"
#include "helper.h"
#include "moe/fast_hardamard_kernel.h"
#include "moe/fused_moe_helper.h"

template <typename DataT, typename NvType, typename WeightSavedT, cutlass::WintQuantMethod QuantMethod>
void WeightOnlyMoeFFNKernel(const paddle::Tensor& permute_input,
                            const paddle::Tensor& tokens_expert_prefix_sum,
                            const paddle::Tensor& ffn1_weight,
                            const paddle::Tensor& ffn2_weight,
                            const paddle::Tensor* ffn1_bias,
                            const paddle::Tensor* ffn1_super_scale,
                            const paddle::Tensor* ffn2_super_scale,
                            const paddle::Tensor* ffn1_local_scale,
                            const paddle::Tensor* ffn1_code_scale,
                            const paddle::Tensor* ffn1_code_zp,
                            const paddle::Tensor* ffn2_local_scale,
                            const paddle::Tensor* ffn2_code_scale,
                            const paddle::Tensor* ffn2_code_zp,
                            paddle::Tensor fc1_out,
                            paddle::Tensor ffn_out,
                            const int64_t total_rows_in_ll_else_minus1,
                            const int64_t actual_total_rows,
                            const int64_t inter_size,
                            const int64_t hidden_size,
                            const int num_experts,
                            bool used_in_ep_low_latency) {
  using namespace phi;
  using WeightOnlyTraits = cutlass::WintQuantTraits<NvType, QuantMethod>;
  using WeightType = typename WeightOnlyTraits::WeightType;

  typename WeightOnlyTraits::Arguments ffn1_quant_args;
  typename WeightOnlyTraits::Arguments ffn2_quant_args;
  if constexpr (QuantMethod == cutlass::WintQuantMethod::kWeightOnlyInt2) {
    ffn1_quant_args.local_scale_ptr = ffn1_local_scale->data<uint8_t>();
    ffn1_quant_args.code_scale_ptr = ffn1_code_scale->data<float>();
    ffn1_quant_args.code_zp_ptr = ffn1_code_zp->data<float>();
    ffn2_quant_args.local_scale_ptr = ffn2_local_scale->data<uint8_t>();
    ffn2_quant_args.code_scale_ptr = ffn2_code_scale->data<float>();
    ffn2_quant_args.code_zp_ptr = ffn2_code_zp->data<float>();
  }

  auto moe_gemm_runner = MoeGemmRunner<NvType, WeightOnlyTraits>();
  auto stream = permute_input.stream();

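  // Editorial note: moe_gemm_bias_act below runs a grouped GEMM over the
  // experts, using tokens_expert_prefix_sum to delimit each expert's row
  // range; the activation string is "none" because SwiGLU is applied
  // separately right after this call (see act_out below).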
  moe_gemm_runner.moe_gemm_bias_act(
      reinterpret_cast<const NvType*>(permute_input.data<DataT>()),
      reinterpret_cast<const WeightType*>(ffn1_weight.data<WeightSavedT>()),
      reinterpret_cast<const NvType*>(ffn1_super_scale ? ffn1_super_scale->data<DataT>() : nullptr),
      reinterpret_cast<const NvType*>(ffn1_bias ? ffn1_bias->data<DataT>() : nullptr),
      reinterpret_cast<NvType*>(fc1_out.data<DataT>()),
      const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
      total_rows_in_ll_else_minus1,
      actual_total_rows,
      inter_size,
      hidden_size,
      num_experts,
      ffn1_quant_args,
      "none",
      stream);

  paddle::Tensor act_out;
  if (used_in_ep_low_latency) {
    act_out = GroupSwigluWithMasked(fc1_out, tokens_expert_prefix_sum);
  } else {
    act_out = paddle::experimental::swiglu(fc1_out, nullptr);
  }

  moe_gemm_runner.moe_gemm(
      reinterpret_cast<const NvType*>(act_out.data<DataT>()),
      reinterpret_cast<const WeightType*>(ffn2_weight.data<WeightSavedT>()),
      reinterpret_cast<const NvType*>(ffn2_super_scale ? ffn2_super_scale->data<DataT>() : nullptr),
      reinterpret_cast<NvType*>(ffn_out.data<DataT>()),
      const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
      total_rows_in_ll_else_minus1,
      actual_total_rows,
      hidden_size,
      inter_size / 2,
      num_experts,
      ffn2_quant_args,
      stream);
}

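// Editorial sketch: a minimal host-side reference of the SwiGLU step between
// the two grouped GEMMs above, to illustrate why the second GEMM consumes
// inter_size / 2 features. This is not the GroupSwigluWithMasked /
// paddle::experimental::swiglu implementation, and the [gate | value] half
// ordering is an assumption made for illustration only.
#include <cmath>
#include <vector>

static std::vector<float> SwigluReference(const std::vector<float>& fc1_row,
                                          int64_t inter_size) {
  const int64_t half = inter_size / 2;  // fc1 produces gate and value halves
  std::vector<float> out(half);
  for (int64_t j = 0; j < half; ++j) {
    const float gate = fc1_row[j];
    const float value = fc1_row[j + half];
    const float silu = gate / (1.0f + std::exp(-gate));  // SiLU(x) = x * sigmoid(x)
    out[j] = silu * value;  // elementwise product halves the feature width
  }
  return out;
}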
template <paddle::DataType T>
void MoeFFNWint2Kernel(const paddle::Tensor& permute_input,
                       const paddle::Tensor& tokens_expert_prefix_sum,
                       const paddle::Tensor& ffn1_weight,
                       const paddle::Tensor& ffn2_weight,
                       const paddle::optional<paddle::Tensor>& ffn1_bias,
                       const paddle::optional<paddle::Tensor>& ffn1_scale,
                       const paddle::optional<paddle::Tensor>& ffn2_scale,
                       const paddle::optional<paddle::Tensor>& ffn1_local_scale,
                       const paddle::optional<paddle::Tensor>& ffn1_code_scale,
                       const paddle::optional<paddle::Tensor>& ffn1_code_zp,
                       const paddle::optional<paddle::Tensor>& ffn2_local_scale,
                       const paddle::optional<paddle::Tensor>& ffn2_code_scale,
                       const paddle::optional<paddle::Tensor>& ffn2_code_zp,
                       paddle::Tensor ffn_out,
                       bool used_in_ep_low_latency) {
  using namespace phi;
  using data_t = typename PDTraits<T>::data_t;
  using NvType = typename PDTraits<T>::DataType;

  auto place = permute_input.place();

  assert(permute_input.dims().size() == 3 || permute_input.dims().size() == 2);
  assert(ffn1_weight.dims().size() == 3);

  const int num_experts = ffn1_weight.dims()[0];
  const int hidden_size = permute_input.dims()[permute_input.dims().size() - 1];

  int inter_dim = ffn1_weight.dims()[1] * ffn1_weight.dims()[2] / hidden_size;

  const int64_t inter_size = inter_dim * 4;

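  // Editorial note (assumption): the factor of 4 appears to undo the wint2
  // packing, where four 2-bit weights share one stored byte, so the packed
  // ffn1_weight dims only account for inter_size / 4. Illustrative numbers:
  // with hidden_size = 4096 and ffn1_weight dims [num_experts, 2048, 4096],
  // inter_dim = 2048 * 4096 / 4096 = 2048 and inter_size = 2048 * 4 = 8192.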
  int num_experts_ = num_experts;
  int num_max_tokens_per_expert = 0;
  int expanded_active_expert_rows = 0;

  paddle::Tensor fc1_out_tensor;
  if (permute_input.dims().size() == 3) {
    num_experts_ = permute_input.dims()[0];
    assert(num_experts == num_experts_);

    num_max_tokens_per_expert = permute_input.dims()[1];
    expanded_active_expert_rows = num_experts_ * num_max_tokens_per_expert;
    fc1_out_tensor = GetEmptyTensor(
        {num_experts_, num_max_tokens_per_expert, inter_size}, T, place);
  } else {
    expanded_active_expert_rows = permute_input.dims()[0];
    fc1_out_tensor = GetEmptyTensor(
        {expanded_active_expert_rows, inter_size}, T, place);
  }

  // This is a trick: expanded_active_expert_rows is not needed by the variable
  // group gemm, but it is needed to accommodate the DeepEP low-latency mode.
  const int64_t total_rows_in_ll_else_minus1 = used_in_ep_low_latency ? expanded_active_expert_rows : -1;

  // When we tune the optimal configuration, we need the actual total_rows.
  const int64_t actual_total_rows = expanded_active_expert_rows;

  WeightOnlyMoeFFNKernel<data_t, NvType, uint8_t, cutlass::WintQuantMethod::kWeightOnlyInt2>(
      permute_input,
      tokens_expert_prefix_sum,
      ffn1_weight,
      ffn2_weight,
      const_cast<paddle::Tensor*>(ffn1_bias.get_ptr()),
      const_cast<paddle::Tensor*>(ffn1_scale.get_ptr()),
      const_cast<paddle::Tensor*>(ffn2_scale.get_ptr()),
      const_cast<paddle::Tensor*>(ffn1_local_scale.get_ptr()),
      const_cast<paddle::Tensor*>(ffn1_code_scale.get_ptr()),
      const_cast<paddle::Tensor*>(ffn1_code_zp.get_ptr()),
      const_cast<paddle::Tensor*>(ffn2_local_scale.get_ptr()),
      const_cast<paddle::Tensor*>(ffn2_code_scale.get_ptr()),
      const_cast<paddle::Tensor*>(ffn2_code_zp.get_ptr()),
      fc1_out_tensor,
      ffn_out,
      total_rows_in_ll_else_minus1,
      actual_total_rows,
      inter_size,
      hidden_size,
      num_experts,
      used_in_ep_low_latency);
}

paddle::Tensor MoeExpertFFNWint2Func(
    const paddle::Tensor& permute_input,
    const paddle::Tensor& tokens_expert_prefix_sum,
    const paddle::Tensor& ffn1_weight,
    const paddle::Tensor& ffn2_weight,
    const paddle::optional<paddle::Tensor>& ffn1_bias,
    const paddle::optional<paddle::Tensor>& ffn1_scale,
    const paddle::optional<paddle::Tensor>& ffn2_scale,
    const paddle::optional<paddle::Tensor>& ffn1_local_scale,
    const paddle::optional<paddle::Tensor>& ffn1_code_scale,
    const paddle::optional<paddle::Tensor>& ffn1_code_zp,
    const paddle::optional<paddle::Tensor>& ffn2_local_scale,
    const paddle::optional<paddle::Tensor>& ffn2_code_scale,
    const paddle::optional<paddle::Tensor>& ffn2_code_zp,
    const bool used_in_ep_low_latency) {

  const auto dtype = permute_input.dtype();
  auto ffn_out = paddle::empty_like(permute_input, dtype);

  switch (dtype) {
    case paddle::DataType::BFLOAT16:
      MoeFFNWint2Kernel<paddle::DataType::BFLOAT16>(permute_input,
                                                    tokens_expert_prefix_sum,
                                                    ffn1_weight,
                                                    ffn2_weight,
                                                    ffn1_bias,
                                                    ffn1_scale,
                                                    ffn2_scale,
                                                    ffn1_local_scale,
                                                    ffn1_code_scale,
                                                    ffn1_code_zp,
                                                    ffn2_local_scale,
                                                    ffn2_code_scale,
                                                    ffn2_code_zp,
                                                    ffn_out,
                                                    used_in_ep_low_latency);
      break;
    case paddle::DataType::FLOAT16:
      MoeFFNWint2Kernel<paddle::DataType::FLOAT16>(permute_input,
                                                   tokens_expert_prefix_sum,
                                                   ffn1_weight,
                                                   ffn2_weight,
                                                   ffn1_bias,
                                                   ffn1_scale,
                                                   ffn2_scale,
                                                   ffn1_local_scale,
                                                   ffn1_code_scale,
                                                   ffn1_code_zp,
                                                   ffn2_local_scale,
                                                   ffn2_code_scale,
                                                   ffn2_code_zp,
                                                   ffn_out,
                                                   used_in_ep_low_latency);
      break;
    default:
      PD_THROW("Unsupported data type for MoeExpertFFN");
  }
  return ffn_out;
}

std::vector<paddle::Tensor> MoeExpertFFNWint2(
    const paddle::Tensor& permute_input,
    const paddle::Tensor& tokens_expert_prefix_sum,
    const paddle::Tensor& ffn1_weight,
    const paddle::Tensor& ffn2_weight,
    const paddle::optional<paddle::Tensor>& ffn1_bias,
    const paddle::optional<paddle::Tensor>& ffn1_scale,
    const paddle::optional<paddle::Tensor>& ffn2_scale,
    const paddle::optional<paddle::Tensor>& ffn1_local_scale,
    const paddle::optional<paddle::Tensor>& ffn1_code_scale,
    const paddle::optional<paddle::Tensor>& ffn1_code_zp,
    const paddle::optional<paddle::Tensor>& ffn2_local_scale,
    const paddle::optional<paddle::Tensor>& ffn2_code_scale,
    const paddle::optional<paddle::Tensor>& ffn2_code_zp,
    const bool used_in_ep_low_latency) {

  return {MoeExpertFFNWint2Func(permute_input,
                                tokens_expert_prefix_sum,
                                ffn1_weight,
                                ffn2_weight,
                                ffn1_bias,
                                ffn1_scale,
                                ffn2_scale,
                                ffn1_local_scale,
                                ffn1_code_scale,
                                ffn1_code_zp,
                                ffn2_local_scale,
                                ffn2_code_scale,
                                ffn2_code_zp,
                                used_in_ep_low_latency)};
}

std::vector<std::vector<int64_t>> MoeExpertFFNWint2InferShape(
    const std::vector<int64_t>& permute_input_shape,
    const std::vector<int64_t>& tokens_expert_prefix_sum_shape,
    const std::vector<int64_t>& ffn1_weight_shape,
    const std::vector<int64_t>& ffn2_weight_shape,
    const paddle::optional<std::vector<int64_t>>& ffn1_bias_shape,
    const paddle::optional<std::vector<int64_t>>& ffn1_scale_shape,
    const paddle::optional<std::vector<int64_t>>& ffn2_scale_shape,
    const paddle::optional<std::vector<int64_t>>& ffn1_local_scale_shape,
    const paddle::optional<std::vector<int64_t>>& ffn1_code_scale_shape,
    const paddle::optional<std::vector<int64_t>>& ffn1_code_zp_shape,
    const paddle::optional<std::vector<int64_t>>& ffn2_local_scale_shape,
    const paddle::optional<std::vector<int64_t>>& ffn2_code_scale_shape,
    const paddle::optional<std::vector<int64_t>>& ffn2_code_zp_shape,
    const bool used_in_ep_low_latency) {

  return {permute_input_shape};
}

std::vector<paddle::DataType> MoeExpertFFNWint2InferDtype(
    const paddle::DataType &permute_input_dtype,
    const paddle::DataType &tokens_expert_prefix_sum_dtype,
    const paddle::DataType &ffn1_weight_dtype,
    const paddle::DataType &ffn2_weight_dtype,
    const paddle::optional<paddle::DataType> &ffn1_bias_dtype,
    const paddle::optional<paddle::DataType> &ffn1_scale_dtype,
    const paddle::optional<paddle::DataType> &ffn2_scale_dtype,
    const paddle::optional<paddle::DataType> &ffn1_local_scale_dtype,
    const paddle::optional<paddle::DataType> &ffn1_code_scale_dtype,
    const paddle::optional<paddle::DataType> &ffn1_code_zp_dtype,
    const paddle::optional<paddle::DataType> &ffn2_local_scale_dtype,
    const paddle::optional<paddle::DataType> &ffn2_code_scale_dtype,
    const paddle::optional<paddle::DataType> &ffn2_code_zp_dtype,
    const bool used_in_ep_low_latency) {

  return {permute_input_dtype};
}

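// Editorial sketch: plain C++ shape bookkeeping mirroring MoeFFNWint2Kernel
// above, to show why the op's inferred output shape simply equals the
// permute_input shape while the intermediate fc1 buffer widens to inter_size.
// The struct and function names are illustrative, not part of the diff.
#include <cstdint>
#include <vector>

struct Wint2FfnShapes {
  std::vector<int64_t> fc1_out;  // [rows..., inter_size]
  std::vector<int64_t> ffn_out;  // same as permute_input
};

static Wint2FfnShapes ComputeWint2FfnShapes(
    const std::vector<int64_t>& permute_input_shape, int64_t inter_size) {
  Wint2FfnShapes s;
  s.ffn_out = permute_input_shape;  // output keeps the input layout
  s.fc1_out = permute_input_shape;
  s.fc1_out.back() = inter_size;    // last dim: hidden_size -> inter_size
  return s;
}
// e.g. ComputeWint2FfnShapes({1024, 7168}, 8192)
//      -> fc1_out {1024, 8192}, ffn_out {1024, 7168}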
/**
 * @brief Weight-Only Quantized Mixture of Experts (MoE) Feed-Forward Network Operator
 *
 * This operator performs the expert computation in MoE architecture, including:
 * 1. First linear transformation (FFN1) with optional quantization
 * 2. SwiGLU activation function
 * 3. Second linear transformation (FFN2) with optional quantization
 *
 * This variant implements weight-only int2 (wint2) quantization.
 *
 * Inputs:
 * - permute_input: Permuted input tensor organized by expert
 *     Shape: [total_tokens * top_k, hidden_size]
 *     dtype: bfloat16/float16
 * - tokens_expert_prefix_sum: Prefix sum array of token counts per expert for group_gemm
 *     Shape: [num_experts]
 *     dtype: int64
 * - ffn1_weight: First FFN layer weights
 *     Shape: [num_experts, inter_size * 2, hidden_size]
 *     dtype: Same as input (unquantized) or int8 (quantized)
 * - ffn2_weight: Second FFN layer weights
 *     Shape: [num_experts, hidden_size, inter_size]
 *     dtype: Same as input (unquantized) or int8 (quantized)
 * - ffn1_bias: Optional bias for first FFN layer
 *     Shape: [num_experts, inter_size * 2]
 *     dtype: Same as input
 * - ffn1_scale: Quantization scales for first FFN layer
 *     Shape: [num_experts, inter_size * 2]
 *     dtype: Same as input
 * - ffn2_scale: Quantization scales for second FFN layer
 *     Shape: [num_experts, hidden_size]
 *     dtype: Same as input
 *
 * Outputs:
 * - output_tensor: Output tensor after MoE FFN computation
 *     Shape: Same as permute_input
 *     dtype: Same as input
 *
 * Attributes:
 * - used_in_ep_low_latency: Whether running in low latency mode
 *     Affects activation function implementation
 *
 * Note:
 * - Low latency mode uses specialized grouped SwiGLU implementation
 */
PD_BUILD_STATIC_OP(moe_expert_ffn_wint2)
    .Inputs({"permute_input",
             "tokens_expert_prefix_sum",
             "ffn1_weight",
             "ffn2_weight",
             paddle::Optional("ffn1_bias"),
             paddle::Optional("ffn1_scale"),
             paddle::Optional("ffn2_scale"),
             paddle::Optional("ffn1_local_scale"),
             paddle::Optional("ffn1_code_scale"),
             paddle::Optional("ffn1_code_zp"),
             paddle::Optional("ffn2_local_scale"),
             paddle::Optional("ffn2_code_scale"),
             paddle::Optional("ffn2_code_zp")})
    .Outputs({"output_tensor"})
    .Attrs({"used_in_ep_low_latency:bool"})
    .SetKernelFn(PD_KERNEL(MoeExpertFFNWint2))
    .SetInferShapeFn(PD_INFER_SHAPE(MoeExpertFFNWint2InferShape))
    .SetInferDtypeFn(PD_INFER_DTYPE(MoeExpertFFNWint2InferDtype));
@@ -96,7 +96,10 @@ std::vector<std::vector<int64_t>> MoeExpertReduceInferShape(
    const std::vector<int64_t> &permute_indices_per_token_shape,
    const std::vector<int64_t> &top_k_indices_shape,
    const paddle::optional<std::vector<int64_t>> &ffn2_bias_shape) {
  return {ffn_out_shape};
  const int moe_topk = top_k_indices_shape[1];
  auto out_shape = ffn_out_shape;
  if (out_shape[0] != -1) out_shape[0] /= moe_topk;
  return {out_shape};
}

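// Editorial sketch of the revised inference rule in the hunk above: the reduce
// step sums the top-k expert outputs for each token, so dim 0 of ffn_out
// shrinks by moe_topk. The helper name and the example numbers below are
// illustrative only.
#include <cstdint>
#include <vector>

static std::vector<int64_t> ReducedShape(std::vector<int64_t> ffn_out_shape,
                                         int64_t moe_topk) {
  if (ffn_out_shape[0] != -1) ffn_out_shape[0] /= moe_topk;  // -1 marks a dynamic dim
  return ffn_out_shape;
}
// e.g. ReducedShape({1024 * 8, 7168}, 8) -> {1024, 7168}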
std::vector<paddle::DataType> MoeExpertReduceInferDtype(

@@ -88,6 +88,7 @@ void moe_topk_select_kernel(const T* input,
          k,
          num_rows);
    }
    cudaGetLastError();
  }
  else {
    assert(k<=TPB);

1122 custom_ops/gpu_ops/moe/moe_wna16_marlin_gemm.cu Normal file
File diff suppressed because it is too large
37 custom_ops/gpu_ops/moe/moe_wna16_marlin_gemm.h Normal file
@@ -0,0 +1,37 @@
#pragma once
#ifndef MARLIN_NAMESPACE_NAME
#define MARLIN_NAMESPACE_NAME marlin_moe_wna16
#endif

#include "paddle/phi/api/include/api.h"
#include "paddle/phi/core/enforce.h"

#include "moe/moe_wna16_marlin_utils/kernel.h"
#include "moe/moe_wna16_marlin_utils/types.h"

std::vector<paddle::Tensor> MoeWna16MarlinGemmApi(
    const paddle::Tensor& a,
    const paddle::optional<paddle::Tensor>& c_or_none,
    const paddle::Tensor& b_q_weight,
    const paddle::Tensor& b_scales,
    const paddle::optional<paddle::Tensor>& global_scale_or_none,
    const paddle::optional<paddle::Tensor>& b_zeros_or_none,
    const paddle::optional<paddle::Tensor>& g_idx_or_none,
    const paddle::optional<paddle::Tensor>& perm_or_none,
    const paddle::Tensor& workspace,
    const paddle::Tensor& sorted_token_ids,
    const paddle::Tensor& expert_ids,
    const paddle::Tensor& num_tokens_post_padded,
    const paddle::Tensor& topk_weights,
    int64_t moe_block_size,
    int64_t top_k,
    bool mul_topk_weights,
    bool is_ep,
    const std::string& b_q_type_str,
    int64_t size_m,
    int64_t size_n,
    int64_t size_k,
    bool is_k_full,
    bool use_atomic_add,
    bool use_fp32_reduce,
    bool is_zp_float);
63 custom_ops/gpu_ops/moe/moe_wna16_marlin_utils/CUDAStream.h Normal file
@@ -0,0 +1,63 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "glog/logging.h"
#include "paddle/phi/api/include/context_pool.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_info.h"
#include "paddle/phi/core/cuda_stream.h"

namespace MARLIN_NAMESPACE_NAME {

using DeviceIndex = int8_t;
using StreamId = int64_t;

class CUDAStream {
 public:
  CUDAStream() {}
  explicit CUDAStream(const cudaStream_t& stream) : raw_stream_(stream) {}
  StreamId id() const { return reinterpret_cast<StreamId>(raw_stream_); }

  operator cudaStream_t() const { return raw_stream_; }

  const cudaStream_t& raw_stream() const { return raw_stream_; }

 private:
  cudaStream_t raw_stream_;
};

/**
 * Get the current CUDA stream, for the passed CUDA device, or for the
 * current device if no device index is passed. The current CUDA stream
 * will usually be the default CUDA stream for the device, but it may
 * be different if someone called 'setCurrentCUDAStream' or used 'StreamGuard'
 * or 'CUDAStreamGuard'.
 */
inline CUDAStream getCurrentCUDAStream(DeviceIndex device_index = -1) {
  if (device_index == -1) {
    device_index = phi::backends::gpu::GetCurrentDeviceId();
  }

  return CUDAStream(
      paddle::GetCurrentCUDAStream(phi::GPUPlace(device_index))->raw_stream());
  // LOG(FATAL) << "getCurrentCUDAStream is not implemented";
  // return *(CUDAStream*)nullptr;
}

cudaStream_t GetCalcStreamFromGroup(int context_ring_id);

cudaStream_t GetCommStreamFromGroup(int context_ring_id);
}  // namespace MARLIN_NAMESPACE_NAME
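// Editorial usage sketch (not part of the diff): because CUDAStream converts
// implicitly to cudaStream_t, the wrapper can be handed straight to CUDA
// runtime calls or kernel launches. ExampleSync and the commented-out kernel
// launch are illustrative placeholders, not APIs from this header.
namespace MARLIN_NAMESPACE_NAME {
inline void ExampleSync() {
  CUDAStream stream = getCurrentCUDAStream();  // current device's stream
  // some_kernel<<<grid, block, 0, stream>>>(...);  // placeholder launch
  cudaStreamSynchronize(stream);  // implicit conversion to cudaStream_t here
}
}  // namespace MARLIN_NAMESPACE_NAME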
Some files were not shown because too many files have changed in this diff