From f08fb25cfe59b83d87cd5555532f579af32930e0 Mon Sep 17 00:00:00 2001
From: lzy <569782149@qq.com>
Date: Tue, 9 Dec 2025 19:41:33 +0800
Subject: [PATCH] [Others] Maintain the mtp branch temporarily. (#5447)

---
 .../append_attn/append_attention_func.cuh     |  25 +-
 .../multiquery_attention_c16_impl.cuh         | 212 +++++++++++-----
 .../multiquery_attention_c4_impl.cuh          | 240 ++++++++++++------
 .../multiquery_attention_c8_impl.cuh          | 240 ++++++++++++------
 4 files changed, 508 insertions(+), 209 deletions(-)

diff --git a/custom_ops/gpu_ops/append_attn/append_attention_func.cuh b/custom_ops/gpu_ops/append_attn/append_attention_func.cuh
index 9f0b9eba1..74de2f39e 100644
--- a/custom_ops/gpu_ops/append_attn/append_attention_func.cuh
+++ b/custom_ops/gpu_ops/append_attn/append_attention_func.cuh
@@ -2451,6 +2451,7 @@ __global__ void merge_multi_chunks_v2_kernel(
     if (bid == -1) {
       continue;
     }
+    const uint32_t local_seq_id = qid - cu_seqlens_q[bid];
     const int seq_len_q = seq_lens_q[bid];
     if (seq_len_q == 0) continue;
     int seq_len_kv = seq_lens_kv[bid];
@@ -2494,14 +2495,32 @@ __global__ void merge_multi_chunks_v2_kernel(
     }
 #pragma unroll 2
     for (int i = ty; i < num_chunks_this_seq; i += bdy) {
-      uint32_t offset = (qid * num_chunks + i) * num_heads + hid;
+      uint32_t offset;
+      if (ENABLE_PREFILL) {
+        offset = (qid * num_chunks + i) * num_heads + hid;
+      } else {
+        offset =
+            ((bid * speculate_max_draft_token_num + local_seq_id) * num_chunks +
+             i) *
+                num_heads +
+            hid;
+      }
       float m_prev = m;
       float d_prev = d;
       const float m_now = multi_m[offset];
       const float d_now = multi_d[offset];
       m = max(m_prev, m_now);
-      offset = (qid * num_chunks * num_heads + i * num_heads + hid) * head_dim +
-               vid * vec_size;
+      if (ENABLE_PREFILL) {
+        offset =
+            (qid * num_chunks * num_heads + i * num_heads + hid) * head_dim +
+            vid * vec_size;
+      } else {
+        offset = ((bid * speculate_max_draft_token_num + local_seq_id) *
+                      num_chunks * num_heads +
+                  i * num_heads + hid) *
+                     head_dim +
+                 vid * vec_size;
+      }
       Load<T, vec_size>(&multi_out[offset], &load_vec);
       const float scale1 = __expf(m_prev - m), scale2 = __expf(m_now - m);
       const T scale1_T = static_cast<T>(scale1),
diff --git a/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh b/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh
index 8bbc7727b..66eb4d032 100644
--- a/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh
+++ b/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh
@@ -134,9 +134,17 @@ __global__ void multi_query_append_attention_kernel(
   T *o_base_ptr_T = nullptr;
   OutT *o_base_ptr_int8 = nullptr;
   if constexpr (partition_kv) {
-    o_base_ptr_T = tmp_workspace + q_start_seq_id * num_chunks * q_n_stride +
-                   chunk_idx * q_n_stride + q_head_idx * HEAD_DIM +
-                   tid % 8 * num_elems_per_128b<T>();
+    if (ENABLE_PREFILL) {
+      o_base_ptr_T = tmp_workspace + q_start_seq_id * num_chunks * q_n_stride +
+                     chunk_idx * q_n_stride + q_head_idx * HEAD_DIM +
+                     tid % 8 * num_elems_per_128b<T>();
+    } else {
+      o_base_ptr_T =
+          tmp_workspace +
+          batch_id * speculate_max_draft_token_num * num_chunks * q_n_stride +
+          chunk_idx * q_n_stride + q_head_idx * HEAD_DIM +
+          tid % 8 * num_elems_per_128b<T>();
+    }
   } else {
     o_base_ptr_int8 = out + o_offset;
   }
@@ -386,8 +394,18 @@ __global__ void multi_query_append_attention_kernel(
         const uint32_t qo_head_idx = q_head_idx + qo_idx_now % GROUP_SIZE;
         const uint32_t qo_idx = q_start_seq_id + qo_idx_now / GROUP_SIZE;
         if (qo_idx - q_start_seq_id < q_len) {
-          uint32_t offset =
-              (qo_idx * num_chunks + chunk_idx) * q_num_heads + qo_head_idx;
+          uint32_t offset;
+          if (ENABLE_PREFILL) {
+            offset =
+                (qo_idx * num_chunks + chunk_idx) * q_num_heads + qo_head_idx;
+          } else {
+            offset = ((batch_id * speculate_max_draft_token_num +
+                       qo_idx_now / GROUP_SIZE) *
+                          num_chunks +
+                      chunk_idx) *
+                         q_num_heads +
+                     qo_head_idx;
+          }
           tmp_m[offset] = m_frag[fx][j];
           tmp_d[offset] = d_frag[fx][j];
         }
@@ -524,9 +542,11 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
                      chunk_idx * q_n_stride + q_head_idx * HEAD_DIM +
                      tid % 8 * num_elems_per_128b<T>();
     } else {
-      o_base_ptr_T = tmp_workspace + q_start_seq_id * num_chunks * q_n_stride +
-                     chunk_idx * q_n_stride + q_head_idx * HEAD_DIM +
-                     tid % 8 * num_elems_per_128b<T>();
+      o_base_ptr_T =
+          tmp_workspace +
+          batch_id * speculate_max_draft_token_num * num_chunks * q_n_stride +
+          chunk_idx * q_n_stride + q_head_idx * HEAD_DIM +
+          tid % 8 * num_elems_per_128b<T>();
     }
   }
   const int *mask_offset_this_seq =
@@ -794,8 +814,12 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
             offset =
                 (batch_id * num_chunks + chunk_idx) * q_num_heads + qo_head_idx;
           } else {
-            offset =
-                (qo_idx * num_chunks + chunk_idx) * q_num_heads + qo_head_idx;
+            offset = ((batch_id * speculate_max_draft_token_num +
+                       qo_idx_now / GROUP_SIZE) *
+                          num_chunks +
+                      chunk_idx) *
+                         q_num_heads +
+                     qo_head_idx;
           }
           tmp_m[offset] = m_frag[fx][j];
           tmp_d[offset] = d_frag[fx][j];
@@ -1026,51 +1050,95 @@ void MultiQueryAppendAttention(
           sliding_window);
       // merge
       constexpr int vec_size = num_elems_per_128b<NV_TYPE>();
-      constexpr int blockx = HEAD_DIM / vec_size;
-      constexpr int blocky = (128 + blockx - 1) / blockx;
-      dim3 grids_merge(min(sm_count * 4, token_num),
-                       num_heads);  // 128k is too large
-      dim3 blocks_merge(blockx, blocky);
-      auto *kernelFn = merge_multi_chunks_v2_kernel<NV_TYPE,
-                                                    vec_size,
-                                                    blocky,
-                                                    HEAD_DIM,
-                                                    OUT_NV_TYPE,
-                                                    ENABLE_PREFILL>;
-      launchWithPdlWhenEnabled(
-          kernelFn,
-          grids_merge,
-          blocks_merge,
-          0,
-          stream,
-          reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
-          static_cast<float *>(tmp_m->ptr()),
-          static_cast<float *>(tmp_d->ptr()),
-          seq_lens_q.data<int>(),
-          seq_lens_kv.data<int>(),
-          seq_lens_encoder.data<int>(),
-          batch_id_per_token.data<int>(),
-          cu_seqlens_q.data<int>(),
-          shift_bias ? reinterpret_cast<NV_TYPE *>(
-                           const_cast<T *>(shift_bias.get().data<T>()))
-                     : nullptr,
-          smooth_weight ? reinterpret_cast<NV_TYPE *>(
-                              const_cast<T *>(smooth_weight.get().data<T>()))
-                        : nullptr,
-          sinks ? reinterpret_cast<float *>(
-                      const_cast<float *>(sinks.get().data<float>()))
-                : nullptr,
-          reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
-          quant_max_bound,
-          quant_min_bound,
-          in_scale,
-          max_seq_len,
-          num_chunks,
-          num_heads,
-          chunk_size,
-          HEAD_DIM,
-          token_num,
-          speculate_max_draft_token_num);
+      if (is_decoder) {
+        constexpr int blockx = HEAD_DIM / vec_size;
+        constexpr int blocky = (128 + blockx - 1) / blockx;
+        dim3 grids_merge(bsz, num_heads);
+        dim3 blocks_merge(blockx, blocky);
+        auto *kernelFn = merge_multi_chunks_decoder_kernel<NV_TYPE,
+                                                           vec_size,
+                                                           blocky,
+                                                           HEAD_DIM,
+                                                           OUT_NV_TYPE>;
+        launchWithPdlWhenEnabled(
+            kernelFn,
+            grids_merge,
+            blocks_merge,
+            0,
+            stream,
+            reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
+            static_cast<float *>(tmp_m->ptr()),
+            static_cast<float *>(tmp_d->ptr()),
+            seq_lens_q.data<int>(),
+            seq_lens_kv.data<int>(),
+            seq_lens_encoder.data<int>(),
+            cu_seqlens_q.data<int>(),
+            shift_bias ? reinterpret_cast<NV_TYPE *>(
+                             const_cast<T *>(shift_bias.get().data<T>()))
+                       : nullptr,
+            smooth_weight ? reinterpret_cast<NV_TYPE *>(
+                                const_cast<T *>(smooth_weight.get().data<T>()))
+                          : nullptr,
+            sinks ? reinterpret_cast<float *>(
+                        const_cast<float *>(sinks.get().data<float>()))
+                  : nullptr,
+            reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
+            quant_max_bound,
+            quant_min_bound,
+            in_scale,
+            max_seq_len,
+            num_chunks,
+            num_heads,
+            chunk_size,
+            HEAD_DIM);
+      } else {
+        constexpr int blockx = HEAD_DIM / vec_size;
+        constexpr int blocky = (128 + blockx - 1) / blockx;
+        dim3 grids_merge(min(sm_count * 4, token_num),
+                         num_heads);  // 128k is too large
+        dim3 blocks_merge(blockx, blocky);
+        auto *kernelFn = merge_multi_chunks_v2_kernel<NV_TYPE,
+                                                      vec_size,
+                                                      blocky,
+                                                      HEAD_DIM,
+                                                      OUT_NV_TYPE,
+                                                      ENABLE_PREFILL>;
+        launchWithPdlWhenEnabled(
+            kernelFn,
+            grids_merge,
+            blocks_merge,
+            0,
+            stream,
+            reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
+            static_cast<float *>(tmp_m->ptr()),
+            static_cast<float *>(tmp_d->ptr()),
+            seq_lens_q.data<int>(),
+            seq_lens_kv.data<int>(),
+            seq_lens_encoder.data<int>(),
+            batch_id_per_token.data<int>(),
+            cu_seqlens_q.data<int>(),
+            shift_bias ? reinterpret_cast<NV_TYPE *>(
+                             const_cast<T *>(shift_bias.get().data<T>()))
+                       : nullptr,
+            smooth_weight ? reinterpret_cast<NV_TYPE *>(
+                                const_cast<T *>(smooth_weight.get().data<T>()))
+                          : nullptr,
+            sinks ? reinterpret_cast<float *>(
+                        const_cast<float *>(sinks.get().data<float>()))
+                  : nullptr,
+            reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
+            quant_max_bound,
+            quant_min_bound,
+            in_scale,
+            max_seq_len,
+            num_chunks,
+            num_heads,
+            chunk_size,
+            HEAD_DIM,
+            token_num,
+            speculate_max_draft_token_num);
+      }
     }
   } else {
     constexpr uint32_t num_frags_z = BLOCK_SIZE / 16 / NUM_WARP_KV;
@@ -1189,15 +1257,31 @@ void MultiQueryAppendAttention(
             phi::SizeOf(paddle::DataType::FLOAT32) *
             static_cast<size_t>(bsz * num_chunks * num_heads));
       } else {
-        tmp_workspace = allocator->Allocate(
-            phi::SizeOf(qkv.dtype()) *
-            static_cast<size_t>(token_num * num_chunks * num_heads * HEAD_DIM));
-        tmp_m = allocator->Allocate(
-            phi::SizeOf(paddle::DataType::FLOAT32) *
-            static_cast<size_t>(token_num * num_chunks * num_heads));
-        tmp_d = allocator->Allocate(
-            phi::SizeOf(paddle::DataType::FLOAT32) *
-            static_cast<size_t>(token_num * num_chunks * num_heads));
+        if (ENABLE_PREFILL) {
+          tmp_workspace =
+              allocator->Allocate(phi::SizeOf(qkv.dtype()) *
+                                  static_cast<size_t>(token_num * num_chunks *
+                                                      num_heads * HEAD_DIM));
+          tmp_m = allocator->Allocate(
+              phi::SizeOf(paddle::DataType::FLOAT32) *
+              static_cast<size_t>(token_num * num_chunks * num_heads));
+          tmp_d = allocator->Allocate(
+              phi::SizeOf(paddle::DataType::FLOAT32) *
+              static_cast<size_t>(token_num * num_chunks * num_heads));
+        } else {
+          tmp_workspace = allocator->Allocate(
+              phi::SizeOf(qkv.dtype()) *
+              static_cast<size_t>(speculate_max_draft_token_num * bsz *
+                                  num_chunks * num_heads * HEAD_DIM));
+          tmp_m = allocator->Allocate(
+              phi::SizeOf(paddle::DataType::FLOAT32) *
+              static_cast<size_t>(speculate_max_draft_token_num * bsz *
+                                  num_chunks * num_heads));
+          tmp_d = allocator->Allocate(
+              phi::SizeOf(paddle::DataType::FLOAT32) *
+              static_cast<size_t>(speculate_max_draft_token_num * bsz *
+                                  num_chunks * num_heads));
+        }
       }
       launchWithPdlWhenEnabled(
           split_kv_kernel,
diff --git a/custom_ops/gpu_ops/append_attn/multiquery_attention_c4_impl.cuh b/custom_ops/gpu_ops/append_attn/multiquery_attention_c4_impl.cuh
index 9629acf5d..4f7091395 100644
--- a/custom_ops/gpu_ops/append_attn/multiquery_attention_c4_impl.cuh
+++ b/custom_ops/gpu_ops/append_attn/multiquery_attention_c4_impl.cuh
@@ -169,9 +169,17 @@ __global__ void multi_query_append_attention_c4_kernel(
   T *o_base_ptr_T = nullptr;
   OutT *o_base_ptr_int8 = nullptr;
   if constexpr (partition_kv) {
-    o_base_ptr_T = tmp_workspace + q_start_seq_id * num_chunks * q_n_stride +
-                   chunk_idx * q_n_stride + q_head_idx * HEAD_DIM +
-                   tid % 8 * num_elems_per_128b<T>();
+    if (ENABLE_PREFILL) {
+      o_base_ptr_T = tmp_workspace + q_start_seq_id * num_chunks * q_n_stride +
+                     chunk_idx * q_n_stride + q_head_idx * HEAD_DIM +
+                     tid % 8 * num_elems_per_128b<T>();
+    } else {
+      o_base_ptr_T =
+          tmp_workspace +
+          batch_id * speculate_max_draft_token_num * num_chunks * q_n_stride +
+          chunk_idx * q_n_stride + q_head_idx * HEAD_DIM +
+          tid % 8 * num_elems_per_128b<T>();
+    }
   } else {
     o_base_ptr_int8 = out + o_offset;
   }
@@ -477,8 +485,18 @@ __global__ void multi_query_append_attention_c4_kernel(
         const uint32_t qo_head_idx = q_head_idx + qo_idx_now % GROUP_SIZE;
         const uint32_t qo_idx = q_start_seq_id + qo_idx_now / GROUP_SIZE;
         if (qo_idx - q_start_seq_id < q_len) {
-          uint32_t offset =
-              (qo_idx * num_chunks + chunk_idx) * q_num_heads + qo_head_idx;
+          uint32_t offset;
+          if (ENABLE_PREFILL) {
+            offset =
+                (qo_idx * num_chunks + chunk_idx) * q_num_heads + qo_head_idx;
+          } else {
+            offset = ((batch_id * speculate_max_draft_token_num +
+                       qo_idx_now / GROUP_SIZE) *
+                          num_chunks +
+                      chunk_idx) *
+                         q_num_heads +
+                     qo_head_idx;
+          }
           tmp_m[offset] = m_frag[fx][j];
           tmp_d[offset] = d_frag[fx][j];
         }
@@ -651,9 +669,11 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
                      chunk_idx * q_n_stride + q_head_idx * HEAD_DIM +
                      tid % 8 * num_elems_per_128b<T>();
     } else {
-      o_base_ptr_T = tmp_workspace + q_start_seq_id * num_chunks * q_n_stride +
-                     chunk_idx * q_n_stride + q_head_idx * HEAD_DIM +
-                     tid % 8 * num_elems_per_128b<T>();
+      o_base_ptr_T =
+          tmp_workspace +
+          batch_id * speculate_max_draft_token_num * num_chunks * q_n_stride +
+          chunk_idx * q_n_stride + q_head_idx * HEAD_DIM +
+          tid % 8 * num_elems_per_128b<T>();
     }
   }
   const int *mask_offset_this_seq =
@@ -969,8 +989,12 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
             offset =
                 (batch_id * num_chunks + chunk_idx) * q_num_heads + qo_head_idx;
           } else {
-            offset =
-                (qo_idx * num_chunks + chunk_idx) * q_num_heads + qo_head_idx;
+            offset = ((batch_id * speculate_max_draft_token_num +
+                       qo_idx_now / GROUP_SIZE) *
+                          num_chunks +
+                      chunk_idx) *
+                         q_num_heads +
+                     qo_head_idx;
           }
           tmp_m[offset] = m_frag[fx][j];
           tmp_d[offset] = d_frag[fx][j];
@@ -1161,15 +1185,30 @@ void MultiQueryAppendC4Attention(
           sliding_window);
     } else {
       phi::Allocator::AllocationPtr tmp_workspace, tmp_m, tmp_d;
-      tmp_workspace = allocator->Allocate(
-          phi::SizeOf(qkv.dtype()) *
-          static_cast<size_t>(token_num * num_chunks * num_heads * HEAD_DIM));
-      tmp_m = allocator->Allocate(
-          phi::SizeOf(paddle::DataType::FLOAT32) *
-          static_cast<size_t>(token_num * num_chunks * num_heads));
-      tmp_d = allocator->Allocate(
-          phi::SizeOf(paddle::DataType::FLOAT32) *
-          static_cast<size_t>(token_num * num_chunks * num_heads));
+      if (ENABLE_PREFILL) {
+        tmp_workspace = allocator->Allocate(
+            phi::SizeOf(qkv.dtype()) *
+            static_cast<size_t>(token_num * num_chunks * num_heads * HEAD_DIM));
+        tmp_m = allocator->Allocate(
+            phi::SizeOf(paddle::DataType::FLOAT32) *
+            static_cast<size_t>(token_num * num_chunks * num_heads));
+        tmp_d = allocator->Allocate(
+            phi::SizeOf(paddle::DataType::FLOAT32) *
+            static_cast<size_t>(token_num * num_chunks * num_heads));
+      } else {
+        tmp_workspace = allocator->Allocate(
+            phi::SizeOf(qkv.dtype()) *
+            static_cast<size_t>(speculate_max_draft_token_num * bsz *
+                                num_chunks * num_heads * HEAD_DIM));
+        tmp_m = allocator->Allocate(
+            phi::SizeOf(paddle::DataType::FLOAT32) *
+            static_cast<size_t>(speculate_max_draft_token_num * bsz *
+                                num_chunks * num_heads));
+        tmp_d = allocator->Allocate(
+            phi::SizeOf(paddle::DataType::FLOAT32) *
+            static_cast<size_t>(speculate_max_draft_token_num * bsz *
+                                num_chunks * num_heads));
+      }
       launchWithPdlWhenEnabled(
           split_kv_kernel,
           grids,
@@ -1220,49 +1259,92 @@ void MultiQueryAppendC4Attention(
           sliding_window);
       // merge
       constexpr int vec_size = num_elems_per_128b<NV_TYPE>();
-      constexpr int blockx = HEAD_DIM / vec_size;
-      constexpr int blocky = (128 + blockx - 1) / blockx;
-      dim3 grids_merge(min(sm_count * 4, token_num), num_heads);
-      dim3 blocks_merge(blockx, blocky);
-      launchWithPdlWhenEnabled(
-          merge_multi_chunks_v2_kernel<NV_TYPE,
-                                       vec_size,
-                                       blocky,
-                                       HEAD_DIM,
-                                       OUT_NV_TYPE,
-                                       ENABLE_PREFILL>,
-          grids_merge,
-          blocks_merge,
-          0,
-          stream,
-          reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
-          static_cast<float *>(tmp_m->ptr()),
-          static_cast<float *>(tmp_d->ptr()),
-          seq_lens_q.data<int>(),
-          seq_lens_kv.data<int>(),
-          seq_lens_encoder.data<int>(),
-          batch_id_per_token.data<int>(),
-          cu_seqlens_q.data<int>(),
-          shift_bias ? reinterpret_cast<NV_TYPE *>(
-                           const_cast<T *>(shift_bias.get().data<T>()))
-                     : nullptr,
-          smooth_weight ? reinterpret_cast<NV_TYPE *>(
-                              const_cast<T *>(smooth_weight.get().data<T>()))
-                        : nullptr,
-          sinks ? reinterpret_cast<float *>(
-                      const_cast<float *>(sinks.get().data<float>()))
-                : nullptr,
-          reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
-          quant_max_bound,
-          quant_min_bound,
-          in_scale,
-          max_seq_len,
-          num_chunks,
-          num_heads,
-          chunk_size,
-          HEAD_DIM,
-          token_num,
-          speculate_max_draft_token_num);
+      if (is_decoder) {
+        constexpr int blockx = HEAD_DIM / vec_size;
+        constexpr int blocky = (128 + blockx - 1) / blockx;
+        dim3 grids_merge(bsz, num_heads);
+        dim3 blocks_merge(blockx, blocky);
+        launchWithPdlWhenEnabled(
+            merge_multi_chunks_decoder_kernel<NV_TYPE,
+                                              vec_size,
+                                              blocky,
+                                              HEAD_DIM,
+                                              OUT_NV_TYPE>,
+            grids_merge,
+            blocks_merge,
+            0,
+            stream,
+            reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
+            static_cast<float *>(tmp_m->ptr()),
+            static_cast<float *>(tmp_d->ptr()),
+            seq_lens_q.data<int>(),
+            seq_lens_kv.data<int>(),
+            seq_lens_encoder.data<int>(),
+            cu_seqlens_q.data<int>(),
+            shift_bias ? reinterpret_cast<NV_TYPE *>(
+                             const_cast<T *>(shift_bias.get().data<T>()))
+                       : nullptr,
+            smooth_weight ? reinterpret_cast<NV_TYPE *>(
+                                const_cast<T *>(smooth_weight.get().data<T>()))
+                          : nullptr,
+            sinks ? reinterpret_cast<float *>(
+                        const_cast<float *>(sinks.get().data<float>()))
+                  : nullptr,
+            reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
+            quant_max_bound,
+            quant_min_bound,
+            in_scale,
+            max_seq_len,
+            num_chunks,
+            num_heads,
+            chunk_size,
+            HEAD_DIM);
+      } else {
+        constexpr int blockx = HEAD_DIM / vec_size;
+        constexpr int blocky = (128 + blockx - 1) / blockx;
+        dim3 grids_merge(min(sm_count * 4, token_num), num_heads);
+        dim3 blocks_merge(blockx, blocky);
+        launchWithPdlWhenEnabled(
+            merge_multi_chunks_v2_kernel<NV_TYPE,
+                                         vec_size,
+                                         blocky,
+                                         HEAD_DIM,
+                                         OUT_NV_TYPE,
+                                         ENABLE_PREFILL>,
+            grids_merge,
+            blocks_merge,
+            0,
+            stream,
+            reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
+            static_cast<float *>(tmp_m->ptr()),
+            static_cast<float *>(tmp_d->ptr()),
+            seq_lens_q.data<int>(),
+            seq_lens_kv.data<int>(),
+            seq_lens_encoder.data<int>(),
+            batch_id_per_token.data<int>(),
+            cu_seqlens_q.data<int>(),
+            shift_bias ? reinterpret_cast<NV_TYPE *>(
+                             const_cast<T *>(shift_bias.get().data<T>()))
+                       : nullptr,
+            smooth_weight ? reinterpret_cast<NV_TYPE *>(
+                                const_cast<T *>(smooth_weight.get().data<T>()))
+                          : nullptr,
+            sinks ? reinterpret_cast<float *>(
+                        const_cast<float *>(sinks.get().data<float>()))
+                  : nullptr,
+            reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
+            quant_max_bound,
+            quant_min_bound,
+            in_scale,
+            max_seq_len,
+            num_chunks,
+            num_heads,
+            chunk_size,
+            HEAD_DIM,
+            token_num,
+            speculate_max_draft_token_num);
+      }
     }
   } else {
     constexpr uint32_t num_frags_z = BLOCK_SIZE / 16 / NUM_WARP_KV * 4;
@@ -1402,15 +1484,31 @@ void MultiQueryAppendC4Attention(
             phi::SizeOf(paddle::DataType::FLOAT32) *
             static_cast<size_t>(bsz * num_chunks * num_heads));
       } else {
-        tmp_workspace = allocator->Allocate(
-            phi::SizeOf(qkv.dtype()) *
-            static_cast<size_t>(token_num * num_chunks * num_heads * HEAD_DIM));
-        tmp_m = allocator->Allocate(
-            phi::SizeOf(paddle::DataType::FLOAT32) *
-            static_cast<size_t>(token_num * num_chunks * num_heads));
-        tmp_d = allocator->Allocate(
-            phi::SizeOf(paddle::DataType::FLOAT32) *
-            static_cast<size_t>(token_num * num_chunks * num_heads));
+        if (ENABLE_PREFILL) {
+          tmp_workspace =
+              allocator->Allocate(phi::SizeOf(qkv.dtype()) *
+                                  static_cast<size_t>(token_num * num_chunks *
+                                                      num_heads * HEAD_DIM));
+          tmp_m = allocator->Allocate(
+              phi::SizeOf(paddle::DataType::FLOAT32) *
+              static_cast<size_t>(token_num * num_chunks * num_heads));
+          tmp_d = allocator->Allocate(
+              phi::SizeOf(paddle::DataType::FLOAT32) *
+              static_cast<size_t>(token_num * num_chunks * num_heads));
+        } else {
+          tmp_workspace = allocator->Allocate(
+              phi::SizeOf(qkv.dtype()) *
+              static_cast<size_t>(speculate_max_draft_token_num * bsz *
+                                  num_chunks * num_heads * HEAD_DIM));
+          tmp_m = allocator->Allocate(
+              phi::SizeOf(paddle::DataType::FLOAT32) *
+              static_cast<size_t>(speculate_max_draft_token_num * bsz *
+                                  num_chunks * num_heads));
+          tmp_d = allocator->Allocate(
+              phi::SizeOf(paddle::DataType::FLOAT32) *
+              static_cast<size_t>(speculate_max_draft_token_num * bsz *
+                                  num_chunks * num_heads));
+        }
       }
       launchWithPdlWhenEnabled(
           split_kv_kernel,
diff --git a/custom_ops/gpu_ops/append_attn/multiquery_attention_c8_impl.cuh b/custom_ops/gpu_ops/append_attn/multiquery_attention_c8_impl.cuh
index dc8e3b5cd..28df1b405 100644
--- a/custom_ops/gpu_ops/append_attn/multiquery_attention_c8_impl.cuh
+++ b/custom_ops/gpu_ops/append_attn/multiquery_attention_c8_impl.cuh
@@ -178,9 +178,17 @@ __global__ void multi_query_append_attention_c8_kernel(
   T *o_base_ptr_T = nullptr;
   OutT *o_base_ptr_int8 = nullptr;
   if constexpr (partition_kv) {
-    o_base_ptr_T = tmp_workspace + q_start_seq_id * num_chunks * q_n_stride +
-                   chunk_idx * q_n_stride + q_head_idx * HEAD_DIM +
-                   tid % 8 * num_elems_per_128b<T>();
+    if (ENABLE_PREFILL) {
+      o_base_ptr_T = tmp_workspace + q_start_seq_id * num_chunks * q_n_stride +
+                     chunk_idx * q_n_stride + q_head_idx * HEAD_DIM +
+                     tid % 8 * num_elems_per_128b<T>();
+    } else {
+      o_base_ptr_T =
+          tmp_workspace +
+          batch_id * speculate_max_draft_token_num * num_chunks * q_n_stride +
+          chunk_idx * q_n_stride + q_head_idx * HEAD_DIM +
+          tid % 8 * num_elems_per_128b<T>();
+    }
   } else {
     o_base_ptr_int8 = out + o_offset;
   }
@@ -524,8 +532,18 @@ __global__ void multi_query_append_attention_c8_kernel(
         const uint32_t qo_head_idx = q_head_idx + qo_idx_now % GROUP_SIZE;
         const uint32_t qo_idx = q_start_seq_id + qo_idx_now / GROUP_SIZE;
         if (qo_idx - q_start_seq_id < q_len) {
-          uint32_t offset =
-              (qo_idx * num_chunks + chunk_idx) * q_num_heads + qo_head_idx;
+          uint32_t offset;
+          if (ENABLE_PREFILL) {
+            offset =
+                (qo_idx * num_chunks + chunk_idx) * q_num_heads + qo_head_idx;
+          } else {
+            offset = ((batch_id * speculate_max_draft_token_num +
+                       qo_idx_now / GROUP_SIZE) *
+                          num_chunks +
+                      chunk_idx) *
+                         q_num_heads +
+                     qo_head_idx;
+          }
           tmp_m[offset] = m_frag[fx][j];
           tmp_d[offset] = d_frag[fx][j];
         }
@@ -702,9 +720,11 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
                      chunk_idx * q_n_stride + q_head_idx * HEAD_DIM +
                      tid % 8 * num_elems_per_128b<T>();
     } else {
-      o_base_ptr_T = tmp_workspace + q_start_seq_id * num_chunks * q_n_stride +
-                     chunk_idx * q_n_stride + q_head_idx * HEAD_DIM +
-                     tid % 8 * num_elems_per_128b<T>();
+      o_base_ptr_T =
+          tmp_workspace +
+          batch_id * speculate_max_draft_token_num * num_chunks * q_n_stride +
+          chunk_idx * q_n_stride + q_head_idx * HEAD_DIM +
+          tid % 8 * num_elems_per_128b<T>();
     }
   }
   const int *mask_offset_this_seq =
@@ -1063,8 +1083,12 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
             offset =
                 (batch_id * num_chunks + chunk_idx) * q_num_heads + qo_head_idx;
           } else {
-            offset =
-                (qo_idx * num_chunks + chunk_idx) * q_num_heads + qo_head_idx;
+            offset = ((batch_id * speculate_max_draft_token_num +
+                       qo_idx_now / GROUP_SIZE) *
+                          num_chunks +
+                      chunk_idx) *
+                         q_num_heads +
+                     qo_head_idx;
           }
           tmp_m[offset] = m_frag[fx][j];
           tmp_d[offset] = d_frag[fx][j];
@@ -1288,15 +1312,30 @@ void MultiQueryAppendC8Attention(
           sliding_window);
     } else {
       phi::Allocator::AllocationPtr tmp_workspace, tmp_m, tmp_d;
-      tmp_workspace = allocator->Allocate(
-          phi::SizeOf(qkv.dtype()) *
-          static_cast<size_t>(token_num * num_chunks * num_heads * HEAD_DIM));
-      tmp_m = allocator->Allocate(
-          phi::SizeOf(paddle::DataType::FLOAT32) *
-          static_cast<size_t>(token_num * num_chunks * num_heads));
-      tmp_d = allocator->Allocate(
-          phi::SizeOf(paddle::DataType::FLOAT32) *
-          static_cast<size_t>(token_num * num_chunks * num_heads));
+      if (ENABLE_PREFILL) {
+        tmp_workspace = allocator->Allocate(
+            phi::SizeOf(qkv.dtype()) *
+            static_cast<size_t>(token_num * num_chunks * num_heads * HEAD_DIM));
+        tmp_m = allocator->Allocate(
+            phi::SizeOf(paddle::DataType::FLOAT32) *
+            static_cast<size_t>(token_num * num_chunks * num_heads));
+        tmp_d = allocator->Allocate(
+            phi::SizeOf(paddle::DataType::FLOAT32) *
+            static_cast<size_t>(token_num * num_chunks * num_heads));
+      } else {
+        tmp_workspace = allocator->Allocate(
+            phi::SizeOf(qkv.dtype()) *
+            static_cast<size_t>(speculate_max_draft_token_num * bsz *
+                                num_chunks * num_heads * HEAD_DIM));
+        tmp_m = allocator->Allocate(
+            phi::SizeOf(paddle::DataType::FLOAT32) *
+            static_cast<size_t>(speculate_max_draft_token_num * bsz *
+                                num_chunks * num_heads));
+        tmp_d = allocator->Allocate(
+            phi::SizeOf(paddle::DataType::FLOAT32) *
+            static_cast<size_t>(speculate_max_draft_token_num * bsz *
+                                num_chunks * num_heads));
+      }
       launchWithPdlWhenEnabled(
          split_kv_kernel,
          grids,
@@ -1341,49 +1380,92 @@ void MultiQueryAppendC8Attention(
           sliding_window);
       // merge
       constexpr int vec_size = num_elems_per_128b<NV_TYPE>();
-      constexpr int blockx = HEAD_DIM / vec_size;
-      constexpr int blocky = (128 + blockx - 1) / blockx;
-      dim3 grids_merge(min(sm_count * 4, token_num), num_heads);
-      dim3 blocks_merge(blockx, blocky);
-      launchWithPdlWhenEnabled(
-          merge_multi_chunks_v2_kernel<NV_TYPE,
-                                       vec_size,
-                                       blocky,
-                                       HEAD_DIM,
-                                       OUT_NV_TYPE,
-                                       ENABLE_PREFILL>,
-          grids_merge,
-          blocks_merge,
-          0,
-          stream,
-          reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
-          static_cast<float *>(tmp_m->ptr()),
-          static_cast<float *>(tmp_d->ptr()),
-          seq_lens_q.data<int>(),
-          seq_lens_kv.data<int>(),
-          seq_lens_encoder.data<int>(),
-          batch_id_per_token.data<int>(),
-          cu_seqlens_q.data<int>(),
-          shift_bias ? reinterpret_cast<NV_TYPE *>(
-                           const_cast<T *>(shift_bias.get().data<T>()))
-                     : nullptr,
-          smooth_weight ? reinterpret_cast<NV_TYPE *>(
-                              const_cast<T *>(smooth_weight.get().data<T>()))
-                        : nullptr,
-          sinks ? reinterpret_cast<float *>(
-                      const_cast<float *>(sinks.get().data<float>()))
-                : nullptr,
-          reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
-          quant_max_bound,
-          quant_min_bound,
-          in_scale,
-          max_seq_len,
-          num_chunks,
-          num_heads,
-          chunk_size,
-          HEAD_DIM,
-          token_num,
-          speculate_max_draft_token_num);
+      if (is_decoder) {
+        constexpr int blockx = HEAD_DIM / vec_size;
+        constexpr int blocky = (128 + blockx - 1) / blockx;
+        dim3 grids_merge(bsz, num_heads);
+        dim3 blocks_merge(blockx, blocky);
+        launchWithPdlWhenEnabled(
+            merge_multi_chunks_decoder_kernel<NV_TYPE,
+                                              vec_size,
+                                              blocky,
+                                              HEAD_DIM,
+                                              OUT_NV_TYPE>,
+            grids_merge,
+            blocks_merge,
+            0,
+            stream,
+            reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
+            static_cast<float *>(tmp_m->ptr()),
+            static_cast<float *>(tmp_d->ptr()),
+            seq_lens_q.data<int>(),
+            seq_lens_kv.data<int>(),
+            seq_lens_encoder.data<int>(),
+            cu_seqlens_q.data<int>(),
+            shift_bias ? reinterpret_cast<NV_TYPE *>(
+                             const_cast<T *>(shift_bias.get().data<T>()))
+                       : nullptr,
+            smooth_weight ? reinterpret_cast<NV_TYPE *>(
+                                const_cast<T *>(smooth_weight.get().data<T>()))
+                          : nullptr,
+            sinks ? reinterpret_cast<float *>(
+                        const_cast<float *>(sinks.get().data<float>()))
+                  : nullptr,
+            reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
+            quant_max_bound,
+            quant_min_bound,
+            in_scale,
+            max_seq_len,
+            num_chunks,
+            num_heads,
+            chunk_size,
+            HEAD_DIM);
+      } else {
+        constexpr int blockx = HEAD_DIM / vec_size;
+        constexpr int blocky = (128 + blockx - 1) / blockx;
+        dim3 grids_merge(min(sm_count * 4, token_num), num_heads);
+        dim3 blocks_merge(blockx, blocky);
+        launchWithPdlWhenEnabled(
+            merge_multi_chunks_v2_kernel<NV_TYPE,
+                                         vec_size,
+                                         blocky,
+                                         HEAD_DIM,
+                                         OUT_NV_TYPE,
+                                         ENABLE_PREFILL>,
+            grids_merge,
+            blocks_merge,
+            0,
+            stream,
+            reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
+            static_cast<float *>(tmp_m->ptr()),
+            static_cast<float *>(tmp_d->ptr()),
+            seq_lens_q.data<int>(),
+            seq_lens_kv.data<int>(),
+            seq_lens_encoder.data<int>(),
+            batch_id_per_token.data<int>(),
+            cu_seqlens_q.data<int>(),
+            shift_bias ? reinterpret_cast<NV_TYPE *>(
+                             const_cast<T *>(shift_bias.get().data<T>()))
+                       : nullptr,
+            smooth_weight ? reinterpret_cast<NV_TYPE *>(
+                                const_cast<T *>(smooth_weight.get().data<T>()))
+                          : nullptr,
+            sinks ? reinterpret_cast<float *>(
+                        const_cast<float *>(sinks.get().data<float>()))
+                  : nullptr,
+            reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
+            quant_max_bound,
+            quant_min_bound,
+            in_scale,
+            max_seq_len,
+            num_chunks,
+            num_heads,
+            chunk_size,
+            HEAD_DIM,
+            token_num,
+            speculate_max_draft_token_num);
+      }
     }
   } else {
     constexpr uint32_t num_frags_z = BLOCK_SIZE / 16 / NUM_WARP_KV * 2;
@@ -1555,15 +1637,31 @@ void MultiQueryAppendC8Attention(
             phi::SizeOf(paddle::DataType::FLOAT32) *
             static_cast<size_t>(bsz * num_chunks * num_heads));
       } else {
-        tmp_workspace = allocator->Allocate(
-            phi::SizeOf(qkv.dtype()) *
-            static_cast<size_t>(token_num * num_chunks * num_heads * HEAD_DIM));
-        tmp_m = allocator->Allocate(
-            phi::SizeOf(paddle::DataType::FLOAT32) *
-            static_cast<size_t>(token_num * num_chunks * num_heads));
-        tmp_d = allocator->Allocate(
-            phi::SizeOf(paddle::DataType::FLOAT32) *
-            static_cast<size_t>(token_num * num_chunks * num_heads));
+        if (ENABLE_PREFILL) {
+          tmp_workspace =
+              allocator->Allocate(phi::SizeOf(qkv.dtype()) *
+                                  static_cast<size_t>(token_num * num_chunks *
+                                                      num_heads * HEAD_DIM));
+          tmp_m = allocator->Allocate(
+              phi::SizeOf(paddle::DataType::FLOAT32) *
+              static_cast<size_t>(token_num * num_chunks * num_heads));
+          tmp_d = allocator->Allocate(
+              phi::SizeOf(paddle::DataType::FLOAT32) *
+              static_cast<size_t>(token_num * num_chunks * num_heads));
+        } else {
+          tmp_workspace = allocator->Allocate(
+              phi::SizeOf(qkv.dtype()) *
+              static_cast<size_t>(speculate_max_draft_token_num * bsz *
+                                  num_chunks * num_heads * HEAD_DIM));
+          tmp_m = allocator->Allocate(
+              phi::SizeOf(paddle::DataType::FLOAT32) *
+              static_cast<size_t>(speculate_max_draft_token_num * bsz *
+                                  num_chunks * num_heads));
+          tmp_d = allocator->Allocate(
+              phi::SizeOf(paddle::DataType::FLOAT32) *
+              static_cast<size_t>(speculate_max_draft_token_num * bsz *
+                                  num_chunks * num_heads));
+        }
       }
       launchWithPdlWhenEnabled(
           split_kv_kernel,
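
Note (editor's addition, not part of the patch): every hunk above makes the same pair of changes. When ENABLE_PREFILL is false (the speculative-decoding / MTP path), the split-KV partial results in tmp_workspace / tmp_m / tmp_d are indexed by (batch, draft-token) slot rather than by global token id, and the scratch buffers are sized speculate_max_draft_token_num * bsz rows instead of token_num rows. The sketch below restates the two row layouts as plain C++; the function names prefill_row and decode_row are illustrative only, but the index arithmetic mirrors the kernels in the patch.

    // Illustrative sketch of the two tmp_m / tmp_d row layouts.
    #include <cstdint>

    // Prefill layout: one row per global token id (qid), so the buffer
    // needs token_num * num_chunks * num_heads entries.
    inline uint32_t prefill_row(uint32_t qid, uint32_t num_chunks,
                                uint32_t chunk, uint32_t num_heads,
                                uint32_t head) {
      return (qid * num_chunks + chunk) * num_heads + head;
    }

    // Decode/MTP layout: one row per (batch, draft-token) slot, so the
    // buffer only needs bsz * speculate_max_draft_token_num rows; local_token
    // is the token's position within its sequence (qid - cu_seqlens_q[bid]).
    inline uint32_t decode_row(uint32_t batch_id, uint32_t max_draft_tokens,
                               uint32_t local_token, uint32_t num_chunks,
                               uint32_t chunk, uint32_t num_heads,
                               uint32_t head) {
      return ((batch_id * max_draft_tokens + local_token) * num_chunks +
              chunk) *
                 num_heads +
             head;
    }

The slot-based layout keeps the workspace size independent of how many draft tokens are actually accepted in a step, which is why the allocation hunks switch from token_num to speculate_max_draft_token_num * bsz on the non-prefill path.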