diff --git a/custom_ops/gpu_ops/append_attn/append_attention_func.cuh b/custom_ops/gpu_ops/append_attn/append_attention_func.cuh
index 4936e0172..9f0b9eba1 100644
--- a/custom_ops/gpu_ops/append_attn/append_attention_func.cuh
+++ b/custom_ops/gpu_ops/append_attn/append_attention_func.cuh
@@ -2451,7 +2451,6 @@ __global__ void merge_multi_chunks_v2_kernel(
     if (bid == -1) {
       continue;
     }
-    const uint32_t local_seq_id = qid - cu_seqlens_q[bid];
     const int seq_len_q = seq_lens_q[bid];
     if (seq_len_q == 0) continue;
     int seq_len_kv = seq_lens_kv[bid];
@@ -2470,8 +2469,6 @@ __global__ void merge_multi_chunks_v2_kernel(
     const int num_chunks_this_seq = div_up(seq_len_kv, chunk_size);
     if (num_chunks_this_seq <= 1) {
       continue;
-    } else if (!ENABLE_PREFILL) {
-      continue;
     }
 
     using LoadT = AlignedVector<T, vec_size>;
@@ -2497,32 +2494,14 @@ __global__ void merge_multi_chunks_v2_kernel(
     }
 #pragma unroll 2
     for (int i = ty; i < num_chunks_this_seq; i += bdy) {
-      uint32_t offset;
-      if (ENABLE_PREFILL) {
-        offset = (qid * num_chunks + i) * num_heads + hid;
-      } else {
-        offset =
-            ((bid * speculate_max_draft_token_num + local_seq_id) * num_chunks +
-             i) *
-                num_heads +
-            hid;
-      }
+      uint32_t offset = (qid * num_chunks + i) * num_heads + hid;
       float m_prev = m;
       float d_prev = d;
       const float m_now = multi_m[offset];
       const float d_now = multi_d[offset];
       m = max(m_prev, m_now);
-      if (ENABLE_PREFILL) {
-        offset =
-            (qid * num_chunks * num_heads + i * num_heads + hid) * head_dim +
-            vid * vec_size;
-      } else {
-        offset = ((bid * speculate_max_draft_token_num + local_seq_id) *
-                      num_chunks * num_heads +
-                  i * num_heads + hid) *
-                     head_dim +
-                 vid * vec_size;
-      }
+      offset = (qid * num_chunks * num_heads + i * num_heads + hid) * head_dim +
+               vid * vec_size;
       Load<T, vec_size>(&multi_out[offset], &load_vec);
       const float scale1 = __expf(m_prev - m), scale2 = __expf(m_now - m);
       const T scale1_T = static_cast<T>(scale1),
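Editor's note: the hunks above only drop the decode-time (`!ENABLE_PREFILL`) indexing of the partial-result buffers; the rescaling math in `merge_multi_chunks_v2_kernel` is unchanged. Each chunk contributes a partial max `m`, denominator `d`, and an unnormalized exp-weighted output, and the kernel merges them with running-max rescaling. A scalar host-side sketch of that reduction (hypothetical helper names, `head_dim` collapsed to one value; not code from this patch):

```cpp
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// One chunk's partial attention result for a single (token, head):
// m = max logit seen in the chunk, d = sum of exp(logit - m),
// out = exp-weighted (unnormalized) partial output.
struct Partial {
  float m, d, out;
};

// Merge partial results across chunks with running-max rescaling,
// mirroring the m/d/offset updates in merge_multi_chunks_v2_kernel.
Partial merge_chunks(const std::vector<Partial> &chunks) {
  Partial acc{-INFINITY, 0.0f, 0.0f};
  for (const Partial &c : chunks) {
    const float m_new = std::max(acc.m, c.m);
    const float s1 = std::exp(acc.m - m_new);  // rescale the accumulator
    const float s2 = std::exp(c.m - m_new);    // rescale the incoming chunk
    acc.out = acc.out * s1 + c.out * s2;
    acc.d = acc.d * s1 + c.d * s2;
    acc.m = m_new;
  }
  acc.out /= acc.d;  // final softmax normalization
  return acc;
}

int main() {
  const std::vector<Partial> chunks = {{0.5f, 2.0f, 1.2f}, {1.5f, 3.0f, 2.4f}};
  const Partial r = merge_chunks(chunks);
  std::printf("m=%f d=%f out=%f\n", r.m, r.d, r.out);
  return 0;
}
```

The division by `d` at the end stands in for the kernel's final normalization of the merged output vector; the result is independent of how the sequence was split into chunks.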
diff --git a/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh b/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh
index 54e0715d7..8bbc7727b 100644
--- a/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh
+++ b/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh
@@ -134,17 +134,9 @@ __global__ void multi_query_append_attention_kernel(
   T *o_base_ptr_T = nullptr;
   OutT *o_base_ptr_int8 = nullptr;
   if constexpr (partition_kv) {
-    if (ENABLE_PREFILL) {
-      o_base_ptr_T = tmp_workspace + q_start_seq_id * num_chunks * q_n_stride +
-                     chunk_idx * q_n_stride + q_head_idx * HEAD_DIM +
-                     tid % 8 * num_elems_per_128b<T>();
-    } else {
-      o_base_ptr_T =
-          tmp_workspace +
-          batch_id * speculate_max_draft_token_num * num_chunks * q_n_stride +
-          chunk_idx * q_n_stride + q_head_idx * HEAD_DIM +
-          tid % 8 * num_elems_per_128b<T>();
-    }
+    o_base_ptr_T = tmp_workspace + q_start_seq_id * num_chunks * q_n_stride +
+                   chunk_idx * q_n_stride + q_head_idx * HEAD_DIM +
+                   tid % 8 * num_elems_per_128b<T>();
   } else {
     o_base_ptr_int8 = out + o_offset;
   }
@@ -394,18 +386,8 @@ __global__ void multi_query_append_attention_kernel(
         const uint32_t qo_head_idx = q_head_idx + qo_idx_now % GROUP_SIZE;
         const uint32_t qo_idx = q_start_seq_id + qo_idx_now / GROUP_SIZE;
         if (qo_idx - q_start_seq_id < q_len) {
-          uint32_t offset;
-          if (ENABLE_PREFILL) {
-            offset =
-                (qo_idx * num_chunks + chunk_idx) * q_num_heads + qo_head_idx;
-          } else {
-            offset = ((batch_id * speculate_max_draft_token_num +
-                       qo_idx_now / GROUP_SIZE) *
-                          num_chunks +
-                      chunk_idx) *
-                         q_num_heads +
-                     qo_head_idx;
-          }
+          uint32_t offset =
+              (qo_idx * num_chunks + chunk_idx) * q_num_heads + qo_head_idx;
           tmp_m[offset] = m_frag[fx][j];
           tmp_d[offset] = d_frag[fx][j];
         }
@@ -542,11 +524,9 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
                      chunk_idx * q_n_stride + q_head_idx * HEAD_DIM +
                      tid % 8 * num_elems_per_128b<T>();
     } else {
-      o_base_ptr_T =
-          tmp_workspace +
-          batch_id * speculate_max_draft_token_num * num_chunks * q_n_stride +
-          chunk_idx * q_n_stride + q_head_idx * HEAD_DIM +
-          tid % 8 * num_elems_per_128b<T>();
+      o_base_ptr_T = tmp_workspace + q_start_seq_id * num_chunks * q_n_stride +
+                     chunk_idx * q_n_stride + q_head_idx * HEAD_DIM +
+                     tid % 8 * num_elems_per_128b<T>();
     }
   }
   const int *mask_offset_this_seq =
@@ -814,12 +794,8 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
             offset =
                 (batch_id * num_chunks + chunk_idx) * q_num_heads + qo_head_idx;
           } else {
-            offset = ((batch_id * speculate_max_draft_token_num +
-                       qo_idx_now / GROUP_SIZE) *
-                          num_chunks +
-                      chunk_idx) *
-                         q_num_heads +
-                     qo_head_idx;
+            offset =
+                (qo_idx * num_chunks + chunk_idx) * q_num_heads + qo_head_idx;
           }
           tmp_m[offset] = m_frag[fx][j];
           tmp_d[offset] = d_frag[fx][j];
@@ -918,10 +894,7 @@ void MultiQueryAppendAttention(
     int sm_count;
     cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id);
-    uint32_t chunk_size = static_cast<uint32_t>(max_partition_size);
-    if (!is_decoder) {
-      chunk_size = static_cast<uint32_t>(encoder_max_partition_size);
-    }
+    uint32_t chunk_size = static_cast<uint32_t>(encoder_max_partition_size);
     const int num_chunks = div_up(max_dec_len, chunk_size);
     dim3 grids(num_blocks_x_cpu, num_chunks, kv_num_heads);
     dim3 blocks(32, num_warps);
@@ -1053,95 +1026,51 @@ void MultiQueryAppendAttention(
           sliding_window);
       // merge
       constexpr int vec_size = num_elems_per_128b<T>();
-      if (is_decoder) {
-        constexpr int blockx = HEAD_DIM / vec_size;
-        constexpr int blocky = (128 + blockx - 1) / blockx;
-        dim3 grids_merge(bsz, num_heads);
-        dim3 blocks_merge(blockx, blocky);
-        auto *kernelFn = merge_multi_chunks_decoder_kernel<NV_TYPE,
-                                                           vec_size,
-                                                           blocky,
-                                                           HEAD_DIM,
-                                                           OUT_NV_TYPE>;
-        launchWithPdlWhenEnabled(
-            kernelFn,
-            grids_merge,
-            blocks_merge,
-            0,
-            stream,
-            reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
-            static_cast<float *>(tmp_m->ptr()),
-            static_cast<float *>(tmp_d->ptr()),
-            seq_lens_q.data<int>(),
-            seq_lens_kv.data<int>(),
-            seq_lens_encoder.data<int>(),
-            cu_seqlens_q.data<int>(),
-            shift_bias ? reinterpret_cast<NV_TYPE *>(
-                             const_cast<T *>(shift_bias.get().data<T>()))
-                       : nullptr,
-            smooth_weight ? reinterpret_cast<NV_TYPE *>(
-                                const_cast<T *>(smooth_weight.get().data<T>()))
-                          : nullptr,
-            sinks ? reinterpret_cast<float *>(
-                        const_cast<float *>(sinks.get().data<float>()))
-                  : nullptr,
-            reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
-            quant_max_bound,
-            quant_min_bound,
-            in_scale,
-            max_seq_len,
-            num_chunks,
-            num_heads,
-            chunk_size,
-            HEAD_DIM);
-      } else {
-        constexpr int blockx = HEAD_DIM / vec_size;
-        constexpr int blocky = (128 + blockx - 1) / blockx;
-        dim3 grids_merge(min(sm_count * 4, token_num),
-                         num_heads);  // 128k is too large
-        dim3 blocks_merge(blockx, blocky);
-        auto *kernelFn = merge_multi_chunks_v2_kernel<NV_TYPE,
-                                                      vec_size,
-                                                      blocky,
-                                                      HEAD_DIM,
-                                                      OUT_NV_TYPE,
-                                                      ENABLE_PREFILL>;
-        launchWithPdlWhenEnabled(
-            kernelFn,
-            grids_merge,
-            blocks_merge,
-            0,
-            stream,
-            reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
-            static_cast<float *>(tmp_m->ptr()),
-            static_cast<float *>(tmp_d->ptr()),
-            seq_lens_q.data<int>(),
-            seq_lens_kv.data<int>(),
-            seq_lens_encoder.data<int>(),
-            batch_id_per_token.data<int>(),
-            cu_seqlens_q.data<int>(),
-            shift_bias ? reinterpret_cast<NV_TYPE *>(
-                             const_cast<T *>(shift_bias.get().data<T>()))
-                       : nullptr,
-            smooth_weight ? reinterpret_cast<NV_TYPE *>(
-                                const_cast<T *>(smooth_weight.get().data<T>()))
-                          : nullptr,
-            sinks ? reinterpret_cast<float *>(
-                        const_cast<float *>(sinks.get().data<float>()))
-                  : nullptr,
-            reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
-            quant_max_bound,
-            quant_min_bound,
-            in_scale,
-            max_seq_len,
-            num_chunks,
-            num_heads,
-            chunk_size,
-            HEAD_DIM,
-            token_num,
-            speculate_max_draft_token_num);
-      }
+      constexpr int blockx = HEAD_DIM / vec_size;
+      constexpr int blocky = (128 + blockx - 1) / blockx;
+      dim3 grids_merge(min(sm_count * 4, token_num),
+                       num_heads);  // 128k is too large
+      dim3 blocks_merge(blockx, blocky);
+      auto *kernelFn = merge_multi_chunks_v2_kernel<NV_TYPE,
+                                                    vec_size,
+                                                    blocky,
+                                                    HEAD_DIM,
+                                                    OUT_NV_TYPE,
+                                                    ENABLE_PREFILL>;
+      launchWithPdlWhenEnabled(
+          kernelFn,
+          grids_merge,
+          blocks_merge,
+          0,
+          stream,
+          reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
+          static_cast<float *>(tmp_m->ptr()),
+          static_cast<float *>(tmp_d->ptr()),
+          seq_lens_q.data<int>(),
+          seq_lens_kv.data<int>(),
+          seq_lens_encoder.data<int>(),
+          batch_id_per_token.data<int>(),
+          cu_seqlens_q.data<int>(),
+          shift_bias ? reinterpret_cast<NV_TYPE *>(
+                           const_cast<T *>(shift_bias.get().data<T>()))
+                     : nullptr,
+          smooth_weight ? reinterpret_cast<NV_TYPE *>(
+                              const_cast<T *>(smooth_weight.get().data<T>()))
+                        : nullptr,
+          sinks ? reinterpret_cast<float *>(
+                      const_cast<float *>(sinks.get().data<float>()))
+                : nullptr,
+          reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
+          quant_max_bound,
+          quant_min_bound,
+          in_scale,
+          max_seq_len,
+          num_chunks,
+          num_heads,
+          chunk_size,
+          HEAD_DIM,
+          token_num,
+          speculate_max_draft_token_num);
     }
   } else {
     constexpr uint32_t num_frags_z = BLOCK_SIZE / 16 / NUM_WARP_KV;
@@ -1173,9 +1102,6 @@ void MultiQueryAppendAttention(
     cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id);
 
     uint32_t chunk_size = static_cast<uint32_t>(max_partition_size);
-    if (!is_decoder) {
-      chunk_size = static_cast<uint32_t>(encoder_max_partition_size);
-    }
 
     uint32_t attn_mask_len;
     if (attn_mask) {
@@ -1263,31 +1189,15 @@ void MultiQueryAppendAttention(
           phi::SizeOf(paddle::DataType::FLOAT32) *
           static_cast<size_t>(bsz * num_chunks * num_heads));
     } else {
-      if (ENABLE_PREFILL) {
-        tmp_workspace =
-            allocator->Allocate(phi::SizeOf(qkv.dtype()) *
-                                static_cast<size_t>(token_num * num_chunks *
-                                                    num_heads * HEAD_DIM));
-        tmp_m = allocator->Allocate(
-            phi::SizeOf(paddle::DataType::FLOAT32) *
-            static_cast<size_t>(token_num * num_chunks * num_heads));
-        tmp_d = allocator->Allocate(
-            phi::SizeOf(paddle::DataType::FLOAT32) *
-            static_cast<size_t>(token_num * num_chunks * num_heads));
-      } else {
-        tmp_workspace = allocator->Allocate(
-            phi::SizeOf(qkv.dtype()) *
-            static_cast<size_t>(speculate_max_draft_token_num * bsz *
-                                num_chunks * num_heads * HEAD_DIM));
-        tmp_m = allocator->Allocate(
-            phi::SizeOf(paddle::DataType::FLOAT32) *
-            static_cast<size_t>(speculate_max_draft_token_num * bsz *
-                                num_chunks * num_heads));
-        tmp_d = allocator->Allocate(
-            phi::SizeOf(paddle::DataType::FLOAT32) *
-            static_cast<size_t>(speculate_max_draft_token_num * bsz *
-                                num_chunks * num_heads));
-      }
+      tmp_workspace = allocator->Allocate(
+          phi::SizeOf(qkv.dtype()) *
+          static_cast<size_t>(token_num * num_chunks * num_heads * HEAD_DIM));
+      tmp_m = allocator->Allocate(
+          phi::SizeOf(paddle::DataType::FLOAT32) *
+          static_cast<size_t>(token_num * num_chunks * num_heads));
+      tmp_d = allocator->Allocate(
+          phi::SizeOf(paddle::DataType::FLOAT32) *
+          static_cast<size_t>(token_num * num_chunks * num_heads));
     }
     launchWithPdlWhenEnabled(
         split_kv_kernel,
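Editor's note: the common thread in the kernels above is the `tmp_workspace` layout. Partial outputs are now always indexed token-major as `[token, chunk, q_head, head_dim]`, where the old decode path used a `[batch, draft_token, chunk, q_head, head_dim]` layout. A sketch of the unified offset arithmetic (hypothetical standalone helper mirroring the kernel expressions; here `q_n_stride` is the per-token, per-chunk stride, i.e. `q_num_heads * HEAD_DIM`):

```cpp
#include <cstdint>
#include <cstdio>

// Offset of one thread's 128-bit slice in tmp_workspace under the unified
// [token, chunk, q_head, head_dim] layout used after this patch.
inline uint32_t workspace_offset(uint32_t token_idx,   // q_start_seq_id
                                 uint32_t num_chunks,
                                 uint32_t chunk_idx,
                                 uint32_t q_n_stride,  // q_num_heads * HEAD_DIM
                                 uint32_t q_head_idx,
                                 uint32_t head_dim,
                                 uint32_t lane_elem) { // tid % 8 * elems/128b
  return token_idx * num_chunks * q_n_stride  // all chunks of earlier tokens
         + chunk_idx * q_n_stride             // this token's current chunk
         + q_head_idx * head_dim              // head within the chunk
         + lane_elem;                         // thread's slice of the head
}

int main() {
  // Hypothetical shapes: 8 heads of dim 128, 4 KV chunks.
  const uint32_t head_dim = 128, q_num_heads = 8, num_chunks = 4;
  const uint32_t q_n_stride = q_num_heads * head_dim;
  std::printf("offset = %u\n",
              workspace_offset(/*token_idx=*/3, num_chunks, /*chunk_idx=*/1,
                               q_n_stride, /*q_head_idx=*/2, head_dim,
                               /*lane_elem=*/0));
  return 0;
}
```

Because `token_idx` already enumerates draft tokens on the speculative-decoding path, the single token-major formula covers both the prefill and decode cases that were previously split on `ENABLE_PREFILL`.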
diff --git a/custom_ops/gpu_ops/append_attn/multiquery_attention_c4_impl.cuh b/custom_ops/gpu_ops/append_attn/multiquery_attention_c4_impl.cuh
index dab3dcee4..9629acf5d 100644
--- a/custom_ops/gpu_ops/append_attn/multiquery_attention_c4_impl.cuh
+++ b/custom_ops/gpu_ops/append_attn/multiquery_attention_c4_impl.cuh
@@ -169,17 +169,9 @@ __global__ void multi_query_append_attention_c4_kernel(
   T *o_base_ptr_T = nullptr;
   OutT *o_base_ptr_int8 = nullptr;
   if constexpr (partition_kv) {
-    if (ENABLE_PREFILL) {
-      o_base_ptr_T = tmp_workspace + q_start_seq_id * num_chunks * q_n_stride +
-                     chunk_idx * q_n_stride + q_head_idx * HEAD_DIM +
-                     tid % 8 * num_elems_per_128b<T>();
-    } else {
-      o_base_ptr_T =
-          tmp_workspace +
-          batch_id * speculate_max_draft_token_num * num_chunks * q_n_stride +
-          chunk_idx * q_n_stride + q_head_idx * HEAD_DIM +
-          tid % 8 * num_elems_per_128b<T>();
-    }
+    o_base_ptr_T = tmp_workspace + q_start_seq_id * num_chunks * q_n_stride +
+                   chunk_idx * q_n_stride + q_head_idx * HEAD_DIM +
+                   tid % 8 * num_elems_per_128b<T>();
   } else {
     o_base_ptr_int8 = out + o_offset;
   }
@@ -485,18 +477,8 @@ __global__ void multi_query_append_attention_c4_kernel(
         const uint32_t qo_head_idx = q_head_idx + qo_idx_now % GROUP_SIZE;
         const uint32_t qo_idx = q_start_seq_id + qo_idx_now / GROUP_SIZE;
         if (qo_idx - q_start_seq_id < q_len) {
-          uint32_t offset;
-          if (ENABLE_PREFILL) {
-            offset =
-                (qo_idx * num_chunks + chunk_idx) * q_num_heads + qo_head_idx;
-          } else {
-            offset = ((batch_id * speculate_max_draft_token_num +
-                       qo_idx_now / GROUP_SIZE) *
-                          num_chunks +
-                      chunk_idx) *
-                         q_num_heads +
-                     qo_head_idx;
-          }
+          uint32_t offset =
+              (qo_idx * num_chunks + chunk_idx) * q_num_heads + qo_head_idx;
           tmp_m[offset] = m_frag[fx][j];
           tmp_d[offset] = d_frag[fx][j];
         }
@@ -669,11 +651,9 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
                      chunk_idx * q_n_stride + q_head_idx * HEAD_DIM +
                      tid % 8 * num_elems_per_128b<T>();
     } else {
-      o_base_ptr_T =
-          tmp_workspace +
-          batch_id * speculate_max_draft_token_num * num_chunks * q_n_stride +
-          chunk_idx * q_n_stride + q_head_idx * HEAD_DIM +
-          tid % 8 * num_elems_per_128b<T>();
+      o_base_ptr_T = tmp_workspace + q_start_seq_id * num_chunks * q_n_stride +
+                     chunk_idx * q_n_stride + q_head_idx * HEAD_DIM +
+                     tid % 8 * num_elems_per_128b<T>();
     }
   }
   const int *mask_offset_this_seq =
@@ -989,12 +969,8 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
             offset =
                 (batch_id * num_chunks + chunk_idx) * q_num_heads + qo_head_idx;
           } else {
-            offset = ((batch_id * speculate_max_draft_token_num +
-                       qo_idx_now / GROUP_SIZE) *
-                          num_chunks +
-                      chunk_idx) *
-                         q_num_heads +
-                     qo_head_idx;
+            offset =
+                (qo_idx * num_chunks + chunk_idx) * q_num_heads + qo_head_idx;
           }
           tmp_m[offset] = m_frag[fx][j];
           tmp_d[offset] = d_frag[fx][j];
@@ -1108,10 +1084,7 @@ void MultiQueryAppendC4Attention(
     const float ratio = static_cast<float>(num_blocks_need) /
                         static_cast<float>(num_blocks_per_wave);
 
-    uint32_t chunk_size = static_cast<uint32_t>(max_partition_size);
-    if (!is_decoder) {
-      chunk_size = static_cast<uint32_t>(encoder_max_partition_size);
-    }
+    uint32_t chunk_size = static_cast<uint32_t>(encoder_max_partition_size);
     const int num_chunks = div_up(max_dec_len, chunk_size);
 
     dim3 grids(num_blocks_x_cpu, num_chunks, kv_num_heads);
@@ -1188,30 +1161,15 @@ void MultiQueryAppendC4Attention(
           sliding_window);
     } else {
       phi::Allocator::AllocationPtr tmp_workspace, tmp_m, tmp_d;
-      if (ENABLE_PREFILL) {
-        tmp_workspace = allocator->Allocate(
-            phi::SizeOf(qkv.dtype()) *
-            static_cast<size_t>(token_num * num_chunks * num_heads * HEAD_DIM));
-        tmp_m = allocator->Allocate(
-            phi::SizeOf(paddle::DataType::FLOAT32) *
-            static_cast<size_t>(token_num * num_chunks * num_heads));
-        tmp_d = allocator->Allocate(
-            phi::SizeOf(paddle::DataType::FLOAT32) *
-            static_cast<size_t>(token_num * num_chunks * num_heads));
-      } else {
-        tmp_workspace = allocator->Allocate(
-            phi::SizeOf(qkv.dtype()) *
-            static_cast<size_t>(speculate_max_draft_token_num * bsz *
-                                num_chunks * num_heads * HEAD_DIM));
-        tmp_m = allocator->Allocate(
-            phi::SizeOf(paddle::DataType::FLOAT32) *
-            static_cast<size_t>(speculate_max_draft_token_num * bsz *
-                                num_chunks * num_heads));
-        tmp_d = allocator->Allocate(
-            phi::SizeOf(paddle::DataType::FLOAT32) *
-            static_cast<size_t>(speculate_max_draft_token_num * bsz *
-                                num_chunks * num_heads));
-      }
+      tmp_workspace = allocator->Allocate(
+          phi::SizeOf(qkv.dtype()) *
+          static_cast<size_t>(token_num * num_chunks * num_heads * HEAD_DIM));
+      tmp_m = allocator->Allocate(
+          phi::SizeOf(paddle::DataType::FLOAT32) *
+          static_cast<size_t>(token_num * num_chunks * num_heads));
+      tmp_d = allocator->Allocate(
+          phi::SizeOf(paddle::DataType::FLOAT32) *
+          static_cast<size_t>(token_num * num_chunks * num_heads));
       launchWithPdlWhenEnabled(
           split_kv_kernel,
           grids,
@@ -1262,92 +1220,49 @@ void MultiQueryAppendC4Attention(
           sliding_window);
       // merge
       constexpr int vec_size = num_elems_per_128b<T>();
-      if (is_decoder) {
-        constexpr int blockx = HEAD_DIM / vec_size;
-        constexpr int blocky = (128 + blockx - 1) / blockx;
-        dim3 grids_merge(bsz, num_heads);
-        dim3 blocks_merge(blockx, blocky);
-        launchWithPdlWhenEnabled(
-            merge_multi_chunks_decoder_kernel<NV_TYPE,
-                                              vec_size,
-                                              blocky,
-                                              HEAD_DIM,
-                                              OUT_NV_TYPE>,
-            grids_merge,
-            blocks_merge,
-            0,
-            stream,
-            reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
-            static_cast<float *>(tmp_m->ptr()),
-            static_cast<float *>(tmp_d->ptr()),
-            seq_lens_q.data<int>(),
-            seq_lens_kv.data<int>(),
-            seq_lens_encoder.data<int>(),
-            cu_seqlens_q.data<int>(),
-            shift_bias ? reinterpret_cast<NV_TYPE *>(
-                             const_cast<T *>(shift_bias.get().data<T>()))
-                       : nullptr,
-            smooth_weight ? reinterpret_cast<NV_TYPE *>(
-                                const_cast<T *>(smooth_weight.get().data<T>()))
-                          : nullptr,
-            sinks ? reinterpret_cast<float *>(
-                        const_cast<float *>(sinks.get().data<float>()))
-                  : nullptr,
-            reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
-            quant_max_bound,
-            quant_min_bound,
-            in_scale,
-            max_seq_len,
-            num_chunks,
-            num_heads,
-            chunk_size,
-            HEAD_DIM);
-      } else {
-        constexpr int blockx = HEAD_DIM / vec_size;
-        constexpr int blocky = (128 + blockx - 1) / blockx;
-        dim3 grids_merge(min(sm_count * 4, token_num), num_heads);
-        dim3 blocks_merge(blockx, blocky);
-        launchWithPdlWhenEnabled(
-            merge_multi_chunks_v2_kernel<NV_TYPE,
-                                         vec_size,
-                                         blocky,
-                                         HEAD_DIM,
-                                         OUT_NV_TYPE,
-                                         ENABLE_PREFILL>,
-            grids_merge,
-            blocks_merge,
-            0,
-            stream,
-            reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
-            static_cast<float *>(tmp_m->ptr()),
-            static_cast<float *>(tmp_d->ptr()),
-            seq_lens_q.data<int>(),
-            seq_lens_kv.data<int>(),
-            seq_lens_encoder.data<int>(),
-            batch_id_per_token.data<int>(),
-            cu_seqlens_q.data<int>(),
-            shift_bias ? reinterpret_cast<NV_TYPE *>(
-                             const_cast<T *>(shift_bias.get().data<T>()))
-                       : nullptr,
-            smooth_weight ? reinterpret_cast<NV_TYPE *>(
-                                const_cast<T *>(smooth_weight.get().data<T>()))
-                          : nullptr,
-            sinks ? reinterpret_cast<float *>(
-                        const_cast<float *>(sinks.get().data<float>()))
-                  : nullptr,
-            reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
-            quant_max_bound,
-            quant_min_bound,
-            in_scale,
-            max_seq_len,
-            num_chunks,
-            num_heads,
-            chunk_size,
-            HEAD_DIM,
-            token_num,
-            speculate_max_draft_token_num);
-      }
+      constexpr int blockx = HEAD_DIM / vec_size;
+      constexpr int blocky = (128 + blockx - 1) / blockx;
+      dim3 grids_merge(min(sm_count * 4, token_num), num_heads);
+      dim3 blocks_merge(blockx, blocky);
+      launchWithPdlWhenEnabled(
+          merge_multi_chunks_v2_kernel<NV_TYPE,
+                                       vec_size,
+                                       blocky,
+                                       HEAD_DIM,
+                                       OUT_NV_TYPE,
+                                       ENABLE_PREFILL>,
+          grids_merge,
+          blocks_merge,
+          0,
+          stream,
+          reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
+          static_cast<float *>(tmp_m->ptr()),
+          static_cast<float *>(tmp_d->ptr()),
+          seq_lens_q.data<int>(),
+          seq_lens_kv.data<int>(),
+          seq_lens_encoder.data<int>(),
+          batch_id_per_token.data<int>(),
+          cu_seqlens_q.data<int>(),
+          shift_bias ? reinterpret_cast<NV_TYPE *>(
+                           const_cast<T *>(shift_bias.get().data<T>()))
+                     : nullptr,
+          smooth_weight ? reinterpret_cast<NV_TYPE *>(
+                              const_cast<T *>(smooth_weight.get().data<T>()))
+                        : nullptr,
+          sinks ? reinterpret_cast<float *>(
+                      const_cast<float *>(sinks.get().data<float>()))
+                : nullptr,
+          reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
+          quant_max_bound,
+          quant_min_bound,
+          in_scale,
+          max_seq_len,
+          num_chunks,
+          num_heads,
+          chunk_size,
+          HEAD_DIM,
+          token_num,
+          speculate_max_draft_token_num);
     }
   } else {
     constexpr uint32_t num_frags_z = BLOCK_SIZE / 16 / NUM_WARP_KV * 4;
@@ -1390,9 +1305,6 @@ void MultiQueryAppendC4Attention(
         static_cast<float>(num_blocks_per_wave);
 
     uint32_t chunk_size = static_cast<uint32_t>(max_partition_size);
-    if (!is_decoder) {
-      chunk_size = static_cast<uint32_t>(encoder_max_partition_size);
-    }
     const int num_chunks = div_up(max_seq_len, chunk_size);
 
     uint32_t attn_mask_len;
@@ -1490,31 +1402,15 @@ void MultiQueryAppendC4Attention(
           phi::SizeOf(paddle::DataType::FLOAT32) *
           static_cast<size_t>(bsz * num_chunks * num_heads));
     } else {
-      if (ENABLE_PREFILL) {
-        tmp_workspace =
-            allocator->Allocate(phi::SizeOf(qkv.dtype()) *
-                                static_cast<size_t>(token_num * num_chunks *
-                                                    num_heads * HEAD_DIM));
-        tmp_m = allocator->Allocate(
-            phi::SizeOf(paddle::DataType::FLOAT32) *
-            static_cast<size_t>(token_num * num_chunks * num_heads));
-        tmp_d = allocator->Allocate(
-            phi::SizeOf(paddle::DataType::FLOAT32) *
-            static_cast<size_t>(token_num * num_chunks * num_heads));
-      } else {
-        tmp_workspace = allocator->Allocate(
-            phi::SizeOf(qkv.dtype()) *
-            static_cast<size_t>(speculate_max_draft_token_num * bsz *
-                                num_chunks * num_heads * HEAD_DIM));
-        tmp_m = allocator->Allocate(
-            phi::SizeOf(paddle::DataType::FLOAT32) *
-            static_cast<size_t>(speculate_max_draft_token_num * bsz *
-                                num_chunks * num_heads));
-        tmp_d = allocator->Allocate(
-            phi::SizeOf(paddle::DataType::FLOAT32) *
-            static_cast<size_t>(speculate_max_draft_token_num * bsz *
-                                num_chunks * num_heads));
-      }
+      tmp_workspace = allocator->Allocate(
+          phi::SizeOf(qkv.dtype()) *
+          static_cast<size_t>(token_num * num_chunks * num_heads * HEAD_DIM));
+      tmp_m = allocator->Allocate(
+          phi::SizeOf(paddle::DataType::FLOAT32) *
+          static_cast<size_t>(token_num * num_chunks * num_heads));
+      tmp_d = allocator->Allocate(
+          phi::SizeOf(paddle::DataType::FLOAT32) *
+          static_cast<size_t>(token_num * num_chunks * num_heads));
    }
     launchWithPdlWhenEnabled(
         split_kv_kernel,
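Editor's note: with the `ENABLE_PREFILL` special case gone, the split-KV scratch buffers are sized identically on every path: `token_num * num_chunks * num_heads * HEAD_DIM` elements for partial outputs, plus two `float` arrays of `token_num * num_chunks * num_heads` for the per-chunk softmax statistics. A minimal sizing sketch (hypothetical helper; `elem_size` stands in for `phi::SizeOf(qkv.dtype())`):

```cpp
#include <cstddef>
#include <cstdio>

// Scratch-buffer sizes for the split-KV path after this patch:
// tmp_workspace holds partial outputs in T; tmp_m / tmp_d hold the
// per-chunk softmax max and denominator in float32.
struct WorkspaceBytes {
  size_t out_bytes, m_bytes, d_bytes;
};

WorkspaceBytes split_kv_workspace(size_t elem_size,  // sizeof one T element
                                  size_t token_num,
                                  size_t num_chunks,
                                  size_t num_heads,
                                  size_t head_dim) {
  const size_t stats = token_num * num_chunks * num_heads;
  return {elem_size * stats * head_dim,  // tmp_workspace
          sizeof(float) * stats,         // tmp_m
          sizeof(float) * stats};        // tmp_d
}

int main() {
  // Hypothetical config: 1024 tokens, 4 chunks, 8 heads of dim 128, 2-byte T.
  const WorkspaceBytes w = split_kv_workspace(2, 1024, 4, 8, 128);
  std::printf("out=%zu m=%zu d=%zu bytes\n", w.out_bytes, w.m_bytes, w.d_bytes);
  return 0;
}
```

Since decode-path tokens are a subset of `token_num`, the unified token-based sizing is never smaller than what the removed `speculate_max_draft_token_num * bsz` branch would have needed for those tokens.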
diff --git a/custom_ops/gpu_ops/append_attn/multiquery_attention_c8_impl.cuh b/custom_ops/gpu_ops/append_attn/multiquery_attention_c8_impl.cuh
index 39b371c78..dc8e3b5cd 100644
--- a/custom_ops/gpu_ops/append_attn/multiquery_attention_c8_impl.cuh
+++ b/custom_ops/gpu_ops/append_attn/multiquery_attention_c8_impl.cuh
@@ -178,17 +178,9 @@ __global__ void multi_query_append_attention_c8_kernel(
   T *o_base_ptr_T = nullptr;
   OutT *o_base_ptr_int8 = nullptr;
   if constexpr (partition_kv) {
-    if (ENABLE_PREFILL) {
-      o_base_ptr_T = tmp_workspace + q_start_seq_id * num_chunks * q_n_stride +
-                     chunk_idx * q_n_stride + q_head_idx * HEAD_DIM +
-                     tid % 8 * num_elems_per_128b<T>();
-    } else {
-      o_base_ptr_T =
-          tmp_workspace +
-          batch_id * speculate_max_draft_token_num * num_chunks * q_n_stride +
-          chunk_idx * q_n_stride + q_head_idx * HEAD_DIM +
-          tid % 8 * num_elems_per_128b<T>();
-    }
+    o_base_ptr_T = tmp_workspace + q_start_seq_id * num_chunks * q_n_stride +
+                   chunk_idx * q_n_stride + q_head_idx * HEAD_DIM +
+                   tid % 8 * num_elems_per_128b<T>();
   } else {
     o_base_ptr_int8 = out + o_offset;
   }
@@ -532,18 +524,8 @@ __global__ void multi_query_append_attention_c8_kernel(
         const uint32_t qo_head_idx = q_head_idx + qo_idx_now % GROUP_SIZE;
         const uint32_t qo_idx = q_start_seq_id + qo_idx_now / GROUP_SIZE;
         if (qo_idx - q_start_seq_id < q_len) {
-          uint32_t offset;
-          if (ENABLE_PREFILL) {
-            offset =
-                (qo_idx * num_chunks + chunk_idx) * q_num_heads + qo_head_idx;
-          } else {
-            offset = ((batch_id * speculate_max_draft_token_num +
-                       qo_idx_now / GROUP_SIZE) *
-                          num_chunks +
-                      chunk_idx) *
-                         q_num_heads +
-                     qo_head_idx;
-          }
+          uint32_t offset =
+              (qo_idx * num_chunks + chunk_idx) * q_num_heads + qo_head_idx;
           tmp_m[offset] = m_frag[fx][j];
           tmp_d[offset] = d_frag[fx][j];
         }
@@ -720,11 +702,9 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
                      chunk_idx * q_n_stride + q_head_idx * HEAD_DIM +
                      tid % 8 * num_elems_per_128b<T>();
     } else {
-      o_base_ptr_T =
-          tmp_workspace +
-          batch_id * speculate_max_draft_token_num * num_chunks * q_n_stride +
-          chunk_idx * q_n_stride + q_head_idx * HEAD_DIM +
-          tid % 8 * num_elems_per_128b<T>();
+      o_base_ptr_T = tmp_workspace + q_start_seq_id * num_chunks * q_n_stride +
+                     chunk_idx * q_n_stride + q_head_idx * HEAD_DIM +
+                     tid % 8 * num_elems_per_128b<T>();
     }
   }
   const int *mask_offset_this_seq =
@@ -1083,12 +1063,8 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
             offset =
                 (batch_id * num_chunks + chunk_idx) * q_num_heads + qo_head_idx;
           } else {
-            offset = ((batch_id * speculate_max_draft_token_num +
-                       qo_idx_now / GROUP_SIZE) *
-                          num_chunks +
-                      chunk_idx) *
-                         q_num_heads +
-                     qo_head_idx;
+            offset =
+                (qo_idx * num_chunks + chunk_idx) * q_num_heads + qo_head_idx;
           }
           tmp_m[offset] = m_frag[fx][j];
           tmp_d[offset] = d_frag[fx][j];
@@ -1218,10 +1194,7 @@ void MultiQueryAppendC8Attention(
     const int dev_id = 0;
     int sm_count;
     cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id);
-    uint32_t chunk_size = static_cast<uint32_t>(max_partition_size);
-    if (!is_decoder) {
-      chunk_size = static_cast<uint32_t>(encoder_max_partition_size);
-    }
+    uint32_t chunk_size = static_cast<uint32_t>(encoder_max_partition_size);
     const int num_chunks = div_up(max_dec_len, chunk_size);
     dim3 grids(num_blocks_x_cpu, num_chunks, kv_num_heads);
     dim3 blocks(32, num_warps);
@@ -1315,30 +1288,15 @@ void MultiQueryAppendC8Attention(
           sliding_window);
     } else {
       phi::Allocator::AllocationPtr tmp_workspace, tmp_m, tmp_d;
-      if (ENABLE_PREFILL) {
-        tmp_workspace = allocator->Allocate(
-            phi::SizeOf(qkv.dtype()) *
-            static_cast<size_t>(token_num * num_chunks * num_heads * HEAD_DIM));
-        tmp_m = allocator->Allocate(
-            phi::SizeOf(paddle::DataType::FLOAT32) *
-            static_cast<size_t>(token_num * num_chunks * num_heads));
-        tmp_d = allocator->Allocate(
-            phi::SizeOf(paddle::DataType::FLOAT32) *
-            static_cast<size_t>(token_num * num_chunks * num_heads));
-      } else {
-        tmp_workspace = allocator->Allocate(
-            phi::SizeOf(qkv.dtype()) *
-            static_cast<size_t>(speculate_max_draft_token_num * bsz *
-                                num_chunks * num_heads * HEAD_DIM));
-        tmp_m = allocator->Allocate(
-            phi::SizeOf(paddle::DataType::FLOAT32) *
-            static_cast<size_t>(speculate_max_draft_token_num * bsz *
-                                num_chunks * num_heads));
-        tmp_d = allocator->Allocate(
-            phi::SizeOf(paddle::DataType::FLOAT32) *
-            static_cast<size_t>(speculate_max_draft_token_num * bsz *
-                                num_chunks * num_heads));
-      }
+      tmp_workspace = allocator->Allocate(
+          phi::SizeOf(qkv.dtype()) *
+          static_cast<size_t>(token_num * num_chunks * num_heads * HEAD_DIM));
+      tmp_m = allocator->Allocate(
+          phi::SizeOf(paddle::DataType::FLOAT32) *
+          static_cast<size_t>(token_num * num_chunks * num_heads));
+      tmp_d = allocator->Allocate(
+          phi::SizeOf(paddle::DataType::FLOAT32) *
+          static_cast<size_t>(token_num * num_chunks * num_heads));
       launchWithPdlWhenEnabled(
          split_kv_kernel,
          grids,
@@ -1383,92 +1341,49 @@ void MultiQueryAppendC8Attention(
           sliding_window);
       // merge
       constexpr int vec_size = num_elems_per_128b<T>();
-      if (is_decoder) {
-        constexpr int blockx = HEAD_DIM / vec_size;
-        constexpr int blocky = (128 + blockx - 1) / blockx;
-        dim3 grids_merge(bsz, num_heads);
-        dim3 blocks_merge(blockx, blocky);
-        launchWithPdlWhenEnabled(
-            merge_multi_chunks_decoder_kernel<NV_TYPE,
-                                              vec_size,
-                                              blocky,
-                                              HEAD_DIM,
-                                              OUT_NV_TYPE>,
-            grids_merge,
-            blocks_merge,
-            0,
-            stream,
-            reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
-            static_cast<float *>(tmp_m->ptr()),
-            static_cast<float *>(tmp_d->ptr()),
-            seq_lens_q.data<int>(),
-            seq_lens_kv.data<int>(),
-            seq_lens_encoder.data<int>(),
-            cu_seqlens_q.data<int>(),
-            shift_bias ? reinterpret_cast<NV_TYPE *>(
-                             const_cast<T *>(shift_bias.get().data<T>()))
-                       : nullptr,
-            smooth_weight ? reinterpret_cast<NV_TYPE *>(
-                                const_cast<T *>(smooth_weight.get().data<T>()))
-                          : nullptr,
-            sinks ? reinterpret_cast<float *>(
-                        const_cast<float *>(sinks.get().data<float>()))
-                  : nullptr,
-            reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
-            quant_max_bound,
-            quant_min_bound,
-            in_scale,
-            max_seq_len,
-            num_chunks,
-            num_heads,
-            chunk_size,
-            HEAD_DIM);
-      } else {
-        constexpr int blockx = HEAD_DIM / vec_size;
-        constexpr int blocky = (128 + blockx - 1) / blockx;
-        dim3 grids_merge(min(sm_count * 4, token_num), num_heads);
-        dim3 blocks_merge(blockx, blocky);
-        launchWithPdlWhenEnabled(
-            merge_multi_chunks_v2_kernel<NV_TYPE,
-                                         vec_size,
-                                         blocky,
-                                         HEAD_DIM,
-                                         OUT_NV_TYPE,
-                                         ENABLE_PREFILL>,
-            grids_merge,
-            blocks_merge,
-            0,
-            stream,
-            reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
-            static_cast<float *>(tmp_m->ptr()),
-            static_cast<float *>(tmp_d->ptr()),
-            seq_lens_q.data<int>(),
-            seq_lens_kv.data<int>(),
-            seq_lens_encoder.data<int>(),
-            batch_id_per_token.data<int>(),
-            cu_seqlens_q.data<int>(),
-            shift_bias ? reinterpret_cast<NV_TYPE *>(
-                             const_cast<T *>(shift_bias.get().data<T>()))
-                       : nullptr,
-            smooth_weight ? reinterpret_cast<NV_TYPE *>(
-                                const_cast<T *>(smooth_weight.get().data<T>()))
-                          : nullptr,
-            sinks ? reinterpret_cast<float *>(
-                        const_cast<float *>(sinks.get().data<float>()))
-                  : nullptr,
-            reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
-            quant_max_bound,
-            quant_min_bound,
-            in_scale,
-            max_seq_len,
-            num_chunks,
-            num_heads,
-            chunk_size,
-            HEAD_DIM,
-            token_num,
-            speculate_max_draft_token_num);
-      }
+      constexpr int blockx = HEAD_DIM / vec_size;
+      constexpr int blocky = (128 + blockx - 1) / blockx;
+      dim3 grids_merge(min(sm_count * 4, token_num), num_heads);
+      dim3 blocks_merge(blockx, blocky);
+      launchWithPdlWhenEnabled(
+          merge_multi_chunks_v2_kernel<NV_TYPE,
+                                       vec_size,
+                                       blocky,
+                                       HEAD_DIM,
+                                       OUT_NV_TYPE,
+                                       ENABLE_PREFILL>,
+          grids_merge,
+          blocks_merge,
+          0,
+          stream,
+          reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
+          static_cast<float *>(tmp_m->ptr()),
+          static_cast<float *>(tmp_d->ptr()),
+          seq_lens_q.data<int>(),
+          seq_lens_kv.data<int>(),
+          seq_lens_encoder.data<int>(),
+          batch_id_per_token.data<int>(),
+          cu_seqlens_q.data<int>(),
+          shift_bias ? reinterpret_cast<NV_TYPE *>(
+                           const_cast<T *>(shift_bias.get().data<T>()))
+                     : nullptr,
+          smooth_weight ? reinterpret_cast<NV_TYPE *>(
+                              const_cast<T *>(smooth_weight.get().data<T>()))
+                        : nullptr,
+          sinks ? reinterpret_cast<float *>(
+                      const_cast<float *>(sinks.get().data<float>()))
+                : nullptr,
+          reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
+          quant_max_bound,
+          quant_min_bound,
+          in_scale,
+          max_seq_len,
+          num_chunks,
+          num_heads,
+          chunk_size,
+          HEAD_DIM,
+          token_num,
+          speculate_max_draft_token_num);
     }
   } else {
     constexpr uint32_t num_frags_z = BLOCK_SIZE / 16 / NUM_WARP_KV * 2;
@@ -1525,9 +1440,6 @@ void MultiQueryAppendC8Attention(
     int sm_count;
     cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id);
     uint32_t chunk_size = static_cast<uint32_t>(max_partition_size);
-    if (!is_decoder) {
-      chunk_size = static_cast<uint32_t>(encoder_max_partition_size);
-    }
     const int num_chunks = div_up(max_seq_len, chunk_size);
 
     uint32_t attn_mask_len;
@@ -1643,31 +1555,15 @@ void MultiQueryAppendC8Attention(
           phi::SizeOf(paddle::DataType::FLOAT32) *
           static_cast<size_t>(bsz * num_chunks * num_heads));
     } else {
-      if (ENABLE_PREFILL) {
-        tmp_workspace =
-            allocator->Allocate(phi::SizeOf(qkv.dtype()) *
-                                static_cast<size_t>(token_num * num_chunks *
-                                                    num_heads * HEAD_DIM));
-        tmp_m = allocator->Allocate(
-            phi::SizeOf(paddle::DataType::FLOAT32) *
-            static_cast<size_t>(token_num * num_chunks * num_heads));
-        tmp_d = allocator->Allocate(
-            phi::SizeOf(paddle::DataType::FLOAT32) *
-            static_cast<size_t>(token_num * num_chunks * num_heads));
-      } else {
-        tmp_workspace = allocator->Allocate(
-            phi::SizeOf(qkv.dtype()) *
-            static_cast<size_t>(speculate_max_draft_token_num * bsz *
-                                num_chunks * num_heads * HEAD_DIM));
-        tmp_m = allocator->Allocate(
-            phi::SizeOf(paddle::DataType::FLOAT32) *
-            static_cast<size_t>(speculate_max_draft_token_num * bsz *
-                                num_chunks * num_heads));
-        tmp_d = allocator->Allocate(
-            phi::SizeOf(paddle::DataType::FLOAT32) *
-            static_cast<size_t>(speculate_max_draft_token_num * bsz *
-                                num_chunks * num_heads));
-      }
+      tmp_workspace = allocator->Allocate(
+          phi::SizeOf(qkv.dtype()) *
+          static_cast<size_t>(token_num * num_chunks * num_heads * HEAD_DIM));
+      tmp_m = allocator->Allocate(
+          phi::SizeOf(paddle::DataType::FLOAT32) *
+          static_cast<size_t>(token_num * num_chunks * num_heads));
+      tmp_d = allocator->Allocate(
+          phi::SizeOf(paddle::DataType::FLOAT32) *
+          static_cast<size_t>(token_num * num_chunks * num_heads));
     }
     launchWithPdlWhenEnabled(
         split_kv_kernel,
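Editor's note: on the decode path, the chunk count that drives the split-KV grid is now always derived from `encoder_max_partition_size`. A toy reproduction of that launch-geometry arithmetic (hypothetical values; mirrors `div_up` and the `dim3 grids(num_blocks_x_cpu, num_chunks, kv_num_heads)` lines above):

```cpp
#include <cstdint>
#include <cstdio>

// Same ceiling-division helper the kernels use for chunk counting.
constexpr uint32_t div_up(uint32_t a, uint32_t b) { return (a + b - 1) / b; }

int main() {
  // Hypothetical sizes; the real values come from the attention config.
  const uint32_t max_dec_len = 8192;
  const uint32_t chunk_size = 1024;  // encoder_max_partition_size
  const uint32_t kv_num_heads = 8;
  const uint32_t num_blocks_x_cpu = 132;

  const uint32_t num_chunks = div_up(max_dec_len, chunk_size);
  // Mirrors: dim3 grids(num_blocks_x_cpu, num_chunks, kv_num_heads);
  std::printf("grid = (%u, %u, %u)\n", num_blocks_x_cpu, num_chunks,
              kv_num_heads);
  return 0;
}
```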