// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once

#include "helper.h"
#include "mem_util.cuh"
#include "utils.cuh"

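// Layout note (derived from the index math in the kernels below): each cached
// token slot stores the "absorbed" MLA latent as nope_size compressed-KV
// channels followed by pe_size rotary-position channels, so kv_cache is laid
// out as [num_blocks, kv_num_heads, block_size, nope_size + pe_size].
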
template <typename T, int VecSize = 1>
__global__ void decode_absorb_cache_kernel(
    const T* __restrict__ kv_nope,  // [bsz, kv_num_heads, nope_size] e.g. 512
    const T* __restrict__ kv_pe,    // [bsz, kv_num_heads, pe_size] e.g. 64
    T* __restrict__ kv_cache,       // [num_blocks, kv_num_heads, block_size,
                                    //  nope_size + pe_size]
    const int* __restrict__ block_tables,  // [bsz, max_blocks_per_seq]
    const int* __restrict__ cu_seqlens_q,
    const int* __restrict__ seq_lens,          // [bsz]
    const int* __restrict__ seq_lens_encoder,  // [bsz]
    const int max_seq_len,
    const int max_blocks_per_seq,
    const int kv_num_heads,
    const int nope_size,
    const int pe_size,
    const int block_size,
    const uint32_t elem_cnt) {
  using LoadT = AlignedVector<T, VecSize>;
  LoadT src_vec;

  int64_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
  const uint32_t nope_hidden_size = kv_num_heads * nope_size;
  const uint32_t pe_hidden_size = kv_num_heads * pe_size;
  const uint32_t all_size = nope_size + pe_size;
  const int64_t hidden_size = nope_hidden_size + pe_hidden_size;

  for (int32_t linear_index = global_thread_idx * VecSize,
               step = gridDim.x * blockDim.x * VecSize;
       linear_index < elem_cnt;
       linear_index += step) {
    const int ori_bi = linear_index / hidden_size;
    const int bias = linear_index % hidden_size;
    const int start_token_idx = cu_seqlens_q[ori_bi];
    // Sequences still in the prefill (encoder) phase are handled elsewhere;
    // skip them rather than returning, so this thread still processes the
    // remaining grid-stride indices belonging to other sequences.
    if (seq_lens_encoder[ori_bi] > 0) continue;
    const int write_seq_id = seq_lens[ori_bi];

    if (write_seq_id == 0) continue;

    const int* block_table_now = block_tables + ori_bi * max_blocks_per_seq;
    const int block_idx = block_table_now[write_seq_id / block_size];
    const int block_offset = write_seq_id % block_size;

    if (bias < nope_hidden_size) {  // nope part
      const uint32_t inner_bias = bias;
      const uint32_t hi = inner_bias / nope_size;
      const uint32_t h_bias = inner_bias % nope_size;
      const uint32_t tgt_idx =
          block_idx * kv_num_heads * block_size * all_size +
          hi * block_size * all_size + block_offset * all_size + h_bias;
      const uint32_t ori_idx = start_token_idx * nope_hidden_size + inner_bias;
      Load<T, VecSize>(&kv_nope[ori_idx], &src_vec);
      Store<T, VecSize>(src_vec, &kv_cache[tgt_idx]);
    } else {  // pe part
      const uint32_t inner_bias = bias - nope_hidden_size;
      const uint32_t hi = inner_bias / pe_size;
      const uint32_t h_bias = inner_bias % pe_size;
      const uint32_t tgt_idx =
          block_idx * kv_num_heads * block_size * all_size +
          hi * block_size * all_size + block_offset * all_size + nope_size +
          h_bias;
      const uint32_t ori_idx = start_token_idx * pe_hidden_size + inner_bias;
      Load<T, VecSize>(&kv_pe[ori_idx], &src_vec);
      Store<T, VecSize>(src_vec, &kv_cache[tgt_idx]);
    }
  }
}

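// The speculative variant below differs from decode_absorb_cache_kernel in
// that one sequence may carry several draft tokens per step: each token is
// mapped back to its sequence through batch_id_per_token, and its cache slot
// is seq_lens[bi] + (token_id - cu_seqlens_q[bi]) rather than seq_lens[bi].
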
template <typename T, int VecSize = 1>
__global__ void speculate_decode_absorb_cache_kernel(
    const T* __restrict__ kv_nope,  // [token_num, kv_num_heads, nope_size] e.g. 512
    const T* __restrict__ kv_pe,    // [token_num, kv_num_heads, pe_size] e.g. 64
    T* __restrict__ kv_cache,       // [num_blocks, kv_num_heads, block_size,
                                    //  nope_size + pe_size]
    const int* __restrict__ block_tables,  // [bsz, max_blocks_per_seq]
    const int* __restrict__ batch_id_per_token,
    const int* __restrict__ cu_seqlens_q,
    const int* __restrict__ seq_lens,          // [bsz]
    const int* __restrict__ seq_lens_encoder,  // [bsz]
    const int max_seq_len,
    const int max_blocks_per_seq,
    const int kv_num_heads,
    const int nope_size,
    const int pe_size,
    const int block_size,
    const uint32_t elem_cnt) {
  using LoadT = AlignedVector<T, VecSize>;
  LoadT src_vec;

  int64_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
  const uint32_t nope_hidden_size = kv_num_heads * nope_size;
  const uint32_t pe_hidden_size = kv_num_heads * pe_size;
  const uint32_t all_size = nope_size + pe_size;
  const int64_t hidden_size = nope_hidden_size + pe_hidden_size;

  for (int32_t linear_index = global_thread_idx * VecSize,
               step = gridDim.x * blockDim.x * VecSize;
       linear_index < elem_cnt;
       linear_index += step) {
    const int token_id = linear_index / hidden_size;
    const int ori_bi = batch_id_per_token[token_id];
    if (seq_lens[ori_bi] == 0) continue;
    const int bias = linear_index % hidden_size;
    const int start_token_idx = cu_seqlens_q[ori_bi];
    // Draft tokens are appended after the seq_lens[ori_bi] tokens already
    // cached for this sequence.
    const int write_seq_id = seq_lens[ori_bi] + token_id - start_token_idx;
    if (write_seq_id == 0) continue;

    const int* block_table_now = block_tables + ori_bi * max_blocks_per_seq;
    const int block_idx = block_table_now[write_seq_id / block_size];
    const int block_offset = write_seq_id % block_size;
    if (block_idx < 0) {
      printf(
          "Fatal Error: block_idx %d when write_seq_id is %d; key vars: "
          "ori_bi %d, seq_len %d, token_id %d, cu_seqlen_q %d\n",
          block_idx,
          write_seq_id,
          ori_bi,
          seq_lens[ori_bi],
          token_id,
          cu_seqlens_q[ori_bi]);
    }
    if (bias < nope_hidden_size) {  // nope part
      const uint32_t inner_bias = bias;
      const uint32_t hi = inner_bias / nope_size;
      const uint32_t h_bias = inner_bias % nope_size;
      const uint32_t tgt_idx =
          block_idx * kv_num_heads * block_size * all_size +
          hi * block_size * all_size + block_offset * all_size + h_bias;
      const uint32_t ori_idx = token_id * nope_hidden_size + inner_bias;
      Load<T, VecSize>(&kv_nope[ori_idx], &src_vec);
      Store<T, VecSize>(src_vec, &kv_cache[tgt_idx]);
    } else {  // pe part
      const uint32_t inner_bias = bias - nope_hidden_size;
      const uint32_t hi = inner_bias / pe_size;
      const uint32_t h_bias = inner_bias % pe_size;
      const uint32_t tgt_idx =
          block_idx * kv_num_heads * block_size * all_size +
          hi * block_size * all_size + block_offset * all_size + nope_size +
          h_bias;
      const uint32_t ori_idx = token_id * pe_hidden_size + inner_bias;
      Load<T, VecSize>(&kv_pe[ori_idx], &src_vec);
      Store<T, VecSize>(src_vec, &kv_cache[tgt_idx]);
    }
  }
}

template <typename T, int VecSize = 1>
__global__ void prefill_absorb_cache_kernel(
    const T* __restrict__ kv_nope,  // [token_num, kv_num_heads, nope_size] e.g. 512
    const T* __restrict__ kv_pe,    // [token_num, kv_num_heads, pe_size] e.g. 64
    T* __restrict__ kv_cache,       // [num_blocks, kv_num_heads, block_size,
                                    //  nope_size + pe_size]
    const int* __restrict__ block_tables,  // [bsz, max_blocks_per_seq]
    const int* __restrict__ batch_id_per_token,
    const int* __restrict__ cu_seqlens_q,
    const int* __restrict__ seq_lens,          // [bsz]
    const int* __restrict__ seq_lens_decoder,  // [bsz]
    const int max_seq_len,
    const int max_blocks_per_seq,
    const int kv_num_heads,
    const int nope_size,
    const int pe_size,
    const int block_size,
    const uint32_t elem_cnt) {
  using LoadT = AlignedVector<T, VecSize>;
  LoadT src_vec;

  int64_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
  const uint32_t nope_hidden_size = kv_num_heads * nope_size;
  const uint32_t pe_hidden_size = kv_num_heads * pe_size;
  const uint32_t all_size = nope_size + pe_size;
  const int64_t hidden_size = nope_hidden_size + pe_hidden_size;

  for (int32_t linear_index = global_thread_idx * VecSize,
               step = gridDim.x * blockDim.x * VecSize;
       linear_index < elem_cnt;
       linear_index += step) {
    const uint32_t token_idx = linear_index / hidden_size;
    const uint32_t bias = linear_index % hidden_size;
    const uint32_t ori_bi = batch_id_per_token[token_idx];
    if (seq_lens[ori_bi] == 0) continue;
    // Absolute position of this prompt token, counting any tokens already
    // cached for the sequence (seq_lens_decoder covers chunked prefill).
    const uint32_t ori_seq_id =
        (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi];

    const int* block_table_now = block_tables + ori_bi * max_blocks_per_seq;
    const uint32_t block_idx = block_table_now[ori_seq_id / block_size];
    const uint32_t block_offset = ori_seq_id % block_size;

    if (bias < nope_hidden_size) {  // nope part
      const uint32_t inner_bias = bias;
      const uint32_t hi = inner_bias / nope_size;
      const uint32_t h_bias = inner_bias % nope_size;
      const uint32_t tgt_idx =
          block_idx * kv_num_heads * block_size * all_size +
          hi * block_size * all_size + block_offset * all_size + h_bias;
      const uint32_t ori_idx = token_idx * nope_hidden_size + inner_bias;
      Load<T, VecSize>(&kv_nope[ori_idx], &src_vec);
      Store<T, VecSize>(src_vec, &kv_cache[tgt_idx]);
    } else {  // pe part
      const uint32_t inner_bias = bias - nope_hidden_size;
      const uint32_t hi = inner_bias / pe_size;
      const uint32_t h_bias = inner_bias % pe_size;
      const uint32_t tgt_idx =
          block_idx * kv_num_heads * block_size * all_size +
          hi * block_size * all_size + block_offset * all_size + nope_size +
          h_bias;
      const uint32_t ori_idx = token_idx * pe_hidden_size + inner_bias;
      Load<T, VecSize>(&kv_pe[ori_idx], &src_vec);
      Store<T, VecSize>(src_vec, &kv_cache[tgt_idx]);
    }
  }
}
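
// ---------------------------------------------------------------------------
// Minimal host-side launch sketch for the decode kernel, illustrative only:
// the helper name LaunchDecodeAbsorbCache, the 512-thread block size, and
// kVecSize = 8 are assumptions, not part of this header. The real dispatch
// lives in the callers of these kernels. kVecSize = 8 presumes T is a 16-bit
// type and that nope_size and pe_size are multiples of 8, so every vectorized
// Load/Store stays within one nope or pe segment and remains aligned.
template <typename T>
void LaunchDecodeAbsorbCache(const T* kv_nope, const T* kv_pe, T* kv_cache,
                             const int* block_tables, const int* cu_seqlens_q,
                             const int* seq_lens, const int* seq_lens_encoder,
                             int bsz, int max_seq_len, int max_blocks_per_seq,
                             int kv_num_heads, int nope_size, int pe_size,
                             int block_size, cudaStream_t stream) {
  constexpr int kVecSize = 8;
  constexpr int kThreads = 512;
  // One element per (sequence, head, channel) of the concatenated
  // nope+pe row; each thread handles kVecSize elements per loop iteration.
  const uint32_t elem_cnt =
      static_cast<uint32_t>(bsz) * kv_num_heads * (nope_size + pe_size);
  const int grid = (elem_cnt / kVecSize + kThreads - 1) / kThreads;
  decode_absorb_cache_kernel<T, kVecSize><<<grid, kThreads, 0, stream>>>(
      kv_nope, kv_pe, kv_cache, block_tables, cu_seqlens_q, seq_lens,
      seq_lens_encoder, max_seq_len, max_blocks_per_seq, kv_num_heads,
      nope_size, pe_size, block_size, elem_cnt);
}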