[Optimization] Fuse get_max_len and get_kv_max_len (#4369)
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled

* opt split_q_block

* fuse max_lens and max_kv_len
This commit is contained in:
Sunny-bot1
2025-10-13 20:35:00 +08:00
committed by GitHub
parent 425205b03c
commit a751d977bc
15 changed files with 29 additions and 116 deletions

View File

@@ -59,7 +59,6 @@ void AppendAttentionKernel(
const paddle::Tensor& decoder_tile_ids_per_batch,
const paddle::Tensor& decoder_num_blocks,
const paddle::Tensor& set_max_lengths,
const paddle::Tensor& max_len_kv,
paddle::Tensor& fmha_out,
const paddle::optional<paddle::Tensor>& rotary_embs,
const paddle::optional<paddle::Tensor>& attn_mask,
@@ -103,6 +102,7 @@ void AppendAttentionKernel(
int max_dec_len_this_time = set_max_lengths.data<int>()[2];
int max_enc_dec_len_this_time = set_max_lengths.data<int>()[3];
int max_just_dec_len_this_time = set_max_lengths.data<int>()[4];
int max_kv_len_this_time = set_max_lengths.data<int>()[8];
auto main_stream = qkv.stream();
static cudaEvent_t main_event;
@@ -245,7 +245,6 @@ void AppendAttentionKernel(
if (max_just_dec_len_this_time > 0) {
int decoder_num_blocks_data = decoder_num_blocks.data<int>()[0];
int max_len_kv_data = max_len_kv.data<int>()[0];
cudaStream_t exec_stream;
if (max_enc_len_this_time > 0) {
@@ -371,20 +370,20 @@ void AppendAttentionKernel(
case paddle::DataType::INT8:{
int8_t tmp;
dispatch_CascadeAppendAttentionKernel(tmp, decoder_batch_ids, decoder_tile_ids_per_batch, decoder_num_blocks_data,
decoder_block_shape_q, max_len_kv_data, !speculate_decoder, !speculate_decoder, exec_stream);
decoder_block_shape_q, max_kv_len_this_time, !speculate_decoder, !speculate_decoder, exec_stream);
break;
}
case paddle::DataType::FLOAT8_E4M3FN:{
phi::dtype::float8_e4m3fn tmp;
dispatch_CascadeAppendAttentionKernel(tmp, decoder_batch_ids, decoder_tile_ids_per_batch, decoder_num_blocks_data,
decoder_block_shape_q, max_len_kv_data, !speculate_decoder, !speculate_decoder, exec_stream);
decoder_block_shape_q, max_kv_len_this_time, !speculate_decoder, !speculate_decoder, exec_stream);
break;
}
}
} else {
data_t tmp;
dispatch_CascadeAppendAttentionKernel(tmp, decoder_batch_ids, decoder_tile_ids_per_batch, decoder_num_blocks_data,
decoder_block_shape_q, max_len_kv_data, !speculate_decoder, !speculate_decoder, exec_stream);
decoder_block_shape_q, max_kv_len_this_time, !speculate_decoder, !speculate_decoder, exec_stream);
}
if (max_enc_len_this_time > 0) {
cudaEventRecord(decoder_event, exec_stream);
@@ -413,7 +412,6 @@ std::vector<paddle::Tensor> AppendAttention(
const paddle::Tensor& decoder_tile_ids_per_batch,
const paddle::Tensor& decoder_num_blocks,
const paddle::Tensor& set_max_lengths,
const paddle::Tensor& max_len_kv,
const paddle::optional<paddle::Tensor>& rotary_embs,
const paddle::optional<paddle::Tensor>& attn_mask,
const paddle::optional<paddle::Tensor>& qkv_bias,
@@ -539,7 +537,6 @@ std::vector<paddle::Tensor> AppendAttention(
decoder_tile_ids_per_batch,
decoder_num_blocks,
set_max_lengths,
max_len_kv,
fmha_out,
rotary_embs,
attn_mask,
@@ -616,7 +613,6 @@ void AppendAttentionWithOutput(
const paddle::Tensor& decoder_tile_ids_per_batch,
const paddle::Tensor& decoder_num_blocks,
const paddle::Tensor& set_max_lengths,
const paddle::Tensor& max_len_kv,
paddle::Tensor& fmha_out,
const paddle::optional<paddle::Tensor>& rotary_embs,
const paddle::optional<paddle::Tensor>& attn_mask,
@@ -695,7 +691,6 @@ void AppendAttentionWithOutput(
decoder_tile_ids_per_batch,
decoder_num_blocks,
set_max_lengths,
max_len_kv,
fmha_out,
rotary_embs,
attn_mask,
@@ -784,7 +779,6 @@ std::vector<std::vector<int64_t>> AppendAttentionInferShape(
const std::vector<int64_t>& decoder_tile_ids_per_batch_shape,
const std::vector<int64_t>& decoder_num_blocks_shape,
const std::vector<int64_t>& set_max_lengths_shape,
const std::vector<int64_t>& max_len_kv_shape,
const paddle::optional<std::vector<int64_t>>& rotary_embs_shape,
const paddle::optional<std::vector<int64_t>>& attn_mask_shape,
const paddle::optional<std::vector<int64_t>>& qkv_bias_shape,
@@ -848,7 +842,6 @@ std::vector<paddle::DataType> AppendAttentionInferDtype(
const paddle::DataType& decoder_tile_ids_per_batch_dtype,
const paddle::DataType& decoder_num_blocks_dtype,
const paddle::DataType& set_max_lengths_dtype,
const paddle::DataType& max_len_kv_dtype,
const paddle::optional<paddle::DataType>& rotary_embs_dtype,
const paddle::optional<paddle::DataType>& attn_mask_dtype,
const paddle::optional<paddle::DataType>& qkv_bias_dtype,
@@ -930,7 +923,6 @@ std::vector<std::vector<int64_t>> AppendAttentionWithOutputInferShape(
const std::vector<int64_t>& decoder_tile_ids_per_batch_shape,
const std::vector<int64_t>& decoder_num_blocks_shape,
const std::vector<int64_t>& set_max_lengths_shape,
const std::vector<int64_t>& max_len_kv_shape,
const std::vector<int64_t>& fmha_out_shape,
const paddle::optional<std::vector<int64_t>>& rotary_embs_shape,
const paddle::optional<std::vector<int64_t>>& attn_mask_shape,
@@ -987,7 +979,6 @@ std::vector<paddle::DataType> AppendAttentionWithOutputInferDtype(
const paddle::DataType& decoder_tile_ids_per_batch_dtype,
const paddle::DataType& decoder_num_blocks_dtype,
const paddle::DataType& set_max_lengths_dtype,
const paddle::DataType& max_len_kv_dtype,
const paddle::DataType& fmha_out_dtype,
const paddle::optional<paddle::DataType>& rotary_embs_dtype,
const paddle::optional<paddle::DataType>& attn_mask_dtype,
@@ -1046,7 +1037,6 @@ PD_BUILD_STATIC_OP(append_attention)
"decoder_tile_ids_per_batch",
"decoder_num_blocks",
"set_max_lengths",
"max_len_kv",
paddle::Optional("rotary_embs"),
paddle::Optional("attn_mask"),
paddle::Optional("qkv_bias"),
@@ -1105,7 +1095,6 @@ PD_BUILD_STATIC_OP(append_attention_with_output)
"decoder_tile_ids_per_batch",
"decoder_num_blocks",
"set_max_lengths",
"max_len_kv",
"fmha_out",
paddle::Optional("rotary_embs"),
paddle::Optional("attn_mask"),

View File

@@ -19,7 +19,7 @@
template <int THREADBLOCK_SIZE>
__global__ void
GetMaxLenKernel(const int *seq_lens, const int *seq_lens_this_time,
GetMaxLenKernel(const int *seq_lens_decoder, const int *seq_lens_this_time,
const int *seq_lens_encoder,
const int *seq_lens_this_time_merged,
const int *seq_lens_encoder_merged, const int *seq_mapping,
@@ -37,41 +37,27 @@ GetMaxLenKernel(const int *seq_lens, const int *seq_lens_this_time,
int max_just_dec_merged_len_this_time_this_thread = 0;
int max_system_len_this_thread = 0;
int max_dec_len_without_system_this_thread = 0;
int max_len_kv_this_thread = 0;
for (int i = tid; i < batch_size; i += blockDim.x) {
const int seq_len_this_time = seq_lens_this_time[i];
const int seq_len_decoder = seq_lens_decoder[i];
max_len_this_time_this_thread =
max(seq_len_this_time, max_len_this_time_this_thread);
max_len_encoder_this_thread =
max(seq_lens_encoder[i], max_len_encoder_this_thread);
max_len_decoder_this_thread = max(seq_lens[i], max_len_decoder_this_thread);
max_len_decoder_this_thread = max(seq_len_decoder, max_len_decoder_this_thread);
if (seq_len_this_time <= 0)
continue;
const int max_just_dec_len_now = seq_lens_encoder[i] > 0 ? 0 : seq_lens[i];
const int max_just_dec_len_now = seq_lens_encoder[i] > 0 ? 0 : seq_len_decoder;
max_len_this_thread =
max(seq_lens[i] + seq_len_this_time, max_len_this_thread);
max(seq_len_decoder + seq_len_this_time, max_len_this_thread);
max_just_dec_len_this_thread =
max(max_just_dec_len_this_thread, max_just_dec_len_now);
if (system_lens) {
const int real_bid = seq_mapping[i];
const int system_len_now = system_lens[real_bid];
max_system_len_this_thread =
max(max_system_len_this_thread, system_len_now);
max_dec_len_without_system_this_thread =
max(max_dec_len_without_system_this_thread,
max_just_dec_len_now - system_len_now);
}
}
if (system_lens) {
for (int i = tid; i < batch_size; i += blockDim.x) {
const int ori_seq_len_this_time = seq_lens_this_time_merged[i];
if (ori_seq_len_this_time <= 0)
continue;
const int max_just_dec_merged_len_this_time_now =
seq_lens_encoder_merged[i] > 0 ? 0 : ori_seq_len_this_time;
max_just_dec_merged_len_this_time_this_thread =
max(max_just_dec_merged_len_this_time_this_thread,
max_just_dec_merged_len_this_time_now);
}
if (seq_len_decoder == 0)
continue;
max_len_kv_this_thread =
max(seq_len_this_time + seq_len_decoder, max_len_kv_this_thread);
}
int total_max_len_this_time =
BlockReduce(temp_storage)
@@ -94,6 +80,8 @@ GetMaxLenKernel(const int *seq_lens, const int *seq_lens_this_time,
int total_dec_len_without_system =
BlockReduce(temp_storage)
.Reduce(max_dec_len_without_system_this_thread, MaxOp<int>());
int total_max_len_kv =
BlockReduce(temp_storage).Reduce(max_len_kv_this_thread, MaxOp<int>());
if (tid == 0) {
max_lens[0] = total_max_len_this_time;
max_lens[1] = total_max_len_encoder;
@@ -103,6 +91,7 @@ GetMaxLenKernel(const int *seq_lens, const int *seq_lens_this_time,
max_lens[5] = total_just_dec_merged;
max_lens[6] = total_system_len;
max_lens[7] = total_dec_len_without_system;
max_lens[8] = total_max_len_kv;
}
}
@@ -256,29 +245,6 @@ __global__ void split_kv_block(const int *__restrict__ seq_lens_decoder,
}
}
template <int THREADBLOCK_SIZE>
__global__ void
get_max_len_kv_ernel(int *max_seq_lens_out, const int *seq_lens_this_time,
const int *seq_lens_decoder, const int batch_size) {
const int tid = threadIdx.x;
typedef cub::BlockReduce<int, THREADBLOCK_SIZE> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
int max_len_this_thread = 0;
for (int i = tid; i < batch_size; i += blockDim.x) {
if (seq_lens_decoder[i] == 0)
continue;
max_len_this_thread =
max(seq_lens_this_time[i] + seq_lens_decoder[i], max_len_this_thread);
}
int total =
BlockReduce(temp_storage).Reduce(max_len_this_thread, MaxOp<int>());
if (tid == 0) {
*max_seq_lens_out = total;
}
}
void GetBlockShapeAndSplitKVBlock(
const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &seq_lens_decoder,
@@ -295,7 +261,6 @@ void GetBlockShapeAndSplitKVBlock(
paddle::Tensor &kv_batch_ids, // Inplace
paddle::Tensor &kv_tile_ids_per_batch, // Inplace
paddle::Tensor &kv_num_blocks_x_cpu, // Inplace, CPU
paddle::Tensor &max_len_kv_cpu, // Inplace, CPU
const int encoder_block_shape_q,
const int decoder_block_shape_q,
const int group_size,
@@ -319,15 +284,7 @@ void GetBlockShapeAndSplitKVBlock(
int max_just_dec_merged_len_this_time = max_len_cpu_ptr[5];
int max_system_len = max_len_cpu_ptr[6];
int max_just_dec_len_without_system = max_len_cpu_ptr[7];
auto max_len_kv =
GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_decoder.place());
get_max_len_kv_ernel<128><<<1, 128, 0, stream>>>(
max_len_kv.data<int>(), seq_lens_this_time.data<int>(),
seq_lens_decoder.data<int>(), bsz);
max_len_kv_cpu.copy_(max_len_kv, max_len_kv_cpu.place(), false);
int max_kv_len_this_time = max_len_cpu_ptr[8];
// decoder
if (max_dec_len_this_time > 0) {
@@ -430,7 +387,7 @@ void GetBlockShapeAndSplitKVBlock(
decoder_num_blocks_cpu.copy_(
decoder_num_blocks_device, decoder_num_blocks_cpu.place(), false);
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
decoder_chunk_size_device.data<int>(), 64, sizeof(int32_t), stream));
decoder_chunk_size_device.data<int>(), 64, sizeof(int32_t), stream));
}
} else {
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
@@ -492,7 +449,6 @@ PD_BUILD_STATIC_OP(get_block_shape_and_split_kv_block)
"kv_batch_ids",
"kv_tile_ids_per_batch",
"kv_num_blocks_x_cpu",
"max_len_kv_cpu"
})
.Outputs({

View File

@@ -64,7 +64,7 @@ std::vector<paddle::Tensor> AppendAttention(
const paddle::Tensor &decoder_batch_ids,
const paddle::Tensor &decoder_tile_ids_per_batch,
const paddle::Tensor &decoder_num_blocks_cpu,
const paddle::Tensor &set_max_lengths, const paddle::Tensor &max_len_kv,
const paddle::Tensor &set_max_lengths,
const paddle::optional<paddle::Tensor> &rotary_embs,
const paddle::optional<paddle::Tensor> &attn_mask,
const paddle::optional<paddle::Tensor> &qkv_bias,
@@ -106,7 +106,7 @@ void AppendAttentionWithOutput(
const paddle::Tensor &decoder_batch_ids,
const paddle::Tensor &decoder_tile_ids_per_batch,
const paddle::Tensor &decoder_num_blocks_cpu,
const paddle::Tensor &set_max_lengths, const paddle::Tensor &max_len_kv,
const paddle::Tensor &set_max_lengths,
paddle::Tensor &fmha_out,
const paddle::optional<paddle::Tensor> &rotary_embs,
const paddle::optional<paddle::Tensor> &attn_mask,
@@ -315,7 +315,6 @@ void GetBlockShapeAndSplitKVBlock(
paddle::Tensor &kv_batch_ids, // Inplace
paddle::Tensor &kv_tile_ids_per_batch, // Inplace
paddle::Tensor &kv_num_blocks_x_cpu, // Inplace, Pinned Memory
paddle::Tensor &max_len_kv_cpu, // Inplace, Pinned Memory
const int encoder_block_shape_q,
const int decoder_block_shape_q,
const int group_size,