diff --git a/custom_ops/gpu_ops/append_attn/append_attention_c8_impl.cuh b/custom_ops/gpu_ops/append_attn/append_attention_c8_impl.cuh index b2fe4c6f6..909f888a4 100644 --- a/custom_ops/gpu_ops/append_attn/append_attention_c8_impl.cuh +++ b/custom_ops/gpu_ops/append_attn/append_attention_c8_impl.cuh @@ -13,1607 +13,9 @@ // limitations under the License. #pragma once -#include "append_attention_func.cuh" -#include "append_attention_kernel.h" +#include "multiquery_attention_c8_kernel.h" -template -__global__ void multi_query_append_attention_c8_kernel( - T *__restrict__ q, // [token_num, (num_heads + 2* kv_num_head) * head_dim] - CacheT *__restrict__ cache_k, // [max_block_num, num_heads, block_size, - // head_dim] - CacheT *__restrict__ cache_v, - const T *__restrict__ cache_k_scale, // [num_kv_heads] or [max_block_num, num_heads, block_size] - const T *__restrict__ cache_v_scale, // [num_kv_heads] or [max_block_num, num_heads, block_size] - const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM] - const T *__restrict__ smooth_weight, // [q_num_heads * HEAD_DIM] - const int *__restrict__ seq_lens, - const int *__restrict__ seq_lens_kv, - const int *__restrict__ batch_ids, - const int *__restrict__ tile_ids_per_batch, - const int *__restrict__ cu_seqlens_q, - const int *__restrict__ block_table, // [bsz, block_num_per_seq] - const int *__restrict__ mask_offset, - const int max_seq_len, - const int max_dec_len, - const int max_block_num_per_seq, - const float scale, - const float quant_max_bound, - const float quant_min_bound, - const float in_scale, - const uint32_t chunk_size, - const int num_blocks_x_cpu, - T *__restrict__ tmp_workspace, // split kv [token_num, num_chunks, - // num_heads, head_dim] - float *__restrict__ tmp_m, // [token_num, num_chunks, num_heads] - float *__restrict__ tmp_d, // [token_num, num_chunks, num_heads] - OutT *__restrict__ out, - const int speculate_max_draft_token_num = 5) { - constexpr uint32_t num_vecs_per_head = - HEAD_DIM / 
num_elems_per_128b(); // 128 / 8 = 16 - constexpr uint32_t num_vecs_per_head_k = - HEAD_DIM / num_elems_per_128b(); // 128 / 16 = 8 - constexpr uint32_t num_vecs_per_blocksize = - BLOCK_SIZE / num_elems_per_128b(); // 64 / 16 = 4 - constexpr uint32_t inv_k_stride = 8 / num_vecs_per_head_k; - constexpr uint32_t inv_v_stride = 8 / num_vecs_per_blocksize; - const uint32_t btid = blockIdx.x, kv_head_idx = blockIdx.z; - const uint32_t kv_num_heads = gridDim.z; - const uint32_t q_num_heads = kv_num_heads * GROUP_SIZE; - const uint32_t q_head_idx = kv_head_idx * GROUP_SIZE; - const uint32_t tid = threadIdx.x, wid = threadIdx.y; - const uint32_t num_chunks = gridDim.y; - const uint32_t chunk_idx = blockIdx.y; - - const uint32_t batch_id = batch_ids[btid]; - const uint32_t tile_id = tile_ids_per_batch[btid]; - const uint32_t num_rows_per_block = NUM_WARPS * num_frags_x * 16; - const int *block_table_now = nullptr; - - block_table_now = block_table + batch_id * max_block_num_per_seq; - - //When cudagraph capture prefill, may launch more gridDim.x - if(btid >= static_cast(num_blocks_x_cpu)){ - return; - } - - const uint32_t q_len = seq_lens[batch_id]; - if (q_len <= 0) { - return; - } - - T cache_k_scale_reg[IsDynamicC8 ? num_frags_z * 2 : num_frags_y * 4]; - T cache_v_scale_reg[IsDynamicC8 ? 
num_frags_z * 4 : num_frags_y * 2]; - if constexpr (!IsDynamicC8) { - if constexpr (is_scale_channel_wise) { - int scale_col_base = threadIdx.x % 4 * 2 + kv_head_idx * HEAD_DIM; - const T *cache_k_scale_cur_head = cache_k_scale + scale_col_base; - for (int i = 0; i < num_frags_y; ++i) { - const int scale_idx = i * 16; - cache_k_scale_reg[i * 4] = cache_k_scale_cur_head[scale_idx]; - cache_k_scale_reg[i * 4 + 1] = cache_k_scale_cur_head[scale_idx + 1]; - cache_k_scale_reg[i * 4 + 2] = cache_k_scale_cur_head[scale_idx + 8]; - cache_k_scale_reg[i * 4 + 3] = cache_k_scale_cur_head[scale_idx + 9]; - } - scale_col_base = threadIdx.x / 4 + kv_head_idx * HEAD_DIM; - const T *cache_v_scale_cur_head = cache_v_scale + scale_col_base; - for (int i = 0; i < num_frags_y; ++i) { - const int scale_idx = i * 16; - cache_v_scale_reg[i * 2] = cache_v_scale_cur_head[scale_idx]; - cache_v_scale_reg[i * 2 + 1] = cache_v_scale_cur_head[scale_idx + 8]; - } - } else { - cache_k_scale_reg[0] = cache_k_scale[kv_head_idx]; - cache_v_scale_reg[0] = cache_v_scale[kv_head_idx]; - } - } - - const uint32_t q_end = - min(q_len, div_up((tile_id + 1) * num_rows_per_block, GROUP_SIZE)); - uint32_t kv_len = seq_lens_kv[batch_id]; - if (ENABLE_PREFILL) { - kv_len += q_len; - if (kv_len <= 0) { - return; - } - } else { - if (kv_len <= 0) { - return; - } - kv_len += q_len; - } - const uint32_t num_chunks_this_seq = div_up(kv_len, chunk_size); - if (chunk_idx >= num_chunks_this_seq) { - return; - } - - const uint32_t chunk_start = partition_kv ? chunk_idx * chunk_size : 0; - const uint32_t chunk_end = - partition_kv ? 
min(kv_len, chunk_start + chunk_size) : kv_len; - const uint32_t chunk_len = chunk_end - chunk_start; - - extern __shared__ uint8_t smem[]; - float s_frag[num_frags_x][num_frags_z][8]; - float o_frag[num_frags_x][num_frags_y][8]; - float m_frag[num_frags_x][2]; - float d_frag[num_frags_x][2]; - init_states(o_frag, m_frag, d_frag); - - const uint32_t q_n_stride = q_num_heads * HEAD_DIM; - const uint32_t q_ori_n_stride = (q_num_heads + kv_num_heads * 2) * HEAD_DIM; - const uint32_t kv_n_stride = kv_num_heads * BLOCK_SIZE * HEAD_DIM; - const uint32_t kv_h_stride = BLOCK_SIZE * HEAD_DIM; - const uint32_t kv_b_stride = HEAD_DIM; - const uint32_t kv_d_stride = BLOCK_SIZE; - const uint32_t q_start_seq_id = cu_seqlens_q[batch_id]; - const uint32_t q_base_seq_id_this_block = - (tile_id * NUM_WARPS + wid) * num_frags_x * 16; - const uint32_t q_offset = q_start_seq_id * q_ori_n_stride + - q_head_idx * HEAD_DIM + - tid % 8 * num_elems_per_128b(); - const uint32_t o_offset = q_start_seq_id * q_n_stride + - q_head_idx * HEAD_DIM + - tid % 8 * num_elems_per_128b(); - T *q_base_ptr = q + q_offset; - - T *o_base_ptr_T = nullptr; - OutT *o_base_ptr_int8 = nullptr; - if constexpr (partition_kv) { - if (ENABLE_PREFILL) { - o_base_ptr_T = tmp_workspace + q_start_seq_id * num_chunks * q_n_stride + - chunk_idx * q_n_stride + q_head_idx * HEAD_DIM + - tid % 8 * num_elems_per_128b(); - } else { - o_base_ptr_T = - tmp_workspace + - batch_id * speculate_max_draft_token_num * num_chunks * q_n_stride + - chunk_idx * q_n_stride + q_head_idx * HEAD_DIM + - tid % 8 * num_elems_per_128b(); - } - } else { - o_base_ptr_int8 = out + o_offset; - } - const int *mask_offset_this_seq = mask_offset ? 
mask_offset + q_start_seq_id * 2 : nullptr; - smem_t qo_smem(smem); - - uint32_t q_smem_offset_r = smem_t::get_permuted_offset( - wid * num_frags_x * 16 + tid % 16, tid / 16); // 16 * 16 - load_q_global_smem( - q_base_ptr, - &qo_smem, - q_base_seq_id_this_block, - q_end, - q_ori_n_stride, - HEAD_DIM); - commit_group(); - wait_group<0>(); - __syncthreads(); - - q_smem_inplace_multiply_sm_scale(&qo_smem, - scale); - smem_t k_smem(smem + NUM_WARPS * num_frags_x * 16 * HEAD_DIM * sizeof(T)), - v_smem(smem + NUM_WARPS * num_frags_x * 16 * HEAD_DIM * sizeof(T) + - num_frags_z * 16 * HEAD_DIM * sizeof(CacheT)); - T* k_smem_scale = nullptr; - T* v_smem_scale = nullptr; - if constexpr (IsDynamicC8) { - k_smem_scale = reinterpret_cast(smem + NUM_WARPS * num_frags_x * 16 * HEAD_DIM * sizeof(T) + - num_frags_z * 16 * HEAD_DIM * sizeof(CacheT) * 2); - v_smem_scale = k_smem_scale + num_frags_z * 16; - } - - - const uint32_t num_iterations = div_up( - CAUSAL - ? (min(chunk_len, - sub_if_greater_or_zero( - kv_len - q_len + - div_up((tile_id + 1) * num_rows_per_block, GROUP_SIZE), - chunk_start))) - : chunk_len, - num_frags_z * 16); - const uint32_t mask_check_iteration = - (CAUSAL ? (min(chunk_len, - sub_if_greater_or_zero( - kv_len - q_len + - tile_id * num_rows_per_block / GROUP_SIZE, - chunk_start))) - : mask_offset ? 
0 : chunk_len) / - (num_frags_z * 16); - - uint32_t k_smem_offset_r = - smem_t::get_permuted_offset( - 8 * (tid / 16) + tid % 8, (tid % 16) / 8); - - uint32_t v_smem_offset_r = - smem_t::get_permuted_offset( - 8 * (tid / 16) + tid % 8, (tid % 16) / 8); - - uint32_t k_smem_offset_w = - smem_t::get_permuted_offset( - wid * 4 + tid / 8, - tid % 8); - uint32_t v_smem_offset_w = - smem_t::get_permuted_offset( - wid * 8 + tid / 4, tid % 4); // 4 * 128 / 8 = 64 - - uint32_t kv_idx_base = chunk_start; - const uint32_t const_k_offset = kv_head_idx * kv_h_stride + - (wid * 4 + tid / 8) * kv_b_stride + - tid % 8 * num_elems_per_128b(); - const uint32_t const_v_offset = kv_head_idx * kv_h_stride + - (wid * 8 + tid / 4) * kv_d_stride + - tid % 4 * num_elems_per_128b(); - - produce_k_blockwise_c8(k_smem, - &k_smem_offset_w, - cache_k, - block_table_now, - kv_head_idx, - kv_n_stride, - kv_h_stride, - kv_b_stride, - kv_idx_base, - chunk_end, - const_k_offset); - commit_group(); - produce_v_blockwise_c8(v_smem, - &v_smem_offset_w, - cache_v, - block_table_now, - kv_head_idx, - kv_n_stride, - kv_h_stride, - kv_d_stride, - kv_idx_base, - chunk_end, - const_v_offset); - commit_group(); - -#pragma unroll 1 - for (uint32_t iter = 0; iter < num_iterations; ++iter) { - if constexpr (IsDynamicC8) { - produce_k_dynamic_scale( - k_smem_scale, - cache_k_scale_reg, - block_table_now, - cache_k_scale, - kv_idx_base, - kv_num_heads, - kv_head_idx, - chunk_end - ); - } - wait_group<1>(); - __syncthreads(); - // s = qk - compute_qk_c8( - &qo_smem, - &q_smem_offset_r, - &k_smem, - &k_smem_offset_r, - cache_k_scale_reg, - s_frag); - - // mask according to kv_idx and q_idx - if (iter >= mask_check_iteration) { - mask_s(nullptr, - q_base_seq_id_this_block, - kv_idx_base, - q_len, - kv_len, - chunk_end, - -1, - s_frag, - mask_offset_this_seq); - } - - // update m,d - update_mdo_states( - s_frag, o_frag, m_frag, d_frag); - __syncthreads(); - - const int ori_kv_idx_base = kv_idx_base; - kv_idx_base += 
num_frags_z * 16; - produce_k_blockwise_c8(k_smem, - &k_smem_offset_w, - cache_k, - block_table_now, - kv_head_idx, - kv_n_stride, - kv_h_stride, - kv_b_stride, - kv_idx_base, - chunk_end, - const_k_offset); - commit_group(); - if constexpr (IsDynamicC8) { - produce_v_dynamic_scale( - v_smem_scale, - cache_v_scale_reg, - block_table_now, - cache_v_scale, - ori_kv_idx_base, - kv_num_heads, - kv_head_idx, - chunk_end - ); - } - wait_group<1>(); - __syncthreads(); - - // compute sfm*v - compute_sfm_v_c8( - &v_smem, &v_smem_offset_r, s_frag, o_frag, d_frag, cache_v_scale_reg); - __syncthreads(); - - produce_v_blockwise_c8(v_smem, - &v_smem_offset_w, - cache_v, - block_table_now, - kv_head_idx, - kv_n_stride, - kv_h_stride, - kv_d_stride, - kv_idx_base, - chunk_end, - const_v_offset); - commit_group(); - - } - wait_group<0>(); - __syncthreads(); - - if constexpr (!partition_kv) { - normalize_d(o_frag, d_frag); - } - - // write o - // [num_frags_x, 16, num_frags_y, 16] - if constexpr (partition_kv) { - write_o_reg_gmem_shift_smooth_quant( - o_frag, - &qo_smem, - o_base_ptr_T, - shift_bias, - smooth_weight, - q_base_seq_id_this_block, - q_head_idx, - quant_max_bound, - quant_min_bound, - in_scale, - q_len, - partition_kv ? q_n_stride * num_chunks : q_n_stride, - HEAD_DIM); - } else { - write_o_reg_gmem_shift_smooth_quant( - o_frag, - &qo_smem, - o_base_ptr_int8, - shift_bias, - smooth_weight, - q_base_seq_id_this_block, - q_head_idx, - quant_max_bound, - quant_min_bound, - in_scale, - q_len, - partition_kv ? 
q_n_stride * num_chunks : q_n_stride, - HEAD_DIM); - } - - - if constexpr (partition_kv) { -#pragma unroll - for (uint32_t fx = 0; fx < num_frags_x; ++fx) { -#pragma unroll - for (uint32_t j = 0; j < 2; ++j) { - const uint32_t qo_idx_now = - q_base_seq_id_this_block + tid / 4 + j * 8 + fx * 16; - const uint32_t qo_head_idx = q_head_idx + qo_idx_now % GROUP_SIZE; - const uint32_t qo_idx = q_start_seq_id + qo_idx_now / GROUP_SIZE; - if (qo_idx - q_start_seq_id < q_len) { - uint32_t offset; - if (ENABLE_PREFILL) { - offset = - (qo_idx * num_chunks + chunk_idx) * q_num_heads + qo_head_idx; - } else { - offset = ((batch_id * speculate_max_draft_token_num + - qo_idx_now / GROUP_SIZE) * - num_chunks + - chunk_idx) * - q_num_heads + - qo_head_idx; - } - tmp_m[offset] = m_frag[fx][j]; - tmp_d[offset] = d_frag[fx][j]; - } - } - } - } -} - -template -__global__ void multi_query_append_attention_c8_warp1_4_kernel( - T *__restrict__ q, // [token_num, (num_heads + 2* kv_num_head) * head_dim] - CacheT *__restrict__ cache_k, // [max_block_num, num_heads, block_size, - // head_dim] - CacheT *__restrict__ cache_v, - const T *__restrict__ cache_k_scale, // [num_kv_heads] or [max_block_num, num_heads, block_size] - const T *__restrict__ cache_v_scale, // [num_kv_heads] or [max_block_num, num_heads, block_size] - const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM] - const T *__restrict__ smooth_weight, // [q_num_heads * HEAD_DIM] - const int *__restrict__ seq_lens, - const int *__restrict__ seq_lens_kv, - const int *__restrict__ batch_ids, - const int *__restrict__ tile_ids_per_batch, - const int *__restrict__ cu_seqlens_q, - const int *__restrict__ block_table, // [bsz, block_num_per_seq] - const int *__restrict__ mask_offset, - const bool *__restrict__ attn_mask, // [bsz, max_q, max_q] for tree-mask - const int max_seq_len, - const int max_dec_len, - const int max_block_num_per_seq, - const float scale, - const float quant_max_bound, - const float quant_min_bound, - const 
float in_scale, - const uint32_t chunk_size, - const int num_blocks_x_cpu, - T *__restrict__ tmp_workspace, // split kv [token_num, num_chunks, - // num_heads, head_dim] - float *__restrict__ tmp_m, // [token_num, num_chunks, num_heads] - float *__restrict__ tmp_d, // [token_num, num_chunks, num_heads] - OutT *__restrict__ out, - const int speculate_max_draft_token_num = 5, - const uint32_t attn_mask_len = -1) { - constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b(); - constexpr uint32_t num_vecs_per_head_k = - HEAD_DIM / num_elems_per_128b(); - constexpr uint32_t num_vecs_per_blocksize = - BLOCK_SIZE / num_elems_per_128b(); - constexpr uint32_t inv_k_stride = 8 / num_vecs_per_head_k; - constexpr uint32_t inv_v_stride = 8 / num_vecs_per_blocksize; - static_assert(NUM_WARP_Q == 1, "NUM_WARP_Q must be 1"); - static_assert(NUM_WARP_KV == 4, "NUM_WARP_KV must be 4"); - const uint32_t btid = blockIdx.x, kv_head_idx = blockIdx.z; - const uint32_t kv_num_heads = gridDim.z; - const uint32_t q_num_heads = kv_num_heads * GROUP_SIZE; - const uint32_t q_head_idx = kv_head_idx * GROUP_SIZE; - const uint32_t tid = threadIdx.x, wid = threadIdx.y; - const uint32_t num_chunks = gridDim.y; - const uint32_t chunk_idx = blockIdx.y; - - const uint32_t batch_id = batch_ids[btid]; - const uint32_t tile_id = tile_ids_per_batch[btid]; - const uint32_t num_rows_per_block = num_frags_x * 16; - const int *block_table_now = block_table + batch_id * max_block_num_per_seq; - - //When cudagraph capture prefill, may launch more gridDim.x - if(btid >= static_cast(num_blocks_x_cpu)){ - return; - } - - const uint32_t q_len = seq_lens[batch_id]; - if (q_len <= 0) { - return; - } - T cache_k_scale_reg[IsDynamicC8 ? num_frags_z * 2 : num_frags_y * 4]; - T cache_v_scale_reg[IsDynamicC8 ? 
num_frags_z * 4 : num_frags_y * 2]; - if constexpr (!IsDynamicC8) { - if constexpr (is_scale_channel_wise) { - int scale_col_base = threadIdx.x % 4 * 2 + kv_head_idx * HEAD_DIM; - const T *cache_k_scale_cur_head = cache_k_scale + scale_col_base; - for (int i = 0; i < num_frags_y; ++i) { - const int scale_idx = i * 16; - cache_k_scale_reg[i * 4] = cache_k_scale_cur_head[scale_idx]; - cache_k_scale_reg[i * 4 + 1] = cache_k_scale_cur_head[scale_idx + 1]; - cache_k_scale_reg[i * 4 + 2] = cache_k_scale_cur_head[scale_idx + 8]; - cache_k_scale_reg[i * 4 + 3] = cache_k_scale_cur_head[scale_idx + 9]; - } - scale_col_base = threadIdx.x / 4 + kv_head_idx * HEAD_DIM; - const T *cache_v_scale_cur_head = cache_v_scale + scale_col_base; - for (int i = 0; i < num_frags_y; ++i) { - const int scale_idx = i * 16; - cache_v_scale_reg[i * 2] = cache_v_scale_cur_head[scale_idx]; - cache_v_scale_reg[i * 2 + 1] = cache_v_scale_cur_head[scale_idx + 8]; - } - } else { - cache_k_scale_reg[0] = cache_k_scale[kv_head_idx]; - cache_v_scale_reg[0] = cache_v_scale[kv_head_idx]; - } - } - const uint32_t q_end = - min(q_len, div_up((tile_id + 1) * num_rows_per_block, GROUP_SIZE)); - uint32_t kv_len = seq_lens_kv[batch_id]; - if (ENABLE_PREFILL) { - kv_len += q_len; - if (kv_len <= 0) { - return; - } - } else { - if (kv_len <= 0) { - return; - } - kv_len += q_len; - } - const uint32_t num_chunks_this_seq = div_up(kv_len, chunk_size); - if (chunk_idx >= num_chunks_this_seq) { - return; - } - - const uint32_t chunk_start = partition_kv ? chunk_idx * chunk_size : 0; - const uint32_t chunk_end = - partition_kv ? 
min(kv_len, chunk_start + chunk_size) : kv_len; - const uint32_t chunk_len = chunk_end - chunk_start; - - extern __shared__ uint8_t smem[]; - float s_frag[num_frags_x][num_frags_z][8]; - float o_frag[num_frags_x][num_frags_y][8]; - float m_frag[num_frags_x][2]; - float d_frag[num_frags_x][2]; - init_states(o_frag, m_frag, d_frag); - - const uint32_t q_n_stride = q_num_heads * HEAD_DIM; - const uint32_t q_ori_n_stride = (q_num_heads + kv_num_heads * 2) * HEAD_DIM; - const uint32_t kv_n_stride = kv_num_heads * BLOCK_SIZE * HEAD_DIM; - const uint32_t kv_h_stride = BLOCK_SIZE * HEAD_DIM; - const uint32_t kv_b_stride = HEAD_DIM; - const uint32_t kv_d_stride = BLOCK_SIZE; - const uint32_t q_start_seq_id = cu_seqlens_q[batch_id]; - const uint32_t q_base_seq_id_this_block = tile_id * num_frags_x * 16; - const uint32_t q_offset = q_start_seq_id * q_ori_n_stride + - q_head_idx * HEAD_DIM + - tid % 8 * num_elems_per_128b(); - const uint32_t o_offset = q_start_seq_id * q_n_stride + - q_head_idx * HEAD_DIM + - tid % 8 * num_elems_per_128b(); - T *q_base_ptr = q + q_offset; - - T *o_base_ptr_T = nullptr; - OutT *o_base_ptr_int8 = nullptr; - if (num_chunks_this_seq <= 1) { - o_base_ptr_int8 = out + o_offset; - } else { - if (ENABLE_PREFILL) { - o_base_ptr_T = tmp_workspace + batch_id * num_chunks * q_n_stride + - chunk_idx * q_n_stride + q_head_idx * HEAD_DIM + - tid % 8 * num_elems_per_128b(); - } else { - o_base_ptr_T = - tmp_workspace + - batch_id * speculate_max_draft_token_num * num_chunks * q_n_stride + - chunk_idx * q_n_stride + q_head_idx * HEAD_DIM + - tid % 8 * num_elems_per_128b(); - } - } - const int *mask_offset_this_seq = mask_offset ? 
mask_offset + q_start_seq_id * 2 : nullptr; - smem_t qo_smem(smem); - - uint32_t q_smem_offset_r = smem_t::get_permuted_offset( - tid % 16, tid / 16); // 16 * 16 - load_q_global_smem_multi_warps(q_base_ptr, - &qo_smem, - q_base_seq_id_this_block, - q_end, - q_ori_n_stride, - HEAD_DIM); - commit_group(); - wait_group<0>(); - __syncthreads(); - - q_smem_inplace_multiply_sm_scale_multi_warps( - &qo_smem, scale); - - smem_t k_smem(smem + num_frags_x * 16 * HEAD_DIM * sizeof(T)), - v_smem(smem + num_frags_x * 16 * HEAD_DIM * sizeof(T) + - NUM_WARP_KV * num_frags_z * 16 * HEAD_DIM * sizeof(CacheT)); - T* k_smem_scale = nullptr; - T* v_smem_scale = nullptr; - if constexpr (IsDynamicC8) { - k_smem_scale = reinterpret_cast(smem + num_frags_x * 16 * HEAD_DIM * sizeof(T) + - NUM_WARP_KV * num_frags_z * 16 * HEAD_DIM * sizeof(CacheT) * 2); - v_smem_scale = k_smem_scale + NUM_WARP_KV * num_frags_z * 16; - } - - const uint32_t num_iterations = div_up( - CAUSAL - ? (min(chunk_len, - sub_if_greater_or_zero( - kv_len - q_len + - div_up((tile_id + 1) * num_rows_per_block, GROUP_SIZE), - chunk_start))) - : chunk_len, - NUM_WARP_KV * num_frags_z * 16); - const uint32_t mask_check_iteration = - (CAUSAL ? (min(chunk_len, - sub_if_greater_or_zero( - kv_len - q_len + - tile_id * num_rows_per_block / GROUP_SIZE, - chunk_start))) - : mask_offset ? 
0 : chunk_len) / - (NUM_WARP_KV * num_frags_z * 16); - - uint32_t k_smem_offset_r = - smem_t::get_permuted_offset( - wid * num_frags_z * 16 + 8 * (tid / 16) + tid % 8, (tid % 16) / 8); - - uint32_t v_smem_offset_r = - smem_t::get_permuted_offset( - (wid / 2) * num_frags_y * 16 + 8 * (tid / 16) + tid % 8, - (wid % 2) * num_frags_z + (tid % 16) / 8); - - uint32_t k_smem_offset_w = - smem_t::get_permuted_offset( - wid * 4 + tid / 8, - tid % - 8); - uint32_t v_smem_offset_w = - smem_t::get_permuted_offset( - wid * 8 + tid / 4, tid % 4); - - uint32_t kv_idx_base = chunk_start; - const uint32_t const_k_offset = kv_head_idx * kv_h_stride + - (wid * 4 + tid / 8) * kv_b_stride + - tid % 8 * num_elems_per_128b(); - const uint32_t const_v_offset = kv_head_idx * kv_h_stride + - (wid * 8 + tid / 4) * kv_d_stride + - tid % 4 * num_elems_per_128b(); - - // load BLOCK_SIZE * HEAD_DIM each time - produce_k_blockwise_c8(k_smem, - &k_smem_offset_w, - cache_k, - block_table_now, - kv_head_idx, - kv_n_stride, - kv_h_stride, - kv_b_stride, - kv_idx_base, - chunk_end, - const_k_offset); - commit_group(); - produce_v_blockwise_c8(v_smem, - &v_smem_offset_w, - cache_v, - block_table_now, - kv_head_idx, - kv_n_stride, - kv_h_stride, - kv_d_stride, - kv_idx_base, - chunk_end, - const_v_offset); - commit_group(); -#pragma unroll 1 - for (uint32_t iter = 0; iter < num_iterations; ++iter) { - if constexpr (IsDynamicC8) { - produce_k_dynamic_scale( - k_smem_scale, - cache_k_scale_reg, - block_table_now, - cache_k_scale, - kv_idx_base, - kv_num_heads, - kv_head_idx, - chunk_end - ); - } - wait_group<1>(); - __syncthreads(); - - // s = qk - compute_qk_c8( - &qo_smem, - &q_smem_offset_r, - &k_smem, - &k_smem_offset_r, - cache_k_scale_reg, - s_frag); - // mask according to kv_idx and q_idx - if (iter >= mask_check_iteration) { - mask_s(attn_mask ? 
attn_mask + batch_id * attn_mask_len *attn_mask_len : nullptr, - q_base_seq_id_this_block, - kv_idx_base + wid * num_frags_z * 16, - q_len, - kv_len, - chunk_end, - attn_mask_len, - s_frag, - mask_offset_this_seq); - - } - - // update m,d - update_mdo_states( - s_frag, o_frag, m_frag, d_frag); - __syncthreads(); - - const uint32_t ori_kv_idx_base = kv_idx_base; - kv_idx_base += NUM_WARP_KV * num_frags_z * 16; - produce_k_blockwise_c8(k_smem, - &k_smem_offset_w, - cache_k, - block_table_now, - kv_head_idx, - kv_n_stride, - kv_h_stride, - kv_b_stride, - kv_idx_base, - chunk_end, - const_k_offset); - commit_group(); - if constexpr (IsDynamicC8) { - produce_v_dynamic_scale( - v_smem_scale, - cache_v_scale_reg, - block_table_now, - cache_v_scale, - ori_kv_idx_base, - kv_num_heads, - kv_head_idx, - chunk_end - ); - } - wait_group<1>(); - __syncthreads(); - - // compute sfm * v - compute_sfm_v_c8_iter_sq_bvec( - &v_smem, &v_smem_offset_r, s_frag, o_frag, d_frag, cache_v_scale_reg); - __syncthreads(); - - produce_v_blockwise_c8(v_smem, - &v_smem_offset_w, - cache_v, - block_table_now, - kv_head_idx, - kv_n_stride, - kv_h_stride, - kv_d_stride, - kv_idx_base, - chunk_end, - const_v_offset); - commit_group(); - } - wait_group<0>(); - __syncthreads(); - - merge_block_res_v2( - o_frag, reinterpret_cast(smem), m_frag, d_frag, wid, tid); - - if (num_chunks_this_seq <= 1) { - normalize_d(o_frag, d_frag); - } - - // write o - // [num_frags_x, 16, num_frags_y, 16] - if (num_chunks_this_seq <= 1) { - write_o_reg_gmem_multi_warps_shift_smooth_quant( - o_frag, - &qo_smem, - o_base_ptr_int8, - shift_bias, - smooth_weight, - q_base_seq_id_this_block, - q_head_idx, - quant_max_bound, - quant_min_bound, - in_scale, - q_len, - q_n_stride, - HEAD_DIM); - } else { - write_o_reg_gmem_multi_warps_shift_smooth_quant( - o_frag, - &qo_smem, - o_base_ptr_T, - shift_bias, - smooth_weight, - q_base_seq_id_this_block, - q_head_idx, - quant_max_bound, - quant_min_bound, - in_scale, - q_len, - 
q_n_stride * num_chunks, - HEAD_DIM); - } - - if (num_chunks_this_seq > 1) { - if (wid == 0) { -#pragma unroll - for (uint32_t fx = 0; fx < num_frags_x; ++fx) { -#pragma unroll - for (uint32_t j = 0; j < 2; ++j) { - const uint32_t qo_idx_now = - q_base_seq_id_this_block + tid / 4 + j * 8 + fx * 16; - const uint32_t qo_head_idx = q_head_idx + qo_idx_now % GROUP_SIZE; - const uint32_t qo_idx = q_start_seq_id + qo_idx_now / GROUP_SIZE; - if (qo_idx - q_start_seq_id < q_len) { - - uint32_t offset; - if (ENABLE_PREFILL) { - offset = (batch_id * num_chunks + chunk_idx) * q_num_heads + - qo_head_idx; - } else { - offset = ((batch_id * speculate_max_draft_token_num + - qo_idx_now / GROUP_SIZE) * - num_chunks + - chunk_idx) * - q_num_heads + - qo_head_idx; - } - tmp_m[offset] = m_frag[fx][j]; - tmp_d[offset] = d_frag[fx][j]; - } - } - } - } - } -} - -template -void MultiQueryAppendC8Attention( - const AppendAttnMetaData &meta_data, - const paddle::Tensor &qkv, - const paddle::Tensor &cache_k, - const paddle::Tensor &cache_v, - const paddle::optional &attn_mask, - const paddle::Tensor &cache_k_scale, - const paddle::Tensor &cache_v_scale, - const paddle::optional &shift_bias, - const paddle::optional &smooth_weight, - const paddle::Tensor &seq_lens_q, - const paddle::Tensor &seq_lens_kv, - const paddle::Tensor &seq_lens_encoder, - const paddle::Tensor &batch_id_per_token, - const paddle::Tensor &cu_seqlens_q, - const paddle::Tensor &block_table, - const paddle::Tensor &batch_ids, - const paddle::Tensor &tile_ids_per_batch, - const int num_blocks_x_cpu, - const int max_seq_len, - const int max_dec_len, - const float quant_max_bound, - const float quant_min_bound, - const float in_scale, - const int max_partition_size, - const int encoder_max_partition_size, - const int speculate_max_draft_token_num, - const bool is_decoder, - cudaStream_t &stream, - paddle::Tensor *out) { - using NV_TYPE = typename cascade_attn_type_traits::type; - using OUT_NV_TYPE = typename 
cascade_attn_type_traits::type; - - auto num_heads = meta_data.q_num_heads; - auto kv_num_heads = meta_data.kv_num_heads; - auto token_num = meta_data.token_nums; - auto bsz = meta_data.batch_size; - auto max_block_num_per_seq = meta_data.max_blocks_per_seq; - - constexpr uint32_t num_warps = 4; - constexpr uint32_t NUM_WARP_KV = num_warps / NUM_WARP_Q; - constexpr uint32_t num_frags_x = BLOCK_SHAPE_Q / (16 * NUM_WARP_Q); - constexpr uint32_t num_frags_y = HEAD_DIM / 16; - constexpr uint32_t num_qrow_per_block = NUM_WARP_Q * num_frags_x * 16; - - auto *allocator = paddle::GetAllocator(qkv.place()); - - const float scale = 1.f / sqrt(HEAD_DIM); - bool is_scale_channel_wise = false; - if (cache_k_scale.dims()[0] == HEAD_DIM * kv_num_heads) { - is_scale_channel_wise = true; - } - - if constexpr (NUM_WARP_Q == 4) { - constexpr uint32_t num_frags_z = BLOCK_SIZE / 16; - constexpr uint32_t smem_size = - num_warps * num_frags_x * 16 * HEAD_DIM * sizeof(T) + - num_frags_z * 16 * HEAD_DIM * sizeof(uint8_t) * 2 + - num_frags_z * 16 * sizeof(T) * 2; - auto split_kv_kernel = - multi_query_append_attention_c8_kernel; - if (is_scale_channel_wise) { - split_kv_kernel = - multi_query_append_attention_c8_kernel; - } - if (smem_size >= 48 * 1024) { - cudaFuncSetAttribute(split_kv_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, - smem_size); - } - const int dev_id = 0; - int sm_count; - cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id); - uint32_t chunk_size = static_cast(max_partition_size); - if (!is_decoder) { - chunk_size = static_cast(encoder_max_partition_size); - } - const int num_chunks = div_up(max_dec_len, chunk_size); - dim3 grids(num_blocks_x_cpu, num_chunks, kv_num_heads); - dim3 blocks(32, num_warps); - if (num_chunks <= 1) { - auto nosplit_kv_kernel = - multi_query_append_attention_c8_kernel; - if (is_scale_channel_wise) { - nosplit_kv_kernel = - multi_query_append_attention_c8_kernel; - } - if (smem_size >= 48 * 1024) { - 
cudaFuncSetAttribute(nosplit_kv_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, - smem_size); - } - - nosplit_kv_kernel<<>>( - reinterpret_cast(const_cast(qkv.data())), - const_cast(cache_k.data()), - const_cast(cache_v.data()), - reinterpret_cast(const_cast(cache_k_scale.data())), - reinterpret_cast(const_cast(cache_v_scale.data())), - shift_bias ? reinterpret_cast( - const_cast(shift_bias.get().data())) - : nullptr, - smooth_weight ? reinterpret_cast( - const_cast(smooth_weight.get().data())) - : nullptr, - seq_lens_q.data(), - seq_lens_kv.data(), - batch_ids.data(), - tile_ids_per_batch.data(), - cu_seqlens_q.data(), - block_table.data(), - meta_data.mask_offset, - max_seq_len, - max_dec_len, - max_block_num_per_seq, - scale, - quant_max_bound, - quant_min_bound, - in_scale, - chunk_size, - num_blocks_x_cpu, - nullptr, - nullptr, - nullptr, - reinterpret_cast(out->data()), - speculate_max_draft_token_num); - } else { - phi::Allocator::AllocationPtr tmp_workspace, tmp_m, tmp_d; - if (ENABLE_PREFILL) { - tmp_workspace = allocator->Allocate( - phi::SizeOf(qkv.dtype()) * - static_cast(token_num * num_chunks * num_heads * HEAD_DIM)); - tmp_m = allocator->Allocate( - phi::SizeOf(paddle::DataType::FLOAT32) * - static_cast(token_num * num_chunks * num_heads)); - tmp_d = allocator->Allocate( - phi::SizeOf(paddle::DataType::FLOAT32) * - static_cast(token_num * num_chunks * num_heads)); - } else { - tmp_workspace = allocator->Allocate( - phi::SizeOf(qkv.dtype()) * - static_cast(speculate_max_draft_token_num * bsz * - num_chunks * num_heads * HEAD_DIM)); - tmp_m = allocator->Allocate( - phi::SizeOf(paddle::DataType::FLOAT32) * - static_cast(speculate_max_draft_token_num * bsz * - num_chunks * num_heads)); - tmp_d = allocator->Allocate( - phi::SizeOf(paddle::DataType::FLOAT32) * - static_cast(speculate_max_draft_token_num * bsz * - num_chunks * num_heads)); - } - split_kv_kernel<<>>( - reinterpret_cast(const_cast(qkv.data())), - const_cast(cache_k.data()), - 
const_cast(cache_v.data()), - reinterpret_cast(const_cast(cache_k_scale.data())), - reinterpret_cast(const_cast(cache_v_scale.data())), - shift_bias ? reinterpret_cast( - const_cast(shift_bias.get().data())) - : nullptr, - smooth_weight ? reinterpret_cast( - const_cast(smooth_weight.get().data())) - : nullptr, - seq_lens_q.data(), - seq_lens_kv.data(), - batch_ids.data(), - tile_ids_per_batch.data(), - cu_seqlens_q.data(), - block_table.data(), - meta_data.mask_offset, - max_seq_len, - max_dec_len, - max_block_num_per_seq, - scale, - quant_max_bound, - quant_min_bound, - in_scale, - chunk_size, - num_blocks_x_cpu, - reinterpret_cast(tmp_workspace->ptr()), - static_cast(tmp_m->ptr()), - static_cast(tmp_d->ptr()), - reinterpret_cast(out->data()), - speculate_max_draft_token_num); - // merge - constexpr int vec_size = num_elems_per_128b(); - if (is_decoder) { - constexpr int blockx = HEAD_DIM / vec_size; - constexpr int blocky = (128 + blockx - 1) / blockx; - dim3 grids_merge(bsz, num_heads); - dim3 blocks_merge(blockx, blocky); - merge_multi_chunks_decoder_kernel - <<>>( - reinterpret_cast(tmp_workspace->ptr()), - static_cast(tmp_m->ptr()), - static_cast(tmp_d->ptr()), - seq_lens_q.data(), - seq_lens_kv.data(), - seq_lens_encoder.data(), - cu_seqlens_q.data(), - shift_bias ? reinterpret_cast( - const_cast(shift_bias.get().data())) - : nullptr, - smooth_weight ? 
reinterpret_cast(const_cast( - smooth_weight.get().data())) - : nullptr, - reinterpret_cast(out->data()), - quant_max_bound, - quant_min_bound, - in_scale, - max_seq_len, - num_chunks, - num_heads, - chunk_size, - HEAD_DIM); - } else { - constexpr int blockx = HEAD_DIM / vec_size; - constexpr int blocky = (128 + blockx - 1) / blockx; - dim3 grids_merge(min(sm_count * 4, token_num), - num_heads); - dim3 blocks_merge(blockx, blocky); - merge_multi_chunks_v2_kernel - <<>>( - reinterpret_cast(tmp_workspace->ptr()), - static_cast(tmp_m->ptr()), - static_cast(tmp_d->ptr()), - seq_lens_q.data(), - seq_lens_kv.data(), - seq_lens_encoder.data(), - batch_id_per_token.data(), - cu_seqlens_q.data(), - shift_bias ? reinterpret_cast( - const_cast(shift_bias.get().data())) - : nullptr, - smooth_weight ? reinterpret_cast(const_cast( - smooth_weight.get().data())) - : nullptr, - reinterpret_cast(out->data()), - quant_max_bound, - quant_min_bound, - in_scale, - max_seq_len, - num_chunks, - num_heads, - chunk_size, - HEAD_DIM, - token_num, - speculate_max_draft_token_num); - } - } - } else { - constexpr uint32_t num_frags_z = BLOCK_SIZE / 16 / NUM_WARP_KV * 2; - constexpr uint32_t smem_size = - num_frags_x * 16 * HEAD_DIM * sizeof(T) + - NUM_WARP_KV * num_frags_z * 16 * HEAD_DIM * sizeof(uint8_t) * 2 + - NUM_WARP_KV * num_frags_z * 16 * sizeof(T) * 2; - auto split_kv_kernel = - multi_query_append_attention_c8_warp1_4_kernel; - if (is_scale_channel_wise) { - split_kv_kernel = - multi_query_append_attention_c8_warp1_4_kernel; - } - if (smem_size >= 48 * 1024) { - cudaFuncSetAttribute(split_kv_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, - smem_size); - } - const int dev_id = 0; - int sm_count; - cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id); - uint32_t chunk_size = static_cast(max_partition_size); - if (!is_decoder) { - chunk_size = static_cast(encoder_max_partition_size); - } - - const int num_chunks = div_up(max_seq_len, chunk_size); - uint32_t 
attn_mask_len; - if (attn_mask) { - attn_mask_len = attn_mask.get().shape()[1]; - } else { - attn_mask_len = -1; - } - - dim3 grids(num_blocks_x_cpu, num_chunks, kv_num_heads); - dim3 blocks(32, num_warps); - if (num_chunks <= 0) { - auto nosplit_kv_kernel = - multi_query_append_attention_c8_warp1_4_kernel; - if (is_scale_channel_wise) { - nosplit_kv_kernel = - multi_query_append_attention_c8_warp1_4_kernel; - } - if (smem_size >= 48 * 1024) { - cudaFuncSetAttribute(nosplit_kv_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, - smem_size); - } - - nosplit_kv_kernel<<>>( - reinterpret_cast(const_cast(qkv.data())), - const_cast(cache_k.data()), - const_cast(cache_v.data()), - reinterpret_cast(const_cast(cache_k_scale.data())), - reinterpret_cast(const_cast(cache_v_scale.data())), - shift_bias ? reinterpret_cast( - const_cast(shift_bias.get().data())) - : nullptr, - smooth_weight ? reinterpret_cast( - const_cast(smooth_weight.get().data())) - : nullptr, - seq_lens_q.data(), - seq_lens_kv.data(), - batch_ids.data(), - tile_ids_per_batch.data(), - cu_seqlens_q.data(), - block_table.data(), - meta_data.mask_offset, - attn_mask ? 
const_cast(attn_mask.get().data()) - : nullptr, - max_seq_len, - max_dec_len, - max_block_num_per_seq, - scale, - quant_max_bound, - quant_min_bound, - in_scale, - chunk_size, - num_blocks_x_cpu, - nullptr, - nullptr, - nullptr, - reinterpret_cast(out->data()), - speculate_max_draft_token_num, - attn_mask_len); - } else { - phi::Allocator::AllocationPtr tmp_workspace, tmp_m, tmp_d; - if (is_decoder) { - tmp_workspace = allocator->Allocate( - phi::SizeOf(qkv.dtype()) * - static_cast(bsz * num_chunks * num_heads * HEAD_DIM)); - tmp_m = allocator->Allocate( - phi::SizeOf(paddle::DataType::FLOAT32) * - static_cast(bsz * num_chunks * num_heads)); - tmp_d = allocator->Allocate( - phi::SizeOf(paddle::DataType::FLOAT32) * - static_cast(bsz * num_chunks * num_heads)); - } else { - if (ENABLE_PREFILL) { - tmp_workspace = - allocator->Allocate(phi::SizeOf(qkv.dtype()) * - static_cast(token_num * num_chunks * - num_heads * HEAD_DIM)); - tmp_m = allocator->Allocate( - phi::SizeOf(paddle::DataType::FLOAT32) * - static_cast(token_num * num_chunks * num_heads)); - tmp_d = allocator->Allocate( - phi::SizeOf(paddle::DataType::FLOAT32) * - static_cast(token_num * num_chunks * num_heads)); - } else { - tmp_workspace = allocator->Allocate( - phi::SizeOf(qkv.dtype()) * - static_cast(speculate_max_draft_token_num * bsz * - num_chunks * num_heads * HEAD_DIM)); - tmp_m = allocator->Allocate( - phi::SizeOf(paddle::DataType::FLOAT32) * - static_cast(speculate_max_draft_token_num * bsz * - num_chunks * num_heads)); - tmp_d = allocator->Allocate( - phi::SizeOf(paddle::DataType::FLOAT32) * - static_cast(speculate_max_draft_token_num * bsz * - num_chunks * num_heads)); - } - } - split_kv_kernel<<>>( - reinterpret_cast(const_cast(qkv.data())), - const_cast(cache_k.data()), - const_cast(cache_v.data()), - reinterpret_cast(const_cast(cache_k_scale.data())), - reinterpret_cast(const_cast(cache_v_scale.data())), - shift_bias ? 
reinterpret_cast( - const_cast(shift_bias.get().data())) - : nullptr, - smooth_weight ? reinterpret_cast( - const_cast(smooth_weight.get().data())) - : nullptr, - seq_lens_q.data(), - seq_lens_kv.data(), - batch_ids.data(), - tile_ids_per_batch.data(), - cu_seqlens_q.data(), - block_table.data(), - meta_data.mask_offset, - attn_mask ? const_cast(attn_mask.get().data()) - : nullptr, - max_seq_len, - max_dec_len, - max_block_num_per_seq, - scale, - quant_max_bound, - quant_min_bound, - in_scale, - chunk_size, - num_blocks_x_cpu, - reinterpret_cast(tmp_workspace->ptr()), - static_cast(tmp_m->ptr()), - static_cast(tmp_d->ptr()), - reinterpret_cast(out->data()), - speculate_max_draft_token_num, - attn_mask_len); - // merge - constexpr int vec_size = num_elems_per_128b(); - if (is_decoder) { - constexpr int blockx = HEAD_DIM / vec_size; - constexpr int blocky = (128 + blockx - 1) / blockx; - dim3 grids_merge(bsz, num_heads); - dim3 blocks_merge(blockx, blocky); - merge_multi_chunks_decoder_kernel - <<>>( - reinterpret_cast(tmp_workspace->ptr()), - static_cast(tmp_m->ptr()), - static_cast(tmp_d->ptr()), - seq_lens_q.data(), - seq_lens_kv.data(), - seq_lens_encoder.data(), - cu_seqlens_q.data(), - shift_bias ? reinterpret_cast( - const_cast(shift_bias.get().data())) - : nullptr, - smooth_weight ? 
reinterpret_cast(const_cast( - smooth_weight.get().data())) - : nullptr, - reinterpret_cast(out->data()), - quant_max_bound, - quant_min_bound, - in_scale, - max_seq_len, - num_chunks, - num_heads, - chunk_size, - HEAD_DIM); - } else { - constexpr int blockx = HEAD_DIM / vec_size; - constexpr int blocky = (128 + blockx - 1) / blockx; - dim3 grids_merge(min(sm_count * 4, token_num), - num_heads); - dim3 blocks_merge(blockx, blocky); - merge_multi_chunks_v2_kernel - <<>>( - reinterpret_cast(tmp_workspace->ptr()), - static_cast(tmp_m->ptr()), - static_cast(tmp_d->ptr()), - seq_lens_q.data(), - seq_lens_kv.data(), - seq_lens_encoder.data(), - batch_id_per_token.data(), - cu_seqlens_q.data(), - shift_bias ? reinterpret_cast( - const_cast(shift_bias.get().data())) - : nullptr, - smooth_weight ? reinterpret_cast(const_cast( - smooth_weight.get().data())) - : nullptr, - reinterpret_cast(out->data()), - quant_max_bound, - quant_min_bound, - in_scale, - max_seq_len, - num_chunks, - num_heads, - chunk_size, - HEAD_DIM, - token_num, - speculate_max_draft_token_num); - } - } - } -} - -template +template void CascadeAppendAttentionC8Kernel( const AppendAttnMetaData& meta_data, const paddle::Tensor& qkv, // [token_num, (num_heads + 2* kv_num_head) * head_dim] @@ -1726,3 +128,447 @@ void CascadeAppendAttentionC8Kernel( out); })})})})})})}) } + +template void CascadeAppendAttentionC8Kernel( + const AppendAttnMetaData& meta_data, + const paddle::Tensor& qkv, + const paddle::Tensor& cache_k, + const paddle::Tensor& cache_v, + const paddle::optional& attn_mask, + const paddle::optional& cache_k_scale, + const paddle::optional& cache_v_scale, + const paddle::optional& cache_k_zp, + const paddle::optional& cache_v_zp, + const paddle::optional& shift_bias, + const paddle::optional& smooth_weight, + const paddle::Tensor& seq_lens_q, + const paddle::Tensor& seq_lens_kv, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& batch_id_per_token, + const paddle::Tensor& 
cu_seqlens_q, + const paddle::Tensor& block_table, + const paddle::Tensor& batch_ids, + const paddle::Tensor& tile_ids_per_batch, + const int num_blocks, + const int block_shape_q, + const int max_seq_len, + const int max_dec_len, + const float quant_max_bound, + const float quant_min_bound, + const float in_scale, + const int max_partition_size, + const int encoder_max_partition_size, + const int speculate_max_draft_token_num, + const bool causal, + const bool is_decoder, + const bool enable_prefill, + const std::string& cache_quant_type_str, + cudaStream_t& stream, + paddle::Tensor* out); + +template void CascadeAppendAttentionC8Kernel( + const AppendAttnMetaData& meta_data, + const paddle::Tensor& qkv, + const paddle::Tensor& cache_k, + const paddle::Tensor& cache_v, + const paddle::optional& attn_mask, + const paddle::optional& cache_k_scale, + const paddle::optional& cache_v_scale, + const paddle::optional& cache_k_zp, + const paddle::optional& cache_v_zp, + const paddle::optional& shift_bias, + const paddle::optional& smooth_weight, + const paddle::Tensor& seq_lens_q, + const paddle::Tensor& seq_lens_kv, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& batch_id_per_token, + const paddle::Tensor& cu_seqlens_q, + const paddle::Tensor& block_table, + const paddle::Tensor& batch_ids, + const paddle::Tensor& tile_ids_per_batch, + const int num_blocks, + const int block_shape_q, + const int max_seq_len, + const int max_dec_len, + const float quant_max_bound, + const float quant_min_bound, + const float in_scale, + const int max_partition_size, + const int encoder_max_partition_size, + const int speculate_max_draft_token_num, + const bool causal, + const bool is_decoder, + const bool enable_prefill, + const std::string& cache_quant_type_str, + cudaStream_t& stream, + paddle::Tensor* out); + +template void CascadeAppendAttentionC8Kernel( + const AppendAttnMetaData& meta_data, + const paddle::Tensor& qkv, + const paddle::Tensor& cache_k, + const 
paddle::Tensor& cache_v, + const paddle::optional& attn_mask, + const paddle::optional& cache_k_scale, + const paddle::optional& cache_v_scale, + const paddle::optional& cache_k_zp, + const paddle::optional& cache_v_zp, + const paddle::optional& shift_bias, + const paddle::optional& smooth_weight, + const paddle::Tensor& seq_lens_q, + const paddle::Tensor& seq_lens_kv, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& batch_id_per_token, + const paddle::Tensor& cu_seqlens_q, + const paddle::Tensor& block_table, + const paddle::Tensor& batch_ids, + const paddle::Tensor& tile_ids_per_batch, + const int num_blocks, + const int block_shape_q, + const int max_seq_len, + const int max_dec_len, + const float quant_max_bound, + const float quant_min_bound, + const float in_scale, + const int max_partition_size, + const int encoder_max_partition_size, + const int speculate_max_draft_token_num, + const bool causal, + const bool is_decoder, + const bool enable_prefill, + const std::string& cache_quant_type_str, + cudaStream_t& stream, + paddle::Tensor* out); + +template void CascadeAppendAttentionC8Kernel( + const AppendAttnMetaData& meta_data, + const paddle::Tensor& qkv, + const paddle::Tensor& cache_k, + const paddle::Tensor& cache_v, + const paddle::optional& attn_mask, + const paddle::optional& cache_k_scale, + const paddle::optional& cache_v_scale, + const paddle::optional& cache_k_zp, + const paddle::optional& cache_v_zp, + const paddle::optional& shift_bias, + const paddle::optional& smooth_weight, + const paddle::Tensor& seq_lens_q, + const paddle::Tensor& seq_lens_kv, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& batch_id_per_token, + const paddle::Tensor& cu_seqlens_q, + const paddle::Tensor& block_table, + const paddle::Tensor& batch_ids, + const paddle::Tensor& tile_ids_per_batch, + const int num_blocks, + const int block_shape_q, + const int max_seq_len, + const int max_dec_len, + const float quant_max_bound, + const float 
quant_min_bound, + const float in_scale, + const int max_partition_size, + const int encoder_max_partition_size, + const int speculate_max_draft_token_num, + const bool causal, + const bool is_decoder, + const bool enable_prefill, + const std::string& cache_quant_type_str, + cudaStream_t& stream, + paddle::Tensor* out); + +template void CascadeAppendAttentionC8Kernel( + const AppendAttnMetaData& meta_data, + const paddle::Tensor& qkv, + const paddle::Tensor& cache_k, + const paddle::Tensor& cache_v, + const paddle::optional& attn_mask, + const paddle::optional& cache_k_scale, + const paddle::optional& cache_v_scale, + const paddle::optional& cache_k_zp, + const paddle::optional& cache_v_zp, + const paddle::optional& shift_bias, + const paddle::optional& smooth_weight, + const paddle::Tensor& seq_lens_q, + const paddle::Tensor& seq_lens_kv, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& batch_id_per_token, + const paddle::Tensor& cu_seqlens_q, + const paddle::Tensor& block_table, + const paddle::Tensor& batch_ids, + const paddle::Tensor& tile_ids_per_batch, + const int num_blocks, + const int block_shape_q, + const int max_seq_len, + const int max_dec_len, + const float quant_max_bound, + const float quant_min_bound, + const float in_scale, + const int max_partition_size, + const int encoder_max_partition_size, + const int speculate_max_draft_token_num, + const bool causal, + const bool is_decoder, + const bool enable_prefill, + const std::string& cache_quant_type_str, + cudaStream_t& stream, + paddle::Tensor* out); + +template void CascadeAppendAttentionC8Kernel( + const AppendAttnMetaData& meta_data, + const paddle::Tensor& qkv, + const paddle::Tensor& cache_k, + const paddle::Tensor& cache_v, + const paddle::optional& attn_mask, + const paddle::optional& cache_k_scale, + const paddle::optional& cache_v_scale, + const paddle::optional& cache_k_zp, + const paddle::optional& cache_v_zp, + const paddle::optional& shift_bias, + const 
paddle::optional& smooth_weight, + const paddle::Tensor& seq_lens_q, + const paddle::Tensor& seq_lens_kv, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& batch_id_per_token, + const paddle::Tensor& cu_seqlens_q, + const paddle::Tensor& block_table, + const paddle::Tensor& batch_ids, + const paddle::Tensor& tile_ids_per_batch, + const int num_blocks, + const int block_shape_q, + const int max_seq_len, + const int max_dec_len, + const float quant_max_bound, + const float quant_min_bound, + const float in_scale, + const int max_partition_size, + const int encoder_max_partition_size, + const int speculate_max_draft_token_num, + const bool causal, + const bool is_decoder, + const bool enable_prefill, + const std::string& cache_quant_type_str, + cudaStream_t& stream, + paddle::Tensor* out); + +template void CascadeAppendAttentionC8Kernel( + const AppendAttnMetaData& meta_data, + const paddle::Tensor& qkv, + const paddle::Tensor& cache_k, + const paddle::Tensor& cache_v, + const paddle::optional& attn_mask, + const paddle::optional& cache_k_scale, + const paddle::optional& cache_v_scale, + const paddle::optional& cache_k_zp, + const paddle::optional& cache_v_zp, + const paddle::optional& shift_bias, + const paddle::optional& smooth_weight, + const paddle::Tensor& seq_lens_q, + const paddle::Tensor& seq_lens_kv, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& batch_id_per_token, + const paddle::Tensor& cu_seqlens_q, + const paddle::Tensor& block_table, + const paddle::Tensor& batch_ids, + const paddle::Tensor& tile_ids_per_batch, + const int num_blocks, + const int block_shape_q, + const int max_seq_len, + const int max_dec_len, + const float quant_max_bound, + const float quant_min_bound, + const float in_scale, + const int max_partition_size, + const int encoder_max_partition_size, + const int speculate_max_draft_token_num, + const bool causal, + const bool is_decoder, + const bool enable_prefill, + const std::string& 
cache_quant_type_str, + cudaStream_t& stream, + paddle::Tensor* out); + +template void CascadeAppendAttentionC8Kernel( + const AppendAttnMetaData& meta_data, + const paddle::Tensor& qkv, + const paddle::Tensor& cache_k, + const paddle::Tensor& cache_v, + const paddle::optional& attn_mask, + const paddle::optional& cache_k_scale, + const paddle::optional& cache_v_scale, + const paddle::optional& cache_k_zp, + const paddle::optional& cache_v_zp, + const paddle::optional& shift_bias, + const paddle::optional& smooth_weight, + const paddle::Tensor& seq_lens_q, + const paddle::Tensor& seq_lens_kv, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& batch_id_per_token, + const paddle::Tensor& cu_seqlens_q, + const paddle::Tensor& block_table, + const paddle::Tensor& batch_ids, + const paddle::Tensor& tile_ids_per_batch, + const int num_blocks, + const int block_shape_q, + const int max_seq_len, + const int max_dec_len, + const float quant_max_bound, + const float quant_min_bound, + const float in_scale, + const int max_partition_size, + const int encoder_max_partition_size, + const int speculate_max_draft_token_num, + const bool causal, + const bool is_decoder, + const bool enable_prefill, + const std::string& cache_quant_type_str, + cudaStream_t& stream, + paddle::Tensor* out); + +template void CascadeAppendAttentionC8Kernel( + const AppendAttnMetaData& meta_data, + const paddle::Tensor& qkv, + const paddle::Tensor& cache_k, + const paddle::Tensor& cache_v, + const paddle::optional& attn_mask, + const paddle::optional& cache_k_scale, + const paddle::optional& cache_v_scale, + const paddle::optional& cache_k_zp, + const paddle::optional& cache_v_zp, + const paddle::optional& shift_bias, + const paddle::optional& smooth_weight, + const paddle::Tensor& seq_lens_q, + const paddle::Tensor& seq_lens_kv, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& batch_id_per_token, + const paddle::Tensor& cu_seqlens_q, + const paddle::Tensor& block_table, 
+ const paddle::Tensor& batch_ids, + const paddle::Tensor& tile_ids_per_batch, + const int num_blocks, + const int block_shape_q, + const int max_seq_len, + const int max_dec_len, + const float quant_max_bound, + const float quant_min_bound, + const float in_scale, + const int max_partition_size, + const int encoder_max_partition_size, + const int speculate_max_draft_token_num, + const bool causal, + const bool is_decoder, + const bool enable_prefill, + const std::string& cache_quant_type_str, + cudaStream_t& stream, + paddle::Tensor* out); + +template void CascadeAppendAttentionC8Kernel( + const AppendAttnMetaData& meta_data, + const paddle::Tensor& qkv, + const paddle::Tensor& cache_k, + const paddle::Tensor& cache_v, + const paddle::optional& attn_mask, + const paddle::optional& cache_k_scale, + const paddle::optional& cache_v_scale, + const paddle::optional& cache_k_zp, + const paddle::optional& cache_v_zp, + const paddle::optional& shift_bias, + const paddle::optional& smooth_weight, + const paddle::Tensor& seq_lens_q, + const paddle::Tensor& seq_lens_kv, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& batch_id_per_token, + const paddle::Tensor& cu_seqlens_q, + const paddle::Tensor& block_table, + const paddle::Tensor& batch_ids, + const paddle::Tensor& tile_ids_per_batch, + const int num_blocks, + const int block_shape_q, + const int max_seq_len, + const int max_dec_len, + const float quant_max_bound, + const float quant_min_bound, + const float in_scale, + const int max_partition_size, + const int encoder_max_partition_size, + const int speculate_max_draft_token_num, + const bool causal, + const bool is_decoder, + const bool enable_prefill, + const std::string& cache_quant_type_str, + cudaStream_t& stream, + paddle::Tensor* out); + +template void CascadeAppendAttentionC8Kernel( + const AppendAttnMetaData& meta_data, + const paddle::Tensor& qkv, + const paddle::Tensor& cache_k, + const paddle::Tensor& cache_v, + const paddle::optional& 
attn_mask, + const paddle::optional& cache_k_scale, + const paddle::optional& cache_v_scale, + const paddle::optional& cache_k_zp, + const paddle::optional& cache_v_zp, + const paddle::optional& shift_bias, + const paddle::optional& smooth_weight, + const paddle::Tensor& seq_lens_q, + const paddle::Tensor& seq_lens_kv, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& batch_id_per_token, + const paddle::Tensor& cu_seqlens_q, + const paddle::Tensor& block_table, + const paddle::Tensor& batch_ids, + const paddle::Tensor& tile_ids_per_batch, + const int num_blocks, + const int block_shape_q, + const int max_seq_len, + const int max_dec_len, + const float quant_max_bound, + const float quant_min_bound, + const float in_scale, + const int max_partition_size, + const int encoder_max_partition_size, + const int speculate_max_draft_token_num, + const bool causal, + const bool is_decoder, + const bool enable_prefill, + const std::string& cache_quant_type_str, + cudaStream_t& stream, + paddle::Tensor* out); + +template void CascadeAppendAttentionC8Kernel( + const AppendAttnMetaData& meta_data, + const paddle::Tensor& qkv, + const paddle::Tensor& cache_k, + const paddle::Tensor& cache_v, + const paddle::optional& attn_mask, + const paddle::optional& cache_k_scale, + const paddle::optional& cache_v_scale, + const paddle::optional& cache_k_zp, + const paddle::optional& cache_v_zp, + const paddle::optional& shift_bias, + const paddle::optional& smooth_weight, + const paddle::Tensor& seq_lens_q, + const paddle::Tensor& seq_lens_kv, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& batch_id_per_token, + const paddle::Tensor& cu_seqlens_q, + const paddle::Tensor& block_table, + const paddle::Tensor& batch_ids, + const paddle::Tensor& tile_ids_per_batch, + const int num_blocks, + const int block_shape_q, + const int max_seq_len, + const int max_dec_len, + const float quant_max_bound, + const float quant_min_bound, + const float in_scale, + const int 
max_partition_size, + const int encoder_max_partition_size, + const int speculate_max_draft_token_num, + const bool causal, + const bool is_decoder, + const bool enable_prefill, + const std::string& cache_quant_type_str, + cudaStream_t& stream, + paddle::Tensor* out); diff --git a/custom_ops/gpu_ops/append_attn/append_attention_kernel.h b/custom_ops/gpu_ops/append_attn/append_attention_kernel.h index 2cc069592..a41c904f1 100644 --- a/custom_ops/gpu_ops/append_attn/append_attention_kernel.h +++ b/custom_ops/gpu_ops/append_attn/append_attention_kernel.h @@ -15,6 +15,7 @@ #include "helper.h" #include "utils.cuh" +#include "append_attention_c8_impl.cuh" template void CascadeAppendAttentionC16Kernel( @@ -61,52 +62,6 @@ void CascadeAppendAttentionC16Kernel( cudaStream_t& stream, paddle::Tensor* out); -template -void CascadeAppendAttentionC8Kernel( - const AppendAttnMetaData& meta_data, - const paddle::Tensor& qkv, // [token_num, num_heads, head_dim] - const paddle::Tensor& - cache_k, // [max_block_num, num_heads, block_size, head_dim] - const paddle::Tensor& - cache_v, // [max_block_num, num_heads, head_dim, block_size] - const paddle::optional& attn_mask, - const paddle::optional& - cache_k_scale, // [num_kv_heads, head_dim] - const paddle::optional& - cache_v_scale, // [num_kv_heads, head_dim] - const paddle::optional& - cache_k_zp, // [num_kv_heads, head_dim] - const paddle::optional& - cache_v_zp, // [num_kv_heads, head_dim] - const paddle::optional& - shift_bias, // [num_kv_heads, head_dim] - const paddle::optional& - smooth_weight, // [num_kv_heads, head_dim] - const paddle::Tensor& seq_lens_q, - const paddle::Tensor& seq_lens_kv, - const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& batch_id_per_token, - const paddle::Tensor& cu_seqlens_q, - const paddle::Tensor& block_table, - const paddle::Tensor& batch_ids, - const paddle::Tensor& tile_ids_per_batch, - const int num_blocks, - const int block_shape_q, - const int max_seq_len, - const int max_dec_len, 
- const float quant_max_bound, - const float quant_min_bound, - const float in_scale, - const int max_partition_size, - const int encoder_max_partition_size, - const int speculate_max_draft_token_num, - const bool causal, - const bool is_decoder, - const bool enable_prefill, - const std::string& cache_quant_type_str, - cudaStream_t& stream, - paddle::Tensor* out); - template void CascadeAppendAttentionC4Kernel( const AppendAttnMetaData& meta_data, @@ -233,7 +188,7 @@ void CascadeAppendAttentionKernel( stream, out); } else if (cache_quant_type_str == "cache_int8") { - CascadeAppendAttentionC8Kernel(meta_data, + CascadeAppendAttentionC8Kernel(meta_data, qkv, cache_k, cache_v, diff --git a/custom_ops/gpu_ops/append_attn/autogen_template_instantiation.py b/custom_ops/gpu_ops/append_attn/autogen_template_instantiation.py new file mode 100644 index 000000000..6e7e492f9 --- /dev/null +++ b/custom_ops/gpu_ops/append_attn/autogen_template_instantiation.py @@ -0,0 +1,154 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""generate multiquery_attention_c8_kernel template instantiation.""" + +from pathlib import Path + +TEMPLATE_DIR = Path("gpu_ops/append_attn/template_instantiation/autogen") +TEMPLATE_DIR.mkdir(exist_ok=True) + +DISPATCH_PARAMS = { + "GROUP_SIZE": [1, 2, 4, 5, 6, 7, 8, 12, 14, 16], + "HEAD_DIM": [128], + "BLOCK_SIZE": [64], + "CAUSAL": [0, 1], + "BLOCK_SHAPE_Q": [16, 32, 64, 128], + "ENABLE_PREFILL": [0, 1], + "IsFP8": [0, 1], + "IsDynamicC8": [0, 1], +} + +DATA_TYPE_COMBINATIONS = [ + ("paddle::float16", "paddle::float16", "float16_float16"), + ("paddle::float16", "paddle::float8_e4m3fn", "float16_fp8"), + ("paddle::float16", "int8_t", "float16_int8"), + ("paddle::bfloat16", "paddle::bfloat16", "bfloat16_bfloat16"), + ("paddle::bfloat16", "paddle::float8_e4m3fn", "bfloat16_fp8"), + ("paddle::bfloat16", "int8_t", "bfloat16_int8"), +] + +MAX_INSTANCES_PER_FILE = 60 + + +def get_num_warp_q(block_shape_q): + if block_shape_q <= 32: + return 1 + else: + return 4 + + +def generate_file_header(): + return """// Generated by autogen_template_instantiation.py - Do not edit. 
+ +#pragma once + +#include "../../multiquery_attention_c8_impl.cuh" +""" + + +def generate_template_instantiation(t_in, t_out, params): + num_warp_q = get_num_warp_q(params["BLOCK_SHAPE_Q"]) + template_args = f"<{t_in}, {params['GROUP_SIZE']}, {params['HEAD_DIM']}, {params['BLOCK_SIZE']}, {params['CAUSAL']}, {params['BLOCK_SHAPE_Q']}, {num_warp_q}, {t_out}, {params['ENABLE_PREFILL']}, {params['IsFP8']}, {params['IsDynamicC8']}>" + + return f""" +template void MultiQueryAppendC8Attention{template_args}( + const AppendAttnMetaData &meta_data, + const paddle::Tensor &qkv, + const paddle::Tensor &cache_k, + const paddle::Tensor &cache_v, + const paddle::optional &attn_mask, + const paddle::Tensor &cache_k_scale, + const paddle::Tensor &cache_v_scale, + const paddle::optional &shift_bias, + const paddle::optional &smooth_weight, + const paddle::Tensor &seq_lens_q, + const paddle::Tensor &seq_lens_kv, + const paddle::Tensor &seq_lens_encoder, + const paddle::Tensor &batch_id_per_token, + const paddle::Tensor &cu_seqlens_q, + const paddle::Tensor &block_table, + const paddle::Tensor &batch_ids, + const paddle::Tensor &tile_ids_per_batch, + const int num_blocks_x_cpu, + const int max_seq_len, + const int max_dec_len, + const float quant_max_bound, + const float quant_min_bound, + const float in_scale, + const int max_partition_size, + const int encoder_max_partition_size, + const int speculate_max_draft_token_num, + const bool is_decoder, + cudaStream_t &stream, + paddle::Tensor *out); + +""" + + +def generate_combinations_for_type(t_in, t_out): + combinations = [] + for group_size in DISPATCH_PARAMS["GROUP_SIZE"]: + for head_dim in DISPATCH_PARAMS["HEAD_DIM"]: + for block_size in DISPATCH_PARAMS["BLOCK_SIZE"]: + for causal in DISPATCH_PARAMS["CAUSAL"]: + for block_shape_q in DISPATCH_PARAMS["BLOCK_SHAPE_Q"]: + for enable_prefill in DISPATCH_PARAMS["ENABLE_PREFILL"]: + for is_fp8 in DISPATCH_PARAMS["IsFP8"]: + for is_dynamic_c8 in DISPATCH_PARAMS["IsDynamicC8"]: + params 
= { + "GROUP_SIZE": group_size, + "HEAD_DIM": head_dim, + "BLOCK_SIZE": block_size, + "CAUSAL": causal, + "BLOCK_SHAPE_Q": block_shape_q, + "ENABLE_PREFILL": enable_prefill, + "IsFP8": is_fp8, + "IsDynamicC8": is_dynamic_c8, + } + combinations.append(params) + + return combinations + + +def split_combinations(combinations, max_per_file): + chunks = [] + for i in range(0, len(combinations), max_per_file): + chunk = combinations[i : i + max_per_file] + chunks.append(chunk) + return chunks + + +def generate_file_content(t_in, t_out, t_out_name, file_index, combinations): + content = generate_file_header() + for params in combinations: + content += generate_template_instantiation(t_in, t_out, params) + + return content + + +def main(): + for t_in, t_out, t_out_name in DATA_TYPE_COMBINATIONS: + combinations = generate_combinations_for_type(t_in, t_out) + if combinations: + chunks = split_combinations(combinations, MAX_INSTANCES_PER_FILE) + for i, chunk in enumerate(chunks): + filename = f"multiquery_attention_c8_{t_out_name}_part_{i:02d}.cu" + filepath = TEMPLATE_DIR / filename + content = generate_file_content(t_in, t_out, t_out_name, i, chunk) + with open(filepath, "w", encoding="utf-8") as f: + f.write(content) + + +if __name__ == "__main__": + main() diff --git a/custom_ops/gpu_ops/append_attn/multiquery_attention_c8_impl.cuh b/custom_ops/gpu_ops/append_attn/multiquery_attention_c8_impl.cuh new file mode 100644 index 000000000..58e1a5bc0 --- /dev/null +++ b/custom_ops/gpu_ops/append_attn/multiquery_attention_c8_impl.cuh @@ -0,0 +1,1614 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include "append_attention_func.cuh" +#include "multiquery_attention_c8_kernel.h" + +template +__global__ void multi_query_append_attention_c8_kernel( + T *__restrict__ q, // [token_num, (num_heads + 2* kv_num_head) * head_dim] + CacheT *__restrict__ cache_k, // [max_block_num, num_heads, block_size, + // head_dim] + CacheT *__restrict__ cache_v, + const T *__restrict__ cache_k_scale, // [num_kv_heads] or [max_block_num, num_heads, block_size] + const T *__restrict__ cache_v_scale, // [num_kv_heads] or [max_block_num, num_heads, block_size] + const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM] + const T *__restrict__ smooth_weight, // [q_num_heads * HEAD_DIM] + const int *__restrict__ seq_lens, + const int *__restrict__ seq_lens_kv, + const int *__restrict__ batch_ids, + const int *__restrict__ tile_ids_per_batch, + const int *__restrict__ cu_seqlens_q, + const int *__restrict__ block_table, // [bsz, block_num_per_seq] + const int *__restrict__ mask_offset, + const int max_seq_len, + const int max_dec_len, + const int max_block_num_per_seq, + const float scale, + const float quant_max_bound, + const float quant_min_bound, + const float in_scale, + const uint32_t chunk_size, + const int num_blocks_x_cpu, + T *__restrict__ tmp_workspace, // split kv [token_num, num_chunks, + // num_heads, head_dim] + float *__restrict__ tmp_m, // [token_num, num_chunks, num_heads] + float *__restrict__ tmp_d, // [token_num, num_chunks, num_heads] + OutT *__restrict__ out, + const int speculate_max_draft_token_num = 5) { + constexpr uint32_t 
num_vecs_per_head = + HEAD_DIM / num_elems_per_128b(); // 128 / 8 = 16 + constexpr uint32_t num_vecs_per_head_k = + HEAD_DIM / num_elems_per_128b(); // 128 / 16 = 8 + constexpr uint32_t num_vecs_per_blocksize = + BLOCK_SIZE / num_elems_per_128b(); // 64 / 16 = 4 + constexpr uint32_t inv_k_stride = 8 / num_vecs_per_head_k; + constexpr uint32_t inv_v_stride = 8 / num_vecs_per_blocksize; + const uint32_t btid = blockIdx.x, kv_head_idx = blockIdx.z; + const uint32_t kv_num_heads = gridDim.z; + const uint32_t q_num_heads = kv_num_heads * GROUP_SIZE; + const uint32_t q_head_idx = kv_head_idx * GROUP_SIZE; + const uint32_t tid = threadIdx.x, wid = threadIdx.y; + const uint32_t num_chunks = gridDim.y; + const uint32_t chunk_idx = blockIdx.y; + + const uint32_t batch_id = batch_ids[btid]; + const uint32_t tile_id = tile_ids_per_batch[btid]; + const uint32_t num_rows_per_block = NUM_WARPS * num_frags_x * 16; + const int *block_table_now = nullptr; + + block_table_now = block_table + batch_id * max_block_num_per_seq; + + //When cudagraph capture prefill, may launch more gridDim.x + if(btid >= static_cast(num_blocks_x_cpu)){ + return; + } + + const uint32_t q_len = seq_lens[batch_id]; + if (q_len <= 0) { + return; + } + + T cache_k_scale_reg[IsDynamicC8 ? num_frags_z * 2 : num_frags_y * 4]; + T cache_v_scale_reg[IsDynamicC8 ? 
num_frags_z * 4 : num_frags_y * 2]; + if constexpr (!IsDynamicC8) { + if constexpr (is_scale_channel_wise) { + int scale_col_base = threadIdx.x % 4 * 2 + kv_head_idx * HEAD_DIM; + const T *cache_k_scale_cur_head = cache_k_scale + scale_col_base; + for (int i = 0; i < num_frags_y; ++i) { + const int scale_idx = i * 16; + cache_k_scale_reg[i * 4] = cache_k_scale_cur_head[scale_idx]; + cache_k_scale_reg[i * 4 + 1] = cache_k_scale_cur_head[scale_idx + 1]; + cache_k_scale_reg[i * 4 + 2] = cache_k_scale_cur_head[scale_idx + 8]; + cache_k_scale_reg[i * 4 + 3] = cache_k_scale_cur_head[scale_idx + 9]; + } + scale_col_base = threadIdx.x / 4 + kv_head_idx * HEAD_DIM; + const T *cache_v_scale_cur_head = cache_v_scale + scale_col_base; + for (int i = 0; i < num_frags_y; ++i) { + const int scale_idx = i * 16; + cache_v_scale_reg[i * 2] = cache_v_scale_cur_head[scale_idx]; + cache_v_scale_reg[i * 2 + 1] = cache_v_scale_cur_head[scale_idx + 8]; + } + } else { + cache_k_scale_reg[0] = cache_k_scale[kv_head_idx]; + cache_v_scale_reg[0] = cache_v_scale[kv_head_idx]; + } + } + + const uint32_t q_end = + min(q_len, div_up((tile_id + 1) * num_rows_per_block, GROUP_SIZE)); + uint32_t kv_len = seq_lens_kv[batch_id]; + if (ENABLE_PREFILL) { + kv_len += q_len; + if (kv_len <= 0) { + return; + } + } else { + if (kv_len <= 0) { + return; + } + kv_len += q_len; + } + const uint32_t num_chunks_this_seq = div_up(kv_len, chunk_size); + if (chunk_idx >= num_chunks_this_seq) { + return; + } + + const uint32_t chunk_start = partition_kv ? chunk_idx * chunk_size : 0; + const uint32_t chunk_end = + partition_kv ? 
min(kv_len, chunk_start + chunk_size) : kv_len; + const uint32_t chunk_len = chunk_end - chunk_start; + + extern __shared__ uint8_t smem[]; + float s_frag[num_frags_x][num_frags_z][8]; + float o_frag[num_frags_x][num_frags_y][8]; + float m_frag[num_frags_x][2]; + float d_frag[num_frags_x][2]; + init_states(o_frag, m_frag, d_frag); + + const uint32_t q_n_stride = q_num_heads * HEAD_DIM; + const uint32_t q_ori_n_stride = (q_num_heads + kv_num_heads * 2) * HEAD_DIM; + const uint32_t kv_n_stride = kv_num_heads * BLOCK_SIZE * HEAD_DIM; + const uint32_t kv_h_stride = BLOCK_SIZE * HEAD_DIM; + const uint32_t kv_b_stride = HEAD_DIM; + const uint32_t kv_d_stride = BLOCK_SIZE; + const uint32_t q_start_seq_id = cu_seqlens_q[batch_id]; + const uint32_t q_base_seq_id_this_block = + (tile_id * NUM_WARPS + wid) * num_frags_x * 16; + const uint32_t q_offset = q_start_seq_id * q_ori_n_stride + + q_head_idx * HEAD_DIM + + tid % 8 * num_elems_per_128b(); + const uint32_t o_offset = q_start_seq_id * q_n_stride + + q_head_idx * HEAD_DIM + + tid % 8 * num_elems_per_128b(); + T *q_base_ptr = q + q_offset; + + T *o_base_ptr_T = nullptr; + OutT *o_base_ptr_int8 = nullptr; + if constexpr (partition_kv) { + if (ENABLE_PREFILL) { + o_base_ptr_T = tmp_workspace + q_start_seq_id * num_chunks * q_n_stride + + chunk_idx * q_n_stride + q_head_idx * HEAD_DIM + + tid % 8 * num_elems_per_128b(); + } else { + o_base_ptr_T = + tmp_workspace + + batch_id * speculate_max_draft_token_num * num_chunks * q_n_stride + + chunk_idx * q_n_stride + q_head_idx * HEAD_DIM + + tid % 8 * num_elems_per_128b(); + } + } else { + o_base_ptr_int8 = out + o_offset; + } + const int *mask_offset_this_seq = mask_offset ? 
mask_offset + q_start_seq_id * 2 : nullptr; + smem_t qo_smem(smem); + + uint32_t q_smem_offset_r = smem_t::get_permuted_offset( + wid * num_frags_x * 16 + tid % 16, tid / 16); // 16 * 16 + load_q_global_smem( + q_base_ptr, + &qo_smem, + q_base_seq_id_this_block, + q_end, + q_ori_n_stride, + HEAD_DIM); + commit_group(); + wait_group<0>(); + __syncthreads(); + + q_smem_inplace_multiply_sm_scale(&qo_smem, + scale); + smem_t k_smem(smem + NUM_WARPS * num_frags_x * 16 * HEAD_DIM * sizeof(T)), + v_smem(smem + NUM_WARPS * num_frags_x * 16 * HEAD_DIM * sizeof(T) + + num_frags_z * 16 * HEAD_DIM * sizeof(CacheT)); + T* k_smem_scale = nullptr; + T* v_smem_scale = nullptr; + if constexpr (IsDynamicC8) { + k_smem_scale = reinterpret_cast(smem + NUM_WARPS * num_frags_x * 16 * HEAD_DIM * sizeof(T) + + num_frags_z * 16 * HEAD_DIM * sizeof(CacheT) * 2); + v_smem_scale = k_smem_scale + num_frags_z * 16; + } + + + const uint32_t num_iterations = div_up( + CAUSAL + ? (min(chunk_len, + sub_if_greater_or_zero( + kv_len - q_len + + div_up((tile_id + 1) * num_rows_per_block, GROUP_SIZE), + chunk_start))) + : chunk_len, + num_frags_z * 16); + const uint32_t mask_check_iteration = + (CAUSAL ? (min(chunk_len, + sub_if_greater_or_zero( + kv_len - q_len + + tile_id * num_rows_per_block / GROUP_SIZE, + chunk_start))) + : mask_offset ? 
0 : chunk_len) / + (num_frags_z * 16); + + uint32_t k_smem_offset_r = + smem_t::get_permuted_offset( + 8 * (tid / 16) + tid % 8, (tid % 16) / 8); + + uint32_t v_smem_offset_r = + smem_t::get_permuted_offset( + 8 * (tid / 16) + tid % 8, (tid % 16) / 8); + + uint32_t k_smem_offset_w = + smem_t::get_permuted_offset( + wid * 4 + tid / 8, + tid % 8); + uint32_t v_smem_offset_w = + smem_t::get_permuted_offset( + wid * 8 + tid / 4, tid % 4); // 4 * 128 / 8 = 64 + + uint32_t kv_idx_base = chunk_start; + const uint32_t const_k_offset = kv_head_idx * kv_h_stride + + (wid * 4 + tid / 8) * kv_b_stride + + tid % 8 * num_elems_per_128b(); + const uint32_t const_v_offset = kv_head_idx * kv_h_stride + + (wid * 8 + tid / 4) * kv_d_stride + + tid % 4 * num_elems_per_128b(); + + produce_k_blockwise_c8(k_smem, + &k_smem_offset_w, + cache_k, + block_table_now, + kv_head_idx, + kv_n_stride, + kv_h_stride, + kv_b_stride, + kv_idx_base, + chunk_end, + const_k_offset); + commit_group(); + produce_v_blockwise_c8(v_smem, + &v_smem_offset_w, + cache_v, + block_table_now, + kv_head_idx, + kv_n_stride, + kv_h_stride, + kv_d_stride, + kv_idx_base, + chunk_end, + const_v_offset); + commit_group(); + +#pragma unroll 1 + for (uint32_t iter = 0; iter < num_iterations; ++iter) { + if constexpr (IsDynamicC8) { + produce_k_dynamic_scale( + k_smem_scale, + cache_k_scale_reg, + block_table_now, + cache_k_scale, + kv_idx_base, + kv_num_heads, + kv_head_idx, + chunk_end + ); + } + wait_group<1>(); + __syncthreads(); + // s = qk + compute_qk_c8( + &qo_smem, + &q_smem_offset_r, + &k_smem, + &k_smem_offset_r, + cache_k_scale_reg, + s_frag); + + // mask according to kv_idx and q_idx + if (iter >= mask_check_iteration) { + mask_s(nullptr, + q_base_seq_id_this_block, + kv_idx_base, + q_len, + kv_len, + chunk_end, + -1, + s_frag, + mask_offset_this_seq); + } + + // update m,d + update_mdo_states( + s_frag, o_frag, m_frag, d_frag); + __syncthreads(); + + const int ori_kv_idx_base = kv_idx_base; + kv_idx_base += 
num_frags_z * 16; + produce_k_blockwise_c8(k_smem, + &k_smem_offset_w, + cache_k, + block_table_now, + kv_head_idx, + kv_n_stride, + kv_h_stride, + kv_b_stride, + kv_idx_base, + chunk_end, + const_k_offset); + commit_group(); + if constexpr (IsDynamicC8) { + produce_v_dynamic_scale( + v_smem_scale, + cache_v_scale_reg, + block_table_now, + cache_v_scale, + ori_kv_idx_base, + kv_num_heads, + kv_head_idx, + chunk_end + ); + } + wait_group<1>(); + __syncthreads(); + + // compute sfm*v + compute_sfm_v_c8( + &v_smem, &v_smem_offset_r, s_frag, o_frag, d_frag, cache_v_scale_reg); + __syncthreads(); + + produce_v_blockwise_c8(v_smem, + &v_smem_offset_w, + cache_v, + block_table_now, + kv_head_idx, + kv_n_stride, + kv_h_stride, + kv_d_stride, + kv_idx_base, + chunk_end, + const_v_offset); + commit_group(); + + } + wait_group<0>(); + __syncthreads(); + + if constexpr (!partition_kv) { + normalize_d(o_frag, d_frag); + } + + // write o + // [num_frags_x, 16, num_frags_y, 16] + if constexpr (partition_kv) { + write_o_reg_gmem_shift_smooth_quant( + o_frag, + &qo_smem, + o_base_ptr_T, + shift_bias, + smooth_weight, + q_base_seq_id_this_block, + q_head_idx, + quant_max_bound, + quant_min_bound, + in_scale, + q_len, + partition_kv ? q_n_stride * num_chunks : q_n_stride, + HEAD_DIM); + } else { + write_o_reg_gmem_shift_smooth_quant( + o_frag, + &qo_smem, + o_base_ptr_int8, + shift_bias, + smooth_weight, + q_base_seq_id_this_block, + q_head_idx, + quant_max_bound, + quant_min_bound, + in_scale, + q_len, + partition_kv ? 
q_n_stride * num_chunks : q_n_stride, + HEAD_DIM); + } + + + if constexpr (partition_kv) { +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t j = 0; j < 2; ++j) { + const uint32_t qo_idx_now = + q_base_seq_id_this_block + tid / 4 + j * 8 + fx * 16; + const uint32_t qo_head_idx = q_head_idx + qo_idx_now % GROUP_SIZE; + const uint32_t qo_idx = q_start_seq_id + qo_idx_now / GROUP_SIZE; + if (qo_idx - q_start_seq_id < q_len) { + uint32_t offset; + if (ENABLE_PREFILL) { + offset = + (qo_idx * num_chunks + chunk_idx) * q_num_heads + qo_head_idx; + } else { + offset = ((batch_id * speculate_max_draft_token_num + + qo_idx_now / GROUP_SIZE) * + num_chunks + + chunk_idx) * + q_num_heads + + qo_head_idx; + } + tmp_m[offset] = m_frag[fx][j]; + tmp_d[offset] = d_frag[fx][j]; + } + } + } + } +} + +template +__global__ void multi_query_append_attention_c8_warp1_4_kernel( + T *__restrict__ q, // [token_num, (num_heads + 2* kv_num_head) * head_dim] + CacheT *__restrict__ cache_k, // [max_block_num, num_heads, block_size, + // head_dim] + CacheT *__restrict__ cache_v, + const T *__restrict__ cache_k_scale, // [num_kv_heads] or [max_block_num, num_heads, block_size] + const T *__restrict__ cache_v_scale, // [num_kv_heads] or [max_block_num, num_heads, block_size] + const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM] + const T *__restrict__ smooth_weight, // [q_num_heads * HEAD_DIM] + const int *__restrict__ seq_lens, + const int *__restrict__ seq_lens_kv, + const int *__restrict__ batch_ids, + const int *__restrict__ tile_ids_per_batch, + const int *__restrict__ cu_seqlens_q, + const int *__restrict__ block_table, // [bsz, block_num_per_seq] + const int *__restrict__ mask_offset, + const bool *__restrict__ attn_mask, // [bsz, max_q, max_q] for tree-mask + const int max_seq_len, + const int max_dec_len, + const int max_block_num_per_seq, + const float scale, + const float quant_max_bound, + const float quant_min_bound, + const 
float in_scale, + const uint32_t chunk_size, + const int num_blocks_x_cpu, + T *__restrict__ tmp_workspace, // split kv [token_num, num_chunks, + // num_heads, head_dim] + float *__restrict__ tmp_m, // [token_num, num_chunks, num_heads] + float *__restrict__ tmp_d, // [token_num, num_chunks, num_heads] + OutT *__restrict__ out, + const int speculate_max_draft_token_num = 5, + const uint32_t attn_mask_len = -1) { + constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b(); + constexpr uint32_t num_vecs_per_head_k = + HEAD_DIM / num_elems_per_128b(); + constexpr uint32_t num_vecs_per_blocksize = + BLOCK_SIZE / num_elems_per_128b(); + constexpr uint32_t inv_k_stride = 8 / num_vecs_per_head_k; + constexpr uint32_t inv_v_stride = 8 / num_vecs_per_blocksize; + static_assert(NUM_WARP_Q == 1, "NUM_WARP_Q must be 1"); + static_assert(NUM_WARP_KV == 4, "NUM_WARP_KV must be 4"); + const uint32_t btid = blockIdx.x, kv_head_idx = blockIdx.z; + const uint32_t kv_num_heads = gridDim.z; + const uint32_t q_num_heads = kv_num_heads * GROUP_SIZE; + const uint32_t q_head_idx = kv_head_idx * GROUP_SIZE; + const uint32_t tid = threadIdx.x, wid = threadIdx.y; + const uint32_t num_chunks = gridDim.y; + const uint32_t chunk_idx = blockIdx.y; + + const uint32_t batch_id = batch_ids[btid]; + const uint32_t tile_id = tile_ids_per_batch[btid]; + const uint32_t num_rows_per_block = num_frags_x * 16; + const int *block_table_now = block_table + batch_id * max_block_num_per_seq; + + //When cudagraph capture prefill, may launch more gridDim.x + if(btid >= static_cast(num_blocks_x_cpu)){ + return; + } + + const uint32_t q_len = seq_lens[batch_id]; + if (q_len <= 0) { + return; + } + T cache_k_scale_reg[IsDynamicC8 ? num_frags_z * 2 : num_frags_y * 4]; + T cache_v_scale_reg[IsDynamicC8 ? 
num_frags_z * 4 : num_frags_y * 2]; + if constexpr (!IsDynamicC8) { + if constexpr (is_scale_channel_wise) { + int scale_col_base = threadIdx.x % 4 * 2 + kv_head_idx * HEAD_DIM; + const T *cache_k_scale_cur_head = cache_k_scale + scale_col_base; + for (int i = 0; i < num_frags_y; ++i) { + const int scale_idx = i * 16; + cache_k_scale_reg[i * 4] = cache_k_scale_cur_head[scale_idx]; + cache_k_scale_reg[i * 4 + 1] = cache_k_scale_cur_head[scale_idx + 1]; + cache_k_scale_reg[i * 4 + 2] = cache_k_scale_cur_head[scale_idx + 8]; + cache_k_scale_reg[i * 4 + 3] = cache_k_scale_cur_head[scale_idx + 9]; + } + scale_col_base = threadIdx.x / 4 + kv_head_idx * HEAD_DIM; + const T *cache_v_scale_cur_head = cache_v_scale + scale_col_base; + for (int i = 0; i < num_frags_y; ++i) { + const int scale_idx = i * 16; + cache_v_scale_reg[i * 2] = cache_v_scale_cur_head[scale_idx]; + cache_v_scale_reg[i * 2 + 1] = cache_v_scale_cur_head[scale_idx + 8]; + } + } else { + cache_k_scale_reg[0] = cache_k_scale[kv_head_idx]; + cache_v_scale_reg[0] = cache_v_scale[kv_head_idx]; + } + } + const uint32_t q_end = + min(q_len, div_up((tile_id + 1) * num_rows_per_block, GROUP_SIZE)); + uint32_t kv_len = seq_lens_kv[batch_id]; + if (ENABLE_PREFILL) { + kv_len += q_len; + if (kv_len <= 0) { + return; + } + } else { + if (kv_len <= 0) { + return; + } + kv_len += q_len; + } + const uint32_t num_chunks_this_seq = div_up(kv_len, chunk_size); + if (chunk_idx >= num_chunks_this_seq) { + return; + } + + const uint32_t chunk_start = partition_kv ? chunk_idx * chunk_size : 0; + const uint32_t chunk_end = + partition_kv ? 
min(kv_len, chunk_start + chunk_size) : kv_len; + const uint32_t chunk_len = chunk_end - chunk_start; + + extern __shared__ uint8_t smem[]; + float s_frag[num_frags_x][num_frags_z][8]; + float o_frag[num_frags_x][num_frags_y][8]; + float m_frag[num_frags_x][2]; + float d_frag[num_frags_x][2]; + init_states(o_frag, m_frag, d_frag); + + const uint32_t q_n_stride = q_num_heads * HEAD_DIM; + const uint32_t q_ori_n_stride = (q_num_heads + kv_num_heads * 2) * HEAD_DIM; + const uint32_t kv_n_stride = kv_num_heads * BLOCK_SIZE * HEAD_DIM; + const uint32_t kv_h_stride = BLOCK_SIZE * HEAD_DIM; + const uint32_t kv_b_stride = HEAD_DIM; + const uint32_t kv_d_stride = BLOCK_SIZE; + const uint32_t q_start_seq_id = cu_seqlens_q[batch_id]; + const uint32_t q_base_seq_id_this_block = tile_id * num_frags_x * 16; + const uint32_t q_offset = q_start_seq_id * q_ori_n_stride + + q_head_idx * HEAD_DIM + + tid % 8 * num_elems_per_128b(); + const uint32_t o_offset = q_start_seq_id * q_n_stride + + q_head_idx * HEAD_DIM + + tid % 8 * num_elems_per_128b(); + T *q_base_ptr = q + q_offset; + + T *o_base_ptr_T = nullptr; + OutT *o_base_ptr_int8 = nullptr; + if (num_chunks_this_seq <= 1) { + o_base_ptr_int8 = out + o_offset; + } else { + if (ENABLE_PREFILL) { + o_base_ptr_T = tmp_workspace + batch_id * num_chunks * q_n_stride + + chunk_idx * q_n_stride + q_head_idx * HEAD_DIM + + tid % 8 * num_elems_per_128b(); + } else { + o_base_ptr_T = + tmp_workspace + + batch_id * speculate_max_draft_token_num * num_chunks * q_n_stride + + chunk_idx * q_n_stride + q_head_idx * HEAD_DIM + + tid % 8 * num_elems_per_128b(); + } + } + const int *mask_offset_this_seq = mask_offset ? 
mask_offset + q_start_seq_id * 2 : nullptr; + smem_t qo_smem(smem); + + uint32_t q_smem_offset_r = smem_t::get_permuted_offset( + tid % 16, tid / 16); // 16 * 16 + load_q_global_smem_multi_warps(q_base_ptr, + &qo_smem, + q_base_seq_id_this_block, + q_end, + q_ori_n_stride, + HEAD_DIM); + commit_group(); + wait_group<0>(); + __syncthreads(); + + q_smem_inplace_multiply_sm_scale_multi_warps( + &qo_smem, scale); + + smem_t k_smem(smem + num_frags_x * 16 * HEAD_DIM * sizeof(T)), + v_smem(smem + num_frags_x * 16 * HEAD_DIM * sizeof(T) + + NUM_WARP_KV * num_frags_z * 16 * HEAD_DIM * sizeof(CacheT)); + T* k_smem_scale = nullptr; + T* v_smem_scale = nullptr; + if constexpr (IsDynamicC8) { + k_smem_scale = reinterpret_cast(smem + num_frags_x * 16 * HEAD_DIM * sizeof(T) + + NUM_WARP_KV * num_frags_z * 16 * HEAD_DIM * sizeof(CacheT) * 2); + v_smem_scale = k_smem_scale + NUM_WARP_KV * num_frags_z * 16; + } + + const uint32_t num_iterations = div_up( + CAUSAL + ? (min(chunk_len, + sub_if_greater_or_zero( + kv_len - q_len + + div_up((tile_id + 1) * num_rows_per_block, GROUP_SIZE), + chunk_start))) + : chunk_len, + NUM_WARP_KV * num_frags_z * 16); + const uint32_t mask_check_iteration = + (CAUSAL ? (min(chunk_len, + sub_if_greater_or_zero( + kv_len - q_len + + tile_id * num_rows_per_block / GROUP_SIZE, + chunk_start))) + : mask_offset ? 
0 : chunk_len) / + (NUM_WARP_KV * num_frags_z * 16); + + uint32_t k_smem_offset_r = + smem_t::get_permuted_offset( + wid * num_frags_z * 16 + 8 * (tid / 16) + tid % 8, (tid % 16) / 8); + + uint32_t v_smem_offset_r = + smem_t::get_permuted_offset( + (wid / 2) * num_frags_y * 16 + 8 * (tid / 16) + tid % 8, + (wid % 2) * num_frags_z + (tid % 16) / 8); + + uint32_t k_smem_offset_w = + smem_t::get_permuted_offset( + wid * 4 + tid / 8, + tid % + 8); + uint32_t v_smem_offset_w = + smem_t::get_permuted_offset( + wid * 8 + tid / 4, tid % 4); + + uint32_t kv_idx_base = chunk_start; + const uint32_t const_k_offset = kv_head_idx * kv_h_stride + + (wid * 4 + tid / 8) * kv_b_stride + + tid % 8 * num_elems_per_128b(); + const uint32_t const_v_offset = kv_head_idx * kv_h_stride + + (wid * 8 + tid / 4) * kv_d_stride + + tid % 4 * num_elems_per_128b(); + + // load BLOCK_SIZE * HEAD_DIM each time + produce_k_blockwise_c8(k_smem, + &k_smem_offset_w, + cache_k, + block_table_now, + kv_head_idx, + kv_n_stride, + kv_h_stride, + kv_b_stride, + kv_idx_base, + chunk_end, + const_k_offset); + commit_group(); + produce_v_blockwise_c8(v_smem, + &v_smem_offset_w, + cache_v, + block_table_now, + kv_head_idx, + kv_n_stride, + kv_h_stride, + kv_d_stride, + kv_idx_base, + chunk_end, + const_v_offset); + commit_group(); +#pragma unroll 1 + for (uint32_t iter = 0; iter < num_iterations; ++iter) { + if constexpr (IsDynamicC8) { + produce_k_dynamic_scale( + k_smem_scale, + cache_k_scale_reg, + block_table_now, + cache_k_scale, + kv_idx_base, + kv_num_heads, + kv_head_idx, + chunk_end + ); + } + wait_group<1>(); + __syncthreads(); + + // s = qk + compute_qk_c8( + &qo_smem, + &q_smem_offset_r, + &k_smem, + &k_smem_offset_r, + cache_k_scale_reg, + s_frag); + // mask according to kv_idx and q_idx + if (iter >= mask_check_iteration) { + mask_s(attn_mask ? 
attn_mask + batch_id * attn_mask_len *attn_mask_len : nullptr, + q_base_seq_id_this_block, + kv_idx_base + wid * num_frags_z * 16, + q_len, + kv_len, + chunk_end, + attn_mask_len, + s_frag, + mask_offset_this_seq); + + } + + // update m,d + update_mdo_states( + s_frag, o_frag, m_frag, d_frag); + __syncthreads(); + + const uint32_t ori_kv_idx_base = kv_idx_base; + kv_idx_base += NUM_WARP_KV * num_frags_z * 16; + produce_k_blockwise_c8(k_smem, + &k_smem_offset_w, + cache_k, + block_table_now, + kv_head_idx, + kv_n_stride, + kv_h_stride, + kv_b_stride, + kv_idx_base, + chunk_end, + const_k_offset); + commit_group(); + if constexpr (IsDynamicC8) { + produce_v_dynamic_scale( + v_smem_scale, + cache_v_scale_reg, + block_table_now, + cache_v_scale, + ori_kv_idx_base, + kv_num_heads, + kv_head_idx, + chunk_end + ); + } + wait_group<1>(); + __syncthreads(); + + // compute sfm * v + compute_sfm_v_c8_iter_sq_bvec( + &v_smem, &v_smem_offset_r, s_frag, o_frag, d_frag, cache_v_scale_reg); + __syncthreads(); + + produce_v_blockwise_c8(v_smem, + &v_smem_offset_w, + cache_v, + block_table_now, + kv_head_idx, + kv_n_stride, + kv_h_stride, + kv_d_stride, + kv_idx_base, + chunk_end, + const_v_offset); + commit_group(); + } + wait_group<0>(); + __syncthreads(); + + merge_block_res_v2( + o_frag, reinterpret_cast(smem), m_frag, d_frag, wid, tid); + + if (num_chunks_this_seq <= 1) { + normalize_d(o_frag, d_frag); + } + + // write o + // [num_frags_x, 16, num_frags_y, 16] + if (num_chunks_this_seq <= 1) { + write_o_reg_gmem_multi_warps_shift_smooth_quant( + o_frag, + &qo_smem, + o_base_ptr_int8, + shift_bias, + smooth_weight, + q_base_seq_id_this_block, + q_head_idx, + quant_max_bound, + quant_min_bound, + in_scale, + q_len, + q_n_stride, + HEAD_DIM); + } else { + write_o_reg_gmem_multi_warps_shift_smooth_quant( + o_frag, + &qo_smem, + o_base_ptr_T, + shift_bias, + smooth_weight, + q_base_seq_id_this_block, + q_head_idx, + quant_max_bound, + quant_min_bound, + in_scale, + q_len, + 
q_n_stride * num_chunks, + HEAD_DIM); + } + + if (num_chunks_this_seq > 1) { + if (wid == 0) { +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t j = 0; j < 2; ++j) { + const uint32_t qo_idx_now = + q_base_seq_id_this_block + tid / 4 + j * 8 + fx * 16; + const uint32_t qo_head_idx = q_head_idx + qo_idx_now % GROUP_SIZE; + const uint32_t qo_idx = q_start_seq_id + qo_idx_now / GROUP_SIZE; + if (qo_idx - q_start_seq_id < q_len) { + + uint32_t offset; + if (ENABLE_PREFILL) { + offset = (batch_id * num_chunks + chunk_idx) * q_num_heads + + qo_head_idx; + } else { + offset = ((batch_id * speculate_max_draft_token_num + + qo_idx_now / GROUP_SIZE) * + num_chunks + + chunk_idx) * + q_num_heads + + qo_head_idx; + } + tmp_m[offset] = m_frag[fx][j]; + tmp_d[offset] = d_frag[fx][j]; + } + } + } + } + } +} + +template +void MultiQueryAppendC8Attention( + const AppendAttnMetaData &meta_data, + const paddle::Tensor &qkv, + const paddle::Tensor &cache_k, + const paddle::Tensor &cache_v, + const paddle::optional &attn_mask, + const paddle::Tensor &cache_k_scale, + const paddle::Tensor &cache_v_scale, + const paddle::optional &shift_bias, + const paddle::optional &smooth_weight, + const paddle::Tensor &seq_lens_q, + const paddle::Tensor &seq_lens_kv, + const paddle::Tensor &seq_lens_encoder, + const paddle::Tensor &batch_id_per_token, + const paddle::Tensor &cu_seqlens_q, + const paddle::Tensor &block_table, + const paddle::Tensor &batch_ids, + const paddle::Tensor &tile_ids_per_batch, + const int num_blocks_x_cpu, + const int max_seq_len, + const int max_dec_len, + const float quant_max_bound, + const float quant_min_bound, + const float in_scale, + const int max_partition_size, + const int encoder_max_partition_size, + const int speculate_max_draft_token_num, + const bool is_decoder, + cudaStream_t &stream, + paddle::Tensor *out) { + using NV_TYPE = typename cascade_attn_type_traits::type; + using OUT_NV_TYPE = typename 
cascade_attn_type_traits::type; + + auto num_heads = meta_data.q_num_heads; + auto kv_num_heads = meta_data.kv_num_heads; + auto token_num = meta_data.token_nums; + auto bsz = meta_data.batch_size; + auto max_block_num_per_seq = meta_data.max_blocks_per_seq; + + constexpr uint32_t num_warps = 4; + constexpr uint32_t NUM_WARP_KV = num_warps / NUM_WARP_Q; + constexpr uint32_t num_frags_x = BLOCK_SHAPE_Q / (16 * NUM_WARP_Q); + constexpr uint32_t num_frags_y = HEAD_DIM / 16; + constexpr uint32_t num_qrow_per_block = NUM_WARP_Q * num_frags_x * 16; + + auto *allocator = paddle::GetAllocator(qkv.place()); + + const float scale = 1.f / sqrt(HEAD_DIM); + bool is_scale_channel_wise = false; + if (cache_k_scale.dims()[0] == HEAD_DIM * kv_num_heads) { + is_scale_channel_wise = true; + } + + if constexpr (NUM_WARP_Q == 4) { + constexpr uint32_t num_frags_z = BLOCK_SIZE / 16; + constexpr uint32_t smem_size = + num_warps * num_frags_x * 16 * HEAD_DIM * sizeof(T) + + num_frags_z * 16 * HEAD_DIM * sizeof(uint8_t) * 2 + + num_frags_z * 16 * sizeof(T) * 2; + auto split_kv_kernel = + multi_query_append_attention_c8_kernel; + if (is_scale_channel_wise) { + split_kv_kernel = + multi_query_append_attention_c8_kernel; + } + if (smem_size >= 48 * 1024) { + cudaFuncSetAttribute(split_kv_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + } + const int dev_id = 0; + int sm_count; + cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id); + uint32_t chunk_size = static_cast(max_partition_size); + if (!is_decoder) { + chunk_size = static_cast(encoder_max_partition_size); + } + const int num_chunks = div_up(max_dec_len, chunk_size); + dim3 grids(num_blocks_x_cpu, num_chunks, kv_num_heads); + dim3 blocks(32, num_warps); + if (num_chunks <= 1) { + auto nosplit_kv_kernel = + multi_query_append_attention_c8_kernel; + if (is_scale_channel_wise) { + nosplit_kv_kernel = + multi_query_append_attention_c8_kernel; + } + if (smem_size >= 48 * 1024) { + 
cudaFuncSetAttribute(nosplit_kv_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + } + + nosplit_kv_kernel<<>>( + reinterpret_cast(const_cast(qkv.data())), + const_cast(cache_k.data()), + const_cast(cache_v.data()), + reinterpret_cast(const_cast(cache_k_scale.data())), + reinterpret_cast(const_cast(cache_v_scale.data())), + shift_bias ? reinterpret_cast( + const_cast(shift_bias.get().data())) + : nullptr, + smooth_weight ? reinterpret_cast( + const_cast(smooth_weight.get().data())) + : nullptr, + seq_lens_q.data(), + seq_lens_kv.data(), + batch_ids.data(), + tile_ids_per_batch.data(), + cu_seqlens_q.data(), + block_table.data(), + meta_data.mask_offset, + max_seq_len, + max_dec_len, + max_block_num_per_seq, + scale, + quant_max_bound, + quant_min_bound, + in_scale, + chunk_size, + num_blocks_x_cpu, + nullptr, + nullptr, + nullptr, + reinterpret_cast(out->data()), + speculate_max_draft_token_num); + } else { + phi::Allocator::AllocationPtr tmp_workspace, tmp_m, tmp_d; + if (ENABLE_PREFILL) { + tmp_workspace = allocator->Allocate( + phi::SizeOf(qkv.dtype()) * + static_cast(token_num * num_chunks * num_heads * HEAD_DIM)); + tmp_m = allocator->Allocate( + phi::SizeOf(paddle::DataType::FLOAT32) * + static_cast(token_num * num_chunks * num_heads)); + tmp_d = allocator->Allocate( + phi::SizeOf(paddle::DataType::FLOAT32) * + static_cast(token_num * num_chunks * num_heads)); + } else { + tmp_workspace = allocator->Allocate( + phi::SizeOf(qkv.dtype()) * + static_cast(speculate_max_draft_token_num * bsz * + num_chunks * num_heads * HEAD_DIM)); + tmp_m = allocator->Allocate( + phi::SizeOf(paddle::DataType::FLOAT32) * + static_cast(speculate_max_draft_token_num * bsz * + num_chunks * num_heads)); + tmp_d = allocator->Allocate( + phi::SizeOf(paddle::DataType::FLOAT32) * + static_cast(speculate_max_draft_token_num * bsz * + num_chunks * num_heads)); + } + split_kv_kernel<<>>( + reinterpret_cast(const_cast(qkv.data())), + const_cast(cache_k.data()), + 
const_cast(cache_v.data()), + reinterpret_cast(const_cast(cache_k_scale.data())), + reinterpret_cast(const_cast(cache_v_scale.data())), + shift_bias ? reinterpret_cast( + const_cast(shift_bias.get().data())) + : nullptr, + smooth_weight ? reinterpret_cast( + const_cast(smooth_weight.get().data())) + : nullptr, + seq_lens_q.data(), + seq_lens_kv.data(), + batch_ids.data(), + tile_ids_per_batch.data(), + cu_seqlens_q.data(), + block_table.data(), + meta_data.mask_offset, + max_seq_len, + max_dec_len, + max_block_num_per_seq, + scale, + quant_max_bound, + quant_min_bound, + in_scale, + chunk_size, + num_blocks_x_cpu, + reinterpret_cast(tmp_workspace->ptr()), + static_cast(tmp_m->ptr()), + static_cast(tmp_d->ptr()), + reinterpret_cast(out->data()), + speculate_max_draft_token_num); + // merge + constexpr int vec_size = num_elems_per_128b(); + if (is_decoder) { + constexpr int blockx = HEAD_DIM / vec_size; + constexpr int blocky = (128 + blockx - 1) / blockx; + dim3 grids_merge(bsz, num_heads); + dim3 blocks_merge(blockx, blocky); + merge_multi_chunks_decoder_kernel + <<>>( + reinterpret_cast(tmp_workspace->ptr()), + static_cast(tmp_m->ptr()), + static_cast(tmp_d->ptr()), + seq_lens_q.data(), + seq_lens_kv.data(), + seq_lens_encoder.data(), + cu_seqlens_q.data(), + shift_bias ? reinterpret_cast( + const_cast(shift_bias.get().data())) + : nullptr, + smooth_weight ? 
reinterpret_cast(const_cast( + smooth_weight.get().data())) + : nullptr, + reinterpret_cast(out->data()), + quant_max_bound, + quant_min_bound, + in_scale, + max_seq_len, + num_chunks, + num_heads, + chunk_size, + HEAD_DIM); + } else { + constexpr int blockx = HEAD_DIM / vec_size; + constexpr int blocky = (128 + blockx - 1) / blockx; + dim3 grids_merge(min(sm_count * 4, token_num), + num_heads); + dim3 blocks_merge(blockx, blocky); + merge_multi_chunks_v2_kernel + <<>>( + reinterpret_cast(tmp_workspace->ptr()), + static_cast(tmp_m->ptr()), + static_cast(tmp_d->ptr()), + seq_lens_q.data(), + seq_lens_kv.data(), + seq_lens_encoder.data(), + batch_id_per_token.data(), + cu_seqlens_q.data(), + shift_bias ? reinterpret_cast( + const_cast(shift_bias.get().data())) + : nullptr, + smooth_weight ? reinterpret_cast(const_cast( + smooth_weight.get().data())) + : nullptr, + reinterpret_cast(out->data()), + quant_max_bound, + quant_min_bound, + in_scale, + max_seq_len, + num_chunks, + num_heads, + chunk_size, + HEAD_DIM, + token_num, + speculate_max_draft_token_num); + } + } + } else { + constexpr uint32_t num_frags_z = BLOCK_SIZE / 16 / NUM_WARP_KV * 2; + constexpr uint32_t smem_size = + num_frags_x * 16 * HEAD_DIM * sizeof(T) + + NUM_WARP_KV * num_frags_z * 16 * HEAD_DIM * sizeof(uint8_t) * 2 + + NUM_WARP_KV * num_frags_z * 16 * sizeof(T) * 2; + auto split_kv_kernel = + multi_query_append_attention_c8_warp1_4_kernel; + if (is_scale_channel_wise) { + split_kv_kernel = + multi_query_append_attention_c8_warp1_4_kernel; + } + if (smem_size >= 48 * 1024) { + cudaFuncSetAttribute(split_kv_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + } + const int dev_id = 0; + int sm_count; + cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id); + uint32_t chunk_size = static_cast(max_partition_size); + if (!is_decoder) { + chunk_size = static_cast(encoder_max_partition_size); + } + + const int num_chunks = div_up(max_seq_len, chunk_size); + uint32_t 
attn_mask_len; + if (attn_mask) { + attn_mask_len = attn_mask.get().shape()[1]; + } else { + attn_mask_len = -1; + } + + dim3 grids(num_blocks_x_cpu, num_chunks, kv_num_heads); + dim3 blocks(32, num_warps); + if (num_chunks <= 0) { + auto nosplit_kv_kernel = + multi_query_append_attention_c8_warp1_4_kernel; + if (is_scale_channel_wise) { + nosplit_kv_kernel = + multi_query_append_attention_c8_warp1_4_kernel; + } + if (smem_size >= 48 * 1024) { + cudaFuncSetAttribute(nosplit_kv_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + } + + nosplit_kv_kernel<<>>( + reinterpret_cast(const_cast(qkv.data())), + const_cast(cache_k.data()), + const_cast(cache_v.data()), + reinterpret_cast(const_cast(cache_k_scale.data())), + reinterpret_cast(const_cast(cache_v_scale.data())), + shift_bias ? reinterpret_cast( + const_cast(shift_bias.get().data())) + : nullptr, + smooth_weight ? reinterpret_cast( + const_cast(smooth_weight.get().data())) + : nullptr, + seq_lens_q.data(), + seq_lens_kv.data(), + batch_ids.data(), + tile_ids_per_batch.data(), + cu_seqlens_q.data(), + block_table.data(), + meta_data.mask_offset, + attn_mask ? 
const_cast(attn_mask.get().data()) + : nullptr, + max_seq_len, + max_dec_len, + max_block_num_per_seq, + scale, + quant_max_bound, + quant_min_bound, + in_scale, + chunk_size, + num_blocks_x_cpu, + nullptr, + nullptr, + nullptr, + reinterpret_cast(out->data()), + speculate_max_draft_token_num, + attn_mask_len); + } else { + phi::Allocator::AllocationPtr tmp_workspace, tmp_m, tmp_d; + if (is_decoder) { + tmp_workspace = allocator->Allocate( + phi::SizeOf(qkv.dtype()) * + static_cast(bsz * num_chunks * num_heads * HEAD_DIM)); + tmp_m = allocator->Allocate( + phi::SizeOf(paddle::DataType::FLOAT32) * + static_cast(bsz * num_chunks * num_heads)); + tmp_d = allocator->Allocate( + phi::SizeOf(paddle::DataType::FLOAT32) * + static_cast(bsz * num_chunks * num_heads)); + } else { + if (ENABLE_PREFILL) { + tmp_workspace = + allocator->Allocate(phi::SizeOf(qkv.dtype()) * + static_cast(token_num * num_chunks * + num_heads * HEAD_DIM)); + tmp_m = allocator->Allocate( + phi::SizeOf(paddle::DataType::FLOAT32) * + static_cast(token_num * num_chunks * num_heads)); + tmp_d = allocator->Allocate( + phi::SizeOf(paddle::DataType::FLOAT32) * + static_cast(token_num * num_chunks * num_heads)); + } else { + tmp_workspace = allocator->Allocate( + phi::SizeOf(qkv.dtype()) * + static_cast(speculate_max_draft_token_num * bsz * + num_chunks * num_heads * HEAD_DIM)); + tmp_m = allocator->Allocate( + phi::SizeOf(paddle::DataType::FLOAT32) * + static_cast(speculate_max_draft_token_num * bsz * + num_chunks * num_heads)); + tmp_d = allocator->Allocate( + phi::SizeOf(paddle::DataType::FLOAT32) * + static_cast(speculate_max_draft_token_num * bsz * + num_chunks * num_heads)); + } + } + split_kv_kernel<<>>( + reinterpret_cast(const_cast(qkv.data())), + const_cast(cache_k.data()), + const_cast(cache_v.data()), + reinterpret_cast(const_cast(cache_k_scale.data())), + reinterpret_cast(const_cast(cache_v_scale.data())), + shift_bias ? 
reinterpret_cast( + const_cast(shift_bias.get().data())) + : nullptr, + smooth_weight ? reinterpret_cast( + const_cast(smooth_weight.get().data())) + : nullptr, + seq_lens_q.data(), + seq_lens_kv.data(), + batch_ids.data(), + tile_ids_per_batch.data(), + cu_seqlens_q.data(), + block_table.data(), + meta_data.mask_offset, + attn_mask ? const_cast(attn_mask.get().data()) + : nullptr, + max_seq_len, + max_dec_len, + max_block_num_per_seq, + scale, + quant_max_bound, + quant_min_bound, + in_scale, + chunk_size, + num_blocks_x_cpu, + reinterpret_cast(tmp_workspace->ptr()), + static_cast(tmp_m->ptr()), + static_cast(tmp_d->ptr()), + reinterpret_cast(out->data()), + speculate_max_draft_token_num, + attn_mask_len); + // merge + constexpr int vec_size = num_elems_per_128b(); + if (is_decoder) { + constexpr int blockx = HEAD_DIM / vec_size; + constexpr int blocky = (128 + blockx - 1) / blockx; + dim3 grids_merge(bsz, num_heads); + dim3 blocks_merge(blockx, blocky); + merge_multi_chunks_decoder_kernel + <<>>( + reinterpret_cast(tmp_workspace->ptr()), + static_cast(tmp_m->ptr()), + static_cast(tmp_d->ptr()), + seq_lens_q.data(), + seq_lens_kv.data(), + seq_lens_encoder.data(), + cu_seqlens_q.data(), + shift_bias ? reinterpret_cast( + const_cast(shift_bias.get().data())) + : nullptr, + smooth_weight ? 
reinterpret_cast(const_cast( + smooth_weight.get().data())) + : nullptr, + reinterpret_cast(out->data()), + quant_max_bound, + quant_min_bound, + in_scale, + max_seq_len, + num_chunks, + num_heads, + chunk_size, + HEAD_DIM); + } else { + constexpr int blockx = HEAD_DIM / vec_size; + constexpr int blocky = (128 + blockx - 1) / blockx; + dim3 grids_merge(min(sm_count * 4, token_num), + num_heads); + dim3 blocks_merge(blockx, blocky); + merge_multi_chunks_v2_kernel + <<>>( + reinterpret_cast(tmp_workspace->ptr()), + static_cast(tmp_m->ptr()), + static_cast(tmp_d->ptr()), + seq_lens_q.data(), + seq_lens_kv.data(), + seq_lens_encoder.data(), + batch_id_per_token.data(), + cu_seqlens_q.data(), + shift_bias ? reinterpret_cast( + const_cast(shift_bias.get().data())) + : nullptr, + smooth_weight ? reinterpret_cast(const_cast( + smooth_weight.get().data())) + : nullptr, + reinterpret_cast(out->data()), + quant_max_bound, + quant_min_bound, + in_scale, + max_seq_len, + num_chunks, + num_heads, + chunk_size, + HEAD_DIM, + token_num, + speculate_max_draft_token_num); + } + } + } +} diff --git a/custom_ops/gpu_ops/append_attn/multiquery_attention_c8_kernel.h b/custom_ops/gpu_ops/append_attn/multiquery_attention_c8_kernel.h new file mode 100644 index 000000000..0a7c244d4 --- /dev/null +++ b/custom_ops/gpu_ops/append_attn/multiquery_attention_c8_kernel.h @@ -0,0 +1,58 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include "append_attention_func.cuh" + +// Declaration only: MultiQueryAppendC8Attention is declared here and no definition is provided in this header. +// attn_mask, shift_bias and smooth_weight are paddle::optional and may be absent; every other tensor argument is required. +// `out` is the only non-const tensor argument and `stream` is taken by non-const reference. +// NOTE(review): presumably a host-side launcher that enqueues its kernels on `stream` and writes the result through `out` - confirm against the definition. +// NOTE(review): the template parameter list is not visible in this view - confirm it matches the definition and its explicit instantiations. +template +void MultiQueryAppendC8Attention( + const AppendAttnMetaData &meta_data, + const paddle::Tensor &qkv, + const paddle::Tensor &cache_k, + const paddle::Tensor &cache_v, + const paddle::optional &attn_mask, + const paddle::Tensor &cache_k_scale, + const paddle::Tensor &cache_v_scale, + const paddle::optional &shift_bias, + const paddle::optional &smooth_weight, + const paddle::Tensor &seq_lens_q, + const paddle::Tensor &seq_lens_kv, + const paddle::Tensor &seq_lens_encoder, + const paddle::Tensor &batch_id_per_token, + const paddle::Tensor &cu_seqlens_q, + const paddle::Tensor &block_table, + const paddle::Tensor &batch_ids, + const paddle::Tensor &tile_ids_per_batch, + const int num_blocks_x_cpu, + const int max_seq_len, + const int max_dec_len, + const float quant_max_bound, + const float quant_min_bound, + const float in_scale, + const int max_partition_size, + const int encoder_max_partition_size, + const int speculate_max_draft_token_num, + const bool is_decoder, + cudaStream_t &stream, + paddle::Tensor *out); diff --git a/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c8_bfloat16_bfloat16_kernel.cu b/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c8_bfloat16_bfloat16_kernel.cu deleted file mode 100644 index 757cccaf9..000000000 --- a/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c8_bfloat16_bfloat16_kernel.cu +++ /dev/null @@ -1,109 +0,0 @@ -// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -#include "../append_attention_c8_impl.cuh" - - -template void -CascadeAppendAttentionC8Kernel( - const AppendAttnMetaData& meta_data, - const paddle::Tensor& qkv, // [token_num, num_heads, head_dim] - const paddle::Tensor& - cache_k, // [max_block_num, num_heads, block_size, head_dim] - const paddle::Tensor& - cache_v, // [max_block_num, num_heads, head_dim, block_size] - const paddle::optional& attn_mask, - const paddle::optional& - cache_k_scale, // [num_kv_heads, head_dim] - const paddle::optional& - cache_v_scale, // [num_kv_heads, head_dim] - const paddle::optional& - cache_k_zp, // [num_kv_heads, head_dim] - const paddle::optional& - cache_v_zp, // [num_kv_heads, head_dim] - const paddle::optional& - shift_bias, // [num_kv_heads, head_dim] - const paddle::optional& - smooth_weight, // [num_kv_heads, head_dim] - const paddle::Tensor& seq_lens_q, - const paddle::Tensor& seq_lens_kv, - const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& batch_id_per_token, - const paddle::Tensor& cu_seqlens_q, - const paddle::Tensor& block_table, - const paddle::Tensor& batch_ids, - const paddle::Tensor& tile_ids_per_batch, - const int num_blocks, - const int block_shape_q, - const int max_seq_len, - const int max_dec_len, - const float quant_max_bound, - const float quant_min_bound, - const float in_scale, - const int max_partition_size, - const int encoder_max_partition_size, - const int speculate_max_draft_token_num, - const bool causal, - const bool is_decoder, - const bool enable_prefill, - const std::string& cache_quant_type_str, - 
cudaStream_t& stream, - paddle::Tensor* out); - - - -template void -CascadeAppendAttentionC8Kernel( - const AppendAttnMetaData& meta_data, - const paddle::Tensor& qkv, // [token_num, num_heads, head_dim] - const paddle::Tensor& - cache_k, // [max_block_num, num_heads, block_size, head_dim] - const paddle::Tensor& - cache_v, // [max_block_num, num_heads, head_dim, block_size] - const paddle::optional& attn_mask, - const paddle::optional& - cache_k_scale, // [num_kv_heads, head_dim] - const paddle::optional& - cache_v_scale, // [num_kv_heads, head_dim] - const paddle::optional& - cache_k_zp, // [num_kv_heads, head_dim] - const paddle::optional& - cache_v_zp, // [num_kv_heads, head_dim] - const paddle::optional& - shift_bias, // [num_kv_heads, head_dim] - const paddle::optional& - smooth_weight, // [num_kv_heads, head_dim] - const paddle::Tensor& seq_lens_q, - const paddle::Tensor& seq_lens_kv, - const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& batch_id_per_token, - const paddle::Tensor& cu_seqlens_q, - const paddle::Tensor& block_table, - const paddle::Tensor& batch_ids, - const paddle::Tensor& tile_ids_per_batch, - const int num_blocks, - const int block_shape_q, - const int max_seq_len, - const int max_dec_len, - const float quant_max_bound, - const float quant_min_bound, - const float in_scale, - const int max_partition_size, - const int encoder_max_partition_size, - const int speculate_max_draft_token_num, - const bool causal, - const bool is_decoder, - const bool enable_prefill, - const std::string& cache_quant_type_str, - cudaStream_t& stream, - paddle::Tensor* out); diff --git a/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c8_bfloat16_fp8_kernel.cu b/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c8_bfloat16_fp8_kernel.cu deleted file mode 100644 index 54b0b0be4..000000000 --- a/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c8_bfloat16_fp8_kernel.cu +++ /dev/null @@ 
-1,104 +0,0 @@ -// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -#include "../append_attention_c8_impl.cuh" - -template void CascadeAppendAttentionC8Kernel( - const AppendAttnMetaData& meta_data, - const paddle::Tensor& qkv, // [token_num, num_heads, head_dim] - const paddle::Tensor& - cache_k, // [max_block_num, num_heads, block_size, head_dim] - const paddle::Tensor& - cache_v, // [max_block_num, num_heads, head_dim, block_size] - const paddle::optional& attn_mask, - const paddle::optional& - cache_k_scale, // [num_kv_heads, head_dim] - const paddle::optional& - cache_v_scale, // [num_kv_heads, head_dim] - const paddle::optional& - cache_k_zp, // [num_kv_heads, head_dim] - const paddle::optional& - cache_v_zp, // [num_kv_heads, head_dim] - const paddle::optional& - shift_bias, // [num_kv_heads, head_dim] - const paddle::optional& - smooth_weight, // [num_kv_heads, head_dim] - const paddle::Tensor& seq_lens_q, - const paddle::Tensor& seq_lens_kv, - const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& batch_id_per_token, - const paddle::Tensor& cu_seqlens_q, - const paddle::Tensor& block_table, - const paddle::Tensor& batch_ids, - const paddle::Tensor& tile_ids_per_batch, - const int num_blocks, - const int block_shape_q, - const int max_seq_len, - const int max_dec_len, - const float quant_max_bound, - const float quant_min_bound, - const float in_scale, - const int 
max_partition_size, - const int encoder_max_partition_size, - const int speculate_max_draft_token_num, - const bool causal, - const bool is_decoder, - const bool enable_prefill, - const std::string& cache_quant_type_str, - cudaStream_t& stream, - paddle::Tensor* out); - -template void CascadeAppendAttentionC8Kernel( - const AppendAttnMetaData& meta_data, - const paddle::Tensor& qkv, // [token_num, num_heads, head_dim] - const paddle::Tensor& - cache_k, // [max_block_num, num_heads, block_size, head_dim] - const paddle::Tensor& - cache_v, // [max_block_num, num_heads, head_dim, block_size] - const paddle::optional& attn_mask, - const paddle::optional& - cache_k_scale, // [num_kv_heads, head_dim] - const paddle::optional& - cache_v_scale, // [num_kv_heads, head_dim] - const paddle::optional& - cache_k_zp, // [num_kv_heads, head_dim] - const paddle::optional& - cache_v_zp, // [num_kv_heads, head_dim] - const paddle::optional& - shift_bias, // [num_kv_heads, head_dim] - const paddle::optional& - smooth_weight, // [num_kv_heads, head_dim] - const paddle::Tensor& seq_lens_q, - const paddle::Tensor& seq_lens_kv, - const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& batch_id_per_token, - const paddle::Tensor& cu_seqlens_q, - const paddle::Tensor& block_table, - const paddle::Tensor& batch_ids, - const paddle::Tensor& tile_ids_per_batch, - const int num_blocks, - const int block_shape_q, - const int max_seq_len, - const int max_dec_len, - const float quant_max_bound, - const float quant_min_bound, - const float in_scale, - const int max_partition_size, - const int encoder_max_partition_size, - const int speculate_max_draft_token_num, - const bool causal, - const bool is_decoder, - const bool enable_prefill, - const std::string& cache_quant_type_str, - cudaStream_t& stream, - paddle::Tensor* out); diff --git a/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c8_bfloat16_int8_kernel.cu 
b/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c8_bfloat16_int8_kernel.cu deleted file mode 100644 index c6bd95576..000000000 --- a/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c8_bfloat16_int8_kernel.cu +++ /dev/null @@ -1,106 +0,0 @@ -// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -#include "../append_attention_c8_impl.cuh" - -template void CascadeAppendAttentionC8Kernel( - const AppendAttnMetaData& meta_data, - const paddle::Tensor& qkv, // [token_num, num_heads, head_dim] - const paddle::Tensor& - cache_k, // [max_block_num, num_heads, block_size, head_dim] - const paddle::Tensor& - cache_v, // [max_block_num, num_heads, head_dim, block_size] - const paddle::optional& attn_mask, - const paddle::optional& - cache_k_scale, // [num_kv_heads, head_dim] - const paddle::optional& - cache_v_scale, // [num_kv_heads, head_dim] - const paddle::optional& - cache_k_zp, // [num_kv_heads, head_dim] - const paddle::optional& - cache_v_zp, // [num_kv_heads, head_dim] - const paddle::optional& - shift_bias, // [num_kv_heads, head_dim] - const paddle::optional& - smooth_weight, // [num_kv_heads, head_dim] - const paddle::Tensor& seq_lens_q, - const paddle::Tensor& seq_lens_kv, - const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& batch_id_per_token, - const paddle::Tensor& cu_seqlens_q, - const paddle::Tensor& block_table, - const 
paddle::Tensor& batch_ids, - const paddle::Tensor& tile_ids_per_batch, - const int num_blocks, - const int block_shape_q, - const int max_seq_len, - const int max_dec_len, - const float quant_max_bound, - const float quant_min_bound, - const float in_scale, - const int max_partition_size, - const int encoder_max_partition_size, - const int speculate_max_draft_token_num, - const bool causal, - const bool is_decoder, - const bool enable_prefill, - const std::string& cache_quant_type_str, - cudaStream_t& stream, - paddle::Tensor* out); - - - -template void CascadeAppendAttentionC8Kernel( - const AppendAttnMetaData& meta_data, - const paddle::Tensor& qkv, // [token_num, num_heads, head_dim] - const paddle::Tensor& - cache_k, // [max_block_num, num_heads, block_size, head_dim] - const paddle::Tensor& - cache_v, // [max_block_num, num_heads, head_dim, block_size] - const paddle::optional& attn_mask, - const paddle::optional& - cache_k_scale, // [num_kv_heads, head_dim] - const paddle::optional& - cache_v_scale, // [num_kv_heads, head_dim] - const paddle::optional& - cache_k_zp, // [num_kv_heads, head_dim] - const paddle::optional& - cache_v_zp, // [num_kv_heads, head_dim] - const paddle::optional& - shift_bias, // [num_kv_heads, head_dim] - const paddle::optional& - smooth_weight, // [num_kv_heads, head_dim] - const paddle::Tensor& seq_lens_q, - const paddle::Tensor& seq_lens_kv, - const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& batch_id_per_token, - const paddle::Tensor& cu_seqlens_q, - const paddle::Tensor& block_table, - const paddle::Tensor& batch_ids, - const paddle::Tensor& tile_ids_per_batch, - const int num_blocks, - const int block_shape_q, - const int max_seq_len, - const int max_dec_len, - const float quant_max_bound, - const float quant_min_bound, - const float in_scale, - const int max_partition_size, - const int encoder_max_partition_size, - const int speculate_max_draft_token_num, - const bool causal, - const bool is_decoder, - const bool 
enable_prefill, - const std::string& cache_quant_type_str, - cudaStream_t& stream, - paddle::Tensor* out); diff --git a/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c8_float16_float16_kernel.cu b/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c8_float16_float16_kernel.cu deleted file mode 100644 index 153b81ee0..000000000 --- a/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c8_float16_float16_kernel.cu +++ /dev/null @@ -1,106 +0,0 @@ -// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-#include "../append_attention_c8_impl.cuh" - -template void CascadeAppendAttentionC8Kernel( - const AppendAttnMetaData& meta_data, - const paddle::Tensor& qkv, // [token_num, num_heads, head_dim] - const paddle::Tensor& - cache_k, // [max_block_num, num_heads, block_size, head_dim] - const paddle::Tensor& - cache_v, // [max_block_num, num_heads, head_dim, block_size] - const paddle::optional& attn_mask, - const paddle::optional& - cache_k_scale, // [num_kv_heads, head_dim] - const paddle::optional& - cache_v_scale, // [num_kv_heads, head_dim] - const paddle::optional& - cache_k_zp, // [num_kv_heads, head_dim] - const paddle::optional& - cache_v_zp, // [num_kv_heads, head_dim] - const paddle::optional& - shift_bias, // [num_kv_heads, head_dim] - const paddle::optional& - smooth_weight, // [num_kv_heads, head_dim] - const paddle::Tensor& seq_lens_q, - const paddle::Tensor& seq_lens_kv, - const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& batch_id_per_token, - const paddle::Tensor& cu_seqlens_q, - const paddle::Tensor& block_table, - const paddle::Tensor& batch_ids, - const paddle::Tensor& tile_ids_per_batch, - const int num_blocks, - const int block_shape_q, - const int max_seq_len, - const int max_dec_len, - const float quant_max_bound, - const float quant_min_bound, - const float in_scale, - const int max_partition_size, - const int encoder_max_partition_size, - const int speculate_max_draft_token_num, - const bool causal, - const bool is_decoder, - const bool enable_prefill, - const std::string& cache_quant_type_str, - cudaStream_t& stream, - paddle::Tensor* out); - - - -template void CascadeAppendAttentionC8Kernel( - const AppendAttnMetaData& meta_data, - const paddle::Tensor& qkv, // [token_num, num_heads, head_dim] - const paddle::Tensor& - cache_k, // [max_block_num, num_heads, block_size, head_dim] - const paddle::Tensor& - cache_v, // [max_block_num, num_heads, head_dim, block_size] - const paddle::optional& attn_mask, - const paddle::optional& 
- cache_k_scale, // [num_kv_heads, head_dim] - const paddle::optional& - cache_v_scale, // [num_kv_heads, head_dim] - const paddle::optional& - cache_k_zp, // [num_kv_heads, head_dim] - const paddle::optional& - cache_v_zp, // [num_kv_heads, head_dim] - const paddle::optional& - shift_bias, // [num_kv_heads, head_dim] - const paddle::optional& - smooth_weight, // [num_kv_heads, head_dim] - const paddle::Tensor& seq_lens_q, - const paddle::Tensor& seq_lens_kv, - const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& batch_id_per_token, - const paddle::Tensor& cu_seqlens_q, - const paddle::Tensor& block_table, - const paddle::Tensor& batch_ids, - const paddle::Tensor& tile_ids_per_batch, - const int num_blocks, - const int block_shape_q, - const int max_seq_len, - const int max_dec_len, - const float quant_max_bound, - const float quant_min_bound, - const float in_scale, - const int max_partition_size, - const int encoder_max_partition_size, - const int speculate_max_draft_token_num, - const bool causal, - const bool is_decoder, - const bool enable_prefill, - const std::string& cache_quant_type_str, - cudaStream_t& stream, - paddle::Tensor* out); diff --git a/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c8_float16_fp8_kernel.cu b/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c8_float16_fp8_kernel.cu deleted file mode 100644 index 7e2539b0a..000000000 --- a/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c8_float16_fp8_kernel.cu +++ /dev/null @@ -1,105 +0,0 @@ -// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -#include "../append_attention_c8_impl.cuh" - -template void CascadeAppendAttentionC8Kernel( - const AppendAttnMetaData& meta_data, - const paddle::Tensor& qkv, // [token_num, num_heads, head_dim] - const paddle::Tensor& - cache_k, // [max_block_num, num_heads, block_size, head_dim] - const paddle::Tensor& - cache_v, // [max_block_num, num_heads, head_dim, block_size] - const paddle::optional& attn_mask, - const paddle::optional& - cache_k_scale, // [num_kv_heads, head_dim] - const paddle::optional& - cache_v_scale, // [num_kv_heads, head_dim] - const paddle::optional& - cache_k_zp, // [num_kv_heads, head_dim] - const paddle::optional& - cache_v_zp, // [num_kv_heads, head_dim] - const paddle::optional& - shift_bias, // [num_kv_heads, head_dim] - const paddle::optional& - smooth_weight, // [num_kv_heads, head_dim] - const paddle::Tensor& seq_lens_q, - const paddle::Tensor& seq_lens_kv, - const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& batch_id_per_token, - const paddle::Tensor& cu_seqlens_q, - const paddle::Tensor& block_table, - const paddle::Tensor& batch_ids, - const paddle::Tensor& tile_ids_per_batch, - const int num_blocks, - const int block_shape_q, - const int max_seq_len, - const int max_dec_len, - const float quant_max_bound, - const float quant_min_bound, - const float in_scale, - const int max_partition_size, - const int encoder_max_partition_size, - const int speculate_max_draft_token_num, - const bool causal, - const bool is_decoder, - const bool enable_prefill, - const std::string& cache_quant_type_str, - cudaStream_t& 
stream, - paddle::Tensor* out); - - -template void CascadeAppendAttentionC8Kernel( - const AppendAttnMetaData& meta_data, - const paddle::Tensor& qkv, // [token_num, num_heads, head_dim] - const paddle::Tensor& - cache_k, // [max_block_num, num_heads, block_size, head_dim] - const paddle::Tensor& - cache_v, // [max_block_num, num_heads, head_dim, block_size] - const paddle::optional& attn_mask, - const paddle::optional& - cache_k_scale, // [num_kv_heads, head_dim] - const paddle::optional& - cache_v_scale, // [num_kv_heads, head_dim] - const paddle::optional& - cache_k_zp, // [num_kv_heads, head_dim] - const paddle::optional& - cache_v_zp, // [num_kv_heads, head_dim] - const paddle::optional& - shift_bias, // [num_kv_heads, head_dim] - const paddle::optional& - smooth_weight, // [num_kv_heads, head_dim] - const paddle::Tensor& seq_lens_q, - const paddle::Tensor& seq_lens_kv, - const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& batch_id_per_token, - const paddle::Tensor& cu_seqlens_q, - const paddle::Tensor& block_table, - const paddle::Tensor& batch_ids, - const paddle::Tensor& tile_ids_per_batch, - const int num_blocks, - const int block_shape_q, - const int max_seq_len, - const int max_dec_len, - const float quant_max_bound, - const float quant_min_bound, - const float in_scale, - const int max_partition_size, - const int encoder_max_partition_size, - const int speculate_max_draft_token_num, - const bool causal, - const bool is_decoder, - const bool enable_prefill, - const std::string& cache_quant_type_str, - cudaStream_t& stream, - paddle::Tensor* out); diff --git a/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c8_float16_int8_kernel.cu b/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c8_float16_int8_kernel.cu deleted file mode 100644 index e46fb31c1..000000000 --- a/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c8_float16_int8_kernel.cu +++ /dev/null @@ -1,105 +0,0 @@ -// 
Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -#include "../append_attention_c8_impl.cuh" - -template void CascadeAppendAttentionC8Kernel( - const AppendAttnMetaData& meta_data, - const paddle::Tensor& qkv, // [token_num, num_heads, head_dim] - const paddle::Tensor& - cache_k, // [max_block_num, num_heads, block_size, head_dim] - const paddle::Tensor& - cache_v, // [max_block_num, num_heads, head_dim, block_size] - const paddle::optional& attn_mask, - const paddle::optional& - cache_k_scale, // [num_kv_heads, head_dim] - const paddle::optional& - cache_v_scale, // [num_kv_heads, head_dim] - const paddle::optional& - cache_k_zp, // [num_kv_heads, head_dim] - const paddle::optional& - cache_v_zp, // [num_kv_heads, head_dim] - const paddle::optional& - shift_bias, // [num_kv_heads, head_dim] - const paddle::optional& - smooth_weight, // [num_kv_heads, head_dim] - const paddle::Tensor& seq_lens_q, - const paddle::Tensor& seq_lens_kv, - const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& batch_id_per_token, - const paddle::Tensor& cu_seqlens_q, - const paddle::Tensor& block_table, - const paddle::Tensor& batch_ids, - const paddle::Tensor& tile_ids_per_batch, - const int num_blocks, - const int block_shape_q, - const int max_seq_len, - const int max_dec_len, - const float quant_max_bound, - const float quant_min_bound, - const float in_scale, - const int max_partition_size, - const int 
encoder_max_partition_size, - const int speculate_max_draft_token_num, - const bool causal, - const bool is_decoder, - const bool enable_prefill, - const std::string& cache_quant_type_str, - cudaStream_t& stream, - paddle::Tensor* out); - - -template void CascadeAppendAttentionC8Kernel( - const AppendAttnMetaData& meta_data, - const paddle::Tensor& qkv, // [token_num, num_heads, head_dim] - const paddle::Tensor& - cache_k, // [max_block_num, num_heads, block_size, head_dim] - const paddle::Tensor& - cache_v, // [max_block_num, num_heads, head_dim, block_size] - const paddle::optional& attn_mask, - const paddle::optional& - cache_k_scale, // [num_kv_heads, head_dim] - const paddle::optional& - cache_v_scale, // [num_kv_heads, head_dim] - const paddle::optional& - cache_k_zp, // [num_kv_heads, head_dim] - const paddle::optional& - cache_v_zp, // [num_kv_heads, head_dim] - const paddle::optional& - shift_bias, // [num_kv_heads, head_dim] - const paddle::optional& - smooth_weight, // [num_kv_heads, head_dim] - const paddle::Tensor& seq_lens_q, - const paddle::Tensor& seq_lens_kv, - const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& batch_id_per_token, - const paddle::Tensor& cu_seqlens_q, - const paddle::Tensor& block_table, - const paddle::Tensor& batch_ids, - const paddle::Tensor& tile_ids_per_batch, - const int num_blocks, - const int block_shape_q, - const int max_seq_len, - const int max_dec_len, - const float quant_max_bound, - const float quant_min_bound, - const float in_scale, - const int max_partition_size, - const int encoder_max_partition_size, - const int speculate_max_draft_token_num, - const bool causal, - const bool is_decoder, - const bool enable_prefill, - const std::string& cache_quant_type_str, - cudaStream_t& stream, - paddle::Tensor* out); diff --git a/custom_ops/setup_ops.py b/custom_ops/setup_ops.py index ee9cb59bf..8636c3de4 100644 --- a/custom_ops/setup_ops.py +++ b/custom_ops/setup_ops.py @@ -377,6 +377,7 @@ elif 
paddle.is_compiled_with_cuda(): if cc >= 80: # append_attention + os.system("python gpu_ops/append_attn/autogen_template_instantiation.py") sources += ["gpu_ops/append_attention.cu"] sources += find_end_files("gpu_ops/append_attn", ".cu") # mla