mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
[Executor] Refactor GetBlockShapeAndSplitKVBlock Kernel (#2989)
* reset decoder_block_shape_q buffer * refactor GetBlockShapeAndSplitKVBlock Kernel and cudagraph padding batch * update decode_max_tile_size * fix pre-commit * update block_multihead_attn_backend * update flas attn backend * update MLA Attention * update XPU Attention * update gcu,iluvatar model runner * Update MTP * fix MTP bug
This commit is contained in:
@@ -195,22 +195,25 @@ std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
|
||||
const paddle::Tensor &seq_lens_encoder,
|
||||
const paddle::Tensor &seq_lens_decoder,
|
||||
const paddle::Tensor &seq_lens_this_time,
|
||||
const int encoder_block_shape_q, const int decoder_block_shape_q,
|
||||
const int group_size, const int block_size,
|
||||
const int decoder_step_token_num) {
|
||||
paddle::Tensor &decoder_batch_ids, // Inplace
|
||||
paddle::Tensor &decoder_tile_ids_per_batch, // Inplace
|
||||
paddle::Tensor &decoder_num_blocks_x_cpu, // Inplace, Pinned Memory
|
||||
paddle::Tensor &max_len_tensor_cpu, // Inplace, Pinned Memory
|
||||
const int encoder_block_shape_q,
|
||||
const int decoder_block_shape_q,
|
||||
const int group_size,
|
||||
const int block_size,
|
||||
const int decoder_step_token_num)
|
||||
{
|
||||
auto stream = seq_lens_encoder.stream();
|
||||
int bsz = seq_lens_this_time.shape()[0];
|
||||
auto max_len_tensor =
|
||||
GetEmptyTensor({8}, paddle::DataType::INT32, seq_lens_encoder.place());
|
||||
GetMaxLen(seq_lens_decoder, seq_lens_this_time, seq_lens_encoder,
|
||||
max_len_tensor, bsz);
|
||||
|
||||
// max_len_this_time, max_enc_len_this_time, max_dec_len_this_time,
|
||||
// max_enc_dec_len_this_time, max_just_dec_len_this_time,
|
||||
// max_just_dec_merged_len_this_time, max_system_len,
|
||||
// max_just_dec_len_without_system
|
||||
auto max_len_cpu = max_len_tensor.copy_to(paddle::CPUPlace(), false);
|
||||
auto max_len_cpu_ptr = max_len_cpu.data<int>();
|
||||
paddle::Tensor max_len_tensor_gpu = GetEmptyTensor({max_len_tensor_cpu.shape()[0]}, paddle::DataType::INT32, seq_lens_this_time.place());
|
||||
GetMaxLen(seq_lens_decoder, seq_lens_this_time, seq_lens_encoder,
|
||||
max_len_tensor_gpu, bsz);
|
||||
max_len_tensor_cpu.copy_(max_len_tensor_gpu, max_len_tensor_cpu.place(), false);
|
||||
|
||||
auto max_len_cpu_ptr = max_len_tensor_cpu.data<int>();
|
||||
int max_len_this_time = max_len_cpu_ptr[0];
|
||||
int max_enc_len_this_time = max_len_cpu_ptr[1];
|
||||
int max_dec_len_this_time = max_len_cpu_ptr[2];
|
||||
@@ -222,14 +225,11 @@ std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
|
||||
|
||||
paddle::Tensor encoder_batch_ids;
|
||||
paddle::Tensor encoder_tile_ids_per_batch;
|
||||
paddle::Tensor encoder_num_blocks_x_cpu; /*cpu*/
|
||||
paddle::Tensor encoder_num_blocks_x_cpu; /*cpu*/
|
||||
paddle::Tensor kv_batch_ids;
|
||||
paddle::Tensor kv_tile_ids_per_batch;
|
||||
paddle::Tensor kv_num_blocks_x_cpu; /*cpu*/
|
||||
paddle::Tensor decoder_batch_ids;
|
||||
paddle::Tensor decoder_tile_ids_per_batch;
|
||||
paddle::Tensor decoder_num_blocks_x_cpu; /*cpu*/
|
||||
paddle::Tensor max_len_kv_cpu; /*cpu*/
|
||||
paddle::Tensor kv_num_blocks_x_cpu; /*cpu*/
|
||||
paddle::Tensor max_len_kv_cpu; /*cpu*/
|
||||
|
||||
auto max_len_kv =
|
||||
GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_decoder.place());
|
||||
@@ -291,92 +291,64 @@ std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
|
||||
kv_num_blocks_x_cpu =
|
||||
GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place());
|
||||
}
|
||||
if (max_just_dec_len_this_time > 0) {
|
||||
const uint32_t decoder_max_tile_size_per_bs_q =
|
||||
div_up((decoder_step_token_num * group_size), decoder_block_shape_q);
|
||||
|
||||
decoder_batch_ids =
|
||||
GetEmptyTensor({bsz * decoder_max_tile_size_per_bs_q},
|
||||
paddle::DataType::INT32, seq_lens_encoder.place());
|
||||
decoder_tile_ids_per_batch =
|
||||
GetEmptyTensor({bsz * decoder_max_tile_size_per_bs_q},
|
||||
paddle::DataType::INT32, seq_lens_encoder.place());
|
||||
if (max_just_dec_len_this_time > 0) {
|
||||
// Clear buffer
|
||||
const uint32_t decoder_max_tile_size_per_bs_q = div_up((decoder_step_token_num * group_size), decoder_block_shape_q);
|
||||
const uint32_t decoder_batch_shape = bsz * decoder_max_tile_size_per_bs_q;
|
||||
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(decoder_batch_ids.data<int>(), 0, decoder_batch_shape * sizeof(int32_t), stream));
|
||||
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(decoder_tile_ids_per_batch.data<int>(), 0, decoder_batch_shape * sizeof(int32_t), stream));
|
||||
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(decoder_num_blocks_x_cpu.data<int>(), 0, sizeof(int32_t), stream));
|
||||
|
||||
auto decoder_num_blocks_x =
|
||||
GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_encoder.place());
|
||||
split_q_block<<<1, 32, 0, stream>>>(
|
||||
seq_lens_this_time.data<int>(), seq_lens_encoder.data<int>(),
|
||||
decoder_batch_ids.data<int>(), decoder_tile_ids_per_batch.data<int>(),
|
||||
decoder_num_blocks_x.data<int>(), bsz, decoder_block_shape_q,
|
||||
seq_lens_this_time.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
decoder_batch_ids.data<int>(),
|
||||
decoder_tile_ids_per_batch.data<int>(),
|
||||
decoder_num_blocks_x.data<int>(),
|
||||
bsz,
|
||||
decoder_block_shape_q,
|
||||
group_size);
|
||||
decoder_num_blocks_x_cpu =
|
||||
decoder_num_blocks_x.copy_to(paddle::CPUPlace(), false);
|
||||
} else {
|
||||
decoder_batch_ids =
|
||||
GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place());
|
||||
decoder_tile_ids_per_batch =
|
||||
GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place());
|
||||
decoder_num_blocks_x_cpu =
|
||||
GetEmptyTensor({0}, paddle::DataType::INT32, paddle::CPUPlace());
|
||||
decoder_num_blocks_x_cpu.copy_(decoder_num_blocks_x, decoder_num_blocks_x_cpu.place(), false);
|
||||
}
|
||||
|
||||
return {encoder_batch_ids,
|
||||
encoder_tile_ids_per_batch,
|
||||
encoder_num_blocks_x_cpu, /*cpu*/
|
||||
kv_batch_ids,
|
||||
kv_tile_ids_per_batch,
|
||||
kv_num_blocks_x_cpu, /*cpu*/
|
||||
decoder_batch_ids,
|
||||
decoder_tile_ids_per_batch,
|
||||
decoder_num_blocks_x_cpu, /*cpu*/
|
||||
max_len_kv_cpu /*cpu*/,
|
||||
max_len_cpu};
|
||||
}
|
||||
|
||||
std::vector<paddle::DataType> GetBlockShapeAndSplitKVBlockInferDtype(
|
||||
const paddle::DataType &seq_lens_encoder_dtype,
|
||||
const paddle::DataType &seq_lens_decoder_dtype,
|
||||
const paddle::DataType &seq_lens_this_time_dtype) {
|
||||
return {
|
||||
paddle::DataType::INT32, paddle::DataType::INT32, paddle::DataType::INT32,
|
||||
paddle::DataType::INT32, paddle::DataType::INT32, paddle::DataType::INT32,
|
||||
paddle::DataType::INT32, paddle::DataType::INT32, paddle::DataType::INT32,
|
||||
paddle::DataType::INT32, paddle::DataType::INT32};
|
||||
}
|
||||
|
||||
std::vector<std::vector<int64_t>> GetBlockShapeAndSplitKVBlockInferShape(
|
||||
const std::vector<int64_t> &seq_lens_encoder_shape,
|
||||
const std::vector<int64_t> &seq_lens_decoder_shape,
|
||||
const std::vector<int64_t> &seq_lens_this_time_shape) {
|
||||
std::vector<int64_t> dynamic_shape = {-1};
|
||||
|
||||
return {dynamic_shape,
|
||||
dynamic_shape,
|
||||
{1},
|
||||
dynamic_shape,
|
||||
dynamic_shape,
|
||||
{1},
|
||||
dynamic_shape,
|
||||
dynamic_shape,
|
||||
{1},
|
||||
{1},
|
||||
{8}};
|
||||
encoder_batch_ids,
|
||||
encoder_tile_ids_per_batch,
|
||||
encoder_num_blocks_x_cpu, /*cpu*/
|
||||
kv_batch_ids,
|
||||
kv_tile_ids_per_batch,
|
||||
kv_num_blocks_x_cpu, /*cpu*/
|
||||
max_len_kv_cpu, /*cpu*/
|
||||
};
|
||||
}
|
||||
|
||||
PD_BUILD_STATIC_OP(get_block_shape_and_split_kv_block)
|
||||
.Inputs({"seq_lens_encoder", "seq_lens_decoder", "seq_lens_this_time"})
|
||||
.Outputs({paddle::Optional("encoder_batch_ids"),
|
||||
paddle::Optional("encoder_tile_ids_per_batch"),
|
||||
paddle::Optional("encoder_num_blocks"),
|
||||
paddle::Optional("kv_batch_ids"),
|
||||
paddle::Optional("kv_tile_ids_per_batch"),
|
||||
paddle::Optional("kv_num_blocks"),
|
||||
paddle::Optional("decoder_batch_ids"),
|
||||
paddle::Optional("decoder_tile_ids_per_batch"),
|
||||
paddle::Optional("decoder_num_blocks"),
|
||||
paddle::Optional("max_len_kv"), "set_max_lengths"})
|
||||
.Attrs({"encoder_block_shape_q: int", "decoder_block_shape_q: int",
|
||||
"group_size: int", "block_size: int",
|
||||
"decoder_step_token_num: int"})
|
||||
.SetKernelFn(PD_KERNEL(GetBlockShapeAndSplitKVBlock))
|
||||
.SetInferShapeFn(PD_INFER_SHAPE(GetBlockShapeAndSplitKVBlockInferShape))
|
||||
.SetInferDtypeFn(PD_INFER_DTYPE(GetBlockShapeAndSplitKVBlockInferDtype));
|
||||
.Inputs({
|
||||
"seq_lens_encoder",
|
||||
"seq_lens_decoder",
|
||||
"seq_lens_this_time",
|
||||
"decoder_batch_ids",
|
||||
"decoder_tile_ids_per_batch",
|
||||
"decoder_num_blocks_x_cpu",
|
||||
"max_len_tensor_cpu"
|
||||
})
|
||||
.Outputs({
|
||||
paddle::Optional("encoder_batch_ids"),
|
||||
paddle::Optional("encoder_tile_ids_per_batch"),
|
||||
paddle::Optional("encoder_num_blocks_x_cpu"),
|
||||
paddle::Optional("kv_batch_ids"),
|
||||
paddle::Optional("kv_tile_ids_per_batch"),
|
||||
paddle::Optional("kv_num_blocks_x_cpu"),
|
||||
"max_len_kv_cpu"
|
||||
})
|
||||
.Attrs({
|
||||
"encoder_block_shape_q: int",
|
||||
"decoder_block_shape_q: int",
|
||||
"group_size: int",
|
||||
"block_size: int",
|
||||
"decoder_step_token_num: int"
|
||||
})
|
||||
.SetKernelFn(PD_KERNEL(GetBlockShapeAndSplitKVBlock));
|
||||
|
||||
Reference in New Issue
Block a user