Sync v2.0 version of code to github repo

This commit is contained in:
Jiang-Jia-Jun
2025-06-29 23:29:37 +00:00
parent d151496038
commit 92c2cfa2e7
597 changed files with 78776 additions and 22905 deletions

View File

@@ -17,15 +17,12 @@
#include "paddle/phi/core/memory/memcpy.h"
template <int THREADBLOCK_SIZE>
__global__ void GetMaxLenKernel(const int *seq_lens,
const int *seq_lens_this_time,
const int *seq_lens_encoder,
const int *seq_lens_this_time_merged,
const int *seq_lens_encoder_merged,
const int *seq_mapping,
const int *system_lens,
int *max_lens,
const int batch_size) {
__global__ void
GetMaxLenKernel(const int *seq_lens, const int *seq_lens_this_time,
const int *seq_lens_encoder,
const int *seq_lens_this_time_merged,
const int *seq_lens_encoder_merged, const int *seq_mapping,
const int *system_lens, int *max_lens, const int batch_size) {
const int tid = threadIdx.x;
typedef cub::BlockReduce<int, THREADBLOCK_SIZE> BlockReduce;
@@ -41,43 +38,61 @@ __global__ void GetMaxLenKernel(const int *seq_lens,
int max_dec_len_without_system_this_thread = 0;
for (int i = tid; i < batch_size; i += blockDim.x) {
const int seq_len_this_time = seq_lens_this_time[i];
max_len_this_time_this_thread = max(seq_len_this_time,
max_len_this_time_this_thread);
max_len_encoder_this_thread = max(seq_lens_encoder[i],
max_len_encoder_this_thread);
max_len_this_time_this_thread =
max(seq_len_this_time, max_len_this_time_this_thread);
max_len_encoder_this_thread =
max(seq_lens_encoder[i], max_len_encoder_this_thread);
max_len_decoder_this_thread = max(seq_lens[i], max_len_decoder_this_thread);
if (seq_len_this_time <= 0) continue;
if (seq_len_this_time <= 0)
continue;
const int max_just_dec_len_now = seq_lens_encoder[i] > 0 ? 0 : seq_lens[i];
max_len_this_thread = max(seq_lens[i] + seq_len_this_time,
max_len_this_thread);
max_just_dec_len_this_thread = max(max_just_dec_len_this_thread,
max_just_dec_len_now);
max_len_this_thread =
max(seq_lens[i] + seq_len_this_time, max_len_this_thread);
max_just_dec_len_this_thread =
max(max_just_dec_len_this_thread, max_just_dec_len_now);
if (system_lens) {
const int real_bid = seq_mapping[i];
const int system_len_now = system_lens[real_bid];
max_system_len_this_thread = max(max_system_len_this_thread, system_len_now);
max_dec_len_without_system_this_thread = max(max_dec_len_without_system_this_thread,
max_just_dec_len_now - system_len_now);
max_system_len_this_thread =
max(max_system_len_this_thread, system_len_now);
max_dec_len_without_system_this_thread =
max(max_dec_len_without_system_this_thread,
max_just_dec_len_now - system_len_now);
}
}
if (system_lens) {
for (int i = tid; i < batch_size; i += blockDim.x) {
const int ori_seq_len_this_time = seq_lens_this_time_merged[i];
if (ori_seq_len_this_time <= 0) continue;
const int max_just_dec_merged_len_this_time_now = seq_lens_encoder_merged[i] > 0 ?
0 : ori_seq_len_this_time;
max_just_dec_merged_len_this_time_this_thread = max(max_just_dec_merged_len_this_time_this_thread,
max_just_dec_merged_len_this_time_now);
if (ori_seq_len_this_time <= 0)
continue;
const int max_just_dec_merged_len_this_time_now =
seq_lens_encoder_merged[i] > 0 ? 0 : ori_seq_len_this_time;
max_just_dec_merged_len_this_time_this_thread =
max(max_just_dec_merged_len_this_time_this_thread,
max_just_dec_merged_len_this_time_now);
}
}
int total_max_len_this_time = BlockReduce(temp_storage).Reduce(max_len_this_time_this_thread, MaxOp<int>());
int total_max_len_encoder = BlockReduce(temp_storage).Reduce(max_len_encoder_this_thread, MaxOp<int>());
int total_max_len_decoder = BlockReduce(temp_storage).Reduce(max_len_decoder_this_thread, MaxOp<int>());
int total = BlockReduce(temp_storage).Reduce(max_len_this_thread, MaxOp<int>());
int total_just_dec = BlockReduce(temp_storage).Reduce(max_just_dec_len_this_thread, MaxOp<int>());
int total_just_dec_merged = BlockReduce(temp_storage).Reduce(max_just_dec_merged_len_this_time_this_thread, MaxOp<int>());
int total_system_len = BlockReduce(temp_storage).Reduce(max_system_len_this_thread, MaxOp<int>());
int total_dec_len_without_system = BlockReduce(temp_storage).Reduce(max_dec_len_without_system_this_thread, MaxOp<int>());
int total_max_len_this_time =
BlockReduce(temp_storage)
.Reduce(max_len_this_time_this_thread, MaxOp<int>());
int total_max_len_encoder =
BlockReduce(temp_storage)
.Reduce(max_len_encoder_this_thread, MaxOp<int>());
int total_max_len_decoder =
BlockReduce(temp_storage)
.Reduce(max_len_decoder_this_thread, MaxOp<int>());
int total =
BlockReduce(temp_storage).Reduce(max_len_this_thread, MaxOp<int>());
int total_just_dec = BlockReduce(temp_storage)
.Reduce(max_just_dec_len_this_thread, MaxOp<int>());
int total_just_dec_merged =
BlockReduce(temp_storage)
.Reduce(max_just_dec_merged_len_this_time_this_thread, MaxOp<int>());
int total_system_len = BlockReduce(temp_storage)
.Reduce(max_system_len_this_thread, MaxOp<int>());
int total_dec_len_without_system =
BlockReduce(temp_storage)
.Reduce(max_dec_len_without_system_this_thread, MaxOp<int>());
if (tid == 0) {
max_lens[0] = total_max_len_this_time;
max_lens[1] = total_max_len_encoder;
@@ -90,30 +105,22 @@ __global__ void GetMaxLenKernel(const int *seq_lens,
}
}
void GetMaxLen(const paddle::Tensor& seq_lens_tensor,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& seq_lens_encoder,
paddle::Tensor &max_len_tensor,
const int batch_size) {
void GetMaxLen(const paddle::Tensor &seq_lens_tensor,
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &seq_lens_encoder,
paddle::Tensor &max_len_tensor, const int batch_size) {
constexpr int blockSize = 1024;
GetMaxLenKernel<blockSize><<<1, blockSize, 0, seq_lens_encoder.stream()>>>(
seq_lens_tensor.data<int>(),
seq_lens_this_time.data<int>(),
seq_lens_encoder.data<int>(),
nullptr,
nullptr,
nullptr,
nullptr,
max_len_tensor.data<int>(),
batch_size);
seq_lens_tensor.data<int>(), seq_lens_this_time.data<int>(),
seq_lens_encoder.data<int>(), nullptr, nullptr, nullptr, nullptr,
max_len_tensor.data<int>(), batch_size);
}
__global__ void split_q_block(const int* __restrict__ seq_lens_q,
const int* __restrict__ seq_lens_encoder,
int* __restrict__ batch_ids,
int* __restrict__ tile_ids_per_batch,
int* __restrict__ num_blocks_x,
const int bsz,
__global__ void split_q_block(const int *__restrict__ seq_lens_q,
const int *__restrict__ seq_lens_encoder,
int *__restrict__ batch_ids,
int *__restrict__ tile_ids_per_batch,
int *__restrict__ num_blocks_x, const int bsz,
const int num_rows_per_block,
const int group_size) {
if (threadIdx.x == 0) {
@@ -124,8 +131,7 @@ __global__ void split_q_block(const int* __restrict__ seq_lens_q,
if (seq_lens_encoder && seq_lens_encoder[bid] > 0) {
seq_len = 0;
}
const int loop_times =
div_up(seq_len * group_size, num_rows_per_block);
const int loop_times = div_up(seq_len * group_size, num_rows_per_block);
for (uint32_t tile_id = 0; tile_id < loop_times; tile_id++) {
batch_ids[index] = bid;
tile_ids_per_batch[index++] = tile_id;
@@ -136,14 +142,12 @@ __global__ void split_q_block(const int* __restrict__ seq_lens_q,
}
}
__global__ void split_kv_block(const int* __restrict__ seq_lens_decoder,
const int* __restrict__ seq_lens_encoder,
int* __restrict__ batch_ids,
int* __restrict__ tile_ids_per_batch,
int* __restrict__ num_blocks_x,
const int bsz,
const int pad_len,
const int num_row_per_block) {
__global__ void split_kv_block(const int *__restrict__ seq_lens_decoder,
const int *__restrict__ seq_lens_encoder,
int *__restrict__ batch_ids,
int *__restrict__ tile_ids_per_batch,
int *__restrict__ num_blocks_x, const int bsz,
const int pad_len, const int num_row_per_block) {
if (threadIdx.x == 0) {
int gridx = 0;
int index = 0;
@@ -165,50 +169,46 @@ __global__ void split_kv_block(const int* __restrict__ seq_lens_decoder,
}
template <int THREADBLOCK_SIZE>
__global__ void get_max_len_kv_ernel(int* max_seq_lens_out,
const int* seq_lens_this_time,
const int* seq_lens_decoder,
const int batch_size) {
__global__ void
get_max_len_kv_ernel(int *max_seq_lens_out, const int *seq_lens_this_time,
const int *seq_lens_decoder, const int batch_size) {
const int tid = threadIdx.x;
typedef cub::BlockReduce<int, THREADBLOCK_SIZE> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
int max_len_this_thread = 0;
for (int i = tid; i < batch_size; i += blockDim.x) {
if (seq_lens_decoder[i] == 0) continue;
max_len_this_thread = max(seq_lens_this_time[i] + seq_lens_decoder[i], max_len_this_thread);
if (seq_lens_decoder[i] == 0)
continue;
max_len_this_thread =
max(seq_lens_this_time[i] + seq_lens_decoder[i], max_len_this_thread);
}
int total = BlockReduce(temp_storage).Reduce(max_len_this_thread, MaxOp<int>());
int total =
BlockReduce(temp_storage).Reduce(max_len_this_thread, MaxOp<int>());
if (tid == 0) {
*max_seq_lens_out = total;
}
}
std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& cum_offsets,
const int encoder_block_shape_q,
const int decoder_block_shape_q,
const int group_size,
const int block_size,
const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &seq_lens_this_time, const paddle::Tensor &cum_offsets,
const int encoder_block_shape_q, const int decoder_block_shape_q,
const int group_size, const int block_size,
const int decoder_step_token_num) {
auto stream = seq_lens_encoder.stream();
int bsz = cum_offsets.shape()[0];
auto max_len_tensor =
GetEmptyTensor({8}, paddle::DataType::INT32, seq_lens_encoder.place());
GetMaxLen(
seq_lens_decoder,
seq_lens_this_time,
seq_lens_encoder,
max_len_tensor,
bsz);
GetMaxLen(seq_lens_decoder, seq_lens_this_time, seq_lens_encoder,
max_len_tensor, bsz);
// max_len_this_time, max_enc_len_this_time, max_dec_len_this_time, max_enc_dec_len_this_time,
// max_just_dec_len_this_time, max_just_dec_merged_len_this_time, max_system_len, max_just_dec_len_without_system
// max_len_this_time, max_enc_len_this_time, max_dec_len_this_time,
// max_enc_dec_len_this_time, max_just_dec_len_this_time,
// max_just_dec_merged_len_this_time, max_system_len,
// max_just_dec_len_without_system
auto max_len_cpu = max_len_tensor.copy_to(paddle::CPUPlace(), false);
auto max_len_cpu_ptr = max_len_cpu.data<int>();
int max_len_this_time = max_len_cpu_ptr[0];
@@ -229,67 +229,67 @@ std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
paddle::Tensor decoder_batch_ids;
paddle::Tensor decoder_tile_ids_per_batch;
paddle::Tensor decoder_num_blocks_x_cpu; /*cpu*/
paddle::Tensor max_len_kv_cpu; /*cpu*/
paddle::Tensor max_len_kv_cpu; /*cpu*/
auto max_len_kv =
GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_decoder.place());
get_max_len_kv_ernel<128><<<1, 128, 0, stream>>>(
max_len_kv.data<int>(),
seq_lens_this_time.data<int>(),
seq_lens_decoder.data<int>(),
bsz
);
max_len_kv.data<int>(), seq_lens_this_time.data<int>(),
seq_lens_decoder.data<int>(), bsz);
max_len_kv_cpu =
max_len_kv.copy_to(paddle::CPUPlace(), false);
max_len_kv_cpu = max_len_kv.copy_to(paddle::CPUPlace(), false);
if (max_enc_len_this_time > 0) {
const uint32_t max_tile_size_per_bs_kv = div_up(max_enc_dec_len_this_time, block_size);
kv_batch_ids = GetEmptyTensor({bsz * max_tile_size_per_bs_kv},
paddle::DataType::INT32,
seq_lens_encoder.place());
kv_tile_ids_per_batch = GetEmptyTensor({bsz * max_tile_size_per_bs_kv},
paddle::DataType::INT32,
seq_lens_encoder.place());
const uint32_t max_tile_size_per_bs_kv =
div_up(max_enc_dec_len_this_time, block_size);
kv_batch_ids =
GetEmptyTensor({bsz * max_tile_size_per_bs_kv}, paddle::DataType::INT32,
seq_lens_encoder.place());
kv_tile_ids_per_batch =
GetEmptyTensor({bsz * max_tile_size_per_bs_kv}, paddle::DataType::INT32,
seq_lens_encoder.place());
auto kv_num_blocks_x =
GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_encoder.place());
split_kv_block<<<1, 32, 0, seq_lens_encoder.stream()>>>(
seq_lens_decoder.data<int>(),
// sequence_lengths->data<int>(),
seq_lens_encoder.data<int>(),
kv_batch_ids.data<int>(),
kv_tile_ids_per_batch.data<int>(),
kv_num_blocks_x.data<int>(),
bsz,
block_size,
block_size
);
seq_lens_decoder.data<int>(),
// sequence_lengths->data<int>(),
seq_lens_encoder.data<int>(), kv_batch_ids.data<int>(),
kv_tile_ids_per_batch.data<int>(), kv_num_blocks_x.data<int>(), bsz,
block_size, block_size);
kv_num_blocks_x_cpu = kv_num_blocks_x.copy_to(paddle::CPUPlace(), false);
const uint32_t encoder_max_tile_size_per_bs_q = div_up(
(max_enc_dec_len_this_time * group_size), encoder_block_shape_q);
const uint32_t encoder_max_tile_size_per_bs_q =
div_up((max_enc_dec_len_this_time * group_size), encoder_block_shape_q);
encoder_batch_ids =
GetEmptyTensor({bsz * encoder_max_tile_size_per_bs_q},
paddle::DataType::INT32,
seq_lens_encoder.place());
paddle::DataType::INT32, seq_lens_encoder.place());
encoder_tile_ids_per_batch =
GetEmptyTensor({bsz * encoder_max_tile_size_per_bs_q},
paddle::DataType::INT32,
seq_lens_encoder.place());
paddle::DataType::INT32, seq_lens_encoder.place());
auto encoder_num_blocks_x =
GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_encoder.place());
split_q_block<<<1, 32, 0, stream>>>(seq_lens_encoder.data<int>(),
nullptr,
split_q_block<<<1, 32, 0, stream>>>(seq_lens_encoder.data<int>(), nullptr,
encoder_batch_ids.data<int>(),
encoder_tile_ids_per_batch.data<int>(),
encoder_num_blocks_x.data<int>(),
bsz,
encoder_block_shape_q,
group_size);
encoder_num_blocks_x.data<int>(), bsz,
encoder_block_shape_q, group_size);
encoder_num_blocks_x_cpu =
encoder_num_blocks_x.copy_to(paddle::CPUPlace(), false);
} else {
encoder_batch_ids =
GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place());
encoder_tile_ids_per_batch =
GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place());
encoder_num_blocks_x_cpu =
GetEmptyTensor({0}, paddle::DataType::INT32, paddle::CPUPlace());
kv_batch_ids =
GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place());
kv_tile_ids_per_batch =
GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place());
kv_num_blocks_x_cpu =
GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place());
}
if (max_just_dec_len_this_time > 0) {
const uint32_t decoder_max_tile_size_per_bs_q =
@@ -297,24 +297,26 @@ std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
decoder_batch_ids =
GetEmptyTensor({bsz * decoder_max_tile_size_per_bs_q},
paddle::DataType::INT32,
seq_lens_encoder.place());
paddle::DataType::INT32, seq_lens_encoder.place());
decoder_tile_ids_per_batch =
GetEmptyTensor({bsz * decoder_max_tile_size_per_bs_q},
paddle::DataType::INT32,
seq_lens_encoder.place());
paddle::DataType::INT32, seq_lens_encoder.place());
auto decoder_num_blocks_x =
GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_encoder.place());
split_q_block<<<1, 32, 0, stream>>>(seq_lens_this_time.data<int>(),
seq_lens_encoder.data<int>(),
decoder_batch_ids.data<int>(),
decoder_tile_ids_per_batch.data<int>(),
decoder_num_blocks_x.data<int>(),
bsz,
decoder_block_shape_q,
group_size);
split_q_block<<<1, 32, 0, stream>>>(
seq_lens_this_time.data<int>(), seq_lens_encoder.data<int>(),
decoder_batch_ids.data<int>(), decoder_tile_ids_per_batch.data<int>(),
decoder_num_blocks_x.data<int>(), bsz, decoder_block_shape_q,
group_size);
decoder_num_blocks_x_cpu =
decoder_num_blocks_x.copy_to(paddle::CPUPlace(), false);
} else {
decoder_batch_ids =
GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place());
decoder_tile_ids_per_batch =
GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place());
decoder_num_blocks_x_cpu =
GetEmptyTensor({0}, paddle::DataType::INT32, paddle::CPUPlace());
}
return {encoder_batch_ids,
@@ -331,28 +333,22 @@ std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
}
std::vector<paddle::DataType> GetBlockShapeAndSplitKVBlockInferDtype(
const paddle::DataType& seq_lens_encoder_dtype,
const paddle::DataType& seq_lens_decoder_dtype,
const paddle::DataType& seq_lens_this_time_dtype,
const paddle::DataType& cum_offsets_dtype) {
return {paddle::DataType::INT32,
paddle::DataType::INT32,
paddle::DataType::INT32,
paddle::DataType::INT32,
paddle::DataType::INT32,
paddle::DataType::INT32,
paddle::DataType::INT32,
paddle::DataType::INT32,
paddle::DataType::INT32,
paddle::DataType::INT32,
paddle::DataType::INT32};
const paddle::DataType &seq_lens_encoder_dtype,
const paddle::DataType &seq_lens_decoder_dtype,
const paddle::DataType &seq_lens_this_time_dtype,
const paddle::DataType &cum_offsets_dtype) {
return {
paddle::DataType::INT32, paddle::DataType::INT32, paddle::DataType::INT32,
paddle::DataType::INT32, paddle::DataType::INT32, paddle::DataType::INT32,
paddle::DataType::INT32, paddle::DataType::INT32, paddle::DataType::INT32,
paddle::DataType::INT32, paddle::DataType::INT32};
}
std::vector<std::vector<int64_t>> GetBlockShapeAndSplitKVBlockInferShape(
const std::vector<int64_t>& seq_lens_encoder_shape,
const std::vector<int64_t>& seq_lens_decoder_shape,
const std::vector<int64_t>& seq_lens_this_time_shape,
const std::vector<int64_t>& cum_offsets_shape) {
const std::vector<int64_t> &seq_lens_encoder_shape,
const std::vector<int64_t> &seq_lens_decoder_shape,
const std::vector<int64_t> &seq_lens_this_time_shape,
const std::vector<int64_t> &cum_offsets_shape) {
std::vector<int64_t> dynamic_shape = {-1};
return {dynamic_shape,
@@ -369,9 +365,7 @@ std::vector<std::vector<int64_t>> GetBlockShapeAndSplitKVBlockInferShape(
}
PD_BUILD_STATIC_OP(get_block_shape_and_split_kv_block)
.Inputs({"seq_lens_encoder",
"seq_lens_decoder",
"seq_lens_this_time",
.Inputs({"seq_lens_encoder", "seq_lens_decoder", "seq_lens_this_time",
"cum_offsets"})
.Outputs({paddle::Optional("encoder_batch_ids"),
paddle::Optional("encoder_tile_ids_per_batch"),
@@ -382,12 +376,9 @@ PD_BUILD_STATIC_OP(get_block_shape_and_split_kv_block)
paddle::Optional("decoder_batch_ids"),
paddle::Optional("decoder_tile_ids_per_batch"),
paddle::Optional("decoder_num_blocks"),
paddle::Optional("max_len_kv"),
"set_max_lengths"})
.Attrs({"encoder_block_shape_q: int",
"decoder_block_shape_q: int",
"group_size: int",
"block_size: int",
paddle::Optional("max_len_kv"), "set_max_lengths"})
.Attrs({"encoder_block_shape_q: int", "decoder_block_shape_q: int",
"group_size: int", "block_size: int",
"decoder_step_token_num: int"})
.SetKernelFn(PD_KERNEL(GetBlockShapeAndSplitKVBlock))
.SetInferShapeFn(PD_INFER_SHAPE(GetBlockShapeAndSplitKVBlockInferShape))

View File

@@ -337,6 +337,8 @@ __forceinline__ __host__ __device__ void vec_cast<nv_bfloat16, float>(
} else if (deal_each_time == 64) { \
constexpr size_t DEAL_EACH_TIME = 64; \
__VA_ARGS__ \
} else { \
PD_THROW("not support the deal_each_time", deal_each_time); \
}
#define DISPATCH_NUM_THREADS(num_threads, NUM_THREADS, ...) \
@@ -346,6 +348,8 @@ __forceinline__ __host__ __device__ void vec_cast<nv_bfloat16, float>(
} else if (num_threads == 256) { \
constexpr size_t NUM_THREADS = 256; \
__VA_ARGS__ \
} else { \
PD_THROW("not support the num_threads", num_threads); \
}
#define DISPATCH_GQA_GROUP_SIZE(group_size, GROUP_SIZE, ...) \
@@ -376,6 +380,11 @@ __forceinline__ __host__ __device__ void vec_cast<nv_bfloat16, float>(
} else if (group_size == 12) { \
constexpr size_t GROUP_SIZE = 12; \
__VA_ARGS__ \
} else if (group_size == 16) { \
constexpr size_t GROUP_SIZE = 16; \
__VA_ARGS__ \
} else { \
PD_THROW("not support the group_size", group_size); \
}
#define DISPATCH_BLOCKSHAPE_Q(block_shape_q, BLOCK_SHAPE_Q, NUM_WARP_Q, ...) \