// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "decode_attention_func.cuh"

#define CHECK(call)                                                     \
  do {                                                                  \
    const cudaError_t error_code = call;                                \
    if (error_code != cudaSuccess) {                                    \
      printf("CUDA Error:\n");                                          \
      printf("    File:       %s\n", __FILE__);                         \
      printf("    Line:       %d\n", __LINE__);                         \
      printf("    Error code: %d\n", error_code);                       \
      printf("    Error text: %s\n", cudaGetErrorString(error_code));   \
      exit(1);                                                          \
    }                                                                   \
  } while (0)

template <typename T, typename OutT, int vec_size, uint32_t bdy, uint32_t HEAD_DIM>
__global__ void merge_varlen_multi_chunks_v2_kernel(
    const T * __restrict__ multi_out,      // [bsz, num_chunks, num_heads, head_dim]
    const T * __restrict__ multi_m,        // [bsz, num_chunks, num_heads]
    const T * __restrict__ multi_d,        // [bsz, num_chunks, num_heads]
    const int * __restrict__ seq_lens_q,
    const int * __restrict__ seq_lens_kv,
    const int * __restrict__ cu_seqlens_q,
    const T * __restrict__ shift_bias,     // [q_num_heads * HEAD_DIM]
    const T * __restrict__ smooth_weight,  // [q_num_heads * HEAD_DIM]
    OutT * __restrict__ out,               // [token_num, num_heads, head_dim]
    const float in_scale,
    const int num_chunks,
    const int chunk_size,
    const int max_seq_len,
    const int num_heads,
    const int head_dim) {
  const int vid = threadIdx.x, ty = threadIdx.y;
  const int qid = blockIdx.x, hid = blockIdx.y;
  const int seq_len_q = seq_lens_q[qid];
  if (seq_len_q == 0) return;
  int seq_len_kv = seq_lens_kv[qid];
  if (seq_len_kv == 0) return;
  seq_len_kv += seq_len_q;
  const int num_chunks_this_seq = div_up(seq_len_kv, chunk_size);
  if (num_chunks_this_seq == 1 || ty >= num_chunks_this_seq) {
    return;
  }

  __shared__ T smem[bdy * HEAD_DIM];
  __shared__ T md_smem[bdy * 2];

  const int start_token_ids = cu_seqlens_q[qid];
  using LoadT = AlignedVector<T, vec_size>;
  LoadT load_vec;
  LoadT res_vec;
  if constexpr (std::is_same<T, half>::value) {
#pragma unroll
    for (int i = 0; i < vec_size / 2; ++i) {
      *((half2*)(&res_vec) + i) = make_half2(0, 0);
    }
  } else if constexpr (std::is_same<T, __nv_bfloat16>::value) {
#pragma unroll
    for (int i = 0; i < vec_size / 2; ++i) {
      *((nv_bfloat162*)(&res_vec) + i) = make_bfloat162(0, 0);
    }
  }
  T m;
  T d = 1.f;
  if constexpr (std::is_same<T, half>::value) {
    m = __float2half(-5e4f);
  } else if constexpr (std::is_same<T, __nv_bfloat16>::value) {
    m = __float2bfloat16(-3.38953e38f);
  }
  // merge per ty
#pragma unroll 2
  for (int i = ty; i < num_chunks_this_seq; i += bdy) {
    uint32_t offset = (qid * num_chunks + i) * num_heads + hid;
    T m_prev = m;
    T d_prev = d;
    const T m_now = multi_m[offset];
    const T d_now = multi_d[offset];
    m = m_prev > m_now ? m_prev : m_now;
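    // Rescale both the running state and this chunk's state onto the new max m:
    //   d <- d * exp(m_prev - m) + d_now * exp(m_now - m)
    //   o <- o * exp(m_prev - m) + o_now * exp(m_now - m)
    // so o / d stays equal to the softmax-weighted value sum over all chunks merged so far.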
    offset = (qid * num_chunks * num_heads + i * num_heads + hid) * head_dim + vid * vec_size;
    Load(&multi_out[offset], &load_vec);
    const T scale1 = hexp(m_prev - m), scale2 = hexp(m_now - m);
    d = d * scale1 + d_now * scale2;
#pragma unroll
    for (int j = 0; j < vec_size; j++) {
      res_vec[j] = res_vec[j] * scale1 + load_vec[j] * scale2;
    }
  }
  // store ty res
  Store(res_vec, &smem[ty * head_dim + vid * vec_size]);
  md_smem[2 * ty] = m;
  md_smem[2 * ty + 1] = d;
  __syncthreads();

  // merge bdy
  softmax_state_t<vec_size, T> st{};
  const uint32_t iter_num = min(num_chunks_this_seq, bdy);
#pragma unroll
  for (int i = 0; i < iter_num; i++) {
    Load(&smem[i * head_dim + vid * vec_size], &load_vec);
    const T m_tmp = md_smem[2 * i], d_tmp = md_smem[2 * i + 1];
    st.merge(load_vec, m_tmp, d_tmp);
  }
  st.normalize();

  AlignedVector<OutT, vec_size> out_vec;
#pragma unroll
  for (int i = 0; i < vec_size; ++i) {
    out_vec[i] = static_cast<OutT>(st.o[i]);
  }
  Store(out_vec, &out[(start_token_ids * num_heads + hid) * head_dim + vid * vec_size]);
}

template <typename T,
          typename CacheT,
          bool partition_kv,
          uint32_t GROUP_SIZE,
          uint32_t NUM_STAGES,
          uint32_t DEAL_EACH_TIME,
          uint32_t HEAD_DIM_QK,
          uint32_t HEAD_DIM_V,
          uint32_t BLOCK_SIZE,
          uint32_t VEC_SIZE,
          uint32_t CACHE_VEC_SIZE,
          uint32_t bdx,
          uint32_t bdy,
          typename OutT = T>
__global__ void multi_query_decode_attention_kernel(
    T * __restrict__ q,             // [token_num, num_heads, head_dim]
    CacheT * __restrict__ cache_k,  // [max_block_num, num_heads, block_size, head_dim]
    CacheT * __restrict__ cache_v,
    const T * __restrict__ shift_bias,     // [q_num_heads * HEAD_DIM]
    const T * __restrict__ smooth_weight,  // [q_num_heads * HEAD_DIM]
    const int * __restrict__ seq_lens_q,
    const int * __restrict__ seq_lens_kv,
    const int * __restrict__ cu_seqlens_q,
    const int * __restrict__ block_table,  // [bsz, block_num_per_seq]
    const int max_seq_len,
    const int max_dec_len,
    const int max_block_num_per_seq,
    const float scale,
    const float in_scale,
    const uint32_t chunk_size,
    T * __restrict__ tmp_workspace,  // [batch_size, num_chunks, num_heads, head_dim]
    T * __restrict__ tmp_m,          // [batch_size, num_chunks, num_heads]
    T * __restrict__ tmp_d,          // [batch_size, num_chunks, num_heads]
    OutT * __restrict__ out) {
  const uint32_t bidx = blockIdx.x, kv_head_idx = blockIdx.z;
  const uint32_t bid = bidx, gid = threadIdx.y;
  const uint32_t tidx = threadIdx.x;
  constexpr uint32_t num_vec_per_head_qk = HEAD_DIM_QK / VEC_SIZE;
  constexpr uint32_t num_vec_per_head_v = HEAD_DIM_V / VEC_SIZE;
  constexpr uint32_t num_tile_v = (num_vec_per_head_v + bdx - 1) / bdx;

  const uint32_t q_head_idx = kv_head_idx * GROUP_SIZE + gid;
  const uint32_t kv_num_heads = gridDim.z;
  const uint32_t q_num_heads = kv_num_heads * GROUP_SIZE;

  const int *block_table_now = block_table + bid * max_block_num_per_seq;

  const uint32_t num_chunks = gridDim.y;
  const uint32_t chunk_id = blockIdx.y;
  const uint32_t q_len = seq_lens_q[bid];
  if (q_len <= 0) {
    return;
  }
  uint32_t kv_len = seq_lens_kv[bid];  // !!!!!!!!
  if (kv_len <= 0) {
    return;
  }
  kv_len += q_len;
  const uint32_t num_chunk_this_seq = div_up(kv_len, chunk_size);
  const uint32_t q_start_idx = cu_seqlens_q[bid];
  const uint32_t q_write_idx = cu_seqlens_q[bid];
  if (chunk_id >= num_chunk_this_seq) {
    return;
  }

  const uint32_t chunk_start = partition_kv ? chunk_id * chunk_size : 0;
  const uint32_t chunk_end = partition_kv ? min(kv_len, chunk_start + chunk_size) : kv_len;
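  // When partition_kv is true, each CTA along gridDim.y handles one [chunk_start, chunk_end)
  // slice of the KV cache; its partial (o, m, d) state goes to tmp_workspace / tmp_m / tmp_d
  // and is combined afterwards by merge_varlen_multi_chunks_v2_kernel. Otherwise a single CTA
  // scans the whole [0, kv_len) range and writes `out` directly.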
  const uint32_t chunk_len = chunk_end - chunk_start;

  extern __shared__ uint8_t smem[];
  const T *q_now = q + (q_start_idx * q_num_heads + q_head_idx) * HEAD_DIM_QK;
  T *q_smem = reinterpret_cast<T*>(smem);  // [HEAD_DIM_QK * sizeof(T)]
  T *cu_q_smem = q_smem + gid * HEAD_DIM_QK;
#pragma unroll
  for (uint32_t vid = tidx; vid < num_vec_per_head_qk; vid += bdx) {
    ((float4*)(&cu_q_smem[vid * VEC_SIZE]))[0] = ((float4*)(&q_now[vid * VEC_SIZE]))[0];
  }
  __syncthreads();

  using VecT = AlignedVector<T, VEC_SIZE>;
  VecT q_vec;
#pragma unroll
  for (uint32_t vid = tidx; vid < num_vec_per_head_qk; vid += bdx) {
    Load(cu_q_smem + vid * VEC_SIZE, &q_vec);
    for (uint32_t i = 0; i < VEC_SIZE; ++i) {
      q_vec[i] *= scale;
    }
    Store(q_vec, cu_q_smem + vid * VEC_SIZE);
  }

  CacheT *kv_smem = reinterpret_cast<CacheT*>(smem + GROUP_SIZE * HEAD_DIM_QK * sizeof(CacheT));
  uint32_t stage_idx = 0;
  constexpr int loop_times = DEAL_EACH_TIME / bdy;
#pragma unroll
  for (int i = 0; i < NUM_STAGES; ++i) {
#pragma unroll
    for (int j = 0; j < loop_times; ++j) {
      const uint32_t k_seq_offset = i * DEAL_EACH_TIME + j * bdy + gid;
      const uint32_t k_seq_id = chunk_start + k_seq_offset;
      produce_kv(
          kv_smem,
          cache_k,
          block_table_now,
          k_seq_id,
          k_seq_offset,
          kv_head_idx,
          kv_num_heads,
          tidx,
          chunk_start,
          chunk_end);
    }
    commit_group();
    stage_idx = (stage_idx + 1) % NUM_STAGES;
  }

  softmax_state_ts<VEC_SIZE, T, num_tile_v> st;
  float s[DEAL_EACH_TIME];

  const uint32_t num_iters = div_up(chunk_len, DEAL_EACH_TIME);
  for (int iter = 0; iter < num_iters; ++iter) {
    wait_group<NUM_STAGES - 1>();
    __syncthreads();
    // compute qk
    compute_qk(
        cu_q_smem,
        kv_smem,
        chunk_start + iter * DEAL_EACH_TIME,
        stage_idx,
        iter * DEAL_EACH_TIME,
        chunk_len,
        tidx,
        gid,
        scale,
        s,
        st);
    __syncthreads();

    // compute sv
    compute_sv(
        s,
        kv_smem,
        stage_idx,
        iter * DEAL_EACH_TIME,
        chunk_len,
        tidx,
        st);
    __syncthreads();

#pragma unroll
    for (int j = 0; j < loop_times; ++j) {
      const uint32_t k_seq_offset = j * bdy + gid;
      produce_kv(
          kv_smem,
          cache_k,
          block_table_now,
          chunk_start + k_seq_offset + (iter + NUM_STAGES) * DEAL_EACH_TIME,
          stage_idx * DEAL_EACH_TIME + k_seq_offset,
          kv_head_idx,
          kv_num_heads,
          tidx,
          chunk_start,
          chunk_end);
    }
    commit_group();
    stage_idx = (stage_idx + 1) % NUM_STAGES;
  }
  wait_group<0>();
  __syncthreads();

  // normalize if not partition_kv
  for (uint32_t vid = tidx; vid < num_vec_per_head_v; vid += bdx) {
    const uint32_t tile_id = vid / bdx;
    if (!partition_kv || num_chunk_this_seq == 1) {
      st.normalize(tile_id);
    }
    if (partition_kv && num_chunk_this_seq > 1) {
      const uint32_t head_idx = (bid * num_chunks + chunk_id) * q_num_heads + q_head_idx;
      Store(st.o[tile_id], tmp_workspace + head_idx * HEAD_DIM_V + vid * VEC_SIZE);
      tmp_m[head_idx] = st.m;
      tmp_d[head_idx] = st.d;
    } else {
      Store(st.o[tile_id], out + (q_write_idx * q_num_heads + q_head_idx) * HEAD_DIM_V + vid * VEC_SIZE);
    }
  }
}

template <typename T,
          uint32_t GROUP_SIZE,
          uint32_t HEAD_DIM_QK,
          uint32_t HEAD_DIM_V,
          uint32_t BLOCK_SIZE,
          bool CAUSAL,
          uint32_t NUM_STAGE,
          uint32_t cache_bytes,
          uint32_t DEAL_EACH_TIME>
void MultiQueryDecoderAttention(
    const AppendAttnMetaData& meta_data,
    cudaStream_t &stream,
    const paddle::Tensor &q,
    const paddle::Tensor &cache_k,  // [max_block_num, num_kv_heads, block_size, head_dim]
    const paddle::Tensor &cache_v,  // [num_kv_heads, head_dim]
    const paddle::optional<paddle::Tensor>& attn_mask,
    const paddle::optional<paddle::Tensor>& shift_bias,
    const paddle::optional<paddle::Tensor>& smooth_weight,
    const paddle::Tensor &seq_lens_q,
    const paddle::Tensor &seq_lens_kv,
    const paddle::Tensor &batch_id_per_token,
    const paddle::Tensor &cu_seqlens_q,
    const paddle::Tensor &block_table,
    const int max_seq_len,
    const int max_dec_len,
    const float rope_scale,
    const float rope_theta,
    const float softmax_scale,
    const float in_scale,
    paddle::Tensor *out) {
  using NV_TYPE = typename cascade_attn_type_traits<T>::type;
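  // cascade_attn_type_traits maps the paddle dtype T to its native CUDA type for the kernel
  // arguments. The constants below fix the launch geometry: threadIdx.x walks HEAD_DIM_QK in
  // 16-byte vectors (at most 32 lanes) and threadIdx.y covers the GROUP_SIZE query heads that
  // share one KV head.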
  auto num_heads = meta_data.q_num_heads;
  auto kv_num_heads = meta_data.kv_num_heads;
  auto token_num = meta_data.token_nums;
  auto bsz = meta_data.batch_size;
  auto max_block_num_per_seq = meta_data.max_blocks_per_seq;

  constexpr int num_stages = NUM_STAGE;
  constexpr int vec_size = 16 / sizeof(T);           // 8 16 32
  constexpr int cache_vec_size = 128 / cache_bytes;  // 8 16 32
  constexpr int blockxc = HEAD_DIM_QK / cache_vec_size;
  constexpr int num_vec_per_head = HEAD_DIM_QK / vec_size;
  constexpr int blockx = num_vec_per_head < 32 ? num_vec_per_head : 32;

  constexpr int blocky = GROUP_SIZE;
  const int gridx = bsz;

  constexpr int num_threads = blockx * blocky;

  auto splitkv_kernel = multi_query_decode_attention_kernel<NV_TYPE,
                                                            NV_TYPE,
                                                            /*partition_kv=*/true,
                                                            GROUP_SIZE,
                                                            num_stages,
                                                            DEAL_EACH_TIME,
                                                            HEAD_DIM_QK,
                                                            HEAD_DIM_V,
                                                            BLOCK_SIZE,
                                                            vec_size,
                                                            cache_vec_size,
                                                            blockx,
                                                            blocky,
                                                            NV_TYPE>;
  uint32_t cache_smem_bytes = 0;

  const T *shift_bias_ptr = shift_bias ? shift_bias.get().data<T>() : nullptr;
  const T *smooth_weight_ptr = smooth_weight ? smooth_weight.get().data<T>() : nullptr;

  cache_smem_bytes = num_stages * DEAL_EACH_TIME * HEAD_DIM_QK * sizeof(T);

  const uint32_t chunk_size = get_max_partition_size(bsz);
  const int num_chunks = div_up(max_dec_len, chunk_size);

  size_t smem_size = cache_smem_bytes + GROUP_SIZE * HEAD_DIM_QK * sizeof(T);

  if (smem_size >= 48 * 1024) {
    cudaFuncSetAttribute(splitkv_kernel,
                         cudaFuncAttributeMaxDynamicSharedMemorySize,
                         smem_size);
  }
  const int dev_id = 0;
  int sm_count;
  int act_blocks_per_sm;
  cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id);
  cudaOccupancyMaxActiveBlocksPerMultiprocessor(
      &act_blocks_per_sm, splitkv_kernel, num_threads, smem_size);
  assert(act_blocks_per_sm > 1);

  const int num_blocks_per_wave = sm_count * act_blocks_per_sm;
  const int num_blocks_need = gridx * num_chunks * kv_num_heads;
  const int max_num_chunks = div_up(num_blocks_per_wave, num_blocks_need);
  const float ratio = static_cast<float>(num_blocks_need) / static_cast<float>(num_blocks_per_wave);

  dim3 grids(gridx, num_chunks, kv_num_heads);
  dim3 blocks(blockx, blocky);
  if (num_chunks <= 1) {
    auto no_splitkv_kernel = multi_query_decode_attention_kernel<NV_TYPE,
                                                                 NV_TYPE,
                                                                 /*partition_kv=*/false,
                                                                 GROUP_SIZE,
                                                                 num_stages,
                                                                 DEAL_EACH_TIME,
                                                                 HEAD_DIM_QK,
                                                                 HEAD_DIM_V,
                                                                 BLOCK_SIZE,
                                                                 vec_size,
                                                                 cache_vec_size,
                                                                 blockx,
                                                                 blocky,
                                                                 NV_TYPE>;
    if (smem_size >= 48 * 1024) {
      cudaFuncSetAttribute(no_splitkv_kernel,
                           cudaFuncAttributeMaxDynamicSharedMemorySize,
                           smem_size);
    }
    no_splitkv_kernel<<<grids, blocks, smem_size, stream>>>(
        reinterpret_cast<NV_TYPE*>(const_cast<T*>(q.data<T>())),
        reinterpret_cast<NV_TYPE*>(const_cast<T*>(cache_k.data<T>())),
        reinterpret_cast<NV_TYPE*>(const_cast<T*>(cache_v.data<T>())),
        reinterpret_cast<NV_TYPE*>(const_cast<T*>(shift_bias_ptr)),
        reinterpret_cast<NV_TYPE*>(const_cast<T*>(smooth_weight_ptr)),
        seq_lens_q.data<int>(),
        seq_lens_kv.data<int>(),
        cu_seqlens_q.data<int>(),
        block_table.data<int>(),
        max_seq_len,
        max_dec_len,
        max_block_num_per_seq,
        softmax_scale,
        in_scale,
        chunk_size,
        nullptr,
        nullptr,
        nullptr,
        reinterpret_cast<NV_TYPE*>(const_cast<T*>(out->data<T>())));
    // CHECK(cudaGetLastError());
    // CHECK(cudaDeviceSynchronize());
  } else {
    auto *allocator = paddle::GetAllocator(q.place());
    phi::Allocator::AllocationPtr tmp_workspace, tmp_m, tmp_d;
    tmp_workspace = allocator->Allocate(
        phi::SizeOf(q.dtype()) *
        static_cast<size_t>(bsz * num_chunks * num_heads * HEAD_DIM_V));
    tmp_m = allocator->Allocate(
        phi::SizeOf(q.dtype()) *
        static_cast<size_t>(bsz * num_chunks * num_heads));
    tmp_d = allocator->Allocate(
        phi::SizeOf(q.dtype()) *
        static_cast<size_t>(bsz * num_chunks * num_heads));

    splitkv_kernel<<<grids, blocks, smem_size, stream>>>(
        reinterpret_cast<NV_TYPE*>(const_cast<T*>(q.data<T>())),
        reinterpret_cast<NV_TYPE*>(const_cast<T*>(cache_k.data<T>())),
        reinterpret_cast<NV_TYPE*>(const_cast<T*>(cache_v.data<T>())),
        reinterpret_cast<NV_TYPE*>(const_cast<T*>(shift_bias_ptr)),
        reinterpret_cast<NV_TYPE*>(const_cast<T*>(smooth_weight_ptr)),
        seq_lens_q.data<int>(),
        seq_lens_kv.data<int>(),
        cu_seqlens_q.data<int>(),
        block_table.data<int>(),
        max_seq_len,
        max_dec_len,
        max_block_num_per_seq,
        softmax_scale,
        in_scale,
        chunk_size,
        reinterpret_cast<NV_TYPE*>(tmp_workspace->ptr()),
        reinterpret_cast<NV_TYPE*>(tmp_m->ptr()),
        reinterpret_cast<NV_TYPE*>(tmp_d->ptr()),
        reinterpret_cast<NV_TYPE*>(const_cast<T*>(out->data<T>())));
    // CHECK(cudaGetLastError());
    // CHECK(cudaDeviceSynchronize());
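    // Second pass: reduce the per-chunk partial results in tmp_workspace / tmp_m / tmp_d
    // into the final output, one thread block per (batch, query head) pair.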
    constexpr int mblockx = HEAD_DIM_V / vec_size;
    constexpr int bdy = 256 / mblockx;
    dim3 grids_merge(bsz, num_heads);
    dim3 blocks_merge(mblockx, bdy);
    merge_varlen_multi_chunks_v2_kernel<NV_TYPE, NV_TYPE, vec_size, bdy, HEAD_DIM_V>
        <<<grids_merge, blocks_merge, 0, stream>>>(
            reinterpret_cast<NV_TYPE*>(tmp_workspace->ptr()),
            reinterpret_cast<NV_TYPE*>(tmp_m->ptr()),
            reinterpret_cast<NV_TYPE*>(tmp_d->ptr()),
            seq_lens_q.data<int>(),
            seq_lens_kv.data<int>(),
            cu_seqlens_q.data<int>(),
            reinterpret_cast<NV_TYPE*>(const_cast<T*>(shift_bias_ptr)),
            reinterpret_cast<NV_TYPE*>(const_cast<T*>(smooth_weight_ptr)),
            reinterpret_cast<NV_TYPE*>(const_cast<T*>(out->data<T>())),
            in_scale,
            num_chunks,
            chunk_size,
            max_seq_len,
            num_heads,
            HEAD_DIM_V);
  }
  // CHECK(cudaGetLastError());
  // CHECK(cudaDeviceSynchronize());
}

template <typename T>
void DecodeMLAAttentionKernel(
    const AppendAttnMetaData& meta_data,
    const paddle::Tensor &q,  // [token_num, num_heads, head_dim]
    const paddle::Tensor &cache_k,
    const paddle::Tensor &cache_v,
    const paddle::optional<paddle::Tensor>& attn_mask,
    const paddle::optional<paddle::Tensor>& shift_bias,
    const paddle::optional<paddle::Tensor>& smooth_weight,
    const paddle::Tensor &seq_lens_q,  // q_seq_len is 1
    const paddle::Tensor &seq_lens_kv,
    const paddle::Tensor &batch_id_per_token,
    const paddle::Tensor &cu_seqlens_q,
    const paddle::Tensor &block_table,
    int max_seq_len,
    int max_dec_len,
    float softmax_scale,
    float in_scale,
    bool causal,
    cudaStream_t &stream,
    paddle::Tensor *out) {
  const auto token_num = meta_data.token_nums;
  const auto block_size = meta_data.block_size;
  const auto bsz = meta_data.batch_size;
  const auto num_heads = meta_data.q_num_heads;
  const auto group_size = meta_data.q_num_heads / meta_data.kv_num_heads;
  const auto head_dim_qk = meta_data.head_dims;
  const auto head_dim_v = meta_data.head_dims_v;
  const float rope_scale = 0.0;
  const float rope_theta = 0.0;
  const uint32_t deal_each_time = get_cascade_attention_deal_each_time();
  const uint32_t num_stage = get_cascade_attention_num_stages();
  const uint32_t num_threads = get_cascade_attention_num_threads();

  DISPATCH_CAUSAL(causal, CAUSAL,
    {DISPATCH_MLA_GROUP_SIZE(group_size, GROUP_SIZE,
      {DISPATCH_MLA_HEAD_DIM(head_dim_qk, HEAD_DIM_QK,
        {DISPATCH_MLA_HEAD_DIM(head_dim_v, HEAD_DIM_V,
          {DISPATCH_BLOCK_SIZE(block_size, BLOCK_SIZE,
            {DISPATCH_DEAL_EACH_TIME(deal_each_time, DEAL_EACH_TIME,
              {MultiQueryDecoderAttention<T, GROUP_SIZE, HEAD_DIM_QK, HEAD_DIM_V, BLOCK_SIZE,
                                          CAUSAL, /*NUM_STAGE=*/2, /*cache_bytes=*/16,
                                          DEAL_EACH_TIME>(
                  meta_data, stream, q, cache_k, cache_v, attn_mask, shift_bias,
                  smooth_weight, seq_lens_q, seq_lens_kv, batch_id_per_token,
                  cu_seqlens_q, block_table, max_seq_len, max_dec_len, rope_scale,
                  rope_theta, softmax_scale, in_scale, out);})})})})})});
}

template void DecodeMLAAttentionKernel<paddle::bfloat16>(
    const AppendAttnMetaData& meta_data,
    const paddle::Tensor &q,  // [token_num, num_heads, head_dim]
    const paddle::Tensor &cache_k,
    const paddle::Tensor &cache_v,
    const paddle::optional<paddle::Tensor>& attn_mask,
    const paddle::optional<paddle::Tensor>& shift_bias,
    const paddle::optional<paddle::Tensor>& smooth_weight,
    const paddle::Tensor &seq_lens_q,  // q_seq_len is 1
    const paddle::Tensor &seq_lens_kv,
    const paddle::Tensor &batch_id_per_token,
    const paddle::Tensor &cu_seqlens_q,
    const paddle::Tensor &block_table,
    int max_seq_len,
    int max_dec_len,
    float softmax_scale,
    float in_scale,
    bool causal,
    cudaStream_t &stream,
    paddle::Tensor *out);
template void DecodeMLAAttentionKernel<paddle::float16>(
    const AppendAttnMetaData& meta_data,
    const paddle::Tensor &q,  // [token_num, num_heads, head_dim]
    const paddle::Tensor &cache_k,
    const paddle::Tensor &cache_v,
    const paddle::optional<paddle::Tensor>& attn_mask,
    const paddle::optional<paddle::Tensor>& shift_bias,
    const paddle::optional<paddle::Tensor>& smooth_weight,
    const paddle::Tensor &seq_lens_q,  // q_seq_len is 1
    const paddle::Tensor &seq_lens_kv,
    const paddle::Tensor &batch_id_per_token,
    const paddle::Tensor &cu_seqlens_q,
    const paddle::Tensor &block_table,
    int max_seq_len,
    int max_dec_len,
    float softmax_scale,
    float in_scale,
    bool causal,
    cudaStream_t &stream,
    paddle::Tensor *out);