diff --git a/custom_ops/gpu_ops/append_attn/append_attention_c16_impl.cuh b/custom_ops/gpu_ops/append_attn/append_attention_c16_impl.cuh index cc537e46c..7a6f68d01 100644 --- a/custom_ops/gpu_ops/append_attn/append_attention_c16_impl.cuh +++ b/custom_ops/gpu_ops/append_attn/append_attention_c16_impl.cuh @@ -13,1300 +13,7 @@ // limitations under the License. #pragma once -#include "append_attention_func.cuh" -#include "append_attention_kernel.h" - -template -__global__ void multi_query_append_attention_kernel( - T *__restrict__ q, // [token_num, (num_heads + 2* kv_num_head) * head_dim] - T *__restrict__ cache_k, // [max_block_num, num_heads, block_size, - // head_dim] - T *__restrict__ cache_v, - const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM] - const T *__restrict__ smooth_weight, // [q_num_heads * HEAD_DIM] - const int *__restrict__ seq_lens, - const int *__restrict__ seq_lens_kv, - const int *__restrict__ batch_ids, - const int *__restrict__ tile_ids_per_batch, - const int *__restrict__ cu_seqlens_q, - const int *__restrict__ block_table, // [bsz, block_num_per_seq] - const int *__restrict__ mask_offset, - const int max_seq_len, - const int max_dec_len, - const int max_block_num_per_seq, - const float scale, - const float quant_max_bound, - const float quant_min_bound, - const float in_scale, - const uint32_t chunk_size, - const int num_blocks_x_cpu, - T *__restrict__ tmp_workspace, // split kv [token_num, num_chunks, - // num_heads, head_dim] - float *__restrict__ tmp_m, // [token_num, num_chunks, num_heads] - float *__restrict__ tmp_d, // [token_num, num_chunks, num_heads] - OutT *__restrict__ out, - const int speculate_max_draft_token_num = 5) { - constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b(); - const uint32_t btid = blockIdx.x, kv_head_idx = blockIdx.z; - const uint32_t kv_num_heads = gridDim.z; - const uint32_t q_num_heads = kv_num_heads * GROUP_SIZE; - const uint32_t q_head_idx = kv_head_idx * GROUP_SIZE; - const uint32_t tid = threadIdx.x, wid = threadIdx.y; - const uint32_t num_chunks = gridDim.y; - const uint32_t chunk_idx = blockIdx.y; - - const uint32_t batch_id = batch_ids[btid]; - const uint32_t tile_id = tile_ids_per_batch[btid]; - const uint32_t num_rows_per_block = NUM_WARPS * num_frags_x * 16; - const int *block_table_now = nullptr; - - block_table_now = block_table + batch_id * max_block_num_per_seq; - - //When cudagraph capture prefill, may launch more gridDim.x - if(btid >= static_cast(num_blocks_x_cpu)){ - return; - } - - const uint32_t q_len = seq_lens[batch_id]; - if (q_len <= 0) { - return; - } - const uint32_t q_end = - min(q_len, div_up((tile_id + 1) * num_rows_per_block, GROUP_SIZE)); - uint32_t kv_len = seq_lens_kv[batch_id]; - if (ENABLE_PREFILL) { - kv_len += q_len; - if (kv_len <= 0) { - return; - } - } else { - if (kv_len <= 0) { - return; - } - kv_len += q_len; - } - - const uint32_t num_chunks_this_seq = div_up(kv_len, chunk_size); - if (chunk_idx >= num_chunks_this_seq) { - return; - } - - const uint32_t chunk_start = partition_kv ? chunk_idx * chunk_size : 0; - const uint32_t chunk_end = - partition_kv ? min(kv_len, chunk_start + chunk_size) : kv_len; - const uint32_t chunk_len = chunk_end - chunk_start; - - extern __shared__ uint8_t smem[]; - float s_frag[num_frags_x][num_frags_z][8]; - float o_frag[num_frags_x][num_frags_y][8]; - float m_frag[num_frags_x][2]; - float d_frag[num_frags_x][2]; - init_states(o_frag, m_frag, d_frag); - - const uint32_t q_n_stride = q_num_heads * HEAD_DIM; - const uint32_t q_ori_n_stride = (q_num_heads + kv_num_heads * 2) * HEAD_DIM; - const uint32_t kv_n_stride = kv_num_heads * BLOCK_SIZE * HEAD_DIM; - const uint32_t kv_h_stride = BLOCK_SIZE * HEAD_DIM; - const uint32_t kv_b_stride = HEAD_DIM; - const uint32_t q_start_seq_id = cu_seqlens_q[batch_id]; - const uint32_t q_base_seq_id_this_block = - (tile_id * NUM_WARPS + wid) * num_frags_x * 16; - const uint32_t q_offset = q_start_seq_id * q_ori_n_stride + - q_head_idx * HEAD_DIM + - tid % 8 * num_elems_per_128b(); - const uint32_t o_offset = q_start_seq_id * q_n_stride + - q_head_idx * HEAD_DIM + - tid % 8 * num_elems_per_128b(); - T *q_base_ptr = q + q_offset; - T *o_base_ptr_T = nullptr; - OutT *o_base_ptr_int8 = nullptr; - if constexpr (partition_kv) { - if (ENABLE_PREFILL) { - o_base_ptr_T = tmp_workspace + q_start_seq_id * num_chunks * q_n_stride + - chunk_idx * q_n_stride + q_head_idx * HEAD_DIM + - tid % 8 * num_elems_per_128b(); - } else { - o_base_ptr_T = - tmp_workspace + - batch_id * speculate_max_draft_token_num * num_chunks * q_n_stride + - chunk_idx * q_n_stride + q_head_idx * HEAD_DIM + - tid % 8 * num_elems_per_128b(); - } - } else { - o_base_ptr_int8 = out + o_offset; - } - const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id * 2 : nullptr; - smem_t qo_smem(smem); - - uint32_t q_smem_offset_r = smem_t::get_permuted_offset( - wid * num_frags_x * 16 + tid % 16, tid / 16); // 16 * 16 - load_q_global_smem( - q_base_ptr, - &qo_smem, - q_base_seq_id_this_block, - q_end, - q_ori_n_stride, - HEAD_DIM); - commit_group(); - wait_group<0>(); - __syncthreads(); - - q_smem_inplace_multiply_sm_scale(&qo_smem, - scale); - - smem_t k_smem(smem + NUM_WARPS * num_frags_x * 16 * HEAD_DIM * sizeof(T)), - v_smem(smem + (NUM_WARPS * num_frags_x + num_frags_z) * 16 * HEAD_DIM * - sizeof(T)); - - - const uint32_t num_iterations = div_up( - CAUSAL - ? (min(chunk_len, - sub_if_greater_or_zero( - kv_len - q_len + - div_up((tile_id + 1) * num_rows_per_block, GROUP_SIZE), - chunk_start))) - : chunk_len, - num_frags_z * 16); - const uint32_t mask_check_iteration = - (CAUSAL ? (min(chunk_len, - sub_if_greater_or_zero( - kv_len - q_len + - tile_id * num_rows_per_block / GROUP_SIZE, - chunk_start))) - : mask_offset ? 0 : chunk_len) / - (num_frags_z * 16); - uint32_t k_smem_offset_r = smem_t::get_permuted_offset( - 8 * (tid / 16) + tid % 8, (tid % 16) / 8); - - uint32_t v_smem_offset_r = - smem_t::get_permuted_offset(tid % 16, tid / 16); - - uint32_t kv_smem_offset_w = smem_t::get_permuted_offset( - wid * 4 + tid / 8, tid % 8); - - uint32_t kv_idx_base = chunk_start; - int block_id = __ldg(&block_table_now[kv_idx_base / BLOCK_SIZE]); - const uint32_t const_offset = kv_head_idx * kv_h_stride + - (wid * 4 + tid / 8) * kv_b_stride + - tid % 8 * num_elems_per_128b(); - T *cache_k_now = cache_k + block_id * kv_n_stride + const_offset; - T *cache_v_now = cache_v + block_id * kv_n_stride + const_offset; - - produce_kv_blockwise(k_smem, - &kv_smem_offset_w, - &cache_k_now, - kv_head_idx, - kv_n_stride, - kv_h_stride, - kv_b_stride, - kv_idx_base, - chunk_end); - commit_group(); - produce_kv_blockwise(v_smem, - &kv_smem_offset_w, - &cache_v_now, - kv_head_idx, - kv_n_stride, - kv_h_stride, - kv_b_stride, - kv_idx_base, - chunk_end); - commit_group(); -#pragma unroll 1 - for (uint32_t iter = 0; iter < num_iterations; ++iter) { - wait_group<1>(); - __syncthreads(); - - // s = qk - compute_qk( - &qo_smem, &q_smem_offset_r, &k_smem, &k_smem_offset_r, s_frag); - // mask according to kv_idx and q_idx - if (iter >= mask_check_iteration) { - mask_s(nullptr, - q_base_seq_id_this_block, - kv_idx_base, - q_len, - kv_len, - chunk_end, - -1, - s_frag, - mask_offset_this_seq); - - } - - // update m,d - update_mdo_states( - s_frag, o_frag, m_frag, d_frag); - __syncthreads(); - - kv_idx_base += num_frags_z * 16; - block_id = __ldg(&block_table_now[kv_idx_base / BLOCK_SIZE]); - if (block_id < 0) { - block_id = 0; - } - cache_k_now = cache_k + block_id * kv_n_stride + const_offset; - produce_kv_blockwise(k_smem, - &kv_smem_offset_w, - &cache_k_now, - kv_head_idx, - kv_n_stride, - kv_h_stride, - kv_b_stride, - kv_idx_base, - chunk_end); - commit_group(); - wait_group<1>(); - __syncthreads(); - - // compute sfm*v - compute_sfm_v( - &v_smem, &v_smem_offset_r, s_frag, o_frag, d_frag); - - __syncthreads(); - cache_v_now = cache_v + block_id * kv_n_stride + const_offset; - produce_kv_blockwise(v_smem, - &kv_smem_offset_w, - &cache_v_now, - kv_head_idx, - kv_n_stride, - kv_h_stride, - kv_b_stride, - kv_idx_base, - chunk_end); - commit_group(); - } - wait_group<0>(); - __syncthreads(); - - if constexpr (!partition_kv) { - normalize_d(o_frag, d_frag); - } - if constexpr (partition_kv) { - write_o_reg_gmem_shift_smooth_quant( - o_frag, - &qo_smem, - o_base_ptr_T, - shift_bias, - smooth_weight, - q_base_seq_id_this_block, - q_head_idx, - quant_max_bound, - quant_min_bound, - in_scale, - q_len, - partition_kv ? q_n_stride * num_chunks : q_n_stride, - HEAD_DIM); - } else { - write_o_reg_gmem_shift_smooth_quant( - o_frag, - &qo_smem, - o_base_ptr_int8, - shift_bias, - smooth_weight, - q_base_seq_id_this_block, - q_head_idx, - quant_max_bound, - quant_min_bound, - in_scale, - q_len, - partition_kv ? q_n_stride * num_chunks : q_n_stride, - HEAD_DIM); - } - - - if constexpr (partition_kv) { -#pragma unroll - for (uint32_t fx = 0; fx < num_frags_x; ++fx) { -#pragma unroll - for (uint32_t j = 0; j < 2; ++j) { - const uint32_t qo_idx_now = - q_base_seq_id_this_block + tid / 4 + j * 8 + fx * 16; - const uint32_t qo_head_idx = q_head_idx + qo_idx_now % GROUP_SIZE; - const uint32_t qo_idx = q_start_seq_id + qo_idx_now / GROUP_SIZE; - if (qo_idx - q_start_seq_id < q_len) { - uint32_t offset; - if (ENABLE_PREFILL) { - offset = - (qo_idx * num_chunks + chunk_idx) * q_num_heads + qo_head_idx; - } else { - offset = ((batch_id * speculate_max_draft_token_num + - qo_idx_now / GROUP_SIZE) * - num_chunks + - chunk_idx) * - q_num_heads + - qo_head_idx; - } - tmp_m[offset] = m_frag[fx][j]; - tmp_d[offset] = d_frag[fx][j]; - } - } - } - } -} - -template -__global__ void multi_query_append_attention_warp1_4_kernel( - T *__restrict__ q, // [token_num, (num_heads + 2* kv_num_head) * head_dim] - T *__restrict__ cache_k, // [max_block_num, num_heads, block_size, - // head_dim] - T *__restrict__ cache_v, - const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM] - const T *__restrict__ smooth_weight, // [q_num_heads * HEAD_DIM] - const int *__restrict__ seq_lens, - const int *__restrict__ seq_lens_kv, - const int *__restrict__ batch_ids, - const int *__restrict__ tile_ids_per_batch, - const int *__restrict__ cu_seqlens_q, - const int *__restrict__ block_table, // [bsz, block_num_per_seq] - const int *__restrict__ mask_offset, - const bool *__restrict__ attn_mask, // [bsz, max_q, max_q] for tree-mask - const int max_seq_len, - const int max_dec_len, - const int max_block_num_per_seq, - const float scale, - const float quant_max_bound, - const float quant_min_bound, - const float in_scale, - const uint32_t chunk_size, - const int num_blocks_x_cpu, - T *__restrict__ tmp_workspace, // split kv [token_num, num_chunks, - // num_heads, head_dim] - float *__restrict__ tmp_m, // [token_num, num_chunks, num_heads] - float *__restrict__ tmp_d, // [token_num, num_chunks, num_heads] - OutT *__restrict__ out, - const int speculate_max_draft_token_num = 5, - const uint32_t attn_mask_len = -1) { - constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b(); - static_assert(NUM_WARP_Q == 1, "NUM_WARP_Q must be 1"); - static_assert(NUM_WARP_KV == 4, "NUM_WARP_KV must be 4"); - const uint32_t btid = blockIdx.x, kv_head_idx = blockIdx.z; - const uint32_t kv_num_heads = gridDim.z; - const uint32_t q_num_heads = kv_num_heads * GROUP_SIZE; - const uint32_t q_head_idx = kv_head_idx * GROUP_SIZE; - const uint32_t tid = threadIdx.x, wid = threadIdx.y; - const uint32_t num_chunks = gridDim.y; - const uint32_t chunk_idx = blockIdx.y; - - const uint32_t batch_id = batch_ids[btid]; - const uint32_t tile_id = tile_ids_per_batch[btid]; - const uint32_t num_rows_per_block = num_frags_x * 16; - const int *block_table_now = block_table + batch_id * max_block_num_per_seq; - - //When cudagraph capture prefill, may launch more gridDim.x - if(btid >= static_cast(num_blocks_x_cpu)){ - return; - } - - const uint32_t q_len = seq_lens[batch_id]; - if (q_len <= 0) { - return; - } - const uint32_t q_end = - min(q_len, div_up((tile_id + 1) * num_rows_per_block, GROUP_SIZE)); - uint32_t kv_len = seq_lens_kv[batch_id]; - if (ENABLE_PREFILL) { - kv_len += q_len; - if (kv_len <= 0) { - return; - } - } else { - if (kv_len <= 0) { - return; - } - kv_len += q_len; - } - const uint32_t num_chunks_this_seq = div_up(kv_len, chunk_size); - if (chunk_idx >= num_chunks_this_seq) { - return; - } - - const uint32_t chunk_start = partition_kv ? chunk_idx * chunk_size : 0; - const uint32_t chunk_end = - partition_kv ? min(kv_len, chunk_start + chunk_size) : kv_len; - const uint32_t chunk_len = chunk_end - chunk_start; - - extern __shared__ uint8_t smem[]; - float s_frag[num_frags_x][num_frags_z][8]; - float o_frag[num_frags_x][num_frags_y][8]; - float m_frag[num_frags_x][2]; - float d_frag[num_frags_x][2]; - init_states(o_frag, m_frag, d_frag); - - const uint32_t q_n_stride = q_num_heads * HEAD_DIM; - const uint32_t q_ori_n_stride = (q_num_heads + kv_num_heads * 2) * HEAD_DIM; - const uint32_t kv_n_stride = kv_num_heads * BLOCK_SIZE * HEAD_DIM; - const uint32_t kv_h_stride = BLOCK_SIZE * HEAD_DIM; - const uint32_t kv_b_stride = HEAD_DIM; - const uint32_t q_start_seq_id = cu_seqlens_q[batch_id]; - const uint32_t q_base_seq_id_this_block = tile_id * num_frags_x * 16; - const uint32_t q_offset = q_start_seq_id * q_ori_n_stride + - q_head_idx * HEAD_DIM + - tid % 8 * num_elems_per_128b(); - const uint32_t o_offset = q_start_seq_id * q_n_stride + - q_head_idx * HEAD_DIM + - tid % 8 * num_elems_per_128b(); - T *q_base_ptr = q + q_offset; - T *o_base_ptr_T = nullptr; - OutT *o_base_ptr_int8 = nullptr; - if (num_chunks_this_seq <= 1) { - o_base_ptr_int8 = out + o_offset; - } else { - if (ENABLE_PREFILL) { - o_base_ptr_T = tmp_workspace + batch_id * num_chunks * q_n_stride + - chunk_idx * q_n_stride + q_head_idx * HEAD_DIM + - tid % 8 * num_elems_per_128b(); - } else { - o_base_ptr_T = - tmp_workspace + - batch_id * speculate_max_draft_token_num * num_chunks * q_n_stride + - chunk_idx * q_n_stride + q_head_idx * HEAD_DIM + - tid % 8 * num_elems_per_128b(); - } - } - const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id * 2 : nullptr; - smem_t qo_smem(smem); - - uint32_t q_smem_offset_r = smem_t::get_permuted_offset( - tid % 16, tid / 16); // 16 * 16 - load_q_global_smem_multi_warps(q_base_ptr, - &qo_smem, - q_base_seq_id_this_block, - q_end, - q_ori_n_stride, - HEAD_DIM); - commit_group(); - wait_group<0>(); - __syncthreads(); - - q_smem_inplace_multiply_sm_scale_multi_warps( - &qo_smem, scale); - - smem_t k_smem(smem + num_frags_x * 16 * HEAD_DIM * sizeof(T)), - v_smem(smem + (num_frags_x + NUM_WARP_KV * num_frags_z) * 16 * HEAD_DIM * - sizeof(T)); - - const uint32_t num_iterations = div_up( - CAUSAL - ? (min(chunk_len, - sub_if_greater_or_zero( - kv_len - q_len + - div_up((tile_id + 1) * num_rows_per_block, GROUP_SIZE), - chunk_start))) - : chunk_len, - NUM_WARP_KV * num_frags_z * 16); - const uint32_t mask_check_iteration = - (CAUSAL ? (min(chunk_len, - sub_if_greater_or_zero( - kv_len - q_len, - chunk_start))) - : mask_offset ? 0 : chunk_len) / - (NUM_WARP_KV * num_frags_z * 16); - - uint32_t k_smem_offset_r = smem_t::get_permuted_offset( - wid * num_frags_z * 16 + 8 * (tid / 16) + tid % 8, (tid % 16) / 8); - - uint32_t v_smem_offset_r = smem_t::get_permuted_offset( - wid * num_frags_z * 16 + tid % 16, tid / 16); - uint32_t kv_smem_offset_w = smem_t::get_permuted_offset( - wid * 4 + tid / 8, tid % 8); - - uint32_t kv_idx_base = chunk_start; - int block_id = __ldg(&block_table_now[kv_idx_base / BLOCK_SIZE]); - const uint32_t const_offset = kv_head_idx * kv_h_stride + - (wid * 4 + tid / 8) * kv_b_stride + - tid % 8 * num_elems_per_128b(); - T *cache_k_now = cache_k + block_id * kv_n_stride + const_offset; - T *cache_v_now = cache_v + block_id * kv_n_stride + const_offset; - - produce_kv_blockwise(k_smem, - &kv_smem_offset_w, - &cache_k_now, - kv_head_idx, - kv_n_stride, - kv_h_stride, - kv_b_stride, - kv_idx_base, - chunk_end); - commit_group(); - - produce_kv_blockwise(v_smem, - &kv_smem_offset_w, - &cache_v_now, - kv_head_idx, - kv_n_stride, - kv_h_stride, - kv_b_stride, - kv_idx_base, - chunk_end); - commit_group(); - -#pragma unroll 1 - for (uint32_t iter = 0; iter < num_iterations; ++iter) { - wait_group<1>(); - __syncthreads(); - - // s = qk - compute_qk( - &qo_smem, &q_smem_offset_r, &k_smem, &k_smem_offset_r, s_frag); - // mask according to kv_idx and q_idx - if (iter >= mask_check_iteration) { - mask_s(attn_mask ? attn_mask + batch_id * attn_mask_len *attn_mask_len : nullptr, - q_base_seq_id_this_block, - kv_idx_base + wid * num_frags_z * 16, - q_len, - kv_len, - chunk_end, - attn_mask_len, - s_frag, - mask_offset_this_seq); - } - - // update m,d - update_mdo_states( - s_frag, o_frag, m_frag, d_frag); - __syncthreads(); - - kv_idx_base += NUM_WARP_KV * num_frags_z * 16; - block_id = __ldg(&block_table_now[kv_idx_base / BLOCK_SIZE]); - if (block_id < 0) { - block_id = 0; - } - cache_k_now = cache_k + block_id * kv_n_stride + const_offset; - produce_kv_blockwise(k_smem, - &kv_smem_offset_w, - &cache_k_now, - kv_head_idx, - kv_n_stride, - kv_h_stride, - kv_b_stride, - kv_idx_base, - chunk_end); - commit_group(); - wait_group<1>(); - __syncthreads(); - - // compute sfm*v - compute_sfm_v( - &v_smem, &v_smem_offset_r, s_frag, o_frag, d_frag); - __syncthreads(); - - cache_v_now = cache_v + block_id * kv_n_stride + const_offset; - produce_kv_blockwise(v_smem, - &kv_smem_offset_w, - &cache_v_now, - kv_head_idx, - kv_n_stride, - kv_h_stride, - kv_b_stride, - kv_idx_base, - chunk_end); - commit_group(); - } - wait_group<0>(); - __syncthreads(); - - merge_block_res_v2( - o_frag, reinterpret_cast(smem), m_frag, d_frag, wid, tid); - - if (num_chunks_this_seq <= 1) { - normalize_d(o_frag, d_frag); - } - - // write o - // [num_frags_x, 16, num_frags_y, 16] - if (num_chunks_this_seq <= 1) { - write_o_reg_gmem_multi_warps_shift_smooth_quant( - o_frag, - &qo_smem, - o_base_ptr_int8, - shift_bias, - smooth_weight, - q_base_seq_id_this_block, - q_head_idx, - quant_max_bound, - quant_min_bound, - in_scale, - q_len, - q_n_stride, - HEAD_DIM); - } else { - write_o_reg_gmem_multi_warps_shift_smooth_quant( - o_frag, - &qo_smem, - o_base_ptr_T, - shift_bias, - smooth_weight, - q_base_seq_id_this_block, - q_head_idx, - quant_max_bound, - quant_min_bound, - in_scale, - q_len, - q_n_stride * num_chunks, - HEAD_DIM); - } - - if (num_chunks_this_seq > 1) { - if (wid == 0) { -#pragma unroll - for (uint32_t fx = 0; fx < num_frags_x; ++fx) { -#pragma unroll - for (uint32_t j = 0; j < 2; ++j) { - const uint32_t qo_idx_now = - q_base_seq_id_this_block + tid / 4 + j * 8 + fx * 16; - const uint32_t qo_head_idx = q_head_idx + qo_idx_now % GROUP_SIZE; - const uint32_t qo_idx = q_start_seq_id + qo_idx_now / GROUP_SIZE; - - if (qo_idx - q_start_seq_id < q_len) { - uint32_t offset; - if (ENABLE_PREFILL) { - offset = (batch_id * num_chunks + chunk_idx) * q_num_heads + - qo_head_idx; - } else { - offset = ((batch_id * speculate_max_draft_token_num + - qo_idx_now / GROUP_SIZE) * - num_chunks + - chunk_idx) * - q_num_heads + - qo_head_idx; - } - tmp_m[offset] = m_frag[fx][j]; - tmp_d[offset] = d_frag[fx][j]; - } - } - } - } - } -} - -template -void MultiQueryAppendAttention( - const AppendAttnMetaData &meta_data, - const paddle::Tensor &qkv, - const paddle::Tensor &cache_k, - const paddle::Tensor &cache_v, - const paddle::optional &attn_mask, - const paddle::optional &shift_bias, - const paddle::optional &smooth_weight, - const paddle::Tensor &seq_lens_q, - const paddle::Tensor &seq_lens_kv, - const paddle::Tensor &seq_lens_encoder, - const paddle::Tensor &batch_id_per_token, - const paddle::Tensor &cu_seqlens_q, - const paddle::Tensor &block_table, - const paddle::Tensor &batch_ids, - const paddle::Tensor &tile_ids_per_batch, - const int num_blocks_x_cpu, - const int max_seq_len, - const int max_dec_len, - const float quant_max_bound, - const float quant_min_bound, - const float in_scale, - const int max_partition_size, - const int encoder_max_partition_size, - const int speculate_max_draft_token_num, - const bool is_decoder, - cudaStream_t &stream, - paddle::Tensor *out) { - using NV_TYPE = typename cascade_attn_type_traits::type; - using OUT_NV_TYPE = typename cascade_attn_type_traits::type; - - auto num_heads = meta_data.q_num_heads; - auto kv_num_heads = meta_data.kv_num_heads; - auto token_num = meta_data.token_nums; - auto bsz = meta_data.batch_size; - auto max_block_num_per_seq = meta_data.max_blocks_per_seq; - - constexpr uint32_t num_warps = 4; - constexpr uint32_t NUM_WARP_KV = num_warps / NUM_WARP_Q; - constexpr uint32_t num_frags_x = BLOCK_SHAPE_Q / (16 * NUM_WARP_Q); // 1 or 2 - constexpr uint32_t num_frags_y = HEAD_DIM / 16; - constexpr uint32_t num_qrow_per_block = NUM_WARP_Q * num_frags_x * 16; - - auto *allocator = paddle::GetAllocator(qkv.place()); - - const float scale = 1.f / sqrt(HEAD_DIM); - - if constexpr (NUM_WARP_Q == 4) { - constexpr uint32_t num_frags_z = BLOCK_SIZE / 16; - constexpr uint32_t smem_size = - (num_warps * num_frags_x + NUM_WARP_KV * num_frags_z * 2) * 16 * - HEAD_DIM * sizeof(T); - auto split_kv_kernel = multi_query_append_attention_kernel; - if (smem_size >= 48 * 1024) { - cudaFuncSetAttribute(split_kv_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, - smem_size); - } - const int dev_id = 0; - int sm_count; - cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id); - - uint32_t chunk_size = static_cast(max_partition_size); - if (!is_decoder) { - chunk_size = static_cast(encoder_max_partition_size); - } - const int num_chunks = div_up(max_dec_len, chunk_size); - dim3 grids(num_blocks_x_cpu, num_chunks, kv_num_heads); - dim3 blocks(32, num_warps); - if (num_chunks <= 1) { - auto nosplit_kv_kernel = - multi_query_append_attention_kernel; - if (smem_size >= 48 * 1024) { - cudaFuncSetAttribute(nosplit_kv_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, - smem_size); - } - - nosplit_kv_kernel<<>>( - reinterpret_cast(const_cast(qkv.data())), - reinterpret_cast(const_cast(cache_k.data())), - reinterpret_cast(const_cast(cache_v.data())), - shift_bias ? reinterpret_cast( - const_cast(shift_bias.get().data())) - : nullptr, - smooth_weight ? reinterpret_cast( - const_cast(smooth_weight.get().data())) - : nullptr, - seq_lens_q.data(), - seq_lens_kv.data(), - batch_ids.data(), - tile_ids_per_batch.data(), - cu_seqlens_q.data(), - block_table.data(), - meta_data.mask_offset, - max_seq_len, - max_dec_len, - max_block_num_per_seq, - scale, - quant_max_bound, - quant_min_bound, - in_scale, - chunk_size, - num_blocks_x_cpu, - nullptr, - nullptr, - nullptr, - reinterpret_cast(out->data()), - speculate_max_draft_token_num); - - } else { - phi::Allocator::AllocationPtr tmp_workspace, tmp_m, tmp_d; - if (ENABLE_PREFILL) { - tmp_workspace = allocator->Allocate( - phi::SizeOf(qkv.dtype()) * - static_cast(token_num * num_chunks * num_heads * HEAD_DIM)); - tmp_m = allocator->Allocate( - phi::SizeOf(paddle::DataType::FLOAT32) * - static_cast(token_num * num_chunks * num_heads)); - tmp_d = allocator->Allocate( - phi::SizeOf(paddle::DataType::FLOAT32) * - static_cast(token_num * num_chunks * num_heads)); - } else { - tmp_workspace = allocator->Allocate( - phi::SizeOf(qkv.dtype()) * - static_cast(speculate_max_draft_token_num * bsz * - num_chunks * num_heads * HEAD_DIM)); - tmp_m = allocator->Allocate( - phi::SizeOf(paddle::DataType::FLOAT32) * - static_cast(speculate_max_draft_token_num * bsz * - num_chunks * num_heads)); - tmp_d = allocator->Allocate( - phi::SizeOf(paddle::DataType::FLOAT32) * - static_cast(speculate_max_draft_token_num * bsz * - num_chunks * num_heads)); - } - - split_kv_kernel<<>>( - reinterpret_cast(const_cast(qkv.data())), - reinterpret_cast(const_cast(cache_k.data())), - reinterpret_cast(const_cast(cache_v.data())), - shift_bias ? reinterpret_cast( - const_cast(shift_bias.get().data())) - : nullptr, - smooth_weight ? reinterpret_cast( - const_cast(smooth_weight.get().data())) - : nullptr, - seq_lens_q.data(), - seq_lens_kv.data(), - batch_ids.data(), - tile_ids_per_batch.data(), - cu_seqlens_q.data(), - block_table.data(), - meta_data.mask_offset, - max_seq_len, - max_dec_len, - max_block_num_per_seq, - scale, - quant_max_bound, - quant_min_bound, - in_scale, - chunk_size, - num_blocks_x_cpu, - reinterpret_cast(tmp_workspace->ptr()), - static_cast(tmp_m->ptr()), - static_cast(tmp_d->ptr()), - reinterpret_cast(out->data()), - speculate_max_draft_token_num); - // merge - constexpr int vec_size = num_elems_per_128b(); - if (is_decoder) { - constexpr int blockx = HEAD_DIM / vec_size; - constexpr int blocky = (128 + blockx - 1) / blockx; - dim3 grids_merge(bsz, num_heads); - dim3 blocks_merge(blockx, blocky); - merge_multi_chunks_decoder_kernel - <<>>( - reinterpret_cast(tmp_workspace->ptr()), - static_cast(tmp_m->ptr()), - static_cast(tmp_d->ptr()), - seq_lens_q.data(), - seq_lens_kv.data(), - seq_lens_encoder.data(), - cu_seqlens_q.data(), - shift_bias ? reinterpret_cast( - const_cast(shift_bias.get().data())) - : nullptr, - smooth_weight ? reinterpret_cast(const_cast( - smooth_weight.get().data())) - : nullptr, - reinterpret_cast(out->data()), - quant_max_bound, - quant_min_bound, - in_scale, - max_seq_len, - num_chunks, - num_heads, - chunk_size, - HEAD_DIM); - } else { - constexpr int blockx = HEAD_DIM / vec_size; - constexpr int blocky = (128 + blockx - 1) / blockx; - dim3 grids_merge(min(sm_count * 4, token_num), - num_heads); // 128k is too large - dim3 blocks_merge(blockx, blocky); - merge_multi_chunks_v2_kernel - <<>>( - reinterpret_cast(tmp_workspace->ptr()), - static_cast(tmp_m->ptr()), - static_cast(tmp_d->ptr()), - seq_lens_q.data(), - seq_lens_kv.data(), - seq_lens_encoder.data(), - batch_id_per_token.data(), - cu_seqlens_q.data(), - shift_bias ? reinterpret_cast( - const_cast(shift_bias.get().data())) - : nullptr, - smooth_weight ? reinterpret_cast(const_cast( - smooth_weight.get().data())) - : nullptr, - reinterpret_cast(out->data()), - quant_max_bound, - quant_min_bound, - in_scale, - max_seq_len, - num_chunks, - num_heads, - chunk_size, - HEAD_DIM, - token_num, - speculate_max_draft_token_num); - } - } - } else { - constexpr uint32_t num_frags_z = BLOCK_SIZE / 16 / NUM_WARP_KV; - constexpr uint32_t smem_size = - (num_frags_x + NUM_WARP_KV * num_frags_z * 2) * 16 * HEAD_DIM * - sizeof(T); - auto split_kv_kernel = - multi_query_append_attention_warp1_4_kernel; - if (smem_size >= 48 * 1024) { - cudaFuncSetAttribute(split_kv_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, - smem_size); - } - const int dev_id = 0; - int sm_count; - cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id); - - uint32_t chunk_size = static_cast(max_partition_size); - if (!is_decoder) { - chunk_size = static_cast(encoder_max_partition_size); - } - - uint32_t attn_mask_len; - if (attn_mask) { - attn_mask_len = attn_mask.get().shape()[1]; - } else { - attn_mask_len = -1; - } - - const int num_chunks = div_up(max_seq_len, chunk_size); - dim3 grids(num_blocks_x_cpu, num_chunks, kv_num_heads); - dim3 blocks(32, num_warps); - if (num_chunks <= 0) { - auto nosplit_kv_kernel = - multi_query_append_attention_warp1_4_kernel; - if (smem_size >= 48 * 1024) { - cudaFuncSetAttribute(nosplit_kv_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, - smem_size); - } - - nosplit_kv_kernel<<>>( - reinterpret_cast(const_cast(qkv.data())), - reinterpret_cast(const_cast(cache_k.data())), - reinterpret_cast(const_cast(cache_v.data())), - shift_bias ? reinterpret_cast( - const_cast(shift_bias.get().data())) - : nullptr, - smooth_weight ? reinterpret_cast( - const_cast(smooth_weight.get().data())) - : nullptr, - seq_lens_q.data(), - seq_lens_kv.data(), - batch_ids.data(), - tile_ids_per_batch.data(), - cu_seqlens_q.data(), - block_table.data(), - meta_data.mask_offset, - attn_mask ? const_cast(attn_mask.get().data()) - : nullptr, - max_seq_len, - max_dec_len, - max_block_num_per_seq, - scale, - quant_max_bound, - quant_min_bound, - in_scale, - chunk_size, - num_blocks_x_cpu, - nullptr, - nullptr, - nullptr, - reinterpret_cast(out->data()), - speculate_max_draft_token_num, - attn_mask_len); - } else { - phi::Allocator::AllocationPtr tmp_workspace, tmp_m, tmp_d; - if (is_decoder) { - tmp_workspace = allocator->Allocate( - phi::SizeOf(qkv.dtype()) * - static_cast(bsz * num_chunks * num_heads * HEAD_DIM)); - tmp_m = allocator->Allocate( - phi::SizeOf(paddle::DataType::FLOAT32) * - static_cast(bsz * num_chunks * num_heads)); - tmp_d = allocator->Allocate( - phi::SizeOf(paddle::DataType::FLOAT32) * - static_cast(bsz * num_chunks * num_heads)); - } else { - if (ENABLE_PREFILL) { - tmp_workspace = - allocator->Allocate(phi::SizeOf(qkv.dtype()) * - static_cast(token_num * num_chunks * - num_heads * HEAD_DIM)); - tmp_m = allocator->Allocate( - phi::SizeOf(paddle::DataType::FLOAT32) * - static_cast(token_num * num_chunks * num_heads)); - tmp_d = allocator->Allocate( - phi::SizeOf(paddle::DataType::FLOAT32) * - static_cast(token_num * num_chunks * num_heads)); - } else { - tmp_workspace = allocator->Allocate( - phi::SizeOf(qkv.dtype()) * - static_cast(speculate_max_draft_token_num * bsz * - num_chunks * num_heads * HEAD_DIM)); - tmp_m = allocator->Allocate( - phi::SizeOf(paddle::DataType::FLOAT32) * - static_cast(speculate_max_draft_token_num * bsz * - num_chunks * num_heads)); - tmp_d = allocator->Allocate( - phi::SizeOf(paddle::DataType::FLOAT32) * - static_cast(speculate_max_draft_token_num * bsz * - num_chunks * num_heads)); - } - } - split_kv_kernel<<>>( - reinterpret_cast(const_cast(qkv.data())), - reinterpret_cast(const_cast(cache_k.data())), - reinterpret_cast(const_cast(cache_v.data())), - shift_bias ? reinterpret_cast( - const_cast(shift_bias.get().data())) - : nullptr, - smooth_weight ? reinterpret_cast( - const_cast(smooth_weight.get().data())) - : nullptr, - seq_lens_q.data(), - seq_lens_kv.data(), - batch_ids.data(), - tile_ids_per_batch.data(), - cu_seqlens_q.data(), - block_table.data(), - meta_data.mask_offset, - attn_mask ? const_cast(attn_mask.get().data()) - : nullptr, - max_seq_len, - max_dec_len, - max_block_num_per_seq, - scale, - quant_max_bound, - quant_min_bound, - in_scale, - chunk_size, - num_blocks_x_cpu, - reinterpret_cast(tmp_workspace->ptr()), - static_cast(tmp_m->ptr()), - static_cast(tmp_d->ptr()), - reinterpret_cast(out->data()), - speculate_max_draft_token_num, - attn_mask_len); - - // merge - constexpr int vec_size = num_elems_per_128b(); - if (is_decoder) { - constexpr int blockx = HEAD_DIM / vec_size; - constexpr int blocky = (128 + blockx - 1) / blockx; - dim3 grids_merge(bsz, num_heads); - dim3 blocks_merge(blockx, blocky); - merge_multi_chunks_decoder_kernel - <<>>( - reinterpret_cast(tmp_workspace->ptr()), - static_cast(tmp_m->ptr()), - static_cast(tmp_d->ptr()), - seq_lens_q.data(), - seq_lens_kv.data(), - seq_lens_encoder.data(), - cu_seqlens_q.data(), - shift_bias ? reinterpret_cast( - const_cast(shift_bias.get().data())) - : nullptr, - smooth_weight ? reinterpret_cast(const_cast( - smooth_weight.get().data())) - : nullptr, - reinterpret_cast(out->data()), - quant_max_bound, - quant_min_bound, - in_scale, - max_seq_len, - num_chunks, - num_heads, - chunk_size, - HEAD_DIM); - } else { - constexpr int blockx = HEAD_DIM / vec_size; - constexpr int blocky = (128 + blockx - 1) / blockx; - dim3 grids_merge(min(sm_count * 4, token_num), - num_heads); - dim3 blocks_merge(blockx, blocky); - merge_multi_chunks_v2_kernel - <<>>( - reinterpret_cast(tmp_workspace->ptr()), - static_cast(tmp_m->ptr()), - static_cast(tmp_d->ptr()), - seq_lens_q.data(), - seq_lens_kv.data(), - seq_lens_encoder.data(), - batch_id_per_token.data(), - cu_seqlens_q.data(), - shift_bias ? reinterpret_cast( - const_cast(shift_bias.get().data())) - : nullptr, - smooth_weight ? reinterpret_cast(const_cast( - smooth_weight.get().data())) - : nullptr, - reinterpret_cast(out->data()), - quant_max_bound, - quant_min_bound, - in_scale, - max_seq_len, - num_chunks, - num_heads, - chunk_size, - HEAD_DIM, - token_num, - speculate_max_draft_token_num); - } - } - } -} +#include "multiquery_attention_c16_kernel.h" template void CascadeAppendAttentionC16Kernel( @@ -1414,3 +121,267 @@ void CascadeAppendAttentionC16Kernel( out); })})})})})}) } + +template void CascadeAppendAttentionC16Kernel( + const AppendAttnMetaData& meta_data, + const paddle::Tensor& qkv, // [token_num, num_heads, head_dim] + const paddle::Tensor& + cache_k, // [max_block_num, num_heads, block_size, head_dim] + const paddle::Tensor& + cache_v, // [max_block_num, num_heads, head_dim, block_size] + const paddle::optional& attn_mask, + const paddle::optional& + cache_k_scale, // [num_kv_heads, head_dim] + const paddle::optional& + cache_v_scale, // [num_kv_heads, head_dim] + const paddle::optional& + cache_k_zp, // [num_kv_heads, head_dim] + const paddle::optional& + cache_v_zp, // [num_kv_heads, head_dim] + const paddle::optional& + shift_bias, // [num_kv_heads, head_dim] + const paddle::optional& + smooth_weight, // [num_kv_heads, head_dim] + const paddle::Tensor& seq_lens_q, + const paddle::Tensor& seq_lens_kv, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& batch_id_per_token, + const paddle::Tensor& cu_seqlens_q, + const paddle::Tensor& block_table, + const paddle::Tensor& batch_ids, + const paddle::Tensor& tile_ids_per_batch, + const int num_blocks, + const int block_shape_q, + const int max_seq_len, + const int max_dec_len, + const float quant_max_bound, + const float quant_min_bound, + const float in_scale, + const int max_partition_size, + const int encoder_max_partition_size, + const int speculate_max_draft_token_num, + const bool causal, + const bool is_decoder, + const bool enable_prefill, + cudaStream_t& stream, + paddle::Tensor* out); + +template void CascadeAppendAttentionC16Kernel( + const AppendAttnMetaData& meta_data, + const paddle::Tensor& qkv, // [token_num, num_heads, head_dim] + const paddle::Tensor& + cache_k, // [max_block_num, num_heads, block_size, head_dim] + const paddle::Tensor& + cache_v, // [max_block_num, num_heads, head_dim, block_size] + const paddle::optional& attn_mask, + const paddle::optional& + cache_k_scale, // [num_kv_heads, head_dim] + const paddle::optional& + cache_v_scale, // [num_kv_heads, head_dim] + const paddle::optional& + cache_k_zp, // [num_kv_heads, head_dim] + const paddle::optional& + cache_v_zp, // [num_kv_heads, head_dim] + const paddle::optional& + shift_bias, // [num_kv_heads, head_dim] + const paddle::optional& + smooth_weight, // [num_kv_heads, head_dim] + const paddle::Tensor& seq_lens_q, + const paddle::Tensor& seq_lens_kv, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& batch_id_per_token, + const paddle::Tensor& cu_seqlens_q, + const paddle::Tensor& block_table, + const paddle::Tensor& batch_ids, + const paddle::Tensor& tile_ids_per_batch, + const int num_blocks, + const int block_shape_q, + const int max_seq_len, + const int max_dec_len, + const float quant_max_bound, + const float quant_min_bound, + const float in_scale, + const int max_partition_size, + const int encoder_max_partition_size, + const int speculate_max_draft_token_num, + const bool causal, + const bool is_decoder, + const bool enable_prefill, + cudaStream_t& stream, + paddle::Tensor* out); + +template void CascadeAppendAttentionC16Kernel( + const AppendAttnMetaData& meta_data, + const paddle::Tensor& qkv, // [token_num, num_heads, head_dim] + const paddle::Tensor& + cache_k, // [max_block_num, num_heads, block_size, head_dim] + const paddle::Tensor& + cache_v, // [max_block_num, num_heads, head_dim, block_size] + const paddle::optional& attn_mask, + const paddle::optional& + cache_k_scale, // [num_kv_heads, head_dim] + const paddle::optional& + cache_v_scale, // [num_kv_heads, head_dim] + const paddle::optional& + cache_k_zp, // [num_kv_heads, head_dim] + const paddle::optional& + cache_v_zp, // [num_kv_heads, head_dim] + const paddle::optional& + shift_bias, // [num_kv_heads, head_dim] + const paddle::optional& + smooth_weight, // [num_kv_heads, head_dim] + const paddle::Tensor& seq_lens_q, + const paddle::Tensor& seq_lens_kv, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& batch_id_per_token, + const paddle::Tensor& cu_seqlens_q, + const paddle::Tensor& block_table, + const paddle::Tensor& batch_ids, + const paddle::Tensor& tile_ids_per_batch, + const int num_blocks, + const int block_shape_q, + const int max_seq_len, + const int max_dec_len, + const float quant_max_bound, + const float quant_min_bound, + const float in_scale, + const int max_partition_size, + const int encoder_max_partition_size, + const int speculate_max_draft_token_num, + const bool causal, + const bool is_decoder, + const bool enable_prefill, + cudaStream_t& stream, + paddle::Tensor* out); + +template void CascadeAppendAttentionC16Kernel( + const AppendAttnMetaData& meta_data, + const paddle::Tensor& qkv, // [token_num, num_heads, head_dim] + const paddle::Tensor& + cache_k, // [max_block_num, num_heads, block_size, head_dim] + const paddle::Tensor& + cache_v, // [max_block_num, num_heads, head_dim, block_size] + const paddle::optional& attn_mask, + const paddle::optional& + cache_k_scale, // [num_kv_heads, head_dim] + const paddle::optional& + cache_v_scale, // [num_kv_heads, head_dim] + const paddle::optional& + cache_k_zp, // [num_kv_heads, head_dim] + const paddle::optional& + cache_v_zp, // [num_kv_heads, head_dim] + const paddle::optional& + shift_bias, // [num_kv_heads, head_dim] + const paddle::optional& + smooth_weight, // [num_kv_heads, head_dim] + const paddle::Tensor& seq_lens_q, + const paddle::Tensor& seq_lens_kv, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& batch_id_per_token, + const paddle::Tensor& cu_seqlens_q, + const paddle::Tensor& block_table, + const paddle::Tensor& batch_ids, + const paddle::Tensor& tile_ids_per_batch, + const int num_blocks, + const int block_shape_q, + const int max_seq_len, + const int max_dec_len, + const float quant_max_bound, + const float quant_min_bound, + const float in_scale, + const int max_partition_size, + const int encoder_max_partition_size, + const int speculate_max_draft_token_num, + const bool causal, + const bool is_decoder, + const bool enable_prefill, + cudaStream_t& stream, + paddle::Tensor* out); + +template void CascadeAppendAttentionC16Kernel( + const AppendAttnMetaData& meta_data, + const paddle::Tensor& qkv, // [token_num, num_heads, head_dim] + const paddle::Tensor& + cache_k, // [max_block_num, num_heads, block_size, head_dim] + const paddle::Tensor& + cache_v, // [max_block_num, num_heads, head_dim, block_size] + const paddle::optional& attn_mask, + const paddle::optional& + cache_k_scale, // [num_kv_heads, head_dim] + const paddle::optional& + cache_v_scale, // [num_kv_heads, head_dim] + const paddle::optional& + cache_k_zp, // [num_kv_heads, head_dim] + const paddle::optional& + cache_v_zp, // [num_kv_heads, head_dim] + const paddle::optional& + shift_bias, // [num_kv_heads, head_dim] + const paddle::optional& + smooth_weight, // [num_kv_heads, head_dim] + const paddle::Tensor& seq_lens_q, + const paddle::Tensor& seq_lens_kv, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& batch_id_per_token, + const paddle::Tensor& cu_seqlens_q, + const paddle::Tensor& block_table, + const paddle::Tensor& batch_ids, + const paddle::Tensor& tile_ids_per_batch, + const int num_blocks, + const int block_shape_q, + const int max_seq_len, + const int max_dec_len, + const float quant_max_bound, + const float quant_min_bound, + const float in_scale, + const int max_partition_size, + const int encoder_max_partition_size, + const int speculate_max_draft_token_num, + const bool causal, + const bool is_decoder, + const bool enable_prefill, + cudaStream_t& stream, + paddle::Tensor* out); + +template void CascadeAppendAttentionC16Kernel( + const AppendAttnMetaData& meta_data, + const paddle::Tensor& qkv, // [token_num, num_heads, head_dim] + const paddle::Tensor& + cache_k, // [max_block_num, num_heads, block_size, head_dim] + const paddle::Tensor& + cache_v, // [max_block_num, num_heads, head_dim, block_size] + const paddle::optional& attn_mask, + const paddle::optional& + cache_k_scale, // [num_kv_heads, head_dim] + const paddle::optional& + cache_v_scale, // [num_kv_heads, head_dim] + const paddle::optional& + cache_k_zp, // [num_kv_heads, head_dim] + const paddle::optional& + cache_v_zp, // [num_kv_heads, head_dim] + const paddle::optional& + shift_bias, // [num_kv_heads, head_dim] + const paddle::optional& + smooth_weight, // [num_kv_heads, head_dim] + const paddle::Tensor& seq_lens_q, + const paddle::Tensor& seq_lens_kv, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& batch_id_per_token, + const paddle::Tensor& cu_seqlens_q, + const paddle::Tensor& block_table, + const paddle::Tensor& batch_ids, + const paddle::Tensor& tile_ids_per_batch, + const int num_blocks, + const int block_shape_q, + const int max_seq_len, + const int max_dec_len, + const float quant_max_bound, + const float quant_min_bound, + const float in_scale, + const int max_partition_size, + const int encoder_max_partition_size, + const int speculate_max_draft_token_num, + const bool causal, + const bool is_decoder, + const bool enable_prefill, + cudaStream_t& stream, + paddle::Tensor* out); diff --git a/custom_ops/gpu_ops/append_attn/append_attention_c4_impl.cuh b/custom_ops/gpu_ops/append_attn/append_attention_c4_impl.cuh index 49317bfdf..9dcc78398 100644 --- a/custom_ops/gpu_ops/append_attn/append_attention_c4_impl.cuh +++ b/custom_ops/gpu_ops/append_attn/append_attention_c4_impl.cuh @@ -13,1538 +13,7 @@ // limitations under the License. #pragma once -#include "append_attention_func.cuh" -#include "append_attention_kernel.h" - -template -__global__ void multi_query_append_attention_c4_kernel( - T *__restrict__ q, // [token_num, (num_heads + 2* kv_num_head) * head_dim] - CacheT *__restrict__ cache_k, // [max_block_num, num_heads, block_size, - // head_dim] - CacheT *__restrict__ cache_v, - const T *__restrict__ cache_k_scale, // [num_kv_heads, head_dim] - const T *__restrict__ cache_k_zero_point, // [num_kv_heads, head_dim] - const T *__restrict__ cache_v_scale, // [num_kv_heads, head_dim] - const T *__restrict__ cache_v_zero_point, // [num_kv_heads, head_dim] - const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM] - const T *__restrict__ smooth_weight, // [q_num_heads * HEAD_DIM] - const int *__restrict__ seq_lens, - const int *__restrict__ seq_lens_kv, - const int *__restrict__ batch_ids, - const int *__restrict__ tile_ids_per_batch, - const int *__restrict__ cu_seqlens_q, - const int *__restrict__ block_table, // [bsz, block_num_per_seq] - const int *__restrict__ mask_offset, - const int max_seq_len, - const int max_dec_len, - const int max_block_num_per_seq, - const float scale, - const float quant_max_bound, - const float quant_min_bound, - const float in_scale, - const uint32_t chunk_size, - const int num_blocks_x_cpu, - T *__restrict__ tmp_workspace, // split kv [token_num, num_chunks, - // num_heads, head_dim] - float *__restrict__ tmp_m, // [token_num, num_chunks, num_heads] - float *__restrict__ tmp_d, // [token_num, num_chunks, num_heads] - OutT *__restrict__ out, - const int speculate_max_draft_token_num = 5) { - constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b(); - constexpr uint32_t num_vecs_per_head_k = - HEAD_DIM / 2 / num_elems_per_128b(); - constexpr uint32_t num_vecs_per_blocksize = - BLOCK_SIZE / 2 / num_elems_per_128b(); - constexpr uint32_t inv_k_stride = 8 / num_vecs_per_head_k; - constexpr uint32_t inv_v_stride = 8 / num_vecs_per_blocksize; - const uint32_t btid = blockIdx.x, kv_head_idx = blockIdx.z; - const uint32_t kv_num_heads = gridDim.z; - const uint32_t q_num_heads = kv_num_heads * GROUP_SIZE; - const uint32_t q_head_idx = kv_head_idx * GROUP_SIZE; - const uint32_t tid = threadIdx.x, wid = threadIdx.y; - const uint32_t num_chunks = gridDim.y; - const uint32_t chunk_idx = blockIdx.y; - - const uint32_t batch_id = batch_ids[btid]; - const uint32_t tile_id = tile_ids_per_batch[btid]; - const uint32_t num_rows_per_block = NUM_WARPS * num_frags_x * 16; - const int *block_table_now = nullptr; - - block_table_now = block_table + batch_id * max_block_num_per_seq; - - //When cudagraph capture prefill, may launch more gridDim.x - if(btid >= static_cast(num_blocks_x_cpu)){ - return; - } - - const uint32_t q_len = seq_lens[batch_id]; - if (q_len <= 0) { - return; - } - const uint32_t q_end = - min(q_len, div_up((tile_id + 1) * num_rows_per_block, GROUP_SIZE)); - uint32_t kv_len = seq_lens_kv[batch_id]; - if (ENABLE_PREFILL) { - kv_len += q_len; - if (kv_len <= 0) { - return; - } - } else { - if (kv_len <= 0) { - return; - } - kv_len += q_len; - } - const uint32_t num_chunks_this_seq = div_up(kv_len, chunk_size); - if (chunk_idx >= num_chunks_this_seq) { - return; - } - - const uint32_t chunk_start = partition_kv ? chunk_idx * chunk_size : 0; - const uint32_t chunk_end = - partition_kv ? min(kv_len, chunk_start + chunk_size) : kv_len; - const uint32_t chunk_len = chunk_end - chunk_start; - - extern __shared__ uint8_t smem[]; - float s_frag[num_frags_x][num_frags_z][8]; - float o_frag[num_frags_x][num_frags_y][8]; - float m_frag[num_frags_x][2]; - float d_frag[num_frags_x][2]; - - const T *cache_k_scale_now = cache_k_scale + kv_head_idx * HEAD_DIM; - const T *cache_k_zp_now = cache_k_zero_point + kv_head_idx * HEAD_DIM; - const T *cache_v_scale_now = cache_v_scale + kv_head_idx * HEAD_DIM; - const T *cache_v_zp_now = cache_v_zero_point + kv_head_idx * HEAD_DIM; - T *cache_k_scale_smem = reinterpret_cast( - smem + NUM_WARPS * num_frags_x * 16 * HEAD_DIM * sizeof(T) + - num_frags_z * 16 * HEAD_DIM / 2 * sizeof(CacheT) * 2); - T *cache_k_zero_point_smem = cache_k_scale_smem + HEAD_DIM; - T *cache_v_scale_smem = cache_k_zero_point_smem + HEAD_DIM; - T *cache_v_zero_point_smem = cache_v_scale_smem + HEAD_DIM; -#pragma unroll - for (uint32_t i = wid * 32 + tid; i < HEAD_DIM; i += 128) { - cache_k_scale_smem[i] = cache_k_scale_now[i]; - cache_k_zero_point_smem[i] = cache_k_zp_now[i]; - cache_v_scale_smem[i] = cache_v_scale_now[i]; - cache_v_zero_point_smem[i] = cache_v_zp_now[i]; - } - - init_states(o_frag, m_frag, d_frag); - - const uint32_t q_n_stride = q_num_heads * HEAD_DIM; - const uint32_t q_ori_n_stride = (q_num_heads + kv_num_heads * 2) * HEAD_DIM; - const uint32_t kv_n_stride = kv_num_heads * BLOCK_SIZE * HEAD_DIM / 2; - const uint32_t kv_h_stride = BLOCK_SIZE * HEAD_DIM / 2; - const uint32_t kv_b_stride = HEAD_DIM / 2; - const uint32_t kv_d_stride = BLOCK_SIZE / 2; - const uint32_t q_start_seq_id = cu_seqlens_q[batch_id]; - const uint32_t q_base_seq_id_this_block = - (tile_id * NUM_WARPS + wid) * num_frags_x * 16; - const uint32_t q_offset = q_start_seq_id * q_ori_n_stride + - q_head_idx * HEAD_DIM + - tid % 8 * num_elems_per_128b(); - const uint32_t o_offset = q_start_seq_id * q_n_stride + - q_head_idx * HEAD_DIM + - tid % 8 * num_elems_per_128b(); - T *q_base_ptr = q + q_offset; - - T *o_base_ptr_T = nullptr; - OutT *o_base_ptr_int8 = nullptr; - if constexpr (partition_kv) { - if (ENABLE_PREFILL) { - o_base_ptr_T = tmp_workspace + q_start_seq_id * num_chunks * q_n_stride + - chunk_idx * q_n_stride + q_head_idx * HEAD_DIM + - tid % 8 * num_elems_per_128b(); - } else { - o_base_ptr_T = - tmp_workspace + - batch_id * speculate_max_draft_token_num * num_chunks * q_n_stride + - chunk_idx * q_n_stride + q_head_idx * HEAD_DIM + - tid % 8 * num_elems_per_128b(); - } - } else { - o_base_ptr_int8 = out + o_offset; - } - const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id * 2 : nullptr; - smem_t qo_smem(smem); - - uint32_t q_smem_offset_r = smem_t::get_permuted_offset( - wid * num_frags_x * 16 + tid % 16, tid / 16); - load_q_global_smem( - q_base_ptr, - &qo_smem, - q_base_seq_id_this_block, - q_end, - q_ori_n_stride, - HEAD_DIM); - commit_group(); - wait_group<0>(); - __syncthreads(); - - q_smem_inplace_multiply_sm_scale(&qo_smem, - scale); - - T cache_k_scale_frag[num_frags_y][4]; - T cache_k_zp_frag[num_frags_y][4]; - T magic_number; - if constexpr (std::is_same::value) { - magic_number = static_cast(1032.f); - } else { - magic_number = static_cast(136.f); - } -#pragma unroll - for (uint32_t fy = 0; fy < num_frags_y; ++fy) { - *(reinterpret_cast(&cache_k_scale_frag[fy][0])) = - *(reinterpret_cast(&cache_k_scale_smem[fy * 16]) + tid % 4); - *(reinterpret_cast(&cache_k_scale_frag[fy][2])) = - *(reinterpret_cast(&cache_k_scale_smem[fy * 16]) + tid % 4 + - 4); - *(reinterpret_cast(&cache_k_zp_frag[fy][0])) = - *(reinterpret_cast(&cache_k_zero_point_smem[fy * 16]) + - tid % 4); - *(reinterpret_cast(&cache_k_zp_frag[fy][2])) = - *(reinterpret_cast(&cache_k_zero_point_smem[fy * 16]) + - tid % 4 + 4); -#pragma unroll - for (uint32_t zp_i = 0; zp_i < 4; ++zp_i) { - cache_k_zp_frag[fy][zp_i] += magic_number; // 128 + 8 - } - } - T cache_v_scale_frag[num_frags_y][2]; - T cache_v_zp_frag[num_frags_y][2]; -#pragma unroll - for (uint32_t fy = 0; fy < num_frags_y; ++fy) { - cache_v_scale_frag[fy][0] = cache_v_scale_smem[fy * 16 + tid / 4]; - cache_v_scale_frag[fy][1] = cache_v_scale_smem[fy * 16 + tid / 4 + 8]; - cache_v_zp_frag[fy][0] = - cache_v_zero_point_smem[fy * 16 + tid / 4] + magic_number; - cache_v_zp_frag[fy][1] = - cache_v_zero_point_smem[fy * 16 + tid / 4 + 8] + magic_number; - } - - smem_t k_smem(smem + NUM_WARPS * num_frags_x * 16 * HEAD_DIM * sizeof(T)), - v_smem(smem + NUM_WARPS * num_frags_x * 16 * HEAD_DIM * sizeof(T) + - num_frags_z * 16 * HEAD_DIM / 2 * sizeof(CacheT)); - - - const uint32_t num_iterations = div_up( - CAUSAL - ? (min(chunk_len, - sub_if_greater_or_zero( - kv_len - q_len + - div_up((tile_id + 1) * num_rows_per_block, GROUP_SIZE), - chunk_start))) - : chunk_len, - num_frags_z * 16); - const uint32_t mask_check_iteration = - (CAUSAL ? (min(chunk_len, - sub_if_greater_or_zero( - kv_len - q_len + - tile_id * num_rows_per_block / GROUP_SIZE, - chunk_start))) - : mask_offset ? 0 : chunk_len) / - (num_frags_z * 16); - - uint32_t k_smem_offset_r = - smem_t::get_permuted_offset( - 8 * (tid / 16) + tid % 8, (tid % 16) / 8); - - uint32_t v_smem_offset_r = - smem_t::get_permuted_offset( - 8 * (tid / 16) + tid % 8, (tid % 16) / 8); - - uint32_t k_smem_offset_w = - smem_t::get_permuted_offset( - wid * 8 + tid / 4, - tid % - 4); - uint32_t v_smem_offset_w = - smem_t::get_permuted_offset( - wid * 16 + tid / 2, tid % 2); // 2 * 128 / 8 = 32B, 64 nums - - uint32_t kv_idx_base = chunk_start; - const uint32_t const_k_offset = kv_head_idx * kv_h_stride + - (wid * 8 + tid / 4) * kv_b_stride + - tid % 4 * num_elems_per_128b(); - const uint32_t const_v_offset = kv_head_idx * kv_h_stride + - (wid * 16 + tid / 2) * kv_d_stride + - tid % 2 * num_elems_per_128b(); - - produce_k_blockwise_c4(k_smem, - &k_smem_offset_w, - cache_k, - block_table_now, - kv_head_idx, - kv_n_stride, - kv_h_stride, - kv_b_stride, - kv_idx_base, - chunk_end, - const_k_offset); - commit_group(); - produce_v_blockwise_c4(v_smem, - &v_smem_offset_w, - cache_v, - block_table_now, - kv_head_idx, - kv_n_stride, - kv_h_stride, - kv_d_stride, - kv_idx_base, - chunk_end, - const_v_offset); - commit_group(); - -#pragma unroll 1 - for (uint32_t iter = 0; iter < num_iterations; ++iter) { - wait_group<1>(); - __syncthreads(); - - compute_qk_c4( - &qo_smem, - &q_smem_offset_r, - &k_smem, - &k_smem_offset_r, - s_frag, - cache_k_scale_frag, - cache_k_zp_frag); - - if (iter >= mask_check_iteration) { - mask_s(nullptr, - q_base_seq_id_this_block, - kv_idx_base, - q_len, - kv_len, - chunk_end, - -1, - s_frag, - mask_offset_this_seq); - } - - update_mdo_states( - s_frag, o_frag, m_frag, d_frag); - __syncthreads(); - - kv_idx_base += num_frags_z * 16; - produce_k_blockwise_c4(k_smem, - &k_smem_offset_w, - cache_k, - block_table_now, - kv_head_idx, - kv_n_stride, - kv_h_stride, - kv_b_stride, - kv_idx_base, - chunk_end, - const_k_offset); - commit_group(); - wait_group<1>(); - __syncthreads(); - - compute_sfm_v_c4(&v_smem, - &v_smem_offset_r, - s_frag, - o_frag, - d_frag, - cache_v_scale_frag, - cache_v_zp_frag); - __syncthreads(); - - produce_v_blockwise_c4(v_smem, - &v_smem_offset_w, - cache_v, - block_table_now, - kv_head_idx, - kv_n_stride, - kv_h_stride, - kv_d_stride, - kv_idx_base, - chunk_end, - const_v_offset); - commit_group(); - } - wait_group<0>(); - __syncthreads(); - - if constexpr (!partition_kv) { - normalize_d(o_frag, d_frag); - } - - if constexpr (partition_kv) { - write_o_reg_gmem_shift_smooth_quant( - o_frag, - &qo_smem, - o_base_ptr_T, - shift_bias, - smooth_weight, - q_base_seq_id_this_block, - q_head_idx, - quant_max_bound, - quant_min_bound, - in_scale, - q_len, - partition_kv ? q_n_stride * num_chunks : q_n_stride, - HEAD_DIM); - } else { - write_o_reg_gmem_shift_smooth_quant( - o_frag, - &qo_smem, - o_base_ptr_int8, - shift_bias, - smooth_weight, - q_base_seq_id_this_block, - q_head_idx, - quant_max_bound, - quant_min_bound, - in_scale, - q_len, - partition_kv ? q_n_stride * num_chunks : q_n_stride, - HEAD_DIM); - } - - if constexpr (partition_kv) { -#pragma unroll - for (uint32_t fx = 0; fx < num_frags_x; ++fx) { -#pragma unroll - for (uint32_t j = 0; j < 2; ++j) { - const uint32_t qo_idx_now = - q_base_seq_id_this_block + tid / 4 + j * 8 + fx * 16; - const uint32_t qo_head_idx = q_head_idx + qo_idx_now % GROUP_SIZE; - const uint32_t qo_idx = q_start_seq_id + qo_idx_now / GROUP_SIZE; - if (qo_idx - q_start_seq_id < q_len) { - uint32_t offset; - if (ENABLE_PREFILL) { - offset = - (qo_idx * num_chunks + chunk_idx) * q_num_heads + qo_head_idx; - } else { - offset = ((batch_id * speculate_max_draft_token_num + - qo_idx_now / GROUP_SIZE) * - num_chunks + - chunk_idx) * - q_num_heads + - qo_head_idx; - } - tmp_m[offset] = m_frag[fx][j]; - tmp_d[offset] = d_frag[fx][j]; - } - } - } - } -} - -template -__global__ void multi_query_append_attention_c4_warp1_4_kernel( - T *__restrict__ q, // [token_num, (num_heads + 2* kv_num_head) * head_dim] - CacheT *__restrict__ cache_k, // [max_block_num, num_heads, block_size, - // head_dim] - CacheT *__restrict__ cache_v, - const T *__restrict__ cache_k_scale, // [num_kv_heads, head_dim] - const T *__restrict__ cache_k_zero_point, // [num_kv_heads, head_dim] - const T *__restrict__ cache_v_scale, // [num_kv_heads, head_dim] - const T *__restrict__ cache_v_zero_point, // [num_kv_heads, head_dim] - const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM] - const T *__restrict__ smooth_weight, // [q_num_heads * HEAD_DIM] - const int *__restrict__ seq_lens, - const int *__restrict__ seq_lens_kv, - const int *__restrict__ batch_ids, - const int *__restrict__ tile_ids_per_batch, - const int *__restrict__ cu_seqlens_q, - const int *__restrict__ block_table, // [bsz, block_num_per_seq] - const int *__restrict__ mask_offset, - const bool *__restrict__ attn_mask, // [bsz, max_q, max_q] for tree-mask - const int max_seq_len, - const int max_dec_len, - const int max_block_num_per_seq, - const float scale, - const float quant_max_bound, - const float quant_min_bound, - const float in_scale, - const uint32_t chunk_size, - const int num_blocks_x_cpu, - T *__restrict__ tmp_workspace, // split kv [token_num, num_chunks, - // num_heads, head_dim] - float *__restrict__ tmp_m, // [token_num, num_chunks, num_heads] - float *__restrict__ tmp_d, // [token_num, num_chunks, num_heads] - OutT *__restrict__ out, - const int speculate_max_draft_token_num = 5, - const uint32_t attn_mask_len = -1) { - constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b(); - constexpr uint32_t num_vecs_per_head_k = - HEAD_DIM / 2 / num_elems_per_128b(); - constexpr uint32_t num_vecs_per_blocksize = - BLOCK_SIZE / 2 / num_elems_per_128b(); - constexpr uint32_t inv_k_stride = 8 / num_vecs_per_head_k; - constexpr uint32_t inv_v_stride = 8 / num_vecs_per_blocksize; - static_assert(NUM_WARP_Q == 1, "NUM_WARP_Q must be 1"); - static_assert(NUM_WARP_KV == 4, "NUM_WARP_KV must be 4"); - const uint32_t btid = blockIdx.x, kv_head_idx = blockIdx.z; - const uint32_t kv_num_heads = gridDim.z; - const uint32_t q_num_heads = kv_num_heads * GROUP_SIZE; - const uint32_t q_head_idx = kv_head_idx * GROUP_SIZE; - const uint32_t tid = threadIdx.x, wid = threadIdx.y; - const uint32_t num_chunks = gridDim.y; - const uint32_t chunk_idx = blockIdx.y; - - const uint32_t batch_id = batch_ids[btid]; - const uint32_t tile_id = tile_ids_per_batch[btid]; - const uint32_t num_rows_per_block = num_frags_x * 16; - const int *block_table_now = block_table + batch_id * max_block_num_per_seq; - - //When cudagraph capture prefill, may launch more gridDim.x - if(btid >= static_cast(num_blocks_x_cpu)){ - return; - } - - const uint32_t q_len = seq_lens[batch_id]; - if (q_len <= 0) { - return; - } - const uint32_t q_end = - min(q_len, div_up((tile_id + 1) * num_rows_per_block, GROUP_SIZE)); - uint32_t kv_len = seq_lens_kv[batch_id]; - if (ENABLE_PREFILL) { - kv_len += q_len; - if (kv_len <= 0) { - return; - } - } else { - if (kv_len <= 0) { - return; - } - kv_len += q_len; - } - const uint32_t num_chunks_this_seq = div_up(kv_len, chunk_size); - if (chunk_idx >= num_chunks_this_seq) { - return; - } - - const uint32_t chunk_start = partition_kv ? chunk_idx * chunk_size : 0; - const uint32_t chunk_end = - partition_kv ? min(kv_len, chunk_start + chunk_size) : kv_len; - const uint32_t chunk_len = chunk_end - chunk_start; - - extern __shared__ uint8_t smem[]; - float s_frag[num_frags_x][num_frags_z][8]; - float o_frag[num_frags_x][num_frags_y][8]; - float m_frag[num_frags_x][2]; - float d_frag[num_frags_x][2]; - init_states(o_frag, m_frag, d_frag); - - const T *cache_k_scale_now = cache_k_scale + kv_head_idx * HEAD_DIM; - const T *cache_k_zp_now = cache_k_zero_point + kv_head_idx * HEAD_DIM; - const T *cache_v_scale_now = cache_v_scale + kv_head_idx * HEAD_DIM; - const T *cache_v_zp_now = cache_v_zero_point + kv_head_idx * HEAD_DIM; - T *cache_k_scale_smem = reinterpret_cast( - smem + NUM_WARP_Q * num_frags_x * 16 * HEAD_DIM * sizeof(T) + - NUM_WARP_KV * num_frags_z * 16 * HEAD_DIM / 2 * sizeof(CacheT) * 2); - T *cache_k_zero_point_smem = cache_k_scale_smem + HEAD_DIM; - T *cache_v_scale_smem = cache_k_zero_point_smem + HEAD_DIM; - T *cache_v_zero_point_smem = cache_v_scale_smem + HEAD_DIM; -#pragma unroll - for (uint32_t i = wid * 32 + tid; i < HEAD_DIM; i += 128) { - cache_k_scale_smem[i] = cache_k_scale_now[i]; - cache_k_zero_point_smem[i] = cache_k_zp_now[i]; - cache_v_scale_smem[i] = cache_v_scale_now[i]; - cache_v_zero_point_smem[i] = cache_v_zp_now[i]; - } - - const uint32_t q_n_stride = q_num_heads * HEAD_DIM; - const uint32_t q_ori_n_stride = (q_num_heads + kv_num_heads * 2) * HEAD_DIM; - const uint32_t kv_n_stride = kv_num_heads * BLOCK_SIZE * HEAD_DIM / 2; - const uint32_t kv_h_stride = BLOCK_SIZE * HEAD_DIM / 2; - const uint32_t kv_b_stride = HEAD_DIM / 2; - const uint32_t kv_d_stride = BLOCK_SIZE / 2; - const uint32_t q_start_seq_id = cu_seqlens_q[batch_id]; - const uint32_t q_base_seq_id_this_block = tile_id * num_frags_x * 16; - const uint32_t q_offset = q_start_seq_id * q_ori_n_stride + - q_head_idx * HEAD_DIM + - tid % 8 * num_elems_per_128b(); - const uint32_t o_offset = q_start_seq_id * q_n_stride + - q_head_idx * HEAD_DIM + - tid % 8 * num_elems_per_128b(); - T *q_base_ptr = q + q_offset; - - T *o_base_ptr_T = nullptr; - OutT *o_base_ptr_int8 = nullptr; - if (num_chunks_this_seq <= 1) { - o_base_ptr_int8 = out + o_offset; - } else { - if (ENABLE_PREFILL) { - o_base_ptr_T = tmp_workspace + batch_id * num_chunks * q_n_stride + - chunk_idx * q_n_stride + q_head_idx * HEAD_DIM + - tid % 8 * num_elems_per_128b(); - } else { - o_base_ptr_T = - tmp_workspace + - batch_id * speculate_max_draft_token_num * num_chunks * q_n_stride + - chunk_idx * q_n_stride + q_head_idx * HEAD_DIM + - tid % 8 * num_elems_per_128b(); - } - } - const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id * 2 : nullptr; - smem_t qo_smem(smem); - - uint32_t q_smem_offset_r = smem_t::get_permuted_offset( - tid % 16, tid / 16); - load_q_global_smem_multi_warps(q_base_ptr, - &qo_smem, - q_base_seq_id_this_block, - q_end, - q_ori_n_stride, - HEAD_DIM); - commit_group(); - wait_group<0>(); - __syncthreads(); - - q_smem_inplace_multiply_sm_scale_multi_warps( - &qo_smem, scale); - - T cache_k_scale_frag[num_frags_y][4]; - T cache_k_zp_frag[num_frags_y][4]; - T magic_number; - if constexpr (std::is_same::value) { - magic_number = static_cast(1032.f); - } else { - magic_number = static_cast(136.f); - } -#pragma unroll - for (uint32_t fy = 0; fy < num_frags_y; ++fy) { - *(reinterpret_cast(&cache_k_scale_frag[fy][0])) = - *(reinterpret_cast(&cache_k_scale_smem[fy * 16]) + tid % 4); - *(reinterpret_cast(&cache_k_scale_frag[fy][2])) = - *(reinterpret_cast(&cache_k_scale_smem[fy * 16]) + tid % 4 + - 4); - *(reinterpret_cast(&cache_k_zp_frag[fy][0])) = - *(reinterpret_cast(&cache_k_zero_point_smem[fy * 16]) + - tid % 4); - *(reinterpret_cast(&cache_k_zp_frag[fy][2])) = - *(reinterpret_cast(&cache_k_zero_point_smem[fy * 16]) + - tid % 4 + 4); -#pragma unroll - for (uint32_t zp_i = 0; zp_i < 4; ++zp_i) { - cache_k_zp_frag[fy][zp_i] += magic_number; - } - } - T cache_v_scale_frag[num_frags_y][2]; - T cache_v_zp_frag[num_frags_y][2]; -#pragma unroll - for (uint32_t fy = 0; fy < num_frags_y; ++fy) { - cache_v_scale_frag[fy][0] = cache_v_scale_smem[fy * 16 + tid / 4]; - cache_v_scale_frag[fy][1] = cache_v_scale_smem[fy * 16 + tid / 4 + 8]; - cache_v_zp_frag[fy][0] = - cache_v_zero_point_smem[fy * 16 + tid / 4] + magic_number; - cache_v_zp_frag[fy][1] = - cache_v_zero_point_smem[fy * 16 + tid / 4 + 8] + magic_number; - } - - smem_t k_smem(smem + num_frags_x * 16 * HEAD_DIM * sizeof(T)), - v_smem(smem + num_frags_x * 16 * HEAD_DIM * sizeof(T) + - NUM_WARP_KV * num_frags_z * 16 * HEAD_DIM / 2 * sizeof(CacheT)); - - const uint32_t num_iterations = div_up( - CAUSAL - ? (min(chunk_len, - sub_if_greater_or_zero( - kv_len - q_len + - div_up((tile_id + 1) * num_rows_per_block, GROUP_SIZE), - chunk_start))) - : chunk_len, - NUM_WARP_KV * num_frags_z * 16); - const uint32_t mask_check_iteration = - (CAUSAL ? (min(chunk_len, - sub_if_greater_or_zero( - kv_len - q_len, - chunk_start))) - : mask_offset ? 0 : chunk_len) / - (NUM_WARP_KV * num_frags_z * 16); - - uint32_t k_smem_offset_r = - smem_t::get_permuted_offset( - wid * num_frags_z * 16 + 8 * (tid / 16) + tid % 8, (tid % 16) / 8); - - uint32_t v_smem_offset_r = - smem_t::get_permuted_offset( - wid * num_frags_y * 16 + 8 * (tid / 16) + tid % 8, (tid % 16) / 8); - - uint32_t k_smem_offset_w = - smem_t::get_permuted_offset( - wid * 8 + tid / 4, - tid % - 4); - uint32_t v_smem_offset_w = - smem_t::get_permuted_offset( - wid * 16 + tid / 2, tid % 2); - - uint32_t kv_idx_base = chunk_start; - const uint32_t const_k_offset = kv_head_idx * kv_h_stride + - (wid * 8 + tid / 4) * kv_b_stride + - tid % 4 * num_elems_per_128b(); - const uint32_t const_v_offset = kv_head_idx * kv_h_stride + - (wid * 16 + tid / 2) * kv_d_stride + - tid % 2 * num_elems_per_128b(); - - produce_k_blockwise_c4(k_smem, - &k_smem_offset_w, - cache_k, - block_table_now, - kv_head_idx, - kv_n_stride, - kv_h_stride, - kv_b_stride, - kv_idx_base, - chunk_end, - const_k_offset); - commit_group(); - produce_v_blockwise_c4(v_smem, - &v_smem_offset_w, - cache_v, - block_table_now, - kv_head_idx, - kv_n_stride, - kv_h_stride, - kv_d_stride, - kv_idx_base, - chunk_end, - const_v_offset); - commit_group(); -#pragma unroll 1 - for (uint32_t iter = 0; iter < num_iterations; ++iter) { - wait_group<1>(); - __syncthreads(); - compute_qk_c4( - &qo_smem, - &q_smem_offset_r, - &k_smem, - &k_smem_offset_r, - s_frag, - cache_k_scale_frag, - cache_k_zp_frag); - if (iter >= mask_check_iteration) { - mask_s(attn_mask ? attn_mask + batch_id * attn_mask_len *attn_mask_len : nullptr, - q_base_seq_id_this_block, - kv_idx_base + wid * num_frags_z * 16, - q_len, - kv_len, - chunk_end, - attn_mask_len, - s_frag, - mask_offset_this_seq); - } - - update_mdo_states( - s_frag, o_frag, m_frag, d_frag); - __syncthreads(); - - kv_idx_base += NUM_WARP_KV * num_frags_z * 16; - produce_k_blockwise_c4(k_smem, - &k_smem_offset_w, - cache_k, - block_table_now, - kv_head_idx, - kv_n_stride, - kv_h_stride, - kv_b_stride, - kv_idx_base, - chunk_end, - const_k_offset); - commit_group(); - wait_group<1>(); - __syncthreads(); - - // compute sfm*v - compute_sfm_v_c4(&v_smem, - &v_smem_offset_r, - s_frag, - o_frag, - d_frag, - cache_v_scale_frag, - cache_v_zp_frag); - __syncthreads(); - - produce_v_blockwise_c4(v_smem, - &v_smem_offset_w, - cache_v, - block_table_now, - kv_head_idx, - kv_n_stride, - kv_h_stride, - kv_d_stride, - kv_idx_base, - chunk_end, - const_v_offset); - commit_group(); - } - wait_group<0>(); - __syncthreads(); - - merge_block_res_v2( - o_frag, reinterpret_cast(smem), m_frag, d_frag, wid, tid); - - if (num_chunks_this_seq <= 1) { - normalize_d(o_frag, d_frag); - } - - // write o - // [num_frags_x, 16, num_frags_y, 16] - if (num_chunks_this_seq <= 1) { - write_o_reg_gmem_multi_warps_shift_smooth_quant( - o_frag, - &qo_smem, - o_base_ptr_int8, - shift_bias, - smooth_weight, - q_base_seq_id_this_block, - q_head_idx, - quant_max_bound, - quant_min_bound, - in_scale, - q_len, - q_n_stride, - HEAD_DIM); - } else { - write_o_reg_gmem_multi_warps_shift_smooth_quant( - o_frag, - &qo_smem, - o_base_ptr_T, - shift_bias, - smooth_weight, - q_base_seq_id_this_block, - q_head_idx, - quant_max_bound, - quant_min_bound, - in_scale, - q_len, - q_n_stride * num_chunks, - HEAD_DIM); - } - - if (num_chunks_this_seq > 1) { - if (wid == 0) { -#pragma unroll - for (uint32_t fx = 0; fx < num_frags_x; ++fx) { -#pragma unroll - for (uint32_t j = 0; j < 2; ++j) { - const uint32_t qo_idx_now = - q_base_seq_id_this_block + tid / 4 + j * 8 + fx * 16; - const uint32_t qo_head_idx = q_head_idx + qo_idx_now % GROUP_SIZE; - const uint32_t qo_idx = q_start_seq_id + qo_idx_now / GROUP_SIZE; - if (qo_idx - q_start_seq_id < q_len) { - uint32_t offset; - if (ENABLE_PREFILL) { - offset = (batch_id * num_chunks + chunk_idx) * q_num_heads + - qo_head_idx; - } else { - offset = ((batch_id * speculate_max_draft_token_num + - qo_idx_now / GROUP_SIZE) * - num_chunks + - chunk_idx) * - q_num_heads + - qo_head_idx; - } - tmp_m[offset] = m_frag[fx][j]; - tmp_d[offset] = d_frag[fx][j]; - } - } - } - } - } -} - -template -void MultiQueryAppendC4Attention( - const AppendAttnMetaData &meta_data, - const paddle::Tensor &qkv, - const paddle::Tensor &cache_k, - const paddle::Tensor &cache_v, - const paddle::optional &attn_mask, - const paddle::Tensor &cache_k_scale, - const paddle::Tensor &cache_v_scale, - const paddle::optional &cache_k_zp, - const paddle::optional &cache_v_zp, - const paddle::optional &shift_bias, - const paddle::optional &smooth_weight, - const paddle::Tensor &seq_lens_q, - const paddle::Tensor &seq_lens_kv, - const paddle::Tensor &seq_lens_encoder, - const paddle::Tensor &batch_id_per_token, - const paddle::Tensor &cu_seqlens_q, - const paddle::Tensor &block_table, - const paddle::Tensor &batch_ids, - const paddle::Tensor &tile_ids_per_batch, - const int num_blocks_x_cpu, - const int max_seq_len, - const int max_dec_len, - const float quant_max_bound, - const float quant_min_bound, - const float in_scale, - const int max_partition_size, - const int encoder_max_partition_size, - const int speculate_max_draft_token_num, - const bool is_decoder, - cudaStream_t &stream, - paddle::Tensor *out) { - using NV_TYPE = typename cascade_attn_type_traits::type; - using OUT_NV_TYPE = typename cascade_attn_type_traits::type; - - auto num_heads = meta_data.q_num_heads; - auto kv_num_heads = meta_data.kv_num_heads; - auto token_num = meta_data.token_nums; - auto bsz = meta_data.batch_size; - auto max_block_num_per_seq = meta_data.max_blocks_per_seq; - - constexpr uint32_t num_warps = 4; - constexpr uint32_t NUM_WARP_KV = num_warps / NUM_WARP_Q; - constexpr uint32_t num_frags_x = BLOCK_SHAPE_Q / (16 * NUM_WARP_Q); - constexpr uint32_t num_frags_y = HEAD_DIM / 16; - constexpr uint32_t num_qrow_per_block = NUM_WARP_Q * num_frags_x * 16; - - auto *allocator = paddle::GetAllocator(qkv.place()); - - const float scale = 1.f / sqrt(HEAD_DIM); - - if constexpr (NUM_WARP_Q == 4) { - constexpr uint32_t num_frags_z = BLOCK_SIZE / 16; - constexpr uint32_t smem_size = - num_warps * num_frags_x * 16 * HEAD_DIM * sizeof(T) + - num_frags_z * 16 * HEAD_DIM / 2 * sizeof(uint8_t) * 2 + - HEAD_DIM * 4 * sizeof(T); - auto split_kv_kernel = - multi_query_append_attention_c4_kernel; - cudaFuncSetAttribute(split_kv_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, - smem_size); - const int dev_id = 0; - int sm_count; - int act_blocks_per_sm; - cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id); - cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &act_blocks_per_sm, split_kv_kernel, num_warps * 32, smem_size); - assert(act_blocks_per_sm > 1); - const int num_blocks_per_wave = sm_count * act_blocks_per_sm; - const int num_blocks_need = num_blocks_x_cpu * kv_num_heads; - const int max_num_chunks = div_up(num_blocks_per_wave, num_blocks_need); - const float ratio = static_cast(num_blocks_need) / - static_cast(num_blocks_per_wave); - - uint32_t chunk_size = static_cast(max_partition_size); - if (!is_decoder) { - chunk_size = static_cast(encoder_max_partition_size); - } - const int num_chunks = div_up(max_dec_len, chunk_size); - - dim3 grids(num_blocks_x_cpu, num_chunks, kv_num_heads); - dim3 blocks(32, num_warps); - if (num_chunks <= 1) { - auto nosplit_kv_kernel = - multi_query_append_attention_c4_kernel; - if (smem_size >= 48 * 1024) { - cudaFuncSetAttribute(nosplit_kv_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, - smem_size); - } - nosplit_kv_kernel<<>>( - reinterpret_cast(const_cast(qkv.data())), - const_cast(cache_k.data()), - const_cast(cache_v.data()), - reinterpret_cast(const_cast(cache_k_scale.data())), - cache_k_zp ? reinterpret_cast( - const_cast(cache_k_zp.get().data())) - : nullptr, - reinterpret_cast(const_cast(cache_v_scale.data())), - cache_v_zp ? reinterpret_cast( - const_cast(cache_v_zp.get().data())) - : nullptr, - shift_bias ? reinterpret_cast( - const_cast(shift_bias.get().data())) - : nullptr, - smooth_weight ? reinterpret_cast( - const_cast(smooth_weight.get().data())) - : nullptr, - seq_lens_q.data(), - seq_lens_kv.data(), - batch_ids.data(), - tile_ids_per_batch.data(), - cu_seqlens_q.data(), - block_table.data(), - meta_data.mask_offset, - max_seq_len, - max_dec_len, - max_block_num_per_seq, - scale, - quant_max_bound, - quant_min_bound, - in_scale, - chunk_size, - num_blocks_x_cpu, - nullptr, - nullptr, - nullptr, - reinterpret_cast(out->data()), - speculate_max_draft_token_num); - } else { - phi::Allocator::AllocationPtr tmp_workspace, tmp_m, tmp_d; - if (ENABLE_PREFILL) { - tmp_workspace = allocator->Allocate( - phi::SizeOf(qkv.dtype()) * - static_cast(token_num * num_chunks * num_heads * HEAD_DIM)); - tmp_m = allocator->Allocate( - phi::SizeOf(paddle::DataType::FLOAT32) * - static_cast(token_num * num_chunks * num_heads)); - tmp_d = allocator->Allocate( - phi::SizeOf(paddle::DataType::FLOAT32) * - static_cast(token_num * num_chunks * num_heads)); - } else { - tmp_workspace = allocator->Allocate( - phi::SizeOf(qkv.dtype()) * - static_cast(speculate_max_draft_token_num * bsz * - num_chunks * num_heads * HEAD_DIM)); - tmp_m = allocator->Allocate( - phi::SizeOf(paddle::DataType::FLOAT32) * - static_cast(speculate_max_draft_token_num * bsz * - num_chunks * num_heads)); - tmp_d = allocator->Allocate( - phi::SizeOf(paddle::DataType::FLOAT32) * - static_cast(speculate_max_draft_token_num * bsz * - num_chunks * num_heads)); - } - split_kv_kernel<<>>( - reinterpret_cast(const_cast(qkv.data())), - const_cast(cache_k.data()), - const_cast(cache_v.data()), - reinterpret_cast(const_cast(cache_k_scale.data())), - cache_k_zp ? reinterpret_cast( - const_cast(cache_k_zp.get().data())) - : nullptr, - reinterpret_cast(const_cast(cache_v_scale.data())), - cache_v_zp ? reinterpret_cast( - const_cast(cache_v_zp.get().data())) - : nullptr, - shift_bias ? reinterpret_cast( - const_cast(shift_bias.get().data())) - : nullptr, - smooth_weight ? reinterpret_cast( - const_cast(smooth_weight.get().data())) - : nullptr, - seq_lens_q.data(), - seq_lens_kv.data(), - batch_ids.data(), - tile_ids_per_batch.data(), - cu_seqlens_q.data(), - block_table.data(), - meta_data.mask_offset, - max_seq_len, - max_dec_len, - max_block_num_per_seq, - scale, - quant_max_bound, - quant_min_bound, - in_scale, - chunk_size, - num_blocks_x_cpu, - reinterpret_cast(tmp_workspace->ptr()), - static_cast(tmp_m->ptr()), - static_cast(tmp_d->ptr()), - reinterpret_cast(out->data()), - speculate_max_draft_token_num); - // merge - constexpr int vec_size = num_elems_per_128b(); - if (is_decoder) { - constexpr int blockx = HEAD_DIM / vec_size; - constexpr int blocky = (128 + blockx - 1) / blockx; - dim3 grids_merge(bsz, num_heads); - dim3 blocks_merge(blockx, blocky); - merge_multi_chunks_decoder_kernel - <<>>( - reinterpret_cast(tmp_workspace->ptr()), - static_cast(tmp_m->ptr()), - static_cast(tmp_d->ptr()), - seq_lens_q.data(), - seq_lens_kv.data(), - seq_lens_encoder.data(), - cu_seqlens_q.data(), - shift_bias ? reinterpret_cast( - const_cast(shift_bias.get().data())) - : nullptr, - smooth_weight ? reinterpret_cast(const_cast( - smooth_weight.get().data())) - : nullptr, - reinterpret_cast(out->data()), - quant_max_bound, - quant_min_bound, - in_scale, - max_seq_len, - num_chunks, - num_heads, - chunk_size, - HEAD_DIM); - } else { - constexpr int blockx = HEAD_DIM / vec_size; - constexpr int blocky = (128 + blockx - 1) / blockx; - dim3 grids_merge(min(sm_count * 4, token_num), - num_heads); - dim3 blocks_merge(blockx, blocky); - merge_multi_chunks_v2_kernel - <<>>( - reinterpret_cast(tmp_workspace->ptr()), - static_cast(tmp_m->ptr()), - static_cast(tmp_d->ptr()), - seq_lens_q.data(), - seq_lens_kv.data(), - seq_lens_encoder.data(), - batch_id_per_token.data(), - cu_seqlens_q.data(), - shift_bias ? reinterpret_cast( - const_cast(shift_bias.get().data())) - : nullptr, - smooth_weight ? reinterpret_cast(const_cast( - smooth_weight.get().data())) - : nullptr, - reinterpret_cast(out->data()), - quant_max_bound, - quant_min_bound, - in_scale, - max_seq_len, - num_chunks, - num_heads, - chunk_size, - HEAD_DIM, - token_num, - speculate_max_draft_token_num); - } - } - } else { - constexpr uint32_t num_frags_z = BLOCK_SIZE / 16 / NUM_WARP_KV * 4; - constexpr uint32_t smem_size = - num_frags_x * 16 * HEAD_DIM * sizeof(T) + - NUM_WARP_KV * num_frags_z * 16 * HEAD_DIM / 2 * sizeof(uint8_t) * 2 + - HEAD_DIM * 4 * sizeof(T); - auto split_kv_kernel = - multi_query_append_attention_c4_warp1_4_kernel; - if (smem_size >= 48 * 1024) { - cudaFuncSetAttribute(split_kv_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, - smem_size); - } - const int dev_id = 0; - int sm_count; - int act_blocks_per_sm; - cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id); - cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &act_blocks_per_sm, split_kv_kernel, num_warps * 32, smem_size); - assert(act_blocks_per_sm > 1); - const int num_blocks_per_wave = sm_count * act_blocks_per_sm; - const int num_blocks_need = num_blocks_x_cpu * kv_num_heads; - const int max_num_chunks = div_up(num_blocks_per_wave, num_blocks_need); - const float ratio = static_cast(num_blocks_need) / - static_cast(num_blocks_per_wave); - - - uint32_t chunk_size = static_cast(max_partition_size); - if (!is_decoder) { - chunk_size = static_cast(encoder_max_partition_size); - } - - const int num_chunks = div_up(max_seq_len, chunk_size); - uint32_t attn_mask_len; - if (attn_mask) { - attn_mask_len = attn_mask.get().shape()[1]; - } else { - attn_mask_len = -1; - } - - dim3 grids(num_blocks_x_cpu, num_chunks, kv_num_heads); - dim3 blocks(32, num_warps); - if (num_chunks <= 0) { - auto nosplit_kv_kernel = - multi_query_append_attention_c4_warp1_4_kernel; - if (smem_size >= 48 * 1024) { - cudaFuncSetAttribute(nosplit_kv_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, - smem_size); - } - nosplit_kv_kernel<<>>( - reinterpret_cast(const_cast(qkv.data())), - const_cast(cache_k.data()), - const_cast(cache_v.data()), - reinterpret_cast(const_cast(cache_k_scale.data())), - cache_k_zp ? reinterpret_cast( - const_cast(cache_k_zp.get().data())) - : nullptr, - reinterpret_cast(const_cast(cache_v_scale.data())), - cache_v_zp ? reinterpret_cast( - const_cast(cache_v_zp.get().data())) - : nullptr, - shift_bias ? reinterpret_cast( - const_cast(shift_bias.get().data())) - : nullptr, - smooth_weight ? reinterpret_cast( - const_cast(smooth_weight.get().data())) - : nullptr, - seq_lens_q.data(), - seq_lens_kv.data(), - batch_ids.data(), - tile_ids_per_batch.data(), - cu_seqlens_q.data(), - block_table.data(), - meta_data.mask_offset, - attn_mask ? const_cast(attn_mask.get().data()) - : nullptr, - max_seq_len, - max_dec_len, - max_block_num_per_seq, - scale, - quant_max_bound, - quant_min_bound, - in_scale, - chunk_size, - num_blocks_x_cpu, - nullptr, - nullptr, - nullptr, - reinterpret_cast(out->data()), - speculate_max_draft_token_num, - attn_mask_len); - } else { - phi::Allocator::AllocationPtr tmp_workspace, tmp_m, tmp_d; - if (is_decoder) { - tmp_workspace = allocator->Allocate( - phi::SizeOf(qkv.dtype()) * - static_cast(bsz * num_chunks * num_heads * HEAD_DIM)); - tmp_m = allocator->Allocate( - phi::SizeOf(paddle::DataType::FLOAT32) * - static_cast(bsz * num_chunks * num_heads)); - tmp_d = allocator->Allocate( - phi::SizeOf(paddle::DataType::FLOAT32) * - static_cast(bsz * num_chunks * num_heads)); - } else { - if (ENABLE_PREFILL) { - tmp_workspace = - allocator->Allocate(phi::SizeOf(qkv.dtype()) * - static_cast(token_num * num_chunks * - num_heads * HEAD_DIM)); - tmp_m = allocator->Allocate( - phi::SizeOf(paddle::DataType::FLOAT32) * - static_cast(token_num * num_chunks * num_heads)); - tmp_d = allocator->Allocate( - phi::SizeOf(paddle::DataType::FLOAT32) * - static_cast(token_num * num_chunks * num_heads)); - } else { - tmp_workspace = allocator->Allocate( - phi::SizeOf(qkv.dtype()) * - static_cast(speculate_max_draft_token_num * bsz * - num_chunks * num_heads * HEAD_DIM)); - tmp_m = allocator->Allocate( - phi::SizeOf(paddle::DataType::FLOAT32) * - static_cast(speculate_max_draft_token_num * bsz * - num_chunks * num_heads)); - tmp_d = allocator->Allocate( - phi::SizeOf(paddle::DataType::FLOAT32) * - static_cast(speculate_max_draft_token_num * bsz * - num_chunks * num_heads)); - } - } - split_kv_kernel<<>>( - reinterpret_cast(const_cast(qkv.data())), - const_cast(cache_k.data()), - const_cast(cache_v.data()), - reinterpret_cast(const_cast(cache_k_scale.data())), - cache_k_zp ? reinterpret_cast( - const_cast(cache_k_zp.get().data())) - : nullptr, - reinterpret_cast(const_cast(cache_v_scale.data())), - cache_v_zp ? reinterpret_cast( - const_cast(cache_v_zp.get().data())) - : nullptr, - shift_bias ? reinterpret_cast( - const_cast(shift_bias.get().data())) - : nullptr, - smooth_weight ? reinterpret_cast( - const_cast(smooth_weight.get().data())) - : nullptr, - seq_lens_q.data(), - seq_lens_kv.data(), - batch_ids.data(), - tile_ids_per_batch.data(), - cu_seqlens_q.data(), - block_table.data(), - meta_data.mask_offset, - attn_mask ? const_cast(attn_mask.get().data()) - : nullptr, - max_seq_len, - max_dec_len, - max_block_num_per_seq, - scale, - quant_max_bound, - quant_min_bound, - in_scale, - chunk_size, - num_blocks_x_cpu, - reinterpret_cast(tmp_workspace->ptr()), - static_cast(tmp_m->ptr()), - static_cast(tmp_d->ptr()), - reinterpret_cast(out->data()), - speculate_max_draft_token_num, - attn_mask_len); - // merge - constexpr int vec_size = num_elems_per_128b(); - if (is_decoder) { - constexpr int blockx = HEAD_DIM / vec_size; - constexpr int blocky = (128 + blockx - 1) / blockx; - dim3 grids_merge(bsz, num_heads); - dim3 blocks_merge(blockx, blocky); - merge_multi_chunks_decoder_kernel - <<>>( - reinterpret_cast(tmp_workspace->ptr()), - static_cast(tmp_m->ptr()), - static_cast(tmp_d->ptr()), - seq_lens_q.data(), - seq_lens_kv.data(), - seq_lens_encoder.data(), - cu_seqlens_q.data(), - shift_bias ? reinterpret_cast( - const_cast(shift_bias.get().data())) - : nullptr, - smooth_weight ? reinterpret_cast(const_cast( - smooth_weight.get().data())) - : nullptr, - reinterpret_cast(out->data()), - quant_max_bound, - quant_min_bound, - in_scale, - max_seq_len, - num_chunks, - num_heads, - chunk_size, - HEAD_DIM); - } else { - constexpr int blockx = HEAD_DIM / vec_size; - constexpr int blocky = (128 + blockx - 1) / blockx; - dim3 grids_merge(min(sm_count * 4, token_num), - num_heads); - dim3 blocks_merge(blockx, blocky); - merge_multi_chunks_v2_kernel - <<>>( - reinterpret_cast(tmp_workspace->ptr()), - static_cast(tmp_m->ptr()), - static_cast(tmp_d->ptr()), - seq_lens_q.data(), - seq_lens_kv.data(), - seq_lens_encoder.data(), - batch_id_per_token.data(), - cu_seqlens_q.data(), - shift_bias ? reinterpret_cast( - const_cast(shift_bias.get().data())) - : nullptr, - smooth_weight ? reinterpret_cast(const_cast( - smooth_weight.get().data())) - : nullptr, - reinterpret_cast(out->data()), - quant_max_bound, - quant_min_bound, - in_scale, - max_seq_len, - num_chunks, - num_heads, - chunk_size, - HEAD_DIM, - token_num, - speculate_max_draft_token_num); - } - } - } -} +#include "multiquery_attention_c4_kernel.h" template void CascadeAppendAttentionC4Kernel( @@ -1656,3 +125,267 @@ void CascadeAppendAttentionC4Kernel( out); })})})})})}) } + +template void CascadeAppendAttentionC4Kernel( + const AppendAttnMetaData& meta_data, + const paddle::Tensor& qkv, // [token_num, num_heads, head_dim] + const paddle::Tensor& + cache_k, // [max_block_num, num_heads, block_size, head_dim] + const paddle::Tensor& + cache_v, // [max_block_num, num_heads, head_dim, block_size] + const paddle::optional& attn_mask, + const paddle::optional& + cache_k_scale, // [num_kv_heads, head_dim] + const paddle::optional& + cache_v_scale, // [num_kv_heads, head_dim] + const paddle::optional& + cache_k_zp, // [num_kv_heads, head_dim] + const paddle::optional& + cache_v_zp, // [num_kv_heads, head_dim] + const paddle::optional& + shift_bias, // [num_kv_heads, head_dim] + const paddle::optional& + smooth_weight, // [num_kv_heads, head_dim] + const paddle::Tensor& seq_lens_q, + const paddle::Tensor& seq_lens_kv, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& batch_id_per_token, + const paddle::Tensor& cu_seqlens_q, + const paddle::Tensor& block_table, + const paddle::Tensor& batch_ids, + const paddle::Tensor& tile_ids_per_batch, + const int num_blocks, + const int block_shape_q, + const int max_seq_len, + const int max_dec_len, + const float quant_max_bound, + const float quant_min_bound, + const float in_scale, + const int max_partition_size, + const int encoder_max_partition_size, + const int speculate_max_draft_token_num, + const bool causal, + const bool is_decoder, + const bool enable_prefill, + cudaStream_t& stream, + paddle::Tensor* out); + +template void CascadeAppendAttentionC4Kernel( + const AppendAttnMetaData& meta_data, + const paddle::Tensor& qkv, // [token_num, num_heads, head_dim] + const paddle::Tensor& + cache_k, // [max_block_num, num_heads, block_size, head_dim] + const paddle::Tensor& + cache_v, // [max_block_num, num_heads, head_dim, block_size] + const paddle::optional& attn_mask, + const paddle::optional& + cache_k_scale, // [num_kv_heads, head_dim] + const paddle::optional& + cache_v_scale, // [num_kv_heads, head_dim] + const paddle::optional& + cache_k_zp, // [num_kv_heads, head_dim] + const paddle::optional& + cache_v_zp, // [num_kv_heads, head_dim] + const paddle::optional& + shift_bias, // [num_kv_heads, head_dim] + const paddle::optional& + smooth_weight, // [num_kv_heads, head_dim] + const paddle::Tensor& seq_lens_q, + const paddle::Tensor& seq_lens_kv, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& batch_id_per_token, + const paddle::Tensor& cu_seqlens_q, + const paddle::Tensor& block_table, + const paddle::Tensor& batch_ids, + const paddle::Tensor& tile_ids_per_batch, + const int num_blocks, + const int block_shape_q, + const int max_seq_len, + const int max_dec_len, + const float quant_max_bound, + const float quant_min_bound, + const float in_scale, + const int max_partition_size, + const int encoder_max_partition_size, + const int speculate_max_draft_token_num, + const bool causal, + const bool is_decoder, + const bool enable_prefill, + cudaStream_t& stream, + paddle::Tensor* out); + +template void CascadeAppendAttentionC4Kernel( + const AppendAttnMetaData& meta_data, + const paddle::Tensor& qkv, // [token_num, num_heads, head_dim] + const paddle::Tensor& + cache_k, // [max_block_num, num_heads, block_size, head_dim] + const paddle::Tensor& + cache_v, // [max_block_num, num_heads, head_dim, block_size] + const paddle::optional& attn_mask, + const paddle::optional& + cache_k_scale, // [num_kv_heads, head_dim] + const paddle::optional& + cache_v_scale, // [num_kv_heads, head_dim] + const paddle::optional& + cache_k_zp, // [num_kv_heads, head_dim] + const paddle::optional& + cache_v_zp, // [num_kv_heads, head_dim] + const paddle::optional& + shift_bias, // [num_kv_heads, head_dim] + const paddle::optional& + smooth_weight, // [num_kv_heads, head_dim] + const paddle::Tensor& seq_lens_q, + const paddle::Tensor& seq_lens_kv, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& batch_id_per_token, + const paddle::Tensor& cu_seqlens_q, + const paddle::Tensor& block_table, + const paddle::Tensor& batch_ids, + const paddle::Tensor& tile_ids_per_batch, + const int num_blocks, + const int block_shape_q, + const int max_seq_len, + const int max_dec_len, + const float quant_max_bound, + const float quant_min_bound, + const float in_scale, + const int max_partition_size, + const int encoder_max_partition_size, + const int speculate_max_draft_token_num, + const bool causal, + const bool is_decoder, + const bool enable_prefill, + cudaStream_t& stream, + paddle::Tensor* out); + +template void CascadeAppendAttentionC4Kernel( + const AppendAttnMetaData& meta_data, + const paddle::Tensor& qkv, // [token_num, num_heads, head_dim] + const paddle::Tensor& + cache_k, // [max_block_num, num_heads, block_size, head_dim] + const paddle::Tensor& + cache_v, // [max_block_num, num_heads, head_dim, block_size] + const paddle::optional& attn_mask, + const paddle::optional& + cache_k_scale, // [num_kv_heads, head_dim] + const paddle::optional& + cache_v_scale, // [num_kv_heads, head_dim] + const paddle::optional& + cache_k_zp, // [num_kv_heads, head_dim] + const paddle::optional& + cache_v_zp, // [num_kv_heads, head_dim] + const paddle::optional& + shift_bias, // [num_kv_heads, head_dim] + const paddle::optional& + smooth_weight, // [num_kv_heads, head_dim] + const paddle::Tensor& seq_lens_q, + const paddle::Tensor& seq_lens_kv, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& batch_id_per_token, + const paddle::Tensor& cu_seqlens_q, + const paddle::Tensor& block_table, + const paddle::Tensor& batch_ids, + const paddle::Tensor& tile_ids_per_batch, + const int num_blocks, + const int block_shape_q, + const int max_seq_len, + const int max_dec_len, + const float quant_max_bound, + const float quant_min_bound, + const float in_scale, + const int max_partition_size, + const int encoder_max_partition_size, + const int speculate_max_draft_token_num, + const bool causal, + const bool is_decoder, + const bool enable_prefill, + cudaStream_t& stream, + paddle::Tensor* out); + +template void CascadeAppendAttentionC4Kernel( + const AppendAttnMetaData& meta_data, + const paddle::Tensor& qkv, // [token_num, num_heads, head_dim] + const paddle::Tensor& + cache_k, // [max_block_num, num_heads, block_size, head_dim] + const paddle::Tensor& + cache_v, // [max_block_num, num_heads, head_dim, block_size] + const paddle::optional& attn_mask, + const paddle::optional& + cache_k_scale, // [num_kv_heads, head_dim] + const paddle::optional& + cache_v_scale, // [num_kv_heads, head_dim] + const paddle::optional& + cache_k_zp, // [num_kv_heads, head_dim] + const paddle::optional& + cache_v_zp, // [num_kv_heads, head_dim] + const paddle::optional& + shift_bias, // [num_kv_heads, head_dim] + const paddle::optional& + smooth_weight, // [num_kv_heads, head_dim] + const paddle::Tensor& seq_lens_q, + const paddle::Tensor& seq_lens_kv, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& batch_id_per_token, + const paddle::Tensor& cu_seqlens_q, + const paddle::Tensor& block_table, + const paddle::Tensor& batch_ids, + const paddle::Tensor& tile_ids_per_batch, + const int num_blocks, + const int block_shape_q, + const int max_seq_len, + const int max_dec_len, + const float quant_max_bound, + const float quant_min_bound, + const float in_scale, + const int max_partition_size, + const int encoder_max_partition_size, + const int speculate_max_draft_token_num, + const bool causal, + const bool is_decoder, + const bool enable_prefill, + cudaStream_t& stream, + paddle::Tensor* out); + +template void CascadeAppendAttentionC4Kernel( + const AppendAttnMetaData& meta_data, + const paddle::Tensor& qkv, // [token_num, num_heads, head_dim] + const paddle::Tensor& + cache_k, // [max_block_num, num_heads, block_size, head_dim] + const paddle::Tensor& + cache_v, // [max_block_num, num_heads, head_dim, block_size] + const paddle::optional& attn_mask, + const paddle::optional& + cache_k_scale, // [num_kv_heads, head_dim] + const paddle::optional& + cache_v_scale, // [num_kv_heads, head_dim] + const paddle::optional& + cache_k_zp, // [num_kv_heads, head_dim] + const paddle::optional& + cache_v_zp, // [num_kv_heads, head_dim] + const paddle::optional& + shift_bias, // [num_kv_heads, head_dim] + const paddle::optional& + smooth_weight, // [num_kv_heads, head_dim] + const paddle::Tensor& seq_lens_q, + const paddle::Tensor& seq_lens_kv, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& batch_id_per_token, + const paddle::Tensor& cu_seqlens_q, + const paddle::Tensor& block_table, + const paddle::Tensor& batch_ids, + const paddle::Tensor& tile_ids_per_batch, + const int num_blocks, + const int block_shape_q, + const int max_seq_len, + const int max_dec_len, + const float quant_max_bound, + const float quant_min_bound, + const float in_scale, + const int max_partition_size, + const int encoder_max_partition_size, + const int speculate_max_draft_token_num, + const bool causal, + const bool is_decoder, + const bool enable_prefill, + cudaStream_t& stream, + paddle::Tensor* out); diff --git a/custom_ops/gpu_ops/append_attn/append_attention_kernel.h b/custom_ops/gpu_ops/append_attn/append_attention_kernel.h index a41c904f1..174bb612d 100644 --- a/custom_ops/gpu_ops/append_attn/append_attention_kernel.h +++ b/custom_ops/gpu_ops/append_attn/append_attention_kernel.h @@ -15,97 +15,9 @@ #include "helper.h" #include "utils.cuh" +#include "append_attention_c16_impl.cuh" #include "append_attention_c8_impl.cuh" - -template -void CascadeAppendAttentionC16Kernel( - const AppendAttnMetaData& meta_data, - const paddle::Tensor& qkv, // [token_num, num_heads, head_dim] - const paddle::Tensor& - cache_k, // [max_block_num, num_heads, block_size, head_dim] - const paddle::Tensor& - cache_v, // [max_block_num, num_heads, head_dim, block_size] - const paddle::optional& attn_mask, - const paddle::optional& - cache_k_scale, // [num_kv_heads, head_dim] - const paddle::optional& - cache_v_scale, // [num_kv_heads, head_dim] - const paddle::optional& - cache_k_zp, // [num_kv_heads, head_dim] - const paddle::optional& - cache_v_zp, // [num_kv_heads, head_dim] - const paddle::optional& - shift_bias, // [num_kv_heads, head_dim] - const paddle::optional& - smooth_weight, // [num_kv_heads, head_dim] - const paddle::Tensor& seq_lens_q, - const paddle::Tensor& seq_lens_kv, - const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& batch_id_per_token, - const paddle::Tensor& cu_seqlens_q, - const paddle::Tensor& block_table, - const paddle::Tensor& batch_ids, - const paddle::Tensor& tile_ids_per_batch, - const int num_blocks, - const int block_shape_q, - const int max_seq_len, - const int max_dec_len, - const float quant_max_bound, - const float quant_min_bound, - const float in_scale, - const int max_partition_size, - const int encoder_max_partition_size, - const int speculate_max_draft_token_num, - const bool causal, - const bool is_decoder, - const bool enable_prefill, - cudaStream_t& stream, - paddle::Tensor* out); - -template -void CascadeAppendAttentionC4Kernel( - const AppendAttnMetaData& meta_data, - const paddle::Tensor& qkv, // [token_num, num_heads, head_dim] - const paddle::Tensor& - cache_k, // [max_block_num, num_heads, block_size, head_dim] - const paddle::Tensor& - cache_v, // [max_block_num, num_heads, head_dim, block_size] - const paddle::optional& attn_mask, - const paddle::optional& - cache_k_scale, // [num_kv_heads, head_dim] - const paddle::optional& - cache_v_scale, // [num_kv_heads, head_dim] - const paddle::optional& - cache_k_zp, // [num_kv_heads, head_dim] - const paddle::optional& - cache_v_zp, // [num_kv_heads, head_dim] - const paddle::optional& - shift_bias, // [num_kv_heads, head_dim] - const paddle::optional& - smooth_weight, // [num_kv_heads, head_dim] - const paddle::Tensor& seq_lens_q, - const paddle::Tensor& seq_lens_kv, - const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& batch_id_per_token, - const paddle::Tensor& cu_seqlens_q, - const paddle::Tensor& block_table, - const paddle::Tensor& batch_ids, - const paddle::Tensor& tile_ids_per_batch, - const int num_blocks, - const int block_shape_q, - const int max_seq_len, - const int max_dec_len, - const float quant_max_bound, - const float quant_min_bound, - const float in_scale, - const int max_partition_size, - const int encoder_max_partition_size, - const int speculate_max_draft_token_num, - const bool causal, - const bool is_decoder, - const bool enable_prefill, - cudaStream_t& stream, - paddle::Tensor* out); +#include "append_attention_c4_impl.cuh" template void CascadeAppendAttentionKernel( diff --git a/custom_ops/gpu_ops/append_attn/autogen_template_instantiation.py b/custom_ops/gpu_ops/append_attn/autogen_template_instantiation.py index 6e7e492f9..1a2f27a87 100644 --- a/custom_ops/gpu_ops/append_attn/autogen_template_instantiation.py +++ b/custom_ops/gpu_ops/append_attn/autogen_template_instantiation.py @@ -11,143 +11,232 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""generate multiquery_attention_c8_kernel template instantiation.""" +"""Universal template instantiation generator - fully based on configuration file template instantiation generation.""" +import argparse +import json +from dataclasses import dataclass from pathlib import Path - -TEMPLATE_DIR = Path("gpu_ops/append_attn/template_instantiation/autogen") -TEMPLATE_DIR.mkdir(exist_ok=True) - -DISPATCH_PARAMS = { - "GROUP_SIZE": [1, 2, 4, 5, 6, 7, 8, 12, 14, 16], - "HEAD_DIM": [128], - "BLOCK_SIZE": [64], - "CAUSAL": [0, 1], - "BLOCK_SHAPE_Q": [16, 32, 64, 128], - "ENABLE_PREFILL": [0, 1], - "IsFP8": [0, 1], - "IsDynamicC8": [0, 1], -} - -DATA_TYPE_COMBINATIONS = [ - ("paddle::float16", "paddle::float16", "float16_float16"), - ("paddle::float16", "paddle::float8_e4m3fn", "float16_fp8"), - ("paddle::float16", "int8_t", "float16_int8"), - ("paddle::bfloat16", "paddle::bfloat16", "bfloat16_bfloat16"), - ("paddle::bfloat16", "paddle::float8_e4m3fn", "bfloat16_fp8"), - ("paddle::bfloat16", "int8_t", "bfloat16_int8"), -] - -MAX_INSTANCES_PER_FILE = 60 +from typing import Any, Dict, List, Optional, Tuple -def get_num_warp_q(block_shape_q): - if block_shape_q <= 32: - return 1 - else: - return 4 +@dataclass +class TemplateConfig: + """Template configuration class.""" + + name: str # Function name + function_name: str # Actual function name + impl_file: str # Implementation file path + template_params: List[str] # Template parameter list (in order) + dispatch_params: Dict[str, List[Any]] # Dispatch parameters + data_types: Optional[List[Tuple[str, str, str]]] = None # Data type combinations (input_type, output_type, suffix) + max_instances_per_file: int = 60 # Maximum instances per file + file_prefix: str = "" # File prefix + function_signature: str = "" # Function signature template -def generate_file_header(): - return """// Generated by autogen_template_instantiation.py - Do not edit. +class UniversalTemplateInstantiator: + """Universal template instantiator - fully based on configuration file.""" + + def __init__(self, config_file: str): + """Initialize the instantiator.""" + self.config_file = config_file + self.configs = self._load_configs() + + def _load_configs(self) -> Dict[str, TemplateConfig]: + """Load configuration file.""" + with open(self.config_file, "r", encoding="utf-8") as f: + config_data = json.load(f) + + configs = {} + for name, config_dict in config_data.items(): + config = TemplateConfig(**config_dict) + self._validate_config(config) + configs[name] = config + return configs + + def _validate_config(self, config: TemplateConfig): + """Validate configuration completeness.""" + has_t = "T" in config.template_params + has_out_t = "OutT" in config.template_params + + if (has_t or has_out_t) and not config.data_types: + raise ValueError( + f"Configuration '{config.name}' has T or OutT in template_params but no data_types configured" + ) + + special_params = {"T", "OutT", "NUM_WARP_Q"} + for param_name in config.template_params: + if param_name not in special_params and param_name not in config.dispatch_params: + raise ValueError(f"Template parameter '{param_name}' in '{config.name}' not found in dispatch_params") + + if "NUM_WARP_Q" in config.template_params and "BLOCK_SHAPE_Q" not in config.dispatch_params: + raise ValueError( + f"Template parameter 'NUM_WARP_Q' in '{config.name}' requires 'BLOCK_SHAPE_Q' in dispatch_params" + ) + + def _calculate_num_warp_q(self, block_shape_q: int) -> int: + """Calculate number of warps.""" + if block_shape_q <= 32: + return 1 + else: + return 4 + + def _build_template_args(self, config: TemplateConfig, t_in: str, t_out: str, params: Dict[str, Any]) -> str: + """Build template arguments.""" + template_args_parts = [] + + for param_name in config.template_params: + if param_name == "T": + if t_in: + template_args_parts.append(t_in) + else: + raise ValueError("Template parameter 'T' requires input type, but data_types is empty or invalid") + elif param_name == "OutT": + if t_out: + template_args_parts.append(t_out) + else: + raise ValueError( + "Template parameter 'OutT' requires output type, but data_types is empty or invalid" + ) + elif param_name == "NUM_WARP_Q": + if "BLOCK_SHAPE_Q" in params: + num_warp_q = self._calculate_num_warp_q(params["BLOCK_SHAPE_Q"]) + template_args_parts.append(str(num_warp_q)) + else: + raise ValueError("Template parameter 'NUM_WARP_Q' requires 'BLOCK_SHAPE_Q' in dispatch_params") + elif param_name in params: + template_args_parts.append(str(params[param_name])) + else: + raise ValueError(f"Template parameter '{param_name}' not found in dispatch_params") + + return f"<{', '.join(template_args_parts)}>" + + def _generate_function_signature(self, config: TemplateConfig, template_args: str) -> str: + """Generate function signature.""" + if config.function_signature: + return config.function_signature.format(function_name=config.function_name, template_args=template_args) + else: + raise ValueError(f"Function signature not found for {config.name}") + + def _generate_file_header(self, config: TemplateConfig) -> str: + """Generate file header.""" + return f"""// Generated by autogen_template_instantiation.py - Do not edit. #pragma once -#include "../../multiquery_attention_c8_impl.cuh" +#include "../../{config.impl_file}" """ + def _generate_template_instantiation( + self, config: TemplateConfig, t_in: str, t_out: str, params: Dict[str, Any] + ) -> str: + """Generate template instantiation.""" + template_args = self._build_template_args(config, t_in, t_out, params) + return self._generate_function_signature(config, template_args) -def generate_template_instantiation(t_in, t_out, params): - num_warp_q = get_num_warp_q(params["BLOCK_SHAPE_Q"]) - template_args = f"<{t_in}, {params['GROUP_SIZE']}, {params['HEAD_DIM']}, {params['BLOCK_SIZE']}, {params['CAUSAL']}, {params['BLOCK_SHAPE_Q']}, {num_warp_q}, {t_out}, {params['ENABLE_PREFILL']}, {params['IsFP8']}, {params['IsDynamicC8']}>" + def generate_combinations_for_type(self, config: TemplateConfig, t_in: str, t_out: str) -> List[Dict[str, Any]]: + """Generate parameter combinations for specific type.""" + combinations = [] - return f""" -template void MultiQueryAppendC8Attention{template_args}( - const AppendAttnMetaData &meta_data, - const paddle::Tensor &qkv, - const paddle::Tensor &cache_k, - const paddle::Tensor &cache_v, - const paddle::optional &attn_mask, - const paddle::Tensor &cache_k_scale, - const paddle::Tensor &cache_v_scale, - const paddle::optional &shift_bias, - const paddle::optional &smooth_weight, - const paddle::Tensor &seq_lens_q, - const paddle::Tensor &seq_lens_kv, - const paddle::Tensor &seq_lens_encoder, - const paddle::Tensor &batch_id_per_token, - const paddle::Tensor &cu_seqlens_q, - const paddle::Tensor &block_table, - const paddle::Tensor &batch_ids, - const paddle::Tensor &tile_ids_per_batch, - const int num_blocks_x_cpu, - const int max_seq_len, - const int max_dec_len, - const float quant_max_bound, - const float quant_min_bound, - const float in_scale, - const int max_partition_size, - const int encoder_max_partition_size, - const int speculate_max_draft_token_num, - const bool is_decoder, - cudaStream_t &stream, - paddle::Tensor *out); + def _generate_recursive( + params_dict: Dict[str, List[Any]], current_params: Dict[str, Any], param_names: List[str] + ): + if not param_names: + combinations.append(current_params.copy()) + return -""" + param_name = param_names[0] + for value in params_dict[param_name]: + current_params[param_name] = value + _generate_recursive(params_dict, current_params, param_names[1:]) + _generate_recursive(config.dispatch_params, {}, list(config.dispatch_params.keys())) + return combinations -def generate_combinations_for_type(t_in, t_out): - combinations = [] - for group_size in DISPATCH_PARAMS["GROUP_SIZE"]: - for head_dim in DISPATCH_PARAMS["HEAD_DIM"]: - for block_size in DISPATCH_PARAMS["BLOCK_SIZE"]: - for causal in DISPATCH_PARAMS["CAUSAL"]: - for block_shape_q in DISPATCH_PARAMS["BLOCK_SHAPE_Q"]: - for enable_prefill in DISPATCH_PARAMS["ENABLE_PREFILL"]: - for is_fp8 in DISPATCH_PARAMS["IsFP8"]: - for is_dynamic_c8 in DISPATCH_PARAMS["IsDynamicC8"]: - params = { - "GROUP_SIZE": group_size, - "HEAD_DIM": head_dim, - "BLOCK_SIZE": block_size, - "CAUSAL": causal, - "BLOCK_SHAPE_Q": block_shape_q, - "ENABLE_PREFILL": enable_prefill, - "IsFP8": is_fp8, - "IsDynamicC8": is_dynamic_c8, - } - combinations.append(params) + def split_combinations(self, combinations: List[Dict[str, Any]], max_per_file: int) -> List[List[Dict[str, Any]]]: + """Split combinations into multiple files.""" + chunks = [] + for i in range(0, len(combinations), max_per_file): + chunk = combinations[i : i + max_per_file] + chunks.append(chunk) + return chunks - return combinations + def generate_file_content( + self, + config: TemplateConfig, + t_in: str, + t_out: str, + t_out_name: str, + file_index: int, + combinations: List[Dict[str, Any]], + ) -> str: + """Generate file content.""" + content = self._generate_file_header(config) + for params in combinations: + content += self._generate_template_instantiation(config, t_in, t_out, params) -def split_combinations(combinations, max_per_file): - chunks = [] - for i in range(0, len(combinations), max_per_file): - chunk = combinations[i : i + max_per_file] - chunks.append(chunk) - return chunks + return content + def generate_for_function_type(self, function_name: str, output_dir: str): + """Generate template instantiation files for specific function type.""" + if function_name not in self.configs: + raise ValueError(f"Function type '{function_name}' not found in config") -def generate_file_content(t_in, t_out, t_out_name, file_index, combinations): - content = generate_file_header() - for params in combinations: - content += generate_template_instantiation(t_in, t_out, params) + config = self.configs[function_name] + output_path = Path(output_dir) + output_path.mkdir(exist_ok=True) - return content + if not config.data_types: + data_types = [("", "", "")] + else: + data_types = config.data_types + + for t_in, t_out, t_out_name in data_types: + combinations = self.generate_combinations_for_type(config, t_in, t_out) + if combinations: + chunks = self.split_combinations(combinations, config.max_instances_per_file) + for i, chunk in enumerate(chunks): + filename = f"{config.file_prefix}{t_out_name}_part_{i:02d}.cu" + filepath = output_path / filename + content = self.generate_file_content(config, t_in, t_out, t_out_name, i, chunk) + with open(filepath, "w", encoding="utf-8") as f: + f.write(content) + + def generate_all(self, output_dir: str): + """Generate all configured function types.""" + for function_name in self.configs.keys(): + print(f"Generating template instantiations for {function_name}...") + self.generate_for_function_type(function_name, output_dir) + print(f"Completed generating {function_name} template instantiations.") def main(): - for t_in, t_out, t_out_name in DATA_TYPE_COMBINATIONS: - combinations = generate_combinations_for_type(t_in, t_out) - if combinations: - chunks = split_combinations(combinations, MAX_INSTANCES_PER_FILE) - for i, chunk in enumerate(chunks): - filename = f"multiquery_attention_c8_{t_out_name}_part_{i:02d}.cu" - filepath = TEMPLATE_DIR / filename - content = generate_file_content(t_in, t_out, t_out_name, i, chunk) - with open(filepath, "w", encoding="utf-8") as f: - f.write(content) + """Main function.""" + parser = argparse.ArgumentParser(description="Universal template instantiation generator") + parser.add_argument( + "--config", + "-c", + type=str, + default="gpu_ops/append_attn/template_config.json", + help="Configuration file path (JSON format)", + ) + parser.add_argument( + "--output", + "-o", + type=str, + default="gpu_ops/append_attn/template_instantiation/autogen", + help="Output directory", + ) + + args = parser.parse_args() + + try: + instantiator = UniversalTemplateInstantiator(args.config) + instantiator.generate_all(args.output) + except Exception as e: + print(f"Error: {e}") if __name__ == "__main__": diff --git a/custom_ops/gpu_ops/append_attn/decode_attention_func.cuh b/custom_ops/gpu_ops/append_attn/decode_attention_func.cuh index 3ac80b6cc..4ff4a0229 100644 --- a/custom_ops/gpu_ops/append_attn/decode_attention_func.cuh +++ b/custom_ops/gpu_ops/append_attn/decode_attention_func.cuh @@ -13,8 +13,8 @@ // limitations under the License. #pragma once - -#include "multi_head_latent_attention_kernel.h" +#include "helper.h" +#include "utils.cuh" template struct softmax_state_t { diff --git a/custom_ops/gpu_ops/append_attn/multi_head_latent_attention_kernel.h b/custom_ops/gpu_ops/append_attn/multi_head_latent_attention_kernel.h index 4d81b99a7..54e4fd6de 100644 --- a/custom_ops/gpu_ops/append_attn/multi_head_latent_attention_kernel.h +++ b/custom_ops/gpu_ops/append_attn/multi_head_latent_attention_kernel.h @@ -1,4 +1,4 @@ -// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,27 +12,94 @@ // See the License for the specific language governing permissions and // limitations under the License. #pragma once + #include "helper.h" #include "utils.cuh" +#include "multiquery_decoder_attention_impl.cuh" template void DecodeMLAAttentionKernel( - const AppendAttnMetaData& meta_data, - const paddle::Tensor &q, // [token_num, num_heads, head_dim] - const paddle::Tensor &cache_k, - const paddle::Tensor &cache_v, - const paddle::optional& attn_mask, - const paddle::optional& shift_bias, - const paddle::optional& smooth_weight, - const paddle::Tensor &seq_lens_q, // q_seq_len is 1 - const paddle::Tensor &seq_lens_kv, - const paddle::Tensor &batch_id_per_token, - const paddle::Tensor &cu_seqlens_q, - const paddle::Tensor &block_table, - int max_seq_len, - int max_dec_len, - float softmax_scale, - float in_scale, - bool causal, - cudaStream_t &stream, - paddle::Tensor *out); + const AppendAttnMetaData& meta_data, + const paddle::Tensor &q, // [token_num, num_heads, head_dim] + const paddle::Tensor &cache_k, + const paddle::Tensor &cache_v, + const paddle::optional& attn_mask, + const paddle::optional& shift_bias, + const paddle::optional& smooth_weight, + const paddle::Tensor &seq_lens_q, // q_seq_len is 1 + const paddle::Tensor &seq_lens_kv, + const paddle::Tensor &batch_id_per_token, + const paddle::Tensor &cu_seqlens_q, + const paddle::Tensor &block_table, + int max_seq_len, + int max_dec_len, + float softmax_scale, + float in_scale, + bool causal, + cudaStream_t &stream, + paddle::Tensor *out) { + const auto token_num = meta_data.token_nums; + const auto block_size = meta_data.block_size; + const auto bsz = meta_data.batch_size; + const auto num_heads = meta_data.q_num_heads; + const auto group_size = meta_data.q_num_heads / meta_data.kv_num_heads; + const auto head_dim_qk = meta_data.head_dims; + const auto head_dim_v = meta_data.head_dims_v; + const float rope_scale = 0.0; + const float rope_theta = 0.0; + const uint32_t deal_each_time = get_cascade_attention_deal_each_time(); + const uint32_t num_stage = get_cascade_attention_num_stages(); + const uint32_t num_threads = get_cascade_attention_num_threads(); + + DISPATCH_CAUSAL(causal, CAUSAL, + {DISPATCH_MLA_GROUP_SIZE(group_size, GROUP_SIZE, + {DISPATCH_MLA_HEAD_DIM(head_dim_qk, HEAD_DIM_QK, + {DISPATCH_MLA_HEAD_DIM(head_dim_v, HEAD_DIM_V, + {DISPATCH_BLOCK_SIZE(block_size, BLOCK_SIZE, + {DISPATCH_DEAL_EACH_TIME(deal_each_time, DEAL_EACH_TIME, + {MultiQueryDecoderAttention( + meta_data, stream, q, cache_k, cache_v, attn_mask, shift_bias, smooth_weight, seq_lens_q, seq_lens_kv, batch_id_per_token, cu_seqlens_q, + block_table, max_seq_len, max_dec_len, rope_scale, rope_theta, softmax_scale, in_scale, out);})})})})})}); +} + +template void DecodeMLAAttentionKernel( + const AppendAttnMetaData& meta_data, + const paddle::Tensor &q, // [token_num, num_heads, head_dim] + const paddle::Tensor &cache_k, + const paddle::Tensor &cache_v, + const paddle::optional& attn_mask, + const paddle::optional& shift_bias, + const paddle::optional& smooth_weight, + const paddle::Tensor &seq_lens_q, // q_seq_len is 1 + const paddle::Tensor &seq_lens_kv, + const paddle::Tensor &batch_id_per_token, + const paddle::Tensor &cu_seqlens_q, + const paddle::Tensor &block_table, + int max_seq_len, + int max_dec_len, + float softmax_scale, + float in_scale, + bool causal, + cudaStream_t &stream, + paddle::Tensor *out); + +template void DecodeMLAAttentionKernel( + const AppendAttnMetaData& meta_data, + const paddle::Tensor &q, // [token_num, num_heads, head_dim] + const paddle::Tensor &cache_k, + const paddle::Tensor &cache_v, + const paddle::optional& attn_mask, + const paddle::optional& shift_bias, + const paddle::optional& smooth_weight, + const paddle::Tensor &seq_lens_q, // q_seq_len is 1 + const paddle::Tensor &seq_lens_kv, + const paddle::Tensor &batch_id_per_token, + const paddle::Tensor &cu_seqlens_q, + const paddle::Tensor &block_table, + int max_seq_len, + int max_dec_len, + float softmax_scale, + float in_scale, + bool causal, + cudaStream_t &stream, + paddle::Tensor *out); diff --git a/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh b/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh new file mode 100644 index 000000000..74a1e8e53 --- /dev/null +++ b/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh @@ -0,0 +1,1308 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include "multiquery_attention_c16_kernel.h" + +template +__global__ void multi_query_append_attention_kernel( + T *__restrict__ q, // [token_num, (num_heads + 2* kv_num_head) * head_dim] + T *__restrict__ cache_k, // [max_block_num, num_heads, block_size, + // head_dim] + T *__restrict__ cache_v, + const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM] + const T *__restrict__ smooth_weight, // [q_num_heads * HEAD_DIM] + const int *__restrict__ seq_lens, + const int *__restrict__ seq_lens_kv, + const int *__restrict__ batch_ids, + const int *__restrict__ tile_ids_per_batch, + const int *__restrict__ cu_seqlens_q, + const int *__restrict__ block_table, // [bsz, block_num_per_seq] + const int *__restrict__ mask_offset, + const int max_seq_len, + const int max_dec_len, + const int max_block_num_per_seq, + const float scale, + const float quant_max_bound, + const float quant_min_bound, + const float in_scale, + const uint32_t chunk_size, + const int num_blocks_x_cpu, + T *__restrict__ tmp_workspace, // split kv [token_num, num_chunks, + // num_heads, head_dim] + float *__restrict__ tmp_m, // [token_num, num_chunks, num_heads] + float *__restrict__ tmp_d, // [token_num, num_chunks, num_heads] + OutT *__restrict__ out, + const int speculate_max_draft_token_num = 5) { + constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b(); + const uint32_t btid = blockIdx.x, kv_head_idx = blockIdx.z; + const uint32_t kv_num_heads = gridDim.z; + const uint32_t q_num_heads = kv_num_heads * GROUP_SIZE; + const uint32_t q_head_idx = kv_head_idx * GROUP_SIZE; + const uint32_t tid = threadIdx.x, wid = threadIdx.y; + const uint32_t num_chunks = gridDim.y; + const uint32_t chunk_idx = blockIdx.y; + + const uint32_t batch_id = batch_ids[btid]; + const uint32_t tile_id = tile_ids_per_batch[btid]; + const uint32_t num_rows_per_block = NUM_WARPS * num_frags_x * 16; + const int *block_table_now = nullptr; + + block_table_now = block_table + batch_id * max_block_num_per_seq; + + //When cudagraph capture prefill, may launch more gridDim.x + if(btid >= static_cast(num_blocks_x_cpu)){ + return; + } + + const uint32_t q_len = seq_lens[batch_id]; + if (q_len <= 0) { + return; + } + const uint32_t q_end = + min(q_len, div_up((tile_id + 1) * num_rows_per_block, GROUP_SIZE)); + uint32_t kv_len = seq_lens_kv[batch_id]; + if (ENABLE_PREFILL) { + kv_len += q_len; + if (kv_len <= 0) { + return; + } + } else { + if (kv_len <= 0) { + return; + } + kv_len += q_len; + } + + const uint32_t num_chunks_this_seq = div_up(kv_len, chunk_size); + if (chunk_idx >= num_chunks_this_seq) { + return; + } + + const uint32_t chunk_start = partition_kv ? chunk_idx * chunk_size : 0; + const uint32_t chunk_end = + partition_kv ? min(kv_len, chunk_start + chunk_size) : kv_len; + const uint32_t chunk_len = chunk_end - chunk_start; + + extern __shared__ uint8_t smem[]; + float s_frag[num_frags_x][num_frags_z][8]; + float o_frag[num_frags_x][num_frags_y][8]; + float m_frag[num_frags_x][2]; + float d_frag[num_frags_x][2]; + init_states(o_frag, m_frag, d_frag); + + const uint32_t q_n_stride = q_num_heads * HEAD_DIM; + const uint32_t q_ori_n_stride = (q_num_heads + kv_num_heads * 2) * HEAD_DIM; + const uint32_t kv_n_stride = kv_num_heads * BLOCK_SIZE * HEAD_DIM; + const uint32_t kv_h_stride = BLOCK_SIZE * HEAD_DIM; + const uint32_t kv_b_stride = HEAD_DIM; + const uint32_t q_start_seq_id = cu_seqlens_q[batch_id]; + const uint32_t q_base_seq_id_this_block = + (tile_id * NUM_WARPS + wid) * num_frags_x * 16; + const uint32_t q_offset = q_start_seq_id * q_ori_n_stride + + q_head_idx * HEAD_DIM + + tid % 8 * num_elems_per_128b(); + const uint32_t o_offset = q_start_seq_id * q_n_stride + + q_head_idx * HEAD_DIM + + tid % 8 * num_elems_per_128b(); + T *q_base_ptr = q + q_offset; + T *o_base_ptr_T = nullptr; + OutT *o_base_ptr_int8 = nullptr; + if constexpr (partition_kv) { + if (ENABLE_PREFILL) { + o_base_ptr_T = tmp_workspace + q_start_seq_id * num_chunks * q_n_stride + + chunk_idx * q_n_stride + q_head_idx * HEAD_DIM + + tid % 8 * num_elems_per_128b(); + } else { + o_base_ptr_T = + tmp_workspace + + batch_id * speculate_max_draft_token_num * num_chunks * q_n_stride + + chunk_idx * q_n_stride + q_head_idx * HEAD_DIM + + tid % 8 * num_elems_per_128b(); + } + } else { + o_base_ptr_int8 = out + o_offset; + } + const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id * 2 : nullptr; + smem_t qo_smem(smem); + + uint32_t q_smem_offset_r = smem_t::get_permuted_offset( + wid * num_frags_x * 16 + tid % 16, tid / 16); // 16 * 16 + load_q_global_smem( + q_base_ptr, + &qo_smem, + q_base_seq_id_this_block, + q_end, + q_ori_n_stride, + HEAD_DIM); + commit_group(); + wait_group<0>(); + __syncthreads(); + + q_smem_inplace_multiply_sm_scale(&qo_smem, + scale); + + smem_t k_smem(smem + NUM_WARPS * num_frags_x * 16 * HEAD_DIM * sizeof(T)), + v_smem(smem + (NUM_WARPS * num_frags_x + num_frags_z) * 16 * HEAD_DIM * + sizeof(T)); + + + const uint32_t num_iterations = div_up( + CAUSAL + ? (min(chunk_len, + sub_if_greater_or_zero( + kv_len - q_len + + div_up((tile_id + 1) * num_rows_per_block, GROUP_SIZE), + chunk_start))) + : chunk_len, + num_frags_z * 16); + const uint32_t mask_check_iteration = + (CAUSAL ? (min(chunk_len, + sub_if_greater_or_zero( + kv_len - q_len + + tile_id * num_rows_per_block / GROUP_SIZE, + chunk_start))) + : mask_offset ? 0 : chunk_len) / + (num_frags_z * 16); + uint32_t k_smem_offset_r = smem_t::get_permuted_offset( + 8 * (tid / 16) + tid % 8, (tid % 16) / 8); + + uint32_t v_smem_offset_r = + smem_t::get_permuted_offset(tid % 16, tid / 16); + + uint32_t kv_smem_offset_w = smem_t::get_permuted_offset( + wid * 4 + tid / 8, tid % 8); + + uint32_t kv_idx_base = chunk_start; + int block_id = __ldg(&block_table_now[kv_idx_base / BLOCK_SIZE]); + const uint32_t const_offset = kv_head_idx * kv_h_stride + + (wid * 4 + tid / 8) * kv_b_stride + + tid % 8 * num_elems_per_128b(); + T *cache_k_now = cache_k + block_id * kv_n_stride + const_offset; + T *cache_v_now = cache_v + block_id * kv_n_stride + const_offset; + + produce_kv_blockwise(k_smem, + &kv_smem_offset_w, + &cache_k_now, + kv_head_idx, + kv_n_stride, + kv_h_stride, + kv_b_stride, + kv_idx_base, + chunk_end); + commit_group(); + produce_kv_blockwise(v_smem, + &kv_smem_offset_w, + &cache_v_now, + kv_head_idx, + kv_n_stride, + kv_h_stride, + kv_b_stride, + kv_idx_base, + chunk_end); + commit_group(); +#pragma unroll 1 + for (uint32_t iter = 0; iter < num_iterations; ++iter) { + wait_group<1>(); + __syncthreads(); + + // s = qk + compute_qk( + &qo_smem, &q_smem_offset_r, &k_smem, &k_smem_offset_r, s_frag); + // mask according to kv_idx and q_idx + if (iter >= mask_check_iteration) { + mask_s(nullptr, + q_base_seq_id_this_block, + kv_idx_base, + q_len, + kv_len, + chunk_end, + -1, + s_frag, + mask_offset_this_seq); + + } + + // update m,d + update_mdo_states( + s_frag, o_frag, m_frag, d_frag); + __syncthreads(); + + kv_idx_base += num_frags_z * 16; + block_id = __ldg(&block_table_now[kv_idx_base / BLOCK_SIZE]); + if (block_id < 0) { + block_id = 0; + } + cache_k_now = cache_k + block_id * kv_n_stride + const_offset; + produce_kv_blockwise(k_smem, + &kv_smem_offset_w, + &cache_k_now, + kv_head_idx, + kv_n_stride, + kv_h_stride, + kv_b_stride, + kv_idx_base, + chunk_end); + commit_group(); + wait_group<1>(); + __syncthreads(); + + // compute sfm*v + compute_sfm_v( + &v_smem, &v_smem_offset_r, s_frag, o_frag, d_frag); + + __syncthreads(); + cache_v_now = cache_v + block_id * kv_n_stride + const_offset; + produce_kv_blockwise(v_smem, + &kv_smem_offset_w, + &cache_v_now, + kv_head_idx, + kv_n_stride, + kv_h_stride, + kv_b_stride, + kv_idx_base, + chunk_end); + commit_group(); + } + wait_group<0>(); + __syncthreads(); + + if constexpr (!partition_kv) { + normalize_d(o_frag, d_frag); + } + if constexpr (partition_kv) { + write_o_reg_gmem_shift_smooth_quant( + o_frag, + &qo_smem, + o_base_ptr_T, + shift_bias, + smooth_weight, + q_base_seq_id_this_block, + q_head_idx, + quant_max_bound, + quant_min_bound, + in_scale, + q_len, + partition_kv ? q_n_stride * num_chunks : q_n_stride, + HEAD_DIM); + } else { + write_o_reg_gmem_shift_smooth_quant( + o_frag, + &qo_smem, + o_base_ptr_int8, + shift_bias, + smooth_weight, + q_base_seq_id_this_block, + q_head_idx, + quant_max_bound, + quant_min_bound, + in_scale, + q_len, + partition_kv ? q_n_stride * num_chunks : q_n_stride, + HEAD_DIM); + } + + + if constexpr (partition_kv) { +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t j = 0; j < 2; ++j) { + const uint32_t qo_idx_now = + q_base_seq_id_this_block + tid / 4 + j * 8 + fx * 16; + const uint32_t qo_head_idx = q_head_idx + qo_idx_now % GROUP_SIZE; + const uint32_t qo_idx = q_start_seq_id + qo_idx_now / GROUP_SIZE; + if (qo_idx - q_start_seq_id < q_len) { + uint32_t offset; + if (ENABLE_PREFILL) { + offset = + (qo_idx * num_chunks + chunk_idx) * q_num_heads + qo_head_idx; + } else { + offset = ((batch_id * speculate_max_draft_token_num + + qo_idx_now / GROUP_SIZE) * + num_chunks + + chunk_idx) * + q_num_heads + + qo_head_idx; + } + tmp_m[offset] = m_frag[fx][j]; + tmp_d[offset] = d_frag[fx][j]; + } + } + } + } +} + +template +__global__ void multi_query_append_attention_warp1_4_kernel( + T *__restrict__ q, // [token_num, (num_heads + 2* kv_num_head) * head_dim] + T *__restrict__ cache_k, // [max_block_num, num_heads, block_size, + // head_dim] + T *__restrict__ cache_v, + const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM] + const T *__restrict__ smooth_weight, // [q_num_heads * HEAD_DIM] + const int *__restrict__ seq_lens, + const int *__restrict__ seq_lens_kv, + const int *__restrict__ batch_ids, + const int *__restrict__ tile_ids_per_batch, + const int *__restrict__ cu_seqlens_q, + const int *__restrict__ block_table, // [bsz, block_num_per_seq] + const int *__restrict__ mask_offset, + const bool *__restrict__ attn_mask, // [bsz, max_q, max_q] for tree-mask + const int max_seq_len, + const int max_dec_len, + const int max_block_num_per_seq, + const float scale, + const float quant_max_bound, + const float quant_min_bound, + const float in_scale, + const uint32_t chunk_size, + const int num_blocks_x_cpu, + T *__restrict__ tmp_workspace, // split kv [token_num, num_chunks, + // num_heads, head_dim] + float *__restrict__ tmp_m, // [token_num, num_chunks, num_heads] + float *__restrict__ tmp_d, // [token_num, num_chunks, num_heads] + OutT *__restrict__ out, + const int speculate_max_draft_token_num = 5, + const uint32_t attn_mask_len = -1) { + constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b(); + static_assert(NUM_WARP_Q == 1, "NUM_WARP_Q must be 1"); + static_assert(NUM_WARP_KV == 4, "NUM_WARP_KV must be 4"); + const uint32_t btid = blockIdx.x, kv_head_idx = blockIdx.z; + const uint32_t kv_num_heads = gridDim.z; + const uint32_t q_num_heads = kv_num_heads * GROUP_SIZE; + const uint32_t q_head_idx = kv_head_idx * GROUP_SIZE; + const uint32_t tid = threadIdx.x, wid = threadIdx.y; + const uint32_t num_chunks = gridDim.y; + const uint32_t chunk_idx = blockIdx.y; + + const uint32_t batch_id = batch_ids[btid]; + const uint32_t tile_id = tile_ids_per_batch[btid]; + const uint32_t num_rows_per_block = num_frags_x * 16; + const int *block_table_now = block_table + batch_id * max_block_num_per_seq; + + //When cudagraph capture prefill, may launch more gridDim.x + if(btid >= static_cast(num_blocks_x_cpu)){ + return; + } + + const uint32_t q_len = seq_lens[batch_id]; + if (q_len <= 0) { + return; + } + const uint32_t q_end = + min(q_len, div_up((tile_id + 1) * num_rows_per_block, GROUP_SIZE)); + uint32_t kv_len = seq_lens_kv[batch_id]; + if (ENABLE_PREFILL) { + kv_len += q_len; + if (kv_len <= 0) { + return; + } + } else { + if (kv_len <= 0) { + return; + } + kv_len += q_len; + } + const uint32_t num_chunks_this_seq = div_up(kv_len, chunk_size); + if (chunk_idx >= num_chunks_this_seq) { + return; + } + + const uint32_t chunk_start = partition_kv ? chunk_idx * chunk_size : 0; + const uint32_t chunk_end = + partition_kv ? min(kv_len, chunk_start + chunk_size) : kv_len; + const uint32_t chunk_len = chunk_end - chunk_start; + + extern __shared__ uint8_t smem[]; + float s_frag[num_frags_x][num_frags_z][8]; + float o_frag[num_frags_x][num_frags_y][8]; + float m_frag[num_frags_x][2]; + float d_frag[num_frags_x][2]; + init_states(o_frag, m_frag, d_frag); + + const uint32_t q_n_stride = q_num_heads * HEAD_DIM; + const uint32_t q_ori_n_stride = (q_num_heads + kv_num_heads * 2) * HEAD_DIM; + const uint32_t kv_n_stride = kv_num_heads * BLOCK_SIZE * HEAD_DIM; + const uint32_t kv_h_stride = BLOCK_SIZE * HEAD_DIM; + const uint32_t kv_b_stride = HEAD_DIM; + const uint32_t q_start_seq_id = cu_seqlens_q[batch_id]; + const uint32_t q_base_seq_id_this_block = tile_id * num_frags_x * 16; + const uint32_t q_offset = q_start_seq_id * q_ori_n_stride + + q_head_idx * HEAD_DIM + + tid % 8 * num_elems_per_128b(); + const uint32_t o_offset = q_start_seq_id * q_n_stride + + q_head_idx * HEAD_DIM + + tid % 8 * num_elems_per_128b(); + T *q_base_ptr = q + q_offset; + T *o_base_ptr_T = nullptr; + OutT *o_base_ptr_int8 = nullptr; + if (num_chunks_this_seq <= 1) { + o_base_ptr_int8 = out + o_offset; + } else { + if (ENABLE_PREFILL) { + o_base_ptr_T = tmp_workspace + batch_id * num_chunks * q_n_stride + + chunk_idx * q_n_stride + q_head_idx * HEAD_DIM + + tid % 8 * num_elems_per_128b(); + } else { + o_base_ptr_T = + tmp_workspace + + batch_id * speculate_max_draft_token_num * num_chunks * q_n_stride + + chunk_idx * q_n_stride + q_head_idx * HEAD_DIM + + tid % 8 * num_elems_per_128b(); + } + } + const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id * 2 : nullptr; + smem_t qo_smem(smem); + + uint32_t q_smem_offset_r = smem_t::get_permuted_offset( + tid % 16, tid / 16); // 16 * 16 + load_q_global_smem_multi_warps(q_base_ptr, + &qo_smem, + q_base_seq_id_this_block, + q_end, + q_ori_n_stride, + HEAD_DIM); + commit_group(); + wait_group<0>(); + __syncthreads(); + + q_smem_inplace_multiply_sm_scale_multi_warps( + &qo_smem, scale); + + smem_t k_smem(smem + num_frags_x * 16 * HEAD_DIM * sizeof(T)), + v_smem(smem + (num_frags_x + NUM_WARP_KV * num_frags_z) * 16 * HEAD_DIM * + sizeof(T)); + + const uint32_t num_iterations = div_up( + CAUSAL + ? (min(chunk_len, + sub_if_greater_or_zero( + kv_len - q_len + + div_up((tile_id + 1) * num_rows_per_block, GROUP_SIZE), + chunk_start))) + : chunk_len, + NUM_WARP_KV * num_frags_z * 16); + const uint32_t mask_check_iteration = + (CAUSAL ? (min(chunk_len, + sub_if_greater_or_zero( + kv_len - q_len, + chunk_start))) + : mask_offset ? 0 : chunk_len) / + (NUM_WARP_KV * num_frags_z * 16); + + uint32_t k_smem_offset_r = smem_t::get_permuted_offset( + wid * num_frags_z * 16 + 8 * (tid / 16) + tid % 8, (tid % 16) / 8); + + uint32_t v_smem_offset_r = smem_t::get_permuted_offset( + wid * num_frags_z * 16 + tid % 16, tid / 16); + uint32_t kv_smem_offset_w = smem_t::get_permuted_offset( + wid * 4 + tid / 8, tid % 8); + + uint32_t kv_idx_base = chunk_start; + int block_id = __ldg(&block_table_now[kv_idx_base / BLOCK_SIZE]); + const uint32_t const_offset = kv_head_idx * kv_h_stride + + (wid * 4 + tid / 8) * kv_b_stride + + tid % 8 * num_elems_per_128b(); + T *cache_k_now = cache_k + block_id * kv_n_stride + const_offset; + T *cache_v_now = cache_v + block_id * kv_n_stride + const_offset; + + produce_kv_blockwise(k_smem, + &kv_smem_offset_w, + &cache_k_now, + kv_head_idx, + kv_n_stride, + kv_h_stride, + kv_b_stride, + kv_idx_base, + chunk_end); + commit_group(); + + produce_kv_blockwise(v_smem, + &kv_smem_offset_w, + &cache_v_now, + kv_head_idx, + kv_n_stride, + kv_h_stride, + kv_b_stride, + kv_idx_base, + chunk_end); + commit_group(); + +#pragma unroll 1 + for (uint32_t iter = 0; iter < num_iterations; ++iter) { + wait_group<1>(); + __syncthreads(); + + // s = qk + compute_qk( + &qo_smem, &q_smem_offset_r, &k_smem, &k_smem_offset_r, s_frag); + // mask according to kv_idx and q_idx + if (iter >= mask_check_iteration) { + mask_s(attn_mask ? attn_mask + batch_id * attn_mask_len *attn_mask_len : nullptr, + q_base_seq_id_this_block, + kv_idx_base + wid * num_frags_z * 16, + q_len, + kv_len, + chunk_end, + attn_mask_len, + s_frag, + mask_offset_this_seq); + } + + // update m,d + update_mdo_states( + s_frag, o_frag, m_frag, d_frag); + __syncthreads(); + + kv_idx_base += NUM_WARP_KV * num_frags_z * 16; + block_id = __ldg(&block_table_now[kv_idx_base / BLOCK_SIZE]); + if (block_id < 0) { + block_id = 0; + } + cache_k_now = cache_k + block_id * kv_n_stride + const_offset; + produce_kv_blockwise(k_smem, + &kv_smem_offset_w, + &cache_k_now, + kv_head_idx, + kv_n_stride, + kv_h_stride, + kv_b_stride, + kv_idx_base, + chunk_end); + commit_group(); + wait_group<1>(); + __syncthreads(); + + // compute sfm*v + compute_sfm_v( + &v_smem, &v_smem_offset_r, s_frag, o_frag, d_frag); + __syncthreads(); + + cache_v_now = cache_v + block_id * kv_n_stride + const_offset; + produce_kv_blockwise(v_smem, + &kv_smem_offset_w, + &cache_v_now, + kv_head_idx, + kv_n_stride, + kv_h_stride, + kv_b_stride, + kv_idx_base, + chunk_end); + commit_group(); + } + wait_group<0>(); + __syncthreads(); + + merge_block_res_v2( + o_frag, reinterpret_cast(smem), m_frag, d_frag, wid, tid); + + if (num_chunks_this_seq <= 1) { + normalize_d(o_frag, d_frag); + } + + // write o + // [num_frags_x, 16, num_frags_y, 16] + if (num_chunks_this_seq <= 1) { + write_o_reg_gmem_multi_warps_shift_smooth_quant( + o_frag, + &qo_smem, + o_base_ptr_int8, + shift_bias, + smooth_weight, + q_base_seq_id_this_block, + q_head_idx, + quant_max_bound, + quant_min_bound, + in_scale, + q_len, + q_n_stride, + HEAD_DIM); + } else { + write_o_reg_gmem_multi_warps_shift_smooth_quant( + o_frag, + &qo_smem, + o_base_ptr_T, + shift_bias, + smooth_weight, + q_base_seq_id_this_block, + q_head_idx, + quant_max_bound, + quant_min_bound, + in_scale, + q_len, + q_n_stride * num_chunks, + HEAD_DIM); + } + + if (num_chunks_this_seq > 1) { + if (wid == 0) { +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t j = 0; j < 2; ++j) { + const uint32_t qo_idx_now = + q_base_seq_id_this_block + tid / 4 + j * 8 + fx * 16; + const uint32_t qo_head_idx = q_head_idx + qo_idx_now % GROUP_SIZE; + const uint32_t qo_idx = q_start_seq_id + qo_idx_now / GROUP_SIZE; + + if (qo_idx - q_start_seq_id < q_len) { + uint32_t offset; + if (ENABLE_PREFILL) { + offset = (batch_id * num_chunks + chunk_idx) * q_num_heads + + qo_head_idx; + } else { + offset = ((batch_id * speculate_max_draft_token_num + + qo_idx_now / GROUP_SIZE) * + num_chunks + + chunk_idx) * + q_num_heads + + qo_head_idx; + } + tmp_m[offset] = m_frag[fx][j]; + tmp_d[offset] = d_frag[fx][j]; + } + } + } + } + } +} + +template +void MultiQueryAppendAttention( + const AppendAttnMetaData &meta_data, + const paddle::Tensor &qkv, + const paddle::Tensor &cache_k, + const paddle::Tensor &cache_v, + const paddle::optional &attn_mask, + const paddle::optional &shift_bias, + const paddle::optional &smooth_weight, + const paddle::Tensor &seq_lens_q, + const paddle::Tensor &seq_lens_kv, + const paddle::Tensor &seq_lens_encoder, + const paddle::Tensor &batch_id_per_token, + const paddle::Tensor &cu_seqlens_q, + const paddle::Tensor &block_table, + const paddle::Tensor &batch_ids, + const paddle::Tensor &tile_ids_per_batch, + const int num_blocks_x_cpu, + const int max_seq_len, + const int max_dec_len, + const float quant_max_bound, + const float quant_min_bound, + const float in_scale, + const int max_partition_size, + const int encoder_max_partition_size, + const int speculate_max_draft_token_num, + const bool is_decoder, + cudaStream_t &stream, + paddle::Tensor *out) { + using NV_TYPE = typename cascade_attn_type_traits::type; + using OUT_NV_TYPE = typename cascade_attn_type_traits::type; + + auto num_heads = meta_data.q_num_heads; + auto kv_num_heads = meta_data.kv_num_heads; + auto token_num = meta_data.token_nums; + auto bsz = meta_data.batch_size; + auto max_block_num_per_seq = meta_data.max_blocks_per_seq; + + constexpr uint32_t num_warps = 4; + constexpr uint32_t NUM_WARP_KV = num_warps / NUM_WARP_Q; + constexpr uint32_t num_frags_x = BLOCK_SHAPE_Q / (16 * NUM_WARP_Q); // 1 or 2 + constexpr uint32_t num_frags_y = HEAD_DIM / 16; + constexpr uint32_t num_qrow_per_block = NUM_WARP_Q * num_frags_x * 16; + + auto *allocator = paddle::GetAllocator(qkv.place()); + + const float scale = 1.f / sqrt(HEAD_DIM); + + if constexpr (NUM_WARP_Q == 4) { + constexpr uint32_t num_frags_z = BLOCK_SIZE / 16; + constexpr uint32_t smem_size = + (num_warps * num_frags_x + NUM_WARP_KV * num_frags_z * 2) * 16 * + HEAD_DIM * sizeof(T); + auto split_kv_kernel = multi_query_append_attention_kernel; + if (smem_size >= 48 * 1024) { + cudaFuncSetAttribute(split_kv_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + } + const int dev_id = 0; + int sm_count; + cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id); + + uint32_t chunk_size = static_cast(max_partition_size); + if (!is_decoder) { + chunk_size = static_cast(encoder_max_partition_size); + } + const int num_chunks = div_up(max_dec_len, chunk_size); + dim3 grids(num_blocks_x_cpu, num_chunks, kv_num_heads); + dim3 blocks(32, num_warps); + if (num_chunks <= 1) { + auto nosplit_kv_kernel = + multi_query_append_attention_kernel; + if (smem_size >= 48 * 1024) { + cudaFuncSetAttribute(nosplit_kv_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + } + + nosplit_kv_kernel<<>>( + reinterpret_cast(const_cast(qkv.data())), + reinterpret_cast(const_cast(cache_k.data())), + reinterpret_cast(const_cast(cache_v.data())), + shift_bias ? reinterpret_cast( + const_cast(shift_bias.get().data())) + : nullptr, + smooth_weight ? reinterpret_cast( + const_cast(smooth_weight.get().data())) + : nullptr, + seq_lens_q.data(), + seq_lens_kv.data(), + batch_ids.data(), + tile_ids_per_batch.data(), + cu_seqlens_q.data(), + block_table.data(), + meta_data.mask_offset, + max_seq_len, + max_dec_len, + max_block_num_per_seq, + scale, + quant_max_bound, + quant_min_bound, + in_scale, + chunk_size, + num_blocks_x_cpu, + nullptr, + nullptr, + nullptr, + reinterpret_cast(out->data()), + speculate_max_draft_token_num); + + } else { + phi::Allocator::AllocationPtr tmp_workspace, tmp_m, tmp_d; + if (ENABLE_PREFILL) { + tmp_workspace = allocator->Allocate( + phi::SizeOf(qkv.dtype()) * + static_cast(token_num * num_chunks * num_heads * HEAD_DIM)); + tmp_m = allocator->Allocate( + phi::SizeOf(paddle::DataType::FLOAT32) * + static_cast(token_num * num_chunks * num_heads)); + tmp_d = allocator->Allocate( + phi::SizeOf(paddle::DataType::FLOAT32) * + static_cast(token_num * num_chunks * num_heads)); + } else { + tmp_workspace = allocator->Allocate( + phi::SizeOf(qkv.dtype()) * + static_cast(speculate_max_draft_token_num * bsz * + num_chunks * num_heads * HEAD_DIM)); + tmp_m = allocator->Allocate( + phi::SizeOf(paddle::DataType::FLOAT32) * + static_cast(speculate_max_draft_token_num * bsz * + num_chunks * num_heads)); + tmp_d = allocator->Allocate( + phi::SizeOf(paddle::DataType::FLOAT32) * + static_cast(speculate_max_draft_token_num * bsz * + num_chunks * num_heads)); + } + + split_kv_kernel<<>>( + reinterpret_cast(const_cast(qkv.data())), + reinterpret_cast(const_cast(cache_k.data())), + reinterpret_cast(const_cast(cache_v.data())), + shift_bias ? reinterpret_cast( + const_cast(shift_bias.get().data())) + : nullptr, + smooth_weight ? reinterpret_cast( + const_cast(smooth_weight.get().data())) + : nullptr, + seq_lens_q.data(), + seq_lens_kv.data(), + batch_ids.data(), + tile_ids_per_batch.data(), + cu_seqlens_q.data(), + block_table.data(), + meta_data.mask_offset, + max_seq_len, + max_dec_len, + max_block_num_per_seq, + scale, + quant_max_bound, + quant_min_bound, + in_scale, + chunk_size, + num_blocks_x_cpu, + reinterpret_cast(tmp_workspace->ptr()), + static_cast(tmp_m->ptr()), + static_cast(tmp_d->ptr()), + reinterpret_cast(out->data()), + speculate_max_draft_token_num); + // merge + constexpr int vec_size = num_elems_per_128b(); + if (is_decoder) { + constexpr int blockx = HEAD_DIM / vec_size; + constexpr int blocky = (128 + blockx - 1) / blockx; + dim3 grids_merge(bsz, num_heads); + dim3 blocks_merge(blockx, blocky); + merge_multi_chunks_decoder_kernel + <<>>( + reinterpret_cast(tmp_workspace->ptr()), + static_cast(tmp_m->ptr()), + static_cast(tmp_d->ptr()), + seq_lens_q.data(), + seq_lens_kv.data(), + seq_lens_encoder.data(), + cu_seqlens_q.data(), + shift_bias ? reinterpret_cast( + const_cast(shift_bias.get().data())) + : nullptr, + smooth_weight ? reinterpret_cast(const_cast( + smooth_weight.get().data())) + : nullptr, + reinterpret_cast(out->data()), + quant_max_bound, + quant_min_bound, + in_scale, + max_seq_len, + num_chunks, + num_heads, + chunk_size, + HEAD_DIM); + } else { + constexpr int blockx = HEAD_DIM / vec_size; + constexpr int blocky = (128 + blockx - 1) / blockx; + dim3 grids_merge(min(sm_count * 4, token_num), + num_heads); // 128k is too large + dim3 blocks_merge(blockx, blocky); + merge_multi_chunks_v2_kernel + <<>>( + reinterpret_cast(tmp_workspace->ptr()), + static_cast(tmp_m->ptr()), + static_cast(tmp_d->ptr()), + seq_lens_q.data(), + seq_lens_kv.data(), + seq_lens_encoder.data(), + batch_id_per_token.data(), + cu_seqlens_q.data(), + shift_bias ? reinterpret_cast( + const_cast(shift_bias.get().data())) + : nullptr, + smooth_weight ? reinterpret_cast(const_cast( + smooth_weight.get().data())) + : nullptr, + reinterpret_cast(out->data()), + quant_max_bound, + quant_min_bound, + in_scale, + max_seq_len, + num_chunks, + num_heads, + chunk_size, + HEAD_DIM, + token_num, + speculate_max_draft_token_num); + } + } + } else { + constexpr uint32_t num_frags_z = BLOCK_SIZE / 16 / NUM_WARP_KV; + constexpr uint32_t smem_size = + (num_frags_x + NUM_WARP_KV * num_frags_z * 2) * 16 * HEAD_DIM * + sizeof(T); + auto split_kv_kernel = + multi_query_append_attention_warp1_4_kernel; + if (smem_size >= 48 * 1024) { + cudaFuncSetAttribute(split_kv_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + } + const int dev_id = 0; + int sm_count; + cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id); + + uint32_t chunk_size = static_cast(max_partition_size); + if (!is_decoder) { + chunk_size = static_cast(encoder_max_partition_size); + } + + uint32_t attn_mask_len; + if (attn_mask) { + attn_mask_len = attn_mask.get().shape()[1]; + } else { + attn_mask_len = -1; + } + + const int num_chunks = div_up(max_seq_len, chunk_size); + dim3 grids(num_blocks_x_cpu, num_chunks, kv_num_heads); + dim3 blocks(32, num_warps); + if (num_chunks <= 0) { + auto nosplit_kv_kernel = + multi_query_append_attention_warp1_4_kernel; + if (smem_size >= 48 * 1024) { + cudaFuncSetAttribute(nosplit_kv_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + } + + nosplit_kv_kernel<<>>( + reinterpret_cast(const_cast(qkv.data())), + reinterpret_cast(const_cast(cache_k.data())), + reinterpret_cast(const_cast(cache_v.data())), + shift_bias ? reinterpret_cast( + const_cast(shift_bias.get().data())) + : nullptr, + smooth_weight ? reinterpret_cast( + const_cast(smooth_weight.get().data())) + : nullptr, + seq_lens_q.data(), + seq_lens_kv.data(), + batch_ids.data(), + tile_ids_per_batch.data(), + cu_seqlens_q.data(), + block_table.data(), + meta_data.mask_offset, + attn_mask ? const_cast(attn_mask.get().data()) + : nullptr, + max_seq_len, + max_dec_len, + max_block_num_per_seq, + scale, + quant_max_bound, + quant_min_bound, + in_scale, + chunk_size, + num_blocks_x_cpu, + nullptr, + nullptr, + nullptr, + reinterpret_cast(out->data()), + speculate_max_draft_token_num, + attn_mask_len); + } else { + phi::Allocator::AllocationPtr tmp_workspace, tmp_m, tmp_d; + if (is_decoder) { + tmp_workspace = allocator->Allocate( + phi::SizeOf(qkv.dtype()) * + static_cast(bsz * num_chunks * num_heads * HEAD_DIM)); + tmp_m = allocator->Allocate( + phi::SizeOf(paddle::DataType::FLOAT32) * + static_cast(bsz * num_chunks * num_heads)); + tmp_d = allocator->Allocate( + phi::SizeOf(paddle::DataType::FLOAT32) * + static_cast(bsz * num_chunks * num_heads)); + } else { + if (ENABLE_PREFILL) { + tmp_workspace = + allocator->Allocate(phi::SizeOf(qkv.dtype()) * + static_cast(token_num * num_chunks * + num_heads * HEAD_DIM)); + tmp_m = allocator->Allocate( + phi::SizeOf(paddle::DataType::FLOAT32) * + static_cast(token_num * num_chunks * num_heads)); + tmp_d = allocator->Allocate( + phi::SizeOf(paddle::DataType::FLOAT32) * + static_cast(token_num * num_chunks * num_heads)); + } else { + tmp_workspace = allocator->Allocate( + phi::SizeOf(qkv.dtype()) * + static_cast(speculate_max_draft_token_num * bsz * + num_chunks * num_heads * HEAD_DIM)); + tmp_m = allocator->Allocate( + phi::SizeOf(paddle::DataType::FLOAT32) * + static_cast(speculate_max_draft_token_num * bsz * + num_chunks * num_heads)); + tmp_d = allocator->Allocate( + phi::SizeOf(paddle::DataType::FLOAT32) * + static_cast(speculate_max_draft_token_num * bsz * + num_chunks * num_heads)); + } + } + split_kv_kernel<<>>( + reinterpret_cast(const_cast(qkv.data())), + reinterpret_cast(const_cast(cache_k.data())), + reinterpret_cast(const_cast(cache_v.data())), + shift_bias ? reinterpret_cast( + const_cast(shift_bias.get().data())) + : nullptr, + smooth_weight ? reinterpret_cast( + const_cast(smooth_weight.get().data())) + : nullptr, + seq_lens_q.data(), + seq_lens_kv.data(), + batch_ids.data(), + tile_ids_per_batch.data(), + cu_seqlens_q.data(), + block_table.data(), + meta_data.mask_offset, + attn_mask ? const_cast(attn_mask.get().data()) + : nullptr, + max_seq_len, + max_dec_len, + max_block_num_per_seq, + scale, + quant_max_bound, + quant_min_bound, + in_scale, + chunk_size, + num_blocks_x_cpu, + reinterpret_cast(tmp_workspace->ptr()), + static_cast(tmp_m->ptr()), + static_cast(tmp_d->ptr()), + reinterpret_cast(out->data()), + speculate_max_draft_token_num, + attn_mask_len); + + // merge + constexpr int vec_size = num_elems_per_128b(); + if (is_decoder) { + constexpr int blockx = HEAD_DIM / vec_size; + constexpr int blocky = (128 + blockx - 1) / blockx; + dim3 grids_merge(bsz, num_heads); + dim3 blocks_merge(blockx, blocky); + merge_multi_chunks_decoder_kernel + <<>>( + reinterpret_cast(tmp_workspace->ptr()), + static_cast(tmp_m->ptr()), + static_cast(tmp_d->ptr()), + seq_lens_q.data(), + seq_lens_kv.data(), + seq_lens_encoder.data(), + cu_seqlens_q.data(), + shift_bias ? reinterpret_cast( + const_cast(shift_bias.get().data())) + : nullptr, + smooth_weight ? reinterpret_cast(const_cast( + smooth_weight.get().data())) + : nullptr, + reinterpret_cast(out->data()), + quant_max_bound, + quant_min_bound, + in_scale, + max_seq_len, + num_chunks, + num_heads, + chunk_size, + HEAD_DIM); + } else { + constexpr int blockx = HEAD_DIM / vec_size; + constexpr int blocky = (128 + blockx - 1) / blockx; + dim3 grids_merge(min(sm_count * 4, token_num), + num_heads); + dim3 blocks_merge(blockx, blocky); + merge_multi_chunks_v2_kernel + <<>>( + reinterpret_cast(tmp_workspace->ptr()), + static_cast(tmp_m->ptr()), + static_cast(tmp_d->ptr()), + seq_lens_q.data(), + seq_lens_kv.data(), + seq_lens_encoder.data(), + batch_id_per_token.data(), + cu_seqlens_q.data(), + shift_bias ? reinterpret_cast( + const_cast(shift_bias.get().data())) + : nullptr, + smooth_weight ? reinterpret_cast(const_cast( + smooth_weight.get().data())) + : nullptr, + reinterpret_cast(out->data()), + quant_max_bound, + quant_min_bound, + in_scale, + max_seq_len, + num_chunks, + num_heads, + chunk_size, + HEAD_DIM, + token_num, + speculate_max_draft_token_num); + } + } + } +} diff --git a/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_kernel.h b/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_kernel.h new file mode 100644 index 000000000..5fa6ecc6f --- /dev/null +++ b/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_kernel.h @@ -0,0 +1,54 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include "append_attention_func.cuh" + +template +void MultiQueryAppendAttention( + const AppendAttnMetaData &meta_data, + const paddle::Tensor &qkv, + const paddle::Tensor &cache_k, + const paddle::Tensor &cache_v, + const paddle::optional &attn_mask, + const paddle::optional &shift_bias, + const paddle::optional &smooth_weight, + const paddle::Tensor &seq_lens_q, + const paddle::Tensor &seq_lens_kv, + const paddle::Tensor &seq_lens_encoder, + const paddle::Tensor &batch_id_per_token, + const paddle::Tensor &cu_seqlens_q, + const paddle::Tensor &block_table, + const paddle::Tensor &batch_ids, + const paddle::Tensor &tile_ids_per_batch, + const int num_blocks_x_cpu, + const int max_seq_len, + const int max_dec_len, + const float quant_max_bound, + const float quant_min_bound, + const float in_scale, + const int max_partition_size, + const int encoder_max_partition_size, + const int speculate_max_draft_token_num, + const bool is_decoder, + cudaStream_t &stream, + paddle::Tensor *out); diff --git a/custom_ops/gpu_ops/append_attn/multiquery_attention_c4_impl.cuh b/custom_ops/gpu_ops/append_attn/multiquery_attention_c4_impl.cuh new file mode 100644 index 000000000..e45889b17 --- /dev/null +++ b/custom_ops/gpu_ops/append_attn/multiquery_attention_c4_impl.cuh @@ -0,0 +1,1546 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include "multiquery_attention_c4_kernel.h" + +template +__global__ void multi_query_append_attention_c4_kernel( + T *__restrict__ q, // [token_num, (num_heads + 2* kv_num_head) * head_dim] + CacheT *__restrict__ cache_k, // [max_block_num, num_heads, block_size, + // head_dim] + CacheT *__restrict__ cache_v, + const T *__restrict__ cache_k_scale, // [num_kv_heads, head_dim] + const T *__restrict__ cache_k_zero_point, // [num_kv_heads, head_dim] + const T *__restrict__ cache_v_scale, // [num_kv_heads, head_dim] + const T *__restrict__ cache_v_zero_point, // [num_kv_heads, head_dim] + const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM] + const T *__restrict__ smooth_weight, // [q_num_heads * HEAD_DIM] + const int *__restrict__ seq_lens, + const int *__restrict__ seq_lens_kv, + const int *__restrict__ batch_ids, + const int *__restrict__ tile_ids_per_batch, + const int *__restrict__ cu_seqlens_q, + const int *__restrict__ block_table, // [bsz, block_num_per_seq] + const int *__restrict__ mask_offset, + const int max_seq_len, + const int max_dec_len, + const int max_block_num_per_seq, + const float scale, + const float quant_max_bound, + const float quant_min_bound, + const float in_scale, + const uint32_t chunk_size, + const int num_blocks_x_cpu, + T *__restrict__ tmp_workspace, // split kv [token_num, num_chunks, + // num_heads, head_dim] + float *__restrict__ tmp_m, // [token_num, num_chunks, num_heads] + float *__restrict__ tmp_d, // [token_num, num_chunks, num_heads] + OutT *__restrict__ out, + const int speculate_max_draft_token_num = 5) { + constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b(); + constexpr uint32_t num_vecs_per_head_k = + HEAD_DIM / 2 / num_elems_per_128b(); + constexpr uint32_t num_vecs_per_blocksize = + BLOCK_SIZE / 2 / num_elems_per_128b(); + constexpr uint32_t inv_k_stride = 8 / num_vecs_per_head_k; + constexpr uint32_t inv_v_stride = 8 / num_vecs_per_blocksize; + const uint32_t btid = blockIdx.x, kv_head_idx = blockIdx.z; + const uint32_t kv_num_heads = gridDim.z; + const uint32_t q_num_heads = kv_num_heads * GROUP_SIZE; + const uint32_t q_head_idx = kv_head_idx * GROUP_SIZE; + const uint32_t tid = threadIdx.x, wid = threadIdx.y; + const uint32_t num_chunks = gridDim.y; + const uint32_t chunk_idx = blockIdx.y; + + const uint32_t batch_id = batch_ids[btid]; + const uint32_t tile_id = tile_ids_per_batch[btid]; + const uint32_t num_rows_per_block = NUM_WARPS * num_frags_x * 16; + const int *block_table_now = nullptr; + + block_table_now = block_table + batch_id * max_block_num_per_seq; + + //When cudagraph capture prefill, may launch more gridDim.x + if(btid >= static_cast(num_blocks_x_cpu)){ + return; + } + + const uint32_t q_len = seq_lens[batch_id]; + if (q_len <= 0) { + return; + } + const uint32_t q_end = + min(q_len, div_up((tile_id + 1) * num_rows_per_block, GROUP_SIZE)); + uint32_t kv_len = seq_lens_kv[batch_id]; + if (ENABLE_PREFILL) { + kv_len += q_len; + if (kv_len <= 0) { + return; + } + } else { + if (kv_len <= 0) { + return; + } + kv_len += q_len; + } + const uint32_t num_chunks_this_seq = div_up(kv_len, chunk_size); + if (chunk_idx >= num_chunks_this_seq) { + return; + } + + const uint32_t chunk_start = partition_kv ? chunk_idx * chunk_size : 0; + const uint32_t chunk_end = + partition_kv ? min(kv_len, chunk_start + chunk_size) : kv_len; + const uint32_t chunk_len = chunk_end - chunk_start; + + extern __shared__ uint8_t smem[]; + float s_frag[num_frags_x][num_frags_z][8]; + float o_frag[num_frags_x][num_frags_y][8]; + float m_frag[num_frags_x][2]; + float d_frag[num_frags_x][2]; + + const T *cache_k_scale_now = cache_k_scale + kv_head_idx * HEAD_DIM; + const T *cache_k_zp_now = cache_k_zero_point + kv_head_idx * HEAD_DIM; + const T *cache_v_scale_now = cache_v_scale + kv_head_idx * HEAD_DIM; + const T *cache_v_zp_now = cache_v_zero_point + kv_head_idx * HEAD_DIM; + T *cache_k_scale_smem = reinterpret_cast( + smem + NUM_WARPS * num_frags_x * 16 * HEAD_DIM * sizeof(T) + + num_frags_z * 16 * HEAD_DIM / 2 * sizeof(CacheT) * 2); + T *cache_k_zero_point_smem = cache_k_scale_smem + HEAD_DIM; + T *cache_v_scale_smem = cache_k_zero_point_smem + HEAD_DIM; + T *cache_v_zero_point_smem = cache_v_scale_smem + HEAD_DIM; +#pragma unroll + for (uint32_t i = wid * 32 + tid; i < HEAD_DIM; i += 128) { + cache_k_scale_smem[i] = cache_k_scale_now[i]; + cache_k_zero_point_smem[i] = cache_k_zp_now[i]; + cache_v_scale_smem[i] = cache_v_scale_now[i]; + cache_v_zero_point_smem[i] = cache_v_zp_now[i]; + } + + init_states(o_frag, m_frag, d_frag); + + const uint32_t q_n_stride = q_num_heads * HEAD_DIM; + const uint32_t q_ori_n_stride = (q_num_heads + kv_num_heads * 2) * HEAD_DIM; + const uint32_t kv_n_stride = kv_num_heads * BLOCK_SIZE * HEAD_DIM / 2; + const uint32_t kv_h_stride = BLOCK_SIZE * HEAD_DIM / 2; + const uint32_t kv_b_stride = HEAD_DIM / 2; + const uint32_t kv_d_stride = BLOCK_SIZE / 2; + const uint32_t q_start_seq_id = cu_seqlens_q[batch_id]; + const uint32_t q_base_seq_id_this_block = + (tile_id * NUM_WARPS + wid) * num_frags_x * 16; + const uint32_t q_offset = q_start_seq_id * q_ori_n_stride + + q_head_idx * HEAD_DIM + + tid % 8 * num_elems_per_128b(); + const uint32_t o_offset = q_start_seq_id * q_n_stride + + q_head_idx * HEAD_DIM + + tid % 8 * num_elems_per_128b(); + T *q_base_ptr = q + q_offset; + + T *o_base_ptr_T = nullptr; + OutT *o_base_ptr_int8 = nullptr; + if constexpr (partition_kv) { + if (ENABLE_PREFILL) { + o_base_ptr_T = tmp_workspace + q_start_seq_id * num_chunks * q_n_stride + + chunk_idx * q_n_stride + q_head_idx * HEAD_DIM + + tid % 8 * num_elems_per_128b(); + } else { + o_base_ptr_T = + tmp_workspace + + batch_id * speculate_max_draft_token_num * num_chunks * q_n_stride + + chunk_idx * q_n_stride + q_head_idx * HEAD_DIM + + tid % 8 * num_elems_per_128b(); + } + } else { + o_base_ptr_int8 = out + o_offset; + } + const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id * 2 : nullptr; + smem_t qo_smem(smem); + + uint32_t q_smem_offset_r = smem_t::get_permuted_offset( + wid * num_frags_x * 16 + tid % 16, tid / 16); + load_q_global_smem( + q_base_ptr, + &qo_smem, + q_base_seq_id_this_block, + q_end, + q_ori_n_stride, + HEAD_DIM); + commit_group(); + wait_group<0>(); + __syncthreads(); + + q_smem_inplace_multiply_sm_scale(&qo_smem, + scale); + + T cache_k_scale_frag[num_frags_y][4]; + T cache_k_zp_frag[num_frags_y][4]; + T magic_number; + if constexpr (std::is_same::value) { + magic_number = static_cast(1032.f); + } else { + magic_number = static_cast(136.f); + } +#pragma unroll + for (uint32_t fy = 0; fy < num_frags_y; ++fy) { + *(reinterpret_cast(&cache_k_scale_frag[fy][0])) = + *(reinterpret_cast(&cache_k_scale_smem[fy * 16]) + tid % 4); + *(reinterpret_cast(&cache_k_scale_frag[fy][2])) = + *(reinterpret_cast(&cache_k_scale_smem[fy * 16]) + tid % 4 + + 4); + *(reinterpret_cast(&cache_k_zp_frag[fy][0])) = + *(reinterpret_cast(&cache_k_zero_point_smem[fy * 16]) + + tid % 4); + *(reinterpret_cast(&cache_k_zp_frag[fy][2])) = + *(reinterpret_cast(&cache_k_zero_point_smem[fy * 16]) + + tid % 4 + 4); +#pragma unroll + for (uint32_t zp_i = 0; zp_i < 4; ++zp_i) { + cache_k_zp_frag[fy][zp_i] += magic_number; // 128 + 8 + } + } + T cache_v_scale_frag[num_frags_y][2]; + T cache_v_zp_frag[num_frags_y][2]; +#pragma unroll + for (uint32_t fy = 0; fy < num_frags_y; ++fy) { + cache_v_scale_frag[fy][0] = cache_v_scale_smem[fy * 16 + tid / 4]; + cache_v_scale_frag[fy][1] = cache_v_scale_smem[fy * 16 + tid / 4 + 8]; + cache_v_zp_frag[fy][0] = + cache_v_zero_point_smem[fy * 16 + tid / 4] + magic_number; + cache_v_zp_frag[fy][1] = + cache_v_zero_point_smem[fy * 16 + tid / 4 + 8] + magic_number; + } + + smem_t k_smem(smem + NUM_WARPS * num_frags_x * 16 * HEAD_DIM * sizeof(T)), + v_smem(smem + NUM_WARPS * num_frags_x * 16 * HEAD_DIM * sizeof(T) + + num_frags_z * 16 * HEAD_DIM / 2 * sizeof(CacheT)); + + + const uint32_t num_iterations = div_up( + CAUSAL + ? (min(chunk_len, + sub_if_greater_or_zero( + kv_len - q_len + + div_up((tile_id + 1) * num_rows_per_block, GROUP_SIZE), + chunk_start))) + : chunk_len, + num_frags_z * 16); + const uint32_t mask_check_iteration = + (CAUSAL ? (min(chunk_len, + sub_if_greater_or_zero( + kv_len - q_len + + tile_id * num_rows_per_block / GROUP_SIZE, + chunk_start))) + : mask_offset ? 0 : chunk_len) / + (num_frags_z * 16); + + uint32_t k_smem_offset_r = + smem_t::get_permuted_offset( + 8 * (tid / 16) + tid % 8, (tid % 16) / 8); + + uint32_t v_smem_offset_r = + smem_t::get_permuted_offset( + 8 * (tid / 16) + tid % 8, (tid % 16) / 8); + + uint32_t k_smem_offset_w = + smem_t::get_permuted_offset( + wid * 8 + tid / 4, + tid % + 4); + uint32_t v_smem_offset_w = + smem_t::get_permuted_offset( + wid * 16 + tid / 2, tid % 2); // 2 * 128 / 8 = 32B, 64 nums + + uint32_t kv_idx_base = chunk_start; + const uint32_t const_k_offset = kv_head_idx * kv_h_stride + + (wid * 8 + tid / 4) * kv_b_stride + + tid % 4 * num_elems_per_128b(); + const uint32_t const_v_offset = kv_head_idx * kv_h_stride + + (wid * 16 + tid / 2) * kv_d_stride + + tid % 2 * num_elems_per_128b(); + + produce_k_blockwise_c4(k_smem, + &k_smem_offset_w, + cache_k, + block_table_now, + kv_head_idx, + kv_n_stride, + kv_h_stride, + kv_b_stride, + kv_idx_base, + chunk_end, + const_k_offset); + commit_group(); + produce_v_blockwise_c4(v_smem, + &v_smem_offset_w, + cache_v, + block_table_now, + kv_head_idx, + kv_n_stride, + kv_h_stride, + kv_d_stride, + kv_idx_base, + chunk_end, + const_v_offset); + commit_group(); + +#pragma unroll 1 + for (uint32_t iter = 0; iter < num_iterations; ++iter) { + wait_group<1>(); + __syncthreads(); + + compute_qk_c4( + &qo_smem, + &q_smem_offset_r, + &k_smem, + &k_smem_offset_r, + s_frag, + cache_k_scale_frag, + cache_k_zp_frag); + + if (iter >= mask_check_iteration) { + mask_s(nullptr, + q_base_seq_id_this_block, + kv_idx_base, + q_len, + kv_len, + chunk_end, + -1, + s_frag, + mask_offset_this_seq); + } + + update_mdo_states( + s_frag, o_frag, m_frag, d_frag); + __syncthreads(); + + kv_idx_base += num_frags_z * 16; + produce_k_blockwise_c4(k_smem, + &k_smem_offset_w, + cache_k, + block_table_now, + kv_head_idx, + kv_n_stride, + kv_h_stride, + kv_b_stride, + kv_idx_base, + chunk_end, + const_k_offset); + commit_group(); + wait_group<1>(); + __syncthreads(); + + compute_sfm_v_c4(&v_smem, + &v_smem_offset_r, + s_frag, + o_frag, + d_frag, + cache_v_scale_frag, + cache_v_zp_frag); + __syncthreads(); + + produce_v_blockwise_c4(v_smem, + &v_smem_offset_w, + cache_v, + block_table_now, + kv_head_idx, + kv_n_stride, + kv_h_stride, + kv_d_stride, + kv_idx_base, + chunk_end, + const_v_offset); + commit_group(); + } + wait_group<0>(); + __syncthreads(); + + if constexpr (!partition_kv) { + normalize_d(o_frag, d_frag); + } + + if constexpr (partition_kv) { + write_o_reg_gmem_shift_smooth_quant( + o_frag, + &qo_smem, + o_base_ptr_T, + shift_bias, + smooth_weight, + q_base_seq_id_this_block, + q_head_idx, + quant_max_bound, + quant_min_bound, + in_scale, + q_len, + partition_kv ? q_n_stride * num_chunks : q_n_stride, + HEAD_DIM); + } else { + write_o_reg_gmem_shift_smooth_quant( + o_frag, + &qo_smem, + o_base_ptr_int8, + shift_bias, + smooth_weight, + q_base_seq_id_this_block, + q_head_idx, + quant_max_bound, + quant_min_bound, + in_scale, + q_len, + partition_kv ? q_n_stride * num_chunks : q_n_stride, + HEAD_DIM); + } + + if constexpr (partition_kv) { +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t j = 0; j < 2; ++j) { + const uint32_t qo_idx_now = + q_base_seq_id_this_block + tid / 4 + j * 8 + fx * 16; + const uint32_t qo_head_idx = q_head_idx + qo_idx_now % GROUP_SIZE; + const uint32_t qo_idx = q_start_seq_id + qo_idx_now / GROUP_SIZE; + if (qo_idx - q_start_seq_id < q_len) { + uint32_t offset; + if (ENABLE_PREFILL) { + offset = + (qo_idx * num_chunks + chunk_idx) * q_num_heads + qo_head_idx; + } else { + offset = ((batch_id * speculate_max_draft_token_num + + qo_idx_now / GROUP_SIZE) * + num_chunks + + chunk_idx) * + q_num_heads + + qo_head_idx; + } + tmp_m[offset] = m_frag[fx][j]; + tmp_d[offset] = d_frag[fx][j]; + } + } + } + } +} + +template +__global__ void multi_query_append_attention_c4_warp1_4_kernel( + T *__restrict__ q, // [token_num, (num_heads + 2* kv_num_head) * head_dim] + CacheT *__restrict__ cache_k, // [max_block_num, num_heads, block_size, + // head_dim] + CacheT *__restrict__ cache_v, + const T *__restrict__ cache_k_scale, // [num_kv_heads, head_dim] + const T *__restrict__ cache_k_zero_point, // [num_kv_heads, head_dim] + const T *__restrict__ cache_v_scale, // [num_kv_heads, head_dim] + const T *__restrict__ cache_v_zero_point, // [num_kv_heads, head_dim] + const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM] + const T *__restrict__ smooth_weight, // [q_num_heads * HEAD_DIM] + const int *__restrict__ seq_lens, + const int *__restrict__ seq_lens_kv, + const int *__restrict__ batch_ids, + const int *__restrict__ tile_ids_per_batch, + const int *__restrict__ cu_seqlens_q, + const int *__restrict__ block_table, // [bsz, block_num_per_seq] + const int *__restrict__ mask_offset, + const bool *__restrict__ attn_mask, // [bsz, max_q, max_q] for tree-mask + const int max_seq_len, + const int max_dec_len, + const int max_block_num_per_seq, + const float scale, + const float quant_max_bound, + const float quant_min_bound, + const float in_scale, + const uint32_t chunk_size, + const int num_blocks_x_cpu, + T *__restrict__ tmp_workspace, // split kv [token_num, num_chunks, + // num_heads, head_dim] + float *__restrict__ tmp_m, // [token_num, num_chunks, num_heads] + float *__restrict__ tmp_d, // [token_num, num_chunks, num_heads] + OutT *__restrict__ out, + const int speculate_max_draft_token_num = 5, + const uint32_t attn_mask_len = -1) { + constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b(); + constexpr uint32_t num_vecs_per_head_k = + HEAD_DIM / 2 / num_elems_per_128b(); + constexpr uint32_t num_vecs_per_blocksize = + BLOCK_SIZE / 2 / num_elems_per_128b(); + constexpr uint32_t inv_k_stride = 8 / num_vecs_per_head_k; + constexpr uint32_t inv_v_stride = 8 / num_vecs_per_blocksize; + static_assert(NUM_WARP_Q == 1, "NUM_WARP_Q must be 1"); + static_assert(NUM_WARP_KV == 4, "NUM_WARP_KV must be 4"); + const uint32_t btid = blockIdx.x, kv_head_idx = blockIdx.z; + const uint32_t kv_num_heads = gridDim.z; + const uint32_t q_num_heads = kv_num_heads * GROUP_SIZE; + const uint32_t q_head_idx = kv_head_idx * GROUP_SIZE; + const uint32_t tid = threadIdx.x, wid = threadIdx.y; + const uint32_t num_chunks = gridDim.y; + const uint32_t chunk_idx = blockIdx.y; + + const uint32_t batch_id = batch_ids[btid]; + const uint32_t tile_id = tile_ids_per_batch[btid]; + const uint32_t num_rows_per_block = num_frags_x * 16; + const int *block_table_now = block_table + batch_id * max_block_num_per_seq; + + //When cudagraph capture prefill, may launch more gridDim.x + if(btid >= static_cast(num_blocks_x_cpu)){ + return; + } + + const uint32_t q_len = seq_lens[batch_id]; + if (q_len <= 0) { + return; + } + const uint32_t q_end = + min(q_len, div_up((tile_id + 1) * num_rows_per_block, GROUP_SIZE)); + uint32_t kv_len = seq_lens_kv[batch_id]; + if (ENABLE_PREFILL) { + kv_len += q_len; + if (kv_len <= 0) { + return; + } + } else { + if (kv_len <= 0) { + return; + } + kv_len += q_len; + } + const uint32_t num_chunks_this_seq = div_up(kv_len, chunk_size); + if (chunk_idx >= num_chunks_this_seq) { + return; + } + + const uint32_t chunk_start = partition_kv ? chunk_idx * chunk_size : 0; + const uint32_t chunk_end = + partition_kv ? min(kv_len, chunk_start + chunk_size) : kv_len; + const uint32_t chunk_len = chunk_end - chunk_start; + + extern __shared__ uint8_t smem[]; + float s_frag[num_frags_x][num_frags_z][8]; + float o_frag[num_frags_x][num_frags_y][8]; + float m_frag[num_frags_x][2]; + float d_frag[num_frags_x][2]; + init_states(o_frag, m_frag, d_frag); + + const T *cache_k_scale_now = cache_k_scale + kv_head_idx * HEAD_DIM; + const T *cache_k_zp_now = cache_k_zero_point + kv_head_idx * HEAD_DIM; + const T *cache_v_scale_now = cache_v_scale + kv_head_idx * HEAD_DIM; + const T *cache_v_zp_now = cache_v_zero_point + kv_head_idx * HEAD_DIM; + T *cache_k_scale_smem = reinterpret_cast( + smem + NUM_WARP_Q * num_frags_x * 16 * HEAD_DIM * sizeof(T) + + NUM_WARP_KV * num_frags_z * 16 * HEAD_DIM / 2 * sizeof(CacheT) * 2); + T *cache_k_zero_point_smem = cache_k_scale_smem + HEAD_DIM; + T *cache_v_scale_smem = cache_k_zero_point_smem + HEAD_DIM; + T *cache_v_zero_point_smem = cache_v_scale_smem + HEAD_DIM; +#pragma unroll + for (uint32_t i = wid * 32 + tid; i < HEAD_DIM; i += 128) { + cache_k_scale_smem[i] = cache_k_scale_now[i]; + cache_k_zero_point_smem[i] = cache_k_zp_now[i]; + cache_v_scale_smem[i] = cache_v_scale_now[i]; + cache_v_zero_point_smem[i] = cache_v_zp_now[i]; + } + + const uint32_t q_n_stride = q_num_heads * HEAD_DIM; + const uint32_t q_ori_n_stride = (q_num_heads + kv_num_heads * 2) * HEAD_DIM; + const uint32_t kv_n_stride = kv_num_heads * BLOCK_SIZE * HEAD_DIM / 2; + const uint32_t kv_h_stride = BLOCK_SIZE * HEAD_DIM / 2; + const uint32_t kv_b_stride = HEAD_DIM / 2; + const uint32_t kv_d_stride = BLOCK_SIZE / 2; + const uint32_t q_start_seq_id = cu_seqlens_q[batch_id]; + const uint32_t q_base_seq_id_this_block = tile_id * num_frags_x * 16; + const uint32_t q_offset = q_start_seq_id * q_ori_n_stride + + q_head_idx * HEAD_DIM + + tid % 8 * num_elems_per_128b(); + const uint32_t o_offset = q_start_seq_id * q_n_stride + + q_head_idx * HEAD_DIM + + tid % 8 * num_elems_per_128b(); + T *q_base_ptr = q + q_offset; + + T *o_base_ptr_T = nullptr; + OutT *o_base_ptr_int8 = nullptr; + if (num_chunks_this_seq <= 1) { + o_base_ptr_int8 = out + o_offset; + } else { + if (ENABLE_PREFILL) { + o_base_ptr_T = tmp_workspace + batch_id * num_chunks * q_n_stride + + chunk_idx * q_n_stride + q_head_idx * HEAD_DIM + + tid % 8 * num_elems_per_128b(); + } else { + o_base_ptr_T = + tmp_workspace + + batch_id * speculate_max_draft_token_num * num_chunks * q_n_stride + + chunk_idx * q_n_stride + q_head_idx * HEAD_DIM + + tid % 8 * num_elems_per_128b(); + } + } + const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id * 2 : nullptr; + smem_t qo_smem(smem); + + uint32_t q_smem_offset_r = smem_t::get_permuted_offset( + tid % 16, tid / 16); + load_q_global_smem_multi_warps(q_base_ptr, + &qo_smem, + q_base_seq_id_this_block, + q_end, + q_ori_n_stride, + HEAD_DIM); + commit_group(); + wait_group<0>(); + __syncthreads(); + + q_smem_inplace_multiply_sm_scale_multi_warps( + &qo_smem, scale); + + T cache_k_scale_frag[num_frags_y][4]; + T cache_k_zp_frag[num_frags_y][4]; + T magic_number; + if constexpr (std::is_same::value) { + magic_number = static_cast(1032.f); + } else { + magic_number = static_cast(136.f); + } +#pragma unroll + for (uint32_t fy = 0; fy < num_frags_y; ++fy) { + *(reinterpret_cast(&cache_k_scale_frag[fy][0])) = + *(reinterpret_cast(&cache_k_scale_smem[fy * 16]) + tid % 4); + *(reinterpret_cast(&cache_k_scale_frag[fy][2])) = + *(reinterpret_cast(&cache_k_scale_smem[fy * 16]) + tid % 4 + + 4); + *(reinterpret_cast(&cache_k_zp_frag[fy][0])) = + *(reinterpret_cast(&cache_k_zero_point_smem[fy * 16]) + + tid % 4); + *(reinterpret_cast(&cache_k_zp_frag[fy][2])) = + *(reinterpret_cast(&cache_k_zero_point_smem[fy * 16]) + + tid % 4 + 4); +#pragma unroll + for (uint32_t zp_i = 0; zp_i < 4; ++zp_i) { + cache_k_zp_frag[fy][zp_i] += magic_number; + } + } + T cache_v_scale_frag[num_frags_y][2]; + T cache_v_zp_frag[num_frags_y][2]; +#pragma unroll + for (uint32_t fy = 0; fy < num_frags_y; ++fy) { + cache_v_scale_frag[fy][0] = cache_v_scale_smem[fy * 16 + tid / 4]; + cache_v_scale_frag[fy][1] = cache_v_scale_smem[fy * 16 + tid / 4 + 8]; + cache_v_zp_frag[fy][0] = + cache_v_zero_point_smem[fy * 16 + tid / 4] + magic_number; + cache_v_zp_frag[fy][1] = + cache_v_zero_point_smem[fy * 16 + tid / 4 + 8] + magic_number; + } + + smem_t k_smem(smem + num_frags_x * 16 * HEAD_DIM * sizeof(T)), + v_smem(smem + num_frags_x * 16 * HEAD_DIM * sizeof(T) + + NUM_WARP_KV * num_frags_z * 16 * HEAD_DIM / 2 * sizeof(CacheT)); + + const uint32_t num_iterations = div_up( + CAUSAL + ? (min(chunk_len, + sub_if_greater_or_zero( + kv_len - q_len + + div_up((tile_id + 1) * num_rows_per_block, GROUP_SIZE), + chunk_start))) + : chunk_len, + NUM_WARP_KV * num_frags_z * 16); + const uint32_t mask_check_iteration = + (CAUSAL ? (min(chunk_len, + sub_if_greater_or_zero( + kv_len - q_len, + chunk_start))) + : mask_offset ? 0 : chunk_len) / + (NUM_WARP_KV * num_frags_z * 16); + + uint32_t k_smem_offset_r = + smem_t::get_permuted_offset( + wid * num_frags_z * 16 + 8 * (tid / 16) + tid % 8, (tid % 16) / 8); + + uint32_t v_smem_offset_r = + smem_t::get_permuted_offset( + wid * num_frags_y * 16 + 8 * (tid / 16) + tid % 8, (tid % 16) / 8); + + uint32_t k_smem_offset_w = + smem_t::get_permuted_offset( + wid * 8 + tid / 4, + tid % + 4); + uint32_t v_smem_offset_w = + smem_t::get_permuted_offset( + wid * 16 + tid / 2, tid % 2); + + uint32_t kv_idx_base = chunk_start; + const uint32_t const_k_offset = kv_head_idx * kv_h_stride + + (wid * 8 + tid / 4) * kv_b_stride + + tid % 4 * num_elems_per_128b(); + const uint32_t const_v_offset = kv_head_idx * kv_h_stride + + (wid * 16 + tid / 2) * kv_d_stride + + tid % 2 * num_elems_per_128b(); + + produce_k_blockwise_c4(k_smem, + &k_smem_offset_w, + cache_k, + block_table_now, + kv_head_idx, + kv_n_stride, + kv_h_stride, + kv_b_stride, + kv_idx_base, + chunk_end, + const_k_offset); + commit_group(); + produce_v_blockwise_c4(v_smem, + &v_smem_offset_w, + cache_v, + block_table_now, + kv_head_idx, + kv_n_stride, + kv_h_stride, + kv_d_stride, + kv_idx_base, + chunk_end, + const_v_offset); + commit_group(); +#pragma unroll 1 + for (uint32_t iter = 0; iter < num_iterations; ++iter) { + wait_group<1>(); + __syncthreads(); + compute_qk_c4( + &qo_smem, + &q_smem_offset_r, + &k_smem, + &k_smem_offset_r, + s_frag, + cache_k_scale_frag, + cache_k_zp_frag); + if (iter >= mask_check_iteration) { + mask_s(attn_mask ? attn_mask + batch_id * attn_mask_len *attn_mask_len : nullptr, + q_base_seq_id_this_block, + kv_idx_base + wid * num_frags_z * 16, + q_len, + kv_len, + chunk_end, + attn_mask_len, + s_frag, + mask_offset_this_seq); + } + + update_mdo_states( + s_frag, o_frag, m_frag, d_frag); + __syncthreads(); + + kv_idx_base += NUM_WARP_KV * num_frags_z * 16; + produce_k_blockwise_c4(k_smem, + &k_smem_offset_w, + cache_k, + block_table_now, + kv_head_idx, + kv_n_stride, + kv_h_stride, + kv_b_stride, + kv_idx_base, + chunk_end, + const_k_offset); + commit_group(); + wait_group<1>(); + __syncthreads(); + + // compute sfm*v + compute_sfm_v_c4(&v_smem, + &v_smem_offset_r, + s_frag, + o_frag, + d_frag, + cache_v_scale_frag, + cache_v_zp_frag); + __syncthreads(); + + produce_v_blockwise_c4(v_smem, + &v_smem_offset_w, + cache_v, + block_table_now, + kv_head_idx, + kv_n_stride, + kv_h_stride, + kv_d_stride, + kv_idx_base, + chunk_end, + const_v_offset); + commit_group(); + } + wait_group<0>(); + __syncthreads(); + + merge_block_res_v2( + o_frag, reinterpret_cast(smem), m_frag, d_frag, wid, tid); + + if (num_chunks_this_seq <= 1) { + normalize_d(o_frag, d_frag); + } + + // write o + // [num_frags_x, 16, num_frags_y, 16] + if (num_chunks_this_seq <= 1) { + write_o_reg_gmem_multi_warps_shift_smooth_quant( + o_frag, + &qo_smem, + o_base_ptr_int8, + shift_bias, + smooth_weight, + q_base_seq_id_this_block, + q_head_idx, + quant_max_bound, + quant_min_bound, + in_scale, + q_len, + q_n_stride, + HEAD_DIM); + } else { + write_o_reg_gmem_multi_warps_shift_smooth_quant( + o_frag, + &qo_smem, + o_base_ptr_T, + shift_bias, + smooth_weight, + q_base_seq_id_this_block, + q_head_idx, + quant_max_bound, + quant_min_bound, + in_scale, + q_len, + q_n_stride * num_chunks, + HEAD_DIM); + } + + if (num_chunks_this_seq > 1) { + if (wid == 0) { +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t j = 0; j < 2; ++j) { + const uint32_t qo_idx_now = + q_base_seq_id_this_block + tid / 4 + j * 8 + fx * 16; + const uint32_t qo_head_idx = q_head_idx + qo_idx_now % GROUP_SIZE; + const uint32_t qo_idx = q_start_seq_id + qo_idx_now / GROUP_SIZE; + if (qo_idx - q_start_seq_id < q_len) { + uint32_t offset; + if (ENABLE_PREFILL) { + offset = (batch_id * num_chunks + chunk_idx) * q_num_heads + + qo_head_idx; + } else { + offset = ((batch_id * speculate_max_draft_token_num + + qo_idx_now / GROUP_SIZE) * + num_chunks + + chunk_idx) * + q_num_heads + + qo_head_idx; + } + tmp_m[offset] = m_frag[fx][j]; + tmp_d[offset] = d_frag[fx][j]; + } + } + } + } + } +} + +template +void MultiQueryAppendC4Attention( + const AppendAttnMetaData &meta_data, + const paddle::Tensor &qkv, + const paddle::Tensor &cache_k, + const paddle::Tensor &cache_v, + const paddle::optional &attn_mask, + const paddle::Tensor &cache_k_scale, + const paddle::Tensor &cache_v_scale, + const paddle::optional &cache_k_zp, + const paddle::optional &cache_v_zp, + const paddle::optional &shift_bias, + const paddle::optional &smooth_weight, + const paddle::Tensor &seq_lens_q, + const paddle::Tensor &seq_lens_kv, + const paddle::Tensor &seq_lens_encoder, + const paddle::Tensor &batch_id_per_token, + const paddle::Tensor &cu_seqlens_q, + const paddle::Tensor &block_table, + const paddle::Tensor &batch_ids, + const paddle::Tensor &tile_ids_per_batch, + const int num_blocks_x_cpu, + const int max_seq_len, + const int max_dec_len, + const float quant_max_bound, + const float quant_min_bound, + const float in_scale, + const int max_partition_size, + const int encoder_max_partition_size, + const int speculate_max_draft_token_num, + const bool is_decoder, + cudaStream_t &stream, + paddle::Tensor *out) { + using NV_TYPE = typename cascade_attn_type_traits::type; + using OUT_NV_TYPE = typename cascade_attn_type_traits::type; + + auto num_heads = meta_data.q_num_heads; + auto kv_num_heads = meta_data.kv_num_heads; + auto token_num = meta_data.token_nums; + auto bsz = meta_data.batch_size; + auto max_block_num_per_seq = meta_data.max_blocks_per_seq; + + constexpr uint32_t num_warps = 4; + constexpr uint32_t NUM_WARP_KV = num_warps / NUM_WARP_Q; + constexpr uint32_t num_frags_x = BLOCK_SHAPE_Q / (16 * NUM_WARP_Q); + constexpr uint32_t num_frags_y = HEAD_DIM / 16; + constexpr uint32_t num_qrow_per_block = NUM_WARP_Q * num_frags_x * 16; + + auto *allocator = paddle::GetAllocator(qkv.place()); + + const float scale = 1.f / sqrt(HEAD_DIM); + + if constexpr (NUM_WARP_Q == 4) { + constexpr uint32_t num_frags_z = BLOCK_SIZE / 16; + constexpr uint32_t smem_size = + num_warps * num_frags_x * 16 * HEAD_DIM * sizeof(T) + + num_frags_z * 16 * HEAD_DIM / 2 * sizeof(uint8_t) * 2 + + HEAD_DIM * 4 * sizeof(T); + auto split_kv_kernel = + multi_query_append_attention_c4_kernel; + cudaFuncSetAttribute(split_kv_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + const int dev_id = 0; + int sm_count; + int act_blocks_per_sm; + cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id); + cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &act_blocks_per_sm, split_kv_kernel, num_warps * 32, smem_size); + assert(act_blocks_per_sm > 1); + const int num_blocks_per_wave = sm_count * act_blocks_per_sm; + const int num_blocks_need = num_blocks_x_cpu * kv_num_heads; + const int max_num_chunks = div_up(num_blocks_per_wave, num_blocks_need); + const float ratio = static_cast(num_blocks_need) / + static_cast(num_blocks_per_wave); + + uint32_t chunk_size = static_cast(max_partition_size); + if (!is_decoder) { + chunk_size = static_cast(encoder_max_partition_size); + } + const int num_chunks = div_up(max_dec_len, chunk_size); + + dim3 grids(num_blocks_x_cpu, num_chunks, kv_num_heads); + dim3 blocks(32, num_warps); + if (num_chunks <= 1) { + auto nosplit_kv_kernel = + multi_query_append_attention_c4_kernel; + if (smem_size >= 48 * 1024) { + cudaFuncSetAttribute(nosplit_kv_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + } + nosplit_kv_kernel<<>>( + reinterpret_cast(const_cast(qkv.data())), + const_cast(cache_k.data()), + const_cast(cache_v.data()), + reinterpret_cast(const_cast(cache_k_scale.data())), + cache_k_zp ? reinterpret_cast( + const_cast(cache_k_zp.get().data())) + : nullptr, + reinterpret_cast(const_cast(cache_v_scale.data())), + cache_v_zp ? reinterpret_cast( + const_cast(cache_v_zp.get().data())) + : nullptr, + shift_bias ? reinterpret_cast( + const_cast(shift_bias.get().data())) + : nullptr, + smooth_weight ? reinterpret_cast( + const_cast(smooth_weight.get().data())) + : nullptr, + seq_lens_q.data(), + seq_lens_kv.data(), + batch_ids.data(), + tile_ids_per_batch.data(), + cu_seqlens_q.data(), + block_table.data(), + meta_data.mask_offset, + max_seq_len, + max_dec_len, + max_block_num_per_seq, + scale, + quant_max_bound, + quant_min_bound, + in_scale, + chunk_size, + num_blocks_x_cpu, + nullptr, + nullptr, + nullptr, + reinterpret_cast(out->data()), + speculate_max_draft_token_num); + } else { + phi::Allocator::AllocationPtr tmp_workspace, tmp_m, tmp_d; + if (ENABLE_PREFILL) { + tmp_workspace = allocator->Allocate( + phi::SizeOf(qkv.dtype()) * + static_cast(token_num * num_chunks * num_heads * HEAD_DIM)); + tmp_m = allocator->Allocate( + phi::SizeOf(paddle::DataType::FLOAT32) * + static_cast(token_num * num_chunks * num_heads)); + tmp_d = allocator->Allocate( + phi::SizeOf(paddle::DataType::FLOAT32) * + static_cast(token_num * num_chunks * num_heads)); + } else { + tmp_workspace = allocator->Allocate( + phi::SizeOf(qkv.dtype()) * + static_cast(speculate_max_draft_token_num * bsz * + num_chunks * num_heads * HEAD_DIM)); + tmp_m = allocator->Allocate( + phi::SizeOf(paddle::DataType::FLOAT32) * + static_cast(speculate_max_draft_token_num * bsz * + num_chunks * num_heads)); + tmp_d = allocator->Allocate( + phi::SizeOf(paddle::DataType::FLOAT32) * + static_cast(speculate_max_draft_token_num * bsz * + num_chunks * num_heads)); + } + split_kv_kernel<<>>( + reinterpret_cast(const_cast(qkv.data())), + const_cast(cache_k.data()), + const_cast(cache_v.data()), + reinterpret_cast(const_cast(cache_k_scale.data())), + cache_k_zp ? reinterpret_cast( + const_cast(cache_k_zp.get().data())) + : nullptr, + reinterpret_cast(const_cast(cache_v_scale.data())), + cache_v_zp ? reinterpret_cast( + const_cast(cache_v_zp.get().data())) + : nullptr, + shift_bias ? reinterpret_cast( + const_cast(shift_bias.get().data())) + : nullptr, + smooth_weight ? reinterpret_cast( + const_cast(smooth_weight.get().data())) + : nullptr, + seq_lens_q.data(), + seq_lens_kv.data(), + batch_ids.data(), + tile_ids_per_batch.data(), + cu_seqlens_q.data(), + block_table.data(), + meta_data.mask_offset, + max_seq_len, + max_dec_len, + max_block_num_per_seq, + scale, + quant_max_bound, + quant_min_bound, + in_scale, + chunk_size, + num_blocks_x_cpu, + reinterpret_cast(tmp_workspace->ptr()), + static_cast(tmp_m->ptr()), + static_cast(tmp_d->ptr()), + reinterpret_cast(out->data()), + speculate_max_draft_token_num); + // merge + constexpr int vec_size = num_elems_per_128b(); + if (is_decoder) { + constexpr int blockx = HEAD_DIM / vec_size; + constexpr int blocky = (128 + blockx - 1) / blockx; + dim3 grids_merge(bsz, num_heads); + dim3 blocks_merge(blockx, blocky); + merge_multi_chunks_decoder_kernel + <<>>( + reinterpret_cast(tmp_workspace->ptr()), + static_cast(tmp_m->ptr()), + static_cast(tmp_d->ptr()), + seq_lens_q.data(), + seq_lens_kv.data(), + seq_lens_encoder.data(), + cu_seqlens_q.data(), + shift_bias ? reinterpret_cast( + const_cast(shift_bias.get().data())) + : nullptr, + smooth_weight ? reinterpret_cast(const_cast( + smooth_weight.get().data())) + : nullptr, + reinterpret_cast(out->data()), + quant_max_bound, + quant_min_bound, + in_scale, + max_seq_len, + num_chunks, + num_heads, + chunk_size, + HEAD_DIM); + } else { + constexpr int blockx = HEAD_DIM / vec_size; + constexpr int blocky = (128 + blockx - 1) / blockx; + dim3 grids_merge(min(sm_count * 4, token_num), + num_heads); + dim3 blocks_merge(blockx, blocky); + merge_multi_chunks_v2_kernel + <<>>( + reinterpret_cast(tmp_workspace->ptr()), + static_cast(tmp_m->ptr()), + static_cast(tmp_d->ptr()), + seq_lens_q.data(), + seq_lens_kv.data(), + seq_lens_encoder.data(), + batch_id_per_token.data(), + cu_seqlens_q.data(), + shift_bias ? reinterpret_cast( + const_cast(shift_bias.get().data())) + : nullptr, + smooth_weight ? reinterpret_cast(const_cast( + smooth_weight.get().data())) + : nullptr, + reinterpret_cast(out->data()), + quant_max_bound, + quant_min_bound, + in_scale, + max_seq_len, + num_chunks, + num_heads, + chunk_size, + HEAD_DIM, + token_num, + speculate_max_draft_token_num); + } + } + } else { + constexpr uint32_t num_frags_z = BLOCK_SIZE / 16 / NUM_WARP_KV * 4; + constexpr uint32_t smem_size = + num_frags_x * 16 * HEAD_DIM * sizeof(T) + + NUM_WARP_KV * num_frags_z * 16 * HEAD_DIM / 2 * sizeof(uint8_t) * 2 + + HEAD_DIM * 4 * sizeof(T); + auto split_kv_kernel = + multi_query_append_attention_c4_warp1_4_kernel; + if (smem_size >= 48 * 1024) { + cudaFuncSetAttribute(split_kv_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + } + const int dev_id = 0; + int sm_count; + int act_blocks_per_sm; + cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id); + cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &act_blocks_per_sm, split_kv_kernel, num_warps * 32, smem_size); + assert(act_blocks_per_sm > 1); + const int num_blocks_per_wave = sm_count * act_blocks_per_sm; + const int num_blocks_need = num_blocks_x_cpu * kv_num_heads; + const int max_num_chunks = div_up(num_blocks_per_wave, num_blocks_need); + const float ratio = static_cast(num_blocks_need) / + static_cast(num_blocks_per_wave); + + + uint32_t chunk_size = static_cast(max_partition_size); + if (!is_decoder) { + chunk_size = static_cast(encoder_max_partition_size); + } + + const int num_chunks = div_up(max_seq_len, chunk_size); + uint32_t attn_mask_len; + if (attn_mask) { + attn_mask_len = attn_mask.get().shape()[1]; + } else { + attn_mask_len = -1; + } + + dim3 grids(num_blocks_x_cpu, num_chunks, kv_num_heads); + dim3 blocks(32, num_warps); + if (num_chunks <= 0) { + auto nosplit_kv_kernel = + multi_query_append_attention_c4_warp1_4_kernel; + if (smem_size >= 48 * 1024) { + cudaFuncSetAttribute(nosplit_kv_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + } + nosplit_kv_kernel<<>>( + reinterpret_cast(const_cast(qkv.data())), + const_cast(cache_k.data()), + const_cast(cache_v.data()), + reinterpret_cast(const_cast(cache_k_scale.data())), + cache_k_zp ? reinterpret_cast( + const_cast(cache_k_zp.get().data())) + : nullptr, + reinterpret_cast(const_cast(cache_v_scale.data())), + cache_v_zp ? reinterpret_cast( + const_cast(cache_v_zp.get().data())) + : nullptr, + shift_bias ? reinterpret_cast( + const_cast(shift_bias.get().data())) + : nullptr, + smooth_weight ? reinterpret_cast( + const_cast(smooth_weight.get().data())) + : nullptr, + seq_lens_q.data(), + seq_lens_kv.data(), + batch_ids.data(), + tile_ids_per_batch.data(), + cu_seqlens_q.data(), + block_table.data(), + meta_data.mask_offset, + attn_mask ? const_cast(attn_mask.get().data()) + : nullptr, + max_seq_len, + max_dec_len, + max_block_num_per_seq, + scale, + quant_max_bound, + quant_min_bound, + in_scale, + chunk_size, + num_blocks_x_cpu, + nullptr, + nullptr, + nullptr, + reinterpret_cast(out->data()), + speculate_max_draft_token_num, + attn_mask_len); + } else { + phi::Allocator::AllocationPtr tmp_workspace, tmp_m, tmp_d; + if (is_decoder) { + tmp_workspace = allocator->Allocate( + phi::SizeOf(qkv.dtype()) * + static_cast(bsz * num_chunks * num_heads * HEAD_DIM)); + tmp_m = allocator->Allocate( + phi::SizeOf(paddle::DataType::FLOAT32) * + static_cast(bsz * num_chunks * num_heads)); + tmp_d = allocator->Allocate( + phi::SizeOf(paddle::DataType::FLOAT32) * + static_cast(bsz * num_chunks * num_heads)); + } else { + if (ENABLE_PREFILL) { + tmp_workspace = + allocator->Allocate(phi::SizeOf(qkv.dtype()) * + static_cast(token_num * num_chunks * + num_heads * HEAD_DIM)); + tmp_m = allocator->Allocate( + phi::SizeOf(paddle::DataType::FLOAT32) * + static_cast(token_num * num_chunks * num_heads)); + tmp_d = allocator->Allocate( + phi::SizeOf(paddle::DataType::FLOAT32) * + static_cast(token_num * num_chunks * num_heads)); + } else { + tmp_workspace = allocator->Allocate( + phi::SizeOf(qkv.dtype()) * + static_cast(speculate_max_draft_token_num * bsz * + num_chunks * num_heads * HEAD_DIM)); + tmp_m = allocator->Allocate( + phi::SizeOf(paddle::DataType::FLOAT32) * + static_cast(speculate_max_draft_token_num * bsz * + num_chunks * num_heads)); + tmp_d = allocator->Allocate( + phi::SizeOf(paddle::DataType::FLOAT32) * + static_cast(speculate_max_draft_token_num * bsz * + num_chunks * num_heads)); + } + } + split_kv_kernel<<>>( + reinterpret_cast(const_cast(qkv.data())), + const_cast(cache_k.data()), + const_cast(cache_v.data()), + reinterpret_cast(const_cast(cache_k_scale.data())), + cache_k_zp ? reinterpret_cast( + const_cast(cache_k_zp.get().data())) + : nullptr, + reinterpret_cast(const_cast(cache_v_scale.data())), + cache_v_zp ? reinterpret_cast( + const_cast(cache_v_zp.get().data())) + : nullptr, + shift_bias ? reinterpret_cast( + const_cast(shift_bias.get().data())) + : nullptr, + smooth_weight ? reinterpret_cast( + const_cast(smooth_weight.get().data())) + : nullptr, + seq_lens_q.data(), + seq_lens_kv.data(), + batch_ids.data(), + tile_ids_per_batch.data(), + cu_seqlens_q.data(), + block_table.data(), + meta_data.mask_offset, + attn_mask ? const_cast(attn_mask.get().data()) + : nullptr, + max_seq_len, + max_dec_len, + max_block_num_per_seq, + scale, + quant_max_bound, + quant_min_bound, + in_scale, + chunk_size, + num_blocks_x_cpu, + reinterpret_cast(tmp_workspace->ptr()), + static_cast(tmp_m->ptr()), + static_cast(tmp_d->ptr()), + reinterpret_cast(out->data()), + speculate_max_draft_token_num, + attn_mask_len); + // merge + constexpr int vec_size = num_elems_per_128b(); + if (is_decoder) { + constexpr int blockx = HEAD_DIM / vec_size; + constexpr int blocky = (128 + blockx - 1) / blockx; + dim3 grids_merge(bsz, num_heads); + dim3 blocks_merge(blockx, blocky); + merge_multi_chunks_decoder_kernel + <<>>( + reinterpret_cast(tmp_workspace->ptr()), + static_cast(tmp_m->ptr()), + static_cast(tmp_d->ptr()), + seq_lens_q.data(), + seq_lens_kv.data(), + seq_lens_encoder.data(), + cu_seqlens_q.data(), + shift_bias ? reinterpret_cast( + const_cast(shift_bias.get().data())) + : nullptr, + smooth_weight ? reinterpret_cast(const_cast( + smooth_weight.get().data())) + : nullptr, + reinterpret_cast(out->data()), + quant_max_bound, + quant_min_bound, + in_scale, + max_seq_len, + num_chunks, + num_heads, + chunk_size, + HEAD_DIM); + } else { + constexpr int blockx = HEAD_DIM / vec_size; + constexpr int blocky = (128 + blockx - 1) / blockx; + dim3 grids_merge(min(sm_count * 4, token_num), + num_heads); + dim3 blocks_merge(blockx, blocky); + merge_multi_chunks_v2_kernel + <<>>( + reinterpret_cast(tmp_workspace->ptr()), + static_cast(tmp_m->ptr()), + static_cast(tmp_d->ptr()), + seq_lens_q.data(), + seq_lens_kv.data(), + seq_lens_encoder.data(), + batch_id_per_token.data(), + cu_seqlens_q.data(), + shift_bias ? reinterpret_cast( + const_cast(shift_bias.get().data())) + : nullptr, + smooth_weight ? reinterpret_cast(const_cast( + smooth_weight.get().data())) + : nullptr, + reinterpret_cast(out->data()), + quant_max_bound, + quant_min_bound, + in_scale, + max_seq_len, + num_chunks, + num_heads, + chunk_size, + HEAD_DIM, + token_num, + speculate_max_draft_token_num); + } + } + } +} diff --git a/custom_ops/gpu_ops/append_attn/multiquery_attention_c4_kernel.h b/custom_ops/gpu_ops/append_attn/multiquery_attention_c4_kernel.h new file mode 100644 index 000000000..e7184caf9 --- /dev/null +++ b/custom_ops/gpu_ops/append_attn/multiquery_attention_c4_kernel.h @@ -0,0 +1,58 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include "append_attention_func.cuh" + +template +void MultiQueryAppendC4Attention( + const AppendAttnMetaData &meta_data, + const paddle::Tensor &qkv, + const paddle::Tensor &cache_k, + const paddle::Tensor &cache_v, + const paddle::optional &attn_mask, + const paddle::Tensor &cache_k_scale, + const paddle::Tensor &cache_v_scale, + const paddle::optional &cache_k_zp, + const paddle::optional &cache_v_zp, + const paddle::optional &shift_bias, + const paddle::optional &smooth_weight, + const paddle::Tensor &seq_lens_q, + const paddle::Tensor &seq_lens_kv, + const paddle::Tensor &seq_lens_encoder, + const paddle::Tensor &batch_id_per_token, + const paddle::Tensor &cu_seqlens_q, + const paddle::Tensor &block_table, + const paddle::Tensor &batch_ids, + const paddle::Tensor &tile_ids_per_batch, + const int num_blocks_x_cpu, + const int max_seq_len, + const int max_dec_len, + const float quant_max_bound, + const float quant_min_bound, + const float in_scale, + const int max_partition_size, + const int encoder_max_partition_size, + const int speculate_max_draft_token_num, + const bool is_decoder, + cudaStream_t &stream, + paddle::Tensor *out); diff --git a/custom_ops/gpu_ops/append_attn/multiquery_attention_c8_impl.cuh b/custom_ops/gpu_ops/append_attn/multiquery_attention_c8_impl.cuh index 58e1a5bc0..ab0abf6d9 100644 --- a/custom_ops/gpu_ops/append_attn/multiquery_attention_c8_impl.cuh +++ b/custom_ops/gpu_ops/append_attn/multiquery_attention_c8_impl.cuh @@ -13,7 +13,6 @@ // limitations under the License. #pragma once -#include "append_attention_func.cuh" #include "multiquery_attention_c8_kernel.h" template -void DecodeMLAAttentionKernel( - const AppendAttnMetaData& meta_data, - const paddle::Tensor &q, // [token_num, num_heads, head_dim] - const paddle::Tensor &cache_k, - const paddle::Tensor &cache_v, - const paddle::optional& attn_mask, - const paddle::optional& shift_bias, - const paddle::optional& smooth_weight, - const paddle::Tensor &seq_lens_q, // q_seq_len is 1 - const paddle::Tensor &seq_lens_kv, - const paddle::Tensor &batch_id_per_token, - const paddle::Tensor &cu_seqlens_q, - const paddle::Tensor &block_table, - int max_seq_len, - int max_dec_len, - float softmax_scale, - float in_scale, - bool causal, - cudaStream_t &stream, - paddle::Tensor *out) { - const auto token_num = meta_data.token_nums; - const auto block_size = meta_data.block_size; - const auto bsz = meta_data.batch_size; - const auto num_heads = meta_data.q_num_heads; - const auto group_size = meta_data.q_num_heads / meta_data.kv_num_heads; - const auto head_dim_qk = meta_data.head_dims; - const auto head_dim_v = meta_data.head_dims_v; - const float rope_scale = 0.0; - const float rope_theta = 0.0; - const uint32_t deal_each_time = get_cascade_attention_deal_each_time(); - const uint32_t num_stage = get_cascade_attention_num_stages(); - const uint32_t num_threads = get_cascade_attention_num_threads(); - - DISPATCH_CAUSAL(causal, CAUSAL, - {DISPATCH_MLA_GROUP_SIZE(group_size, GROUP_SIZE, - {DISPATCH_MLA_HEAD_DIM(head_dim_qk, HEAD_DIM_QK, - {DISPATCH_MLA_HEAD_DIM(head_dim_v, HEAD_DIM_V, - {DISPATCH_BLOCK_SIZE(block_size, BLOCK_SIZE, - {DISPATCH_DEAL_EACH_TIME(deal_each_time, DEAL_EACH_TIME, - {MultiQueryDecoderAttention( - meta_data, stream, q, cache_k, cache_v, attn_mask, shift_bias, smooth_weight, seq_lens_q, seq_lens_kv, batch_id_per_token, cu_seqlens_q, - block_table, max_seq_len, max_dec_len, rope_scale, rope_theta, softmax_scale, in_scale, out);})})})})})}); -} - -template void DecodeMLAAttentionKernel( - const AppendAttnMetaData& meta_data, - const paddle::Tensor &q, // [token_num, num_heads, head_dim] - const paddle::Tensor &cache_k, - const paddle::Tensor &cache_v, - const paddle::optional& attn_mask, - const paddle::optional& shift_bias, - const paddle::optional& smooth_weight, - const paddle::Tensor &seq_lens_q, // q_seq_len is 1 - const paddle::Tensor &seq_lens_kv, - const paddle::Tensor &batch_id_per_token, - const paddle::Tensor &cu_seqlens_q, - const paddle::Tensor &block_table, - int max_seq_len, - int max_dec_len, - float softmax_scale, - float in_scale, - bool causal, - cudaStream_t &stream, - paddle::Tensor *out); - -template void DecodeMLAAttentionKernel( - const AppendAttnMetaData& meta_data, - const paddle::Tensor &q, // [token_num, num_heads, head_dim] - const paddle::Tensor &cache_k, - const paddle::Tensor &cache_v, - const paddle::optional& attn_mask, - const paddle::optional& shift_bias, - const paddle::optional& smooth_weight, - const paddle::Tensor &seq_lens_q, // q_seq_len is 1 - const paddle::Tensor &seq_lens_kv, - const paddle::Tensor &batch_id_per_token, - const paddle::Tensor &cu_seqlens_q, - const paddle::Tensor &block_table, - int max_seq_len, - int max_dec_len, - float softmax_scale, - float in_scale, - bool causal, - cudaStream_t &stream, - paddle::Tensor *out); diff --git a/custom_ops/gpu_ops/append_attn/multiquery_decoder_attention_kernel.h b/custom_ops/gpu_ops/append_attn/multiquery_decoder_attention_kernel.h new file mode 100644 index 000000000..a1f8aa7cc --- /dev/null +++ b/custom_ops/gpu_ops/append_attn/multiquery_decoder_attention_kernel.h @@ -0,0 +1,39 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include "decode_attention_func.cuh" + +template +void MultiQueryDecoderAttention( + const AppendAttnMetaData& meta_data, + cudaStream_t &stream, + const paddle::Tensor &q, + const paddle::Tensor &cache_k, // [max_block_num, num_kv_heads, block_size, head_dim] + const paddle::Tensor &cache_v, // [num_kv_heads, head_dim] + const paddle::optional& attn_mask, + const paddle::optional& shift_bias, + const paddle::optional& smooth_weight, + const paddle::Tensor &seq_lens_q, + const paddle::Tensor &seq_lens_kv, + const paddle::Tensor &batch_id_per_token, + const paddle::Tensor &cu_seqlens_q, + const paddle::Tensor &block_table, + const int max_seq_len, + const int max_dec_len, + const float rope_scale, + const float rope_theta, + const float softmax_scale, + const float in_scale, + paddle::Tensor *out); diff --git a/custom_ops/gpu_ops/append_attn/template_config.json b/custom_ops/gpu_ops/append_attn/template_config.json new file mode 100644 index 000000000..aa91b5ac3 --- /dev/null +++ b/custom_ops/gpu_ops/append_attn/template_config.json @@ -0,0 +1,144 @@ +{ + "multiquery_attention_c8": { + "name": "multiquery_attention_c8", + "function_name": "MultiQueryAppendC8Attention", + "impl_file": "multiquery_attention_c8_impl.cuh", + "template_params": [ + "T", + "GROUP_SIZE", + "HEAD_DIM", + "BLOCK_SIZE", + "CAUSAL", + "BLOCK_SHAPE_Q", + "NUM_WARP_Q", + "OutT", + "ENABLE_PREFILL", + "IsFP8", + "IsDynamicC8" + ], + "dispatch_params": { + "GROUP_SIZE": [1, 2, 4, 5, 6, 7, 8, 12, 14, 16], + "HEAD_DIM": [128], + "BLOCK_SIZE": [64], + "CAUSAL": [0, 1], + "BLOCK_SHAPE_Q": [16, 32, 64, 128], + "ENABLE_PREFILL": [0, 1], + "IsFP8": [0, 1], + "IsDynamicC8": [0, 1] + }, + "data_types": [ + ["paddle::float16", "paddle::float16", "float16_float16"], + ["paddle::float16", "paddle::float8_e4m3fn", "float16_fp8"], + ["paddle::float16", "int8_t", "float16_int8"], + ["paddle::bfloat16", "paddle::bfloat16", "bfloat16_bfloat16"], + ["paddle::bfloat16", "paddle::float8_e4m3fn", "bfloat16_fp8"], + ["paddle::bfloat16", "int8_t", "bfloat16_int8"] + ], + "max_instances_per_file": 80, + "file_prefix": "multiquery_attention_c8_", + "function_signature": "template void {function_name}{template_args}(\n const AppendAttnMetaData &meta_data,\n const paddle::Tensor &qkv,\n const paddle::Tensor &cache_k,\n const paddle::Tensor &cache_v,\n const paddle::optional &attn_mask,\n const paddle::Tensor &cache_k_scale,\n const paddle::Tensor &cache_v_scale,\n const paddle::optional &shift_bias,\n const paddle::optional &smooth_weight,\n const paddle::Tensor &seq_lens_q,\n const paddle::Tensor &seq_lens_kv,\n const paddle::Tensor &seq_lens_encoder,\n const paddle::Tensor &batch_id_per_token,\n const paddle::Tensor &cu_seqlens_q,\n const paddle::Tensor &block_table,\n const paddle::Tensor &batch_ids,\n const paddle::Tensor &tile_ids_per_batch,\n const int num_blocks_x_cpu,\n const int max_seq_len,\n const int max_dec_len,\n const float quant_max_bound,\n const float quant_min_bound,\n const float in_scale,\n const int max_partition_size,\n const int encoder_max_partition_size,\n const int speculate_max_draft_token_num,\n const bool is_decoder,\n cudaStream_t &stream,\n paddle::Tensor *out);\n\n" + }, + "multiquery_attention_c4": { + "name": "multiquery_attention_c4", + "function_name": "MultiQueryAppendC4Attention", + "impl_file": "multiquery_attention_c4_impl.cuh", + "template_params": [ + "T", + "GROUP_SIZE", + "HEAD_DIM", + "BLOCK_SIZE", + "CAUSAL", + "BLOCK_SHAPE_Q", + "NUM_WARP_Q", + "OutT", + "ENABLE_PREFILL" + ], + "dispatch_params": { + "GROUP_SIZE": [1, 2, 4, 5, 6, 7, 8, 12, 14, 16], + "HEAD_DIM": [128], + "BLOCK_SIZE": [64], + "CAUSAL": [0, 1], + "BLOCK_SHAPE_Q": [16, 32, 64, 128], + "ENABLE_PREFILL": [0, 1] + }, + "data_types": [ + ["paddle::float16", "paddle::float16", "float16_float16"], + ["paddle::float16", "paddle::float8_e4m3fn", "float16_fp8"], + ["paddle::float16", "int8_t", "float16_int8"], + ["paddle::bfloat16", "paddle::bfloat16", "bfloat16_bfloat16"], + ["paddle::bfloat16", "paddle::float8_e4m3fn", "bfloat16_fp8"], + ["paddle::bfloat16", "int8_t", "bfloat16_int8"] + ], + "max_instances_per_file": 160, + "file_prefix": "multiquery_attention_c4_", + "function_signature": "template void {function_name}{template_args}(\n const AppendAttnMetaData &meta_data,\n const paddle::Tensor &qkv,\n const paddle::Tensor &cache_k,\n const paddle::Tensor &cache_v,\n const paddle::optional &attn_mask,\n const paddle::Tensor &cache_k_scale,\n const paddle::Tensor &cache_v_scale,\n const paddle::optional &cache_k_zp,\n const paddle::optional &cache_v_zp,\n const paddle::optional &shift_bias,\n const paddle::optional &smooth_weight,\n const paddle::Tensor &seq_lens_q,\n const paddle::Tensor &seq_lens_kv,\n const paddle::Tensor &seq_lens_encoder,\n const paddle::Tensor &batch_id_per_token,\n const paddle::Tensor &cu_seqlens_q,\n const paddle::Tensor &block_table,\n const paddle::Tensor &batch_ids,\n const paddle::Tensor &tile_ids_per_batch,\n const int num_blocks_x_cpu,\n const int max_seq_len,\n const int max_dec_len,\n const float quant_max_bound,\n const float quant_min_bound,\n const float in_scale,\n const int max_partition_size,\n const int encoder_max_partition_size,\n const int speculate_max_draft_token_num,\n const bool is_decoder,\n cudaStream_t &stream,\n paddle::Tensor *out);\n\n" + }, + "multiquery_attention_c16": { + "name": "multiquery_attention_c16", + "function_name": "MultiQueryAppendAttention", + "impl_file": "multiquery_attention_c16_impl.cuh", + "template_params": [ + "T", + "GROUP_SIZE", + "HEAD_DIM", + "BLOCK_SIZE", + "CAUSAL", + "BLOCK_SHAPE_Q", + "NUM_WARP_Q", + "OutT", + "ENABLE_PREFILL" + ], + "dispatch_params": { + "GROUP_SIZE": [1, 2, 4, 5, 6, 7, 8, 12, 14, 16], + "HEAD_DIM": [128], + "BLOCK_SIZE": [64], + "CAUSAL": [0, 1], + "BLOCK_SHAPE_Q": [16, 32, 64, 128], + "ENABLE_PREFILL": [0, 1] + }, + "data_types": [ + ["paddle::float16", "paddle::float16", "float16_float16"], + ["paddle::float16", "paddle::float8_e4m3fn", "float16_fp8"], + ["paddle::float16", "int8_t", "float16_int8"], + ["paddle::bfloat16", "paddle::bfloat16", "bfloat16_bfloat16"], + ["paddle::bfloat16", "paddle::float8_e4m3fn", "bfloat16_fp8"], + ["paddle::bfloat16", "int8_t", "bfloat16_int8"] + ], + "max_instances_per_file": 160, + "file_prefix": "multiquery_attention_c16_", + "function_signature": "template void {function_name}{template_args}(\n const AppendAttnMetaData &meta_data,\n const paddle::Tensor &qkv,\n const paddle::Tensor &cache_k,\n const paddle::Tensor &cache_v,\n const paddle::optional &attn_mask,\n const paddle::optional &shift_bias,\n const paddle::optional &smooth_weight,\n const paddle::Tensor &seq_lens_q,\n const paddle::Tensor &seq_lens_kv,\n const paddle::Tensor &seq_lens_encoder,\n const paddle::Tensor &batch_id_per_token,\n const paddle::Tensor &cu_seqlens_q,\n const paddle::Tensor &block_table,\n const paddle::Tensor &batch_ids,\n const paddle::Tensor &tile_ids_per_batch,\n const int num_blocks_x_cpu,\n const int max_seq_len,\n const int max_dec_len,\n const float quant_max_bound,\n const float quant_min_bound,\n const float in_scale,\n const int max_partition_size,\n const int encoder_max_partition_size,\n const int speculate_max_draft_token_num,\n const bool is_decoder,\n cudaStream_t &stream,\n paddle::Tensor *out);\n\n" + }, + "multiquery_decoder_attention": { + "name": "multiquery_decoder_attention", + "function_name": "MultiQueryDecoderAttention", + "impl_file": "multiquery_decoder_attention_impl.cuh", + "template_params": [ + "T", + "GROUP_SIZE", + "HEAD_DIM_QK", + "HEAD_DIM_V", + "BLOCK_SIZE", + "CAUSAL", + "NUM_STAGE", + "cache_bytes", + "DEAL_EACH_TIME" + ], + "dispatch_params": { + "GROUP_SIZE": [8, 16, 128], + "HEAD_DIM_QK": [128, 192, 512, 576], + "HEAD_DIM_V": [128, 192, 512, 576], + "BLOCK_SIZE": [64], + "CAUSAL": [0, 1], + "NUM_STAGE": [2], + "cache_bytes": [16], + "DEAL_EACH_TIME": [32, 64] + }, + "data_types": [ + ["paddle::float16", "", "float16"], + ["paddle::bfloat16", "", "bfloat16"] + ], + "max_instances_per_file": 60, + "file_prefix": "multiquery_decoder_attention_", + "function_signature": "template void {function_name}{template_args}(\n const AppendAttnMetaData& meta_data,\n cudaStream_t &stream,\n const paddle::Tensor &q,\n const paddle::Tensor &cache_k,\n const paddle::Tensor &cache_v,\n const paddle::optional& attn_mask,\n const paddle::optional& shift_bias,\n const paddle::optional& smooth_weight,\n const paddle::Tensor &seq_lens_q,\n const paddle::Tensor &seq_lens_kv,\n const paddle::Tensor &batch_id_per_token,\n const paddle::Tensor &cu_seqlens_q,\n const paddle::Tensor &block_table,\n const int max_seq_len,\n const int max_dec_len,\n const float rope_scale,\n const float rope_theta,\n const float softmax_scale,\n const float in_scale,\n paddle::Tensor *out);\n\n" + } +} diff --git a/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c16_bfloat16_bfloat16_kernel.cu b/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c16_bfloat16_bfloat16_kernel.cu deleted file mode 100644 index 93db78513..000000000 --- a/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c16_bfloat16_bfloat16_kernel.cu +++ /dev/null @@ -1,59 +0,0 @@ -// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -#include "../append_attention_c16_impl.cuh" - - -template void CascadeAppendAttentionC16Kernel( - const AppendAttnMetaData& meta_data, - const paddle::Tensor& qkv, // [token_num, num_heads, head_dim] - const paddle::Tensor& - cache_k, // [max_block_num, num_heads, block_size, head_dim] - const paddle::Tensor& - cache_v, // [max_block_num, num_heads, head_dim, block_size] - const paddle::optional& attn_mask, - const paddle::optional& - cache_k_scale, // [num_kv_heads, head_dim] - const paddle::optional& - cache_v_scale, // [num_kv_heads, head_dim] - const paddle::optional& - cache_k_zp, // [num_kv_heads, head_dim] - const paddle::optional& - cache_v_zp, // [num_kv_heads, head_dim] - const paddle::optional& - shift_bias, // [num_kv_heads, head_dim] - const paddle::optional& - smooth_weight, // [num_kv_heads, head_dim] - const paddle::Tensor& seq_lens_q, - const paddle::Tensor& seq_lens_kv, - const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& batch_id_per_token, - const paddle::Tensor& cu_seqlens_q, - const paddle::Tensor& block_table, - const paddle::Tensor& batch_ids, - const paddle::Tensor& tile_ids_per_batch, - const int num_blocks, - const int block_shape_q, - const int max_seq_len, - const int max_dec_len, - const float quant_max_bound, - const float quant_min_bound, - const float in_scale, - const int max_partition_size, - const int encoder_max_partition_size, - const int speculate_max_draft_token_num, - const bool causal, - const bool is_decoder, - const bool enable_prefill, - cudaStream_t& stream, - paddle::Tensor* out); diff --git a/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c16_bfloat16_fp8_kernel.cu b/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c16_bfloat16_fp8_kernel.cu deleted file mode 100644 index 573703648..000000000 --- a/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c16_bfloat16_fp8_kernel.cu +++ /dev/null @@ -1,58 +0,0 @@ -// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -#include "../append_attention_c16_impl.cuh" - -template void CascadeAppendAttentionC16Kernel( - const AppendAttnMetaData& meta_data, - const paddle::Tensor& qkv, // [token_num, num_heads, head_dim] - const paddle::Tensor& - cache_k, // [max_block_num, num_heads, block_size, head_dim] - const paddle::Tensor& - cache_v, // [max_block_num, num_heads, head_dim, block_size] - const paddle::optional& attn_mask, - const paddle::optional& - cache_k_scale, // [num_kv_heads, head_dim] - const paddle::optional& - cache_v_scale, // [num_kv_heads, head_dim] - const paddle::optional& - cache_k_zp, // [num_kv_heads, head_dim] - const paddle::optional& - cache_v_zp, // [num_kv_heads, head_dim] - const paddle::optional& - shift_bias, // [num_kv_heads, head_dim] - const paddle::optional& - smooth_weight, // [num_kv_heads, head_dim] - const paddle::Tensor& seq_lens_q, - const paddle::Tensor& seq_lens_kv, - const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& batch_id_per_token, - const paddle::Tensor& cu_seqlens_q, - const paddle::Tensor& block_table, - const paddle::Tensor& batch_ids, - const paddle::Tensor& tile_ids_per_batch, - const int num_blocks, - const int block_shape_q, - const int max_seq_len, - const int max_dec_len, - const float quant_max_bound, - const float quant_min_bound, - const float in_scale, - const int max_partition_size, - const int encoder_max_partition_size, - const int speculate_max_draft_token_num, - const bool causal, - const bool is_decoder, - const bool enable_prefill, - cudaStream_t& stream, - paddle::Tensor* out); diff --git a/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c16_bfloat16_int8_kernel.cu b/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c16_bfloat16_int8_kernel.cu deleted file mode 100644 index 077a5764e..000000000 --- a/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c16_bfloat16_int8_kernel.cu +++ /dev/null @@ -1,58 +0,0 @@ -// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -#include "../append_attention_c16_impl.cuh" - -template void CascadeAppendAttentionC16Kernel( - const AppendAttnMetaData& meta_data, - const paddle::Tensor& qkv, // [token_num, num_heads, head_dim] - const paddle::Tensor& - cache_k, // [max_block_num, num_heads, block_size, head_dim] - const paddle::Tensor& - cache_v, // [max_block_num, num_heads, head_dim, block_size] - const paddle::optional& attn_mask, - const paddle::optional& - cache_k_scale, // [num_kv_heads, head_dim] - const paddle::optional& - cache_v_scale, // [num_kv_heads, head_dim] - const paddle::optional& - cache_k_zp, // [num_kv_heads, head_dim] - const paddle::optional& - cache_v_zp, // [num_kv_heads, head_dim] - const paddle::optional& - shift_bias, // [num_kv_heads, head_dim] - const paddle::optional& - smooth_weight, // [num_kv_heads, head_dim] - const paddle::Tensor& seq_lens_q, - const paddle::Tensor& seq_lens_kv, - const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& batch_id_per_token, - const paddle::Tensor& cu_seqlens_q, - const paddle::Tensor& block_table, - const paddle::Tensor& batch_ids, - const paddle::Tensor& tile_ids_per_batch, - const int num_blocks, - const int block_shape_q, - const int max_seq_len, - const int max_dec_len, - const float quant_max_bound, - const float quant_min_bound, - const float in_scale, - const int max_partition_size, - const int encoder_max_partition_size, - const int speculate_max_draft_token_num, - const bool causal, - const bool is_decoder, - const bool enable_prefill, - cudaStream_t& stream, - paddle::Tensor* out); diff --git a/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c16_float16_float16_kernel.cu b/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c16_float16_float16_kernel.cu deleted file mode 100644 index 436250238..000000000 --- a/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c16_float16_float16_kernel.cu +++ /dev/null @@ -1,58 +0,0 @@ -// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -#include "../append_attention_c16_impl.cuh" - -template void CascadeAppendAttentionC16Kernel( - const AppendAttnMetaData& meta_data, - const paddle::Tensor& qkv, // [token_num, num_heads, head_dim] - const paddle::Tensor& - cache_k, // [max_block_num, num_heads, block_size, head_dim] - const paddle::Tensor& - cache_v, // [max_block_num, num_heads, head_dim, block_size] - const paddle::optional& attn_mask, - const paddle::optional& - cache_k_scale, // [num_kv_heads, head_dim] - const paddle::optional& - cache_v_scale, // [num_kv_heads, head_dim] - const paddle::optional& - cache_k_zp, // [num_kv_heads, head_dim] - const paddle::optional& - cache_v_zp, // [num_kv_heads, head_dim] - const paddle::optional& - shift_bias, // [num_kv_heads, head_dim] - const paddle::optional& - smooth_weight, // [num_kv_heads, head_dim] - const paddle::Tensor& seq_lens_q, - const paddle::Tensor& seq_lens_kv, - const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& batch_id_per_token, - const paddle::Tensor& cu_seqlens_q, - const paddle::Tensor& block_table, - const paddle::Tensor& batch_ids, - const paddle::Tensor& tile_ids_per_batch, - const int num_blocks, - const int block_shape_q, - const int max_seq_len, - const int max_dec_len, - const float quant_max_bound, - const float quant_min_bound, - const float in_scale, - const int max_partition_size, - const int encoder_max_partition_size, - const int speculate_max_draft_token_num, - const bool causal, - const bool is_decoder, - const bool enable_prefill, - cudaStream_t& stream, - paddle::Tensor* out); diff --git a/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c16_float16_fp8_kernel.cu b/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c16_float16_fp8_kernel.cu deleted file mode 100644 index daaad4de6..000000000 --- a/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c16_float16_fp8_kernel.cu +++ /dev/null @@ -1,58 +0,0 @@ -// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -#include "../append_attention_c16_impl.cuh" - -template void CascadeAppendAttentionC16Kernel( - const AppendAttnMetaData& meta_data, - const paddle::Tensor& qkv, // [token_num, num_heads, head_dim] - const paddle::Tensor& - cache_k, // [max_block_num, num_heads, block_size, head_dim] - const paddle::Tensor& - cache_v, // [max_block_num, num_heads, head_dim, block_size] - const paddle::optional& attn_mask, - const paddle::optional& - cache_k_scale, // [num_kv_heads, head_dim] - const paddle::optional& - cache_v_scale, // [num_kv_heads, head_dim] - const paddle::optional& - cache_k_zp, // [num_kv_heads, head_dim] - const paddle::optional& - cache_v_zp, // [num_kv_heads, head_dim] - const paddle::optional& - shift_bias, // [num_kv_heads, head_dim] - const paddle::optional& - smooth_weight, // [num_kv_heads, head_dim] - const paddle::Tensor& seq_lens_q, - const paddle::Tensor& seq_lens_kv, - const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& batch_id_per_token, - const paddle::Tensor& cu_seqlens_q, - const paddle::Tensor& block_table, - const paddle::Tensor& batch_ids, - const paddle::Tensor& tile_ids_per_batch, - const int num_blocks, - const int block_shape_q, - const int max_seq_len, - const int max_dec_len, - const float quant_max_bound, - const float quant_min_bound, - const float in_scale, - const int max_partition_size, - const int encoder_max_partition_size, - const int speculate_max_draft_token_num, - const bool causal, - const bool is_decoder, - const bool enable_prefill, - cudaStream_t& stream, - paddle::Tensor* out); diff --git a/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c16_float16_int8_kernel.cu b/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c16_float16_int8_kernel.cu deleted file mode 100644 index 549f1cec2..000000000 --- a/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c16_float16_int8_kernel.cu +++ /dev/null @@ -1,58 +0,0 @@ -// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -#include "../append_attention_c16_impl.cuh" - -template void CascadeAppendAttentionC16Kernel( - const AppendAttnMetaData& meta_data, - const paddle::Tensor& qkv, // [token_num, num_heads, head_dim] - const paddle::Tensor& - cache_k, // [max_block_num, num_heads, block_size, head_dim] - const paddle::Tensor& - cache_v, // [max_block_num, num_heads, head_dim, block_size] - const paddle::optional& attn_mask, - const paddle::optional& - cache_k_scale, // [num_kv_heads, head_dim] - const paddle::optional& - cache_v_scale, // [num_kv_heads, head_dim] - const paddle::optional& - cache_k_zp, // [num_kv_heads, head_dim] - const paddle::optional& - cache_v_zp, // [num_kv_heads, head_dim] - const paddle::optional& - shift_bias, // [num_kv_heads, head_dim] - const paddle::optional& - smooth_weight, // [num_kv_heads, head_dim] - const paddle::Tensor& seq_lens_q, - const paddle::Tensor& seq_lens_kv, - const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& batch_id_per_token, - const paddle::Tensor& cu_seqlens_q, - const paddle::Tensor& block_table, - const paddle::Tensor& batch_ids, - const paddle::Tensor& tile_ids_per_batch, - const int num_blocks, - const int block_shape_q, - const int max_seq_len, - const int max_dec_len, - const float quant_max_bound, - const float quant_min_bound, - const float in_scale, - const int max_partition_size, - const int encoder_max_partition_size, - const int speculate_max_draft_token_num, - const bool causal, - const bool is_decoder, - const bool enable_prefill, - cudaStream_t& stream, - paddle::Tensor* out); diff --git a/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c4_bfloat16_bfloat16_kernel.cu b/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c4_bfloat16_bfloat16_kernel.cu deleted file mode 100644 index 923f9b0d3..000000000 --- a/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c4_bfloat16_bfloat16_kernel.cu +++ /dev/null @@ -1,58 +0,0 @@ -// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -#include "../append_attention_c4_impl.cuh" - -template void CascadeAppendAttentionC4Kernel( - const AppendAttnMetaData& meta_data, - const paddle::Tensor& qkv, // [token_num, num_heads, head_dim] - const paddle::Tensor& - cache_k, // [max_block_num, num_heads, block_size, head_dim] - const paddle::Tensor& - cache_v, // [max_block_num, num_heads, head_dim, block_size] - const paddle::optional& attn_mask, - const paddle::optional& - cache_k_scale, // [num_kv_heads, head_dim] - const paddle::optional& - cache_v_scale, // [num_kv_heads, head_dim] - const paddle::optional& - cache_k_zp, // [num_kv_heads, head_dim] - const paddle::optional& - cache_v_zp, // [num_kv_heads, head_dim] - const paddle::optional& - shift_bias, // [num_kv_heads, head_dim] - const paddle::optional& - smooth_weight, // [num_kv_heads, head_dim] - const paddle::Tensor& seq_lens_q, - const paddle::Tensor& seq_lens_kv, - const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& batch_id_per_token, - const paddle::Tensor& cu_seqlens_q, - const paddle::Tensor& block_table, - const paddle::Tensor& batch_ids, - const paddle::Tensor& tile_ids_per_batch, - const int num_blocks, - const int block_shape_q, - const int max_seq_len, - const int max_dec_len, - const float quant_max_bound, - const float quant_min_bound, - const float in_scale, - const int max_partition_size, - const int encoder_max_partition_size, - const int speculate_max_draft_token_num, - const bool causal, - const bool is_decoder, - const bool enable_prefill, - cudaStream_t& stream, - paddle::Tensor* out); diff --git a/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c4_bfloat16_fp8_kernel.cu b/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c4_bfloat16_fp8_kernel.cu deleted file mode 100644 index 888c410bb..000000000 --- a/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c4_bfloat16_fp8_kernel.cu +++ /dev/null @@ -1,58 +0,0 @@ -// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -#include "../append_attention_c4_impl.cuh" - -template void CascadeAppendAttentionC4Kernel( - const AppendAttnMetaData& meta_data, - const paddle::Tensor& qkv, // [token_num, num_heads, head_dim] - const paddle::Tensor& - cache_k, // [max_block_num, num_heads, block_size, head_dim] - const paddle::Tensor& - cache_v, // [max_block_num, num_heads, head_dim, block_size] - const paddle::optional& attn_mask, - const paddle::optional& - cache_k_scale, // [num_kv_heads, head_dim] - const paddle::optional& - cache_v_scale, // [num_kv_heads, head_dim] - const paddle::optional& - cache_k_zp, // [num_kv_heads, head_dim] - const paddle::optional& - cache_v_zp, // [num_kv_heads, head_dim] - const paddle::optional& - shift_bias, // [num_kv_heads, head_dim] - const paddle::optional& - smooth_weight, // [num_kv_heads, head_dim] - const paddle::Tensor& seq_lens_q, - const paddle::Tensor& seq_lens_kv, - const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& batch_id_per_token, - const paddle::Tensor& cu_seqlens_q, - const paddle::Tensor& block_table, - const paddle::Tensor& batch_ids, - const paddle::Tensor& tile_ids_per_batch, - const int num_blocks, - const int block_shape_q, - const int max_seq_len, - const int max_dec_len, - const float quant_max_bound, - const float quant_min_bound, - const float in_scale, - const int max_partition_size, - const int encoder_max_partition_size, - const int speculate_max_draft_token_num, - const bool causal, - const bool is_decoder, - const bool enable_prefill, - cudaStream_t& stream, - paddle::Tensor* out); diff --git a/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c4_bfloat16_int8_kernel.cu b/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c4_bfloat16_int8_kernel.cu deleted file mode 100644 index fcef546ea..000000000 --- a/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c4_bfloat16_int8_kernel.cu +++ /dev/null @@ -1,58 +0,0 @@ -// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -#include "../append_attention_c4_impl.cuh" - -template void CascadeAppendAttentionC4Kernel( - const AppendAttnMetaData& meta_data, - const paddle::Tensor& qkv, // [token_num, num_heads, head_dim] - const paddle::Tensor& - cache_k, // [max_block_num, num_heads, block_size, head_dim] - const paddle::Tensor& - cache_v, // [max_block_num, num_heads, head_dim, block_size] - const paddle::optional& attn_mask, - const paddle::optional& - cache_k_scale, // [num_kv_heads, head_dim] - const paddle::optional& - cache_v_scale, // [num_kv_heads, head_dim] - const paddle::optional& - cache_k_zp, // [num_kv_heads, head_dim] - const paddle::optional& - cache_v_zp, // [num_kv_heads, head_dim] - const paddle::optional& - shift_bias, // [num_kv_heads, head_dim] - const paddle::optional& - smooth_weight, // [num_kv_heads, head_dim] - const paddle::Tensor& seq_lens_q, - const paddle::Tensor& seq_lens_kv, - const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& batch_id_per_token, - const paddle::Tensor& cu_seqlens_q, - const paddle::Tensor& block_table, - const paddle::Tensor& batch_ids, - const paddle::Tensor& tile_ids_per_batch, - const int num_blocks, - const int block_shape_q, - const int max_seq_len, - const int max_dec_len, - const float quant_max_bound, - const float quant_min_bound, - const float in_scale, - const int max_partition_size, - const int encoder_max_partition_size, - const int speculate_max_draft_token_num, - const bool causal, - const bool is_decoder, - const bool enable_prefill, - cudaStream_t& stream, - paddle::Tensor* out); diff --git a/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c4_float16_float16_kernel.cu b/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c4_float16_float16_kernel.cu deleted file mode 100644 index 656374937..000000000 --- a/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c4_float16_float16_kernel.cu +++ /dev/null @@ -1,59 +0,0 @@ -// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -#include "../append_attention_c4_impl.cuh" - - -template void CascadeAppendAttentionC4Kernel( - const AppendAttnMetaData& meta_data, - const paddle::Tensor& qkv, // [token_num, num_heads, head_dim] - const paddle::Tensor& - cache_k, // [max_block_num, num_heads, block_size, head_dim] - const paddle::Tensor& - cache_v, // [max_block_num, num_heads, head_dim, block_size] - const paddle::optional& attn_mask, - const paddle::optional& - cache_k_scale, // [num_kv_heads, head_dim] - const paddle::optional& - cache_v_scale, // [num_kv_heads, head_dim] - const paddle::optional& - cache_k_zp, // [num_kv_heads, head_dim] - const paddle::optional& - cache_v_zp, // [num_kv_heads, head_dim] - const paddle::optional& - shift_bias, // [num_kv_heads, head_dim] - const paddle::optional& - smooth_weight, // [num_kv_heads, head_dim] - const paddle::Tensor& seq_lens_q, - const paddle::Tensor& seq_lens_kv, - const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& batch_id_per_token, - const paddle::Tensor& cu_seqlens_q, - const paddle::Tensor& block_table, - const paddle::Tensor& batch_ids, - const paddle::Tensor& tile_ids_per_batch, - const int num_blocks, - const int block_shape_q, - const int max_seq_len, - const int max_dec_len, - const float quant_max_bound, - const float quant_min_bound, - const float in_scale, - const int max_partition_size, - const int encoder_max_partition_size, - const int speculate_max_draft_token_num, - const bool causal, - const bool is_decoder, - const bool enable_prefill, - cudaStream_t& stream, - paddle::Tensor* out); diff --git a/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c4_float16_fp8_kernel.cu b/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c4_float16_fp8_kernel.cu deleted file mode 100644 index fba62df2b..000000000 --- a/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c4_float16_fp8_kernel.cu +++ /dev/null @@ -1,58 +0,0 @@ -// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -#include "../append_attention_c4_impl.cuh" - -template void CascadeAppendAttentionC4Kernel( - const AppendAttnMetaData& meta_data, - const paddle::Tensor& qkv, // [token_num, num_heads, head_dim] - const paddle::Tensor& - cache_k, // [max_block_num, num_heads, block_size, head_dim] - const paddle::Tensor& - cache_v, // [max_block_num, num_heads, head_dim, block_size] - const paddle::optional& attn_mask, - const paddle::optional& - cache_k_scale, // [num_kv_heads, head_dim] - const paddle::optional& - cache_v_scale, // [num_kv_heads, head_dim] - const paddle::optional& - cache_k_zp, // [num_kv_heads, head_dim] - const paddle::optional& - cache_v_zp, // [num_kv_heads, head_dim] - const paddle::optional& - shift_bias, // [num_kv_heads, head_dim] - const paddle::optional& - smooth_weight, // [num_kv_heads, head_dim] - const paddle::Tensor& seq_lens_q, - const paddle::Tensor& seq_lens_kv, - const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& batch_id_per_token, - const paddle::Tensor& cu_seqlens_q, - const paddle::Tensor& block_table, - const paddle::Tensor& batch_ids, - const paddle::Tensor& tile_ids_per_batch, - const int num_blocks, - const int block_shape_q, - const int max_seq_len, - const int max_dec_len, - const float quant_max_bound, - const float quant_min_bound, - const float in_scale, - const int max_partition_size, - const int encoder_max_partition_size, - const int speculate_max_draft_token_num, - const bool causal, - const bool is_decoder, - const bool enable_prefill, - cudaStream_t& stream, - paddle::Tensor* out); diff --git a/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c4_float16_int8_kernel.cu b/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c4_float16_int8_kernel.cu deleted file mode 100644 index 7a6e21fa7..000000000 --- a/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c4_float16_int8_kernel.cu +++ /dev/null @@ -1,58 +0,0 @@ -// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -#include "../append_attention_c4_impl.cuh" - -template void CascadeAppendAttentionC4Kernel( - const AppendAttnMetaData& meta_data, - const paddle::Tensor& qkv, // [token_num, num_heads, head_dim] - const paddle::Tensor& - cache_k, // [max_block_num, num_heads, block_size, head_dim] - const paddle::Tensor& - cache_v, // [max_block_num, num_heads, head_dim, block_size] - const paddle::optional& attn_mask, - const paddle::optional& - cache_k_scale, // [num_kv_heads, head_dim] - const paddle::optional& - cache_v_scale, // [num_kv_heads, head_dim] - const paddle::optional& - cache_k_zp, // [num_kv_heads, head_dim] - const paddle::optional& - cache_v_zp, // [num_kv_heads, head_dim] - const paddle::optional& - shift_bias, // [num_kv_heads, head_dim] - const paddle::optional& - smooth_weight, // [num_kv_heads, head_dim] - const paddle::Tensor& seq_lens_q, - const paddle::Tensor& seq_lens_kv, - const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& batch_id_per_token, - const paddle::Tensor& cu_seqlens_q, - const paddle::Tensor& block_table, - const paddle::Tensor& batch_ids, - const paddle::Tensor& tile_ids_per_batch, - const int num_blocks, - const int block_shape_q, - const int max_seq_len, - const int max_dec_len, - const float quant_max_bound, - const float quant_min_bound, - const float in_scale, - const int max_partition_size, - const int encoder_max_partition_size, - const int speculate_max_draft_token_num, - const bool causal, - const bool is_decoder, - const bool enable_prefill, - cudaStream_t& stream, - paddle::Tensor* out); diff --git a/custom_ops/setup_ops.py b/custom_ops/setup_ops.py index 8636c3de4..a38f14d8c 100644 --- a/custom_ops/setup_ops.py +++ b/custom_ops/setup_ops.py @@ -353,6 +353,8 @@ elif paddle.is_compiled_with_cuda(): "-Igpu_ops", "-Ithird_party/nlohmann_json/include", ] + worker_threads = os.cpu_count() + nvcc_compile_args += ["-t", str(worker_threads)] nvcc_version = get_nvcc_version() print(f"nvcc_version = {nvcc_version}")