// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "helper.h"
#include "mem_util.cuh"
#include "utils.cuh"

template <typename T, int VecSize = 1>
__global__ void decode_absorb_cache_kernel(
    const T* __restrict__ kv_nope,  // [bsz, kv_num_heads, nope_size] 512
    const T* __restrict__ kv_pe,    // [bsz, kv_num_heads, pe_size] 64
    T* __restrict__ kv_cache,       // [num_blocks, kv_num_heads, block_size,
                                    //  nope_size + pe_size]
    const int* __restrict__ block_tables,  // [bsz, max_blocks_per_seq]
    const int* __restrict__ cu_seqlens_q,
    const int* __restrict__ seq_lens,          // [bsz]
    const int* __restrict__ seq_lens_encoder,  // [bsz]
    const int max_seq_len,
    const int max_blocks_per_seq,
    const int kv_num_heads,
    const int nope_size,
    const int pe_size,
    const int block_size,
    const uint32_t elem_cnt) {
  using LoadT = AlignedVector<T, VecSize>;
  constexpr int HalfVecSize = VecSize / 2;
  LoadT src_vec;

  int64_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
  const uint32_t nope_hidden_size = kv_num_heads * nope_size;
  const uint32_t pe_hidden_size = kv_num_heads * pe_size;
  const uint32_t all_size = nope_size + pe_size;
  const int64_t hidden_size = nope_hidden_size + pe_hidden_size;

  for (int32_t linear_index = global_thread_idx * VecSize,
               step = gridDim.x * blockDim.x * VecSize;
       linear_index < elem_cnt;
       linear_index += step) {
    const int ori_bi = linear_index / hidden_size;
    const int bias = linear_index % hidden_size;
    const int start_token_idx = cu_seqlens_q[ori_bi];
    if (seq_lens_encoder[ori_bi] > 0) return;
    const int write_seq_id = seq_lens[ori_bi];  // slot of the newly decoded token
    if (write_seq_id == 0) continue;

    const int* block_table_now = block_tables + ori_bi * max_blocks_per_seq;
    const int block_idx = block_table_now[write_seq_id / block_size];
    const int block_offset = write_seq_id % block_size;

    if (bias < nope_hidden_size) {  // nope part
      const uint32_t inner_bias = bias;
      const uint32_t hi = inner_bias / nope_size;
      const uint32_t h_bias = inner_bias % nope_size;
      const uint32_t tgt_idx =
          block_idx * kv_num_heads * block_size * all_size +
          hi * block_size * all_size + block_offset * all_size + h_bias;
      const uint32_t ori_idx = start_token_idx * nope_hidden_size + inner_bias;
      Load<T, VecSize>(&kv_nope[ori_idx], &src_vec);
      Store<T, VecSize>(src_vec, &kv_cache[tgt_idx]);
    } else {  // pe part
      const uint32_t inner_bias = bias - nope_hidden_size;
      const uint32_t hi = inner_bias / pe_size;
      const uint32_t h_bias = inner_bias % pe_size;
      const uint32_t tgt_idx =
          block_idx * kv_num_heads * block_size * all_size +
          hi * block_size * all_size + block_offset * all_size + nope_size +
          h_bias;
      const uint32_t ori_idx = start_token_idx * pe_hidden_size + inner_bias;
      Load<T, VecSize>(&kv_pe[ori_idx], &src_vec);
      Store<T, VecSize>(src_vec, &kv_cache[tgt_idx]);
    }
  }
}
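// Illustrative host-side launch sketch for decode_absorb_cache_kernel (not
// part of the original kernels in this file). It assumes the grid-stride
// convention the kernel expects: elem_cnt covers one token per sequence, i.e.
// bsz * kv_num_heads * (nope_size + pe_size), and a VecSize that divides both
// nope_size and pe_size. The helper name `LaunchDecodeAbsorbCache` and the
// block size `kBlockDim` are hypothetical, not part of the library API.
template <typename T, int VecSize = 4>
void LaunchDecodeAbsorbCache(const T* kv_nope,
                             const T* kv_pe,
                             T* kv_cache,
                             const int* block_tables,
                             const int* cu_seqlens_q,
                             const int* seq_lens,
                             const int* seq_lens_encoder,
                             const int bsz,
                             const int max_seq_len,
                             const int max_blocks_per_seq,
                             const int kv_num_heads,
                             const int nope_size,
                             const int pe_size,
                             const int block_size,
                             cudaStream_t stream) {
  constexpr int kBlockDim = 256;  // hypothetical thread-block size
  const uint32_t elem_cnt =
      static_cast<uint32_t>(bsz) * kv_num_heads * (nope_size + pe_size);
  // Each thread consumes VecSize contiguous elements per grid-stride step.
  const uint32_t grid_dim = (elem_cnt / VecSize + kBlockDim - 1) / kBlockDim;
  decode_absorb_cache_kernel<T, VecSize>
      <<<grid_dim, kBlockDim, 0, stream>>>(kv_nope,
                                           kv_pe,
                                           kv_cache,
                                           block_tables,
                                           cu_seqlens_q,
                                           seq_lens,
                                           seq_lens_encoder,
                                           max_seq_len,
                                           max_blocks_per_seq,
                                           kv_num_heads,
                                           nope_size,
                                           pe_size,
                                           block_size,
                                           elem_cnt);
}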
template <typename T, int VecSize = 1>
__global__ void speculate_decode_absorb_cache_kernel(
    const T* __restrict__ kv_nope,  // [token_num, kv_num_heads, nope_size] 512
    const T* __restrict__ kv_pe,    // [token_num, kv_num_heads, pe_size] 64
    T* __restrict__ kv_cache,       // [num_blocks, kv_num_heads, block_size,
                                    //  nope_size + pe_size]
    const int* __restrict__ block_tables,  // [bsz, max_blocks_per_seq]
    const int* __restrict__ batch_id_per_token,
    const int* __restrict__ cu_seqlens_q,
    const int* __restrict__ seq_lens,          // [bsz]
    const int* __restrict__ seq_lens_encoder,  // [bsz]
    const int max_seq_len,
    const int max_blocks_per_seq,
    const int kv_num_heads,
    const int nope_size,
    const int pe_size,
    const int block_size,
    const uint32_t elem_cnt) {
  using LoadT = AlignedVector<T, VecSize>;
  constexpr int HalfVecSize = VecSize / 2;
  LoadT src_vec;

  int64_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
  const uint32_t nope_hidden_size = kv_num_heads * nope_size;
  const uint32_t pe_hidden_size = kv_num_heads * pe_size;
  const uint32_t all_size = nope_size + pe_size;
  const int64_t hidden_size = nope_hidden_size + pe_hidden_size;

  for (int32_t linear_index = global_thread_idx * VecSize,
               step = gridDim.x * blockDim.x * VecSize;
       linear_index < elem_cnt;
       linear_index += step) {
    const int token_id = linear_index / hidden_size;
    const int ori_bi = batch_id_per_token[token_id];
    if (seq_lens[ori_bi] == 0) continue;
    const int bias = linear_index % hidden_size;
    const int start_token_idx = cu_seqlens_q[ori_bi];
    const int write_seq_id = seq_lens[ori_bi] + token_id - start_token_idx;
    if (write_seq_id == 0) continue;

    const int* block_table_now = block_tables + ori_bi * max_blocks_per_seq;
    const int block_idx = block_table_now[write_seq_id / block_size];
    const int block_offset = write_seq_id % block_size;
    if (block_idx < 0) {
      printf(
          "Fatal Error!!! block_idx %d when write_seq_id is %d; "
          "key vars: %d %d %d %d\n",
          block_idx,
          write_seq_id,
          ori_bi,
          seq_lens[ori_bi],
          token_id,
          cu_seqlens_q[ori_bi]);
    }

    if (bias < nope_hidden_size) {  // nope part
      const uint32_t inner_bias = bias;
      const uint32_t hi = inner_bias / nope_size;
      const uint32_t h_bias = inner_bias % nope_size;
      const uint32_t tgt_idx =
          block_idx * kv_num_heads * block_size * all_size +
          hi * block_size * all_size + block_offset * all_size + h_bias;
      const uint32_t ori_idx = token_id * nope_hidden_size + inner_bias;
      Load<T, VecSize>(&kv_nope[ori_idx], &src_vec);
      Store<T, VecSize>(src_vec, &kv_cache[tgt_idx]);
    } else {  // pe part
      const uint32_t inner_bias = bias - nope_hidden_size;
      const uint32_t hi = inner_bias / pe_size;
      const uint32_t h_bias = inner_bias % pe_size;
      const uint32_t tgt_idx =
          block_idx * kv_num_heads * block_size * all_size +
          hi * block_size * all_size + block_offset * all_size + nope_size +
          h_bias;
      const uint32_t ori_idx = token_id * pe_hidden_size + inner_bias;
      Load<T, VecSize>(&kv_pe[ori_idx], &src_vec);
      Store<T, VecSize>(src_vec, &kv_cache[tgt_idx]);
    }
  }
}
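// Cache addressing shared by the kernels in this file, written out for
// reference (this note is not part of the original code): kv_cache is laid
// out as [num_blocks, kv_num_heads, block_size, nope_size + pe_size], so for
// head `hi`, in-block slot `block_offset`, and element offset `h_bias`, the
// destination index is
//   block_idx * kv_num_heads * block_size * all_size
//     + hi * block_size * all_size
//     + block_offset * all_size
//     + h_bias                    // nope part
//     + nope_size + h_bias        // pe part (instead of h_bias)
// For example, with the annotated MLA sizes nope_size = 512 and pe_size = 64,
// all_size = 576 and each cache block holds kv_num_heads * block_size * 576
// elements.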
template <typename T, int VecSize = 1>
__global__ void prefill_absorb_cache_kernel(
    const T* __restrict__ kv_nope,  // [token_num, kv_num_heads, nope_size] 512
    const T* __restrict__ kv_pe,    // [token_num, kv_num_heads, pe_size] 64
    T* __restrict__ kv_cache,       // [num_blocks, kv_num_heads, block_size,
                                    //  nope_size + pe_size]
    const int* __restrict__ block_tables,  // [bsz, max_blocks_per_seq]
    const int* __restrict__ batch_id_per_token,
    const int* __restrict__ cu_seqlens_q,
    const int* __restrict__ seq_lens,          // [bsz]
    const int* __restrict__ seq_lens_decoder,  // [bsz]
    const int max_seq_len,
    const int max_blocks_per_seq,
    const int kv_num_heads,
    const int nope_size,
    const int pe_size,
    const int block_size,
    const uint32_t elem_cnt) {
  using LoadT = AlignedVector<T, VecSize>;
  LoadT src_vec;

  int64_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
  const uint32_t nope_hidden_size = kv_num_heads * nope_size;
  const uint32_t pe_hidden_size = kv_num_heads * pe_size;
  const uint32_t all_size = nope_size + pe_size;
  const int64_t hidden_size = nope_hidden_size + pe_hidden_size;

  for (int32_t linear_index = global_thread_idx * VecSize,
               step = gridDim.x * blockDim.x * VecSize;
       linear_index < elem_cnt;
       linear_index += step) {
    const uint32_t token_idx = linear_index / hidden_size;
    const uint32_t bias = linear_index % hidden_size;
    const uint32_t ori_bi = batch_id_per_token[token_idx];
    if (seq_lens[ori_bi] == 0) continue;
    const uint32_t ori_seq_id =
        (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi];

    const int* block_table_now = block_tables + ori_bi * max_blocks_per_seq;
    const uint32_t block_idx = block_table_now[ori_seq_id / block_size];
    const uint32_t block_offset = ori_seq_id % block_size;

    if (bias < nope_hidden_size) {  // nope part
      const uint32_t inner_bias = bias;
      const uint32_t hi = inner_bias / nope_size;
      const uint32_t h_bias = inner_bias % nope_size;
      const uint32_t tgt_idx =
          block_idx * kv_num_heads * block_size * all_size +
          hi * block_size * all_size + block_offset * all_size + h_bias;
      const uint32_t ori_idx = token_idx * nope_hidden_size + inner_bias;
      Load<T, VecSize>(&kv_nope[ori_idx], &src_vec);
      Store<T, VecSize>(src_vec, &kv_cache[tgt_idx]);
    } else {  // pe part
      const uint32_t inner_bias = bias - nope_hidden_size;
      const uint32_t hi = inner_bias / pe_size;
      const uint32_t h_bias = inner_bias % pe_size;
      const uint32_t tgt_idx =
          block_idx * kv_num_heads * block_size * all_size +
          hi * block_size * all_size + block_offset * all_size + nope_size +
          h_bias;
      const uint32_t ori_idx = token_idx * pe_hidden_size + inner_bias;
      Load<T, VecSize>(&kv_pe[ori_idx], &src_vec);
      Store<T, VecSize>(src_vec, &kv_cache[tgt_idx]);
    }
  }
}
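// How the three kernels divide the work (summary note, not part of the
// original file): decode_absorb_cache_kernel writes exactly one new token per
// sequence at slot seq_lens[bi], with elem_cnt covering one token per
// sequence; speculate_decode_absorb_cache_kernel writes several draft tokens
// per sequence, mapping each flattened token back to its sequence through
// batch_id_per_token; and prefill_absorb_cache_kernel writes a whole prompt
// chunk, placing token i of sequence bi at slot
// seq_lens_decoder[bi] + (i - cu_seqlens_q[bi]). The latter two expect
// elem_cnt = token_num * (kv_num_heads * (nope_size + pe_size)).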