mirror of
				https://github.com/PaddlePaddle/FastDeploy.git
				synced 2025-10-31 20:02:53 +08:00 
			
		
		
		
	
		
			
				
	
	
		
			345 lines
		
	
	
		
			16 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
			
		
		
	
	
			345 lines
		
	
	
		
			16 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
| // Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
 | |
| //
 | |
| // Licensed under the Apache License, Version 2.0 (the "License");
 | |
| // you may not use this file except in compliance with the License.
 | |
| // You may obtain a copy of the License at
 | |
| //
 | |
| //     http://www.apache.org/licenses/LICENSE-2.0
 | |
| //
 | |
| // Unless required by applicable law or agreed to in writing, software
 | |
| // distributed under the License is distributed on an "AS IS" BASIS,
 | |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | |
| // See the License for the specific language governing permissions and
 | |
| // limitations under the License.
 | |
| #pragma once
 | |
| 
 | |
| #include "helper.h"
 | |
| #include "utils.cuh"
 | |
| 
 | |
| template <typename T, typename OutT>
 | |
| void CascadeAppendAttentionC16Kernel(
 | |
|     const AppendAttnMetaData& meta_data,
 | |
|     const paddle::Tensor& qkv,  // [token_num, num_heads, head_dim]
 | |
|     const paddle::Tensor&
 | |
|         cache_k,  // [max_block_num, num_heads, block_size, head_dim]
 | |
|     const paddle::Tensor&
 | |
|         cache_v,  // [max_block_num, num_heads, head_dim, block_size]
 | |
|     const paddle::optional<paddle::Tensor>& attn_mask,
 | |
|     const paddle::optional<paddle::Tensor>&
 | |
|         cache_k_scale,  // [num_kv_heads, head_dim]
 | |
|     const paddle::optional<paddle::Tensor>&
 | |
|         cache_v_scale,  // [num_kv_heads, head_dim]
 | |
|     const paddle::optional<paddle::Tensor>&
 | |
|         cache_k_zp,  // [num_kv_heads, head_dim]
 | |
|     const paddle::optional<paddle::Tensor>&
 | |
|         cache_v_zp,  // [num_kv_heads, head_dim]
 | |
|     const paddle::optional<paddle::Tensor>&
 | |
|         shift_bias,  // [num_kv_heads, head_dim]
 | |
|     const paddle::optional<paddle::Tensor>&
 | |
|         smooth_weight,  // [num_kv_heads, head_dim]
 | |
|     const paddle::Tensor& seq_lens_q,
 | |
|     const paddle::Tensor& seq_lens_kv,
 | |
|     const paddle::Tensor& seq_lens_encoder,
 | |
|     const paddle::Tensor& batch_id_per_token,
 | |
|     const paddle::Tensor& cu_seqlens_q,
 | |
|     const paddle::Tensor& block_table,
 | |
|     const paddle::Tensor& batch_ids,
 | |
|     const paddle::Tensor& tile_ids_per_batch,
 | |
|     const int num_blocks,
 | |
|     const int block_shape_q,
 | |
|     const int max_seq_len,
 | |
|     const int max_dec_len,
 | |
|     const float quant_max_bound,
 | |
|     const float quant_min_bound,
 | |
|     const float in_scale,
 | |
|     const int max_partition_size,
 | |
|     const int encoder_max_partition_size,
 | |
|     const int speculate_max_draft_token_num,
 | |
|     const bool causal,
 | |
|     const bool is_decoder,
 | |
|     const bool enable_prefill,
 | |
|     cudaStream_t& stream,
 | |
|     paddle::Tensor* out);
 | |
| 
 | |
| template <typename T, typename OutT, bool IsFP8 = false>
 | |
| void CascadeAppendAttentionC8Kernel(
 | |
|     const AppendAttnMetaData& meta_data,
 | |
|     const paddle::Tensor& qkv,  // [token_num, num_heads, head_dim]
 | |
|     const paddle::Tensor&
 | |
|         cache_k,  // [max_block_num, num_heads, block_size, head_dim]
 | |
|     const paddle::Tensor&
 | |
|         cache_v,  // [max_block_num, num_heads, head_dim, block_size]
 | |
|     const paddle::optional<paddle::Tensor>& attn_mask,
 | |
|     const paddle::optional<paddle::Tensor>&
 | |
|         cache_k_scale,  // [num_kv_heads, head_dim]
 | |
|     const paddle::optional<paddle::Tensor>&
 | |
|         cache_v_scale,  // [num_kv_heads, head_dim]
 | |
|     const paddle::optional<paddle::Tensor>&
 | |
|         cache_k_zp,  // [num_kv_heads, head_dim]
 | |
|     const paddle::optional<paddle::Tensor>&
 | |
|         cache_v_zp,  // [num_kv_heads, head_dim]
 | |
|     const paddle::optional<paddle::Tensor>&
 | |
|         shift_bias,  // [num_kv_heads, head_dim]
 | |
|     const paddle::optional<paddle::Tensor>&
 | |
|         smooth_weight,  // [num_kv_heads, head_dim]
 | |
|     const paddle::Tensor& seq_lens_q,
 | |
|     const paddle::Tensor& seq_lens_kv,
 | |
|     const paddle::Tensor& seq_lens_encoder,
 | |
|     const paddle::Tensor& batch_id_per_token,
 | |
|     const paddle::Tensor& cu_seqlens_q,
 | |
|     const paddle::Tensor& block_table,
 | |
|     const paddle::Tensor& batch_ids,
 | |
|     const paddle::Tensor& tile_ids_per_batch,
 | |
|     const int num_blocks,
 | |
|     const int block_shape_q,
 | |
|     const int max_seq_len,
 | |
|     const int max_dec_len,
 | |
|     const float quant_max_bound,
 | |
|     const float quant_min_bound,
 | |
|     const float in_scale,
 | |
|     const int max_partition_size,
 | |
|     const int encoder_max_partition_size,
 | |
|     const int speculate_max_draft_token_num,
 | |
|     const bool causal,
 | |
|     const bool is_decoder,
 | |
|     const bool enable_prefill,
 | |
|     cudaStream_t& stream,
 | |
|     paddle::Tensor* out);
 | |
| 
 | |
| template <typename T, typename OutT>
 | |
| void CascadeAppendAttentionC4Kernel(
 | |
|     const AppendAttnMetaData& meta_data,
 | |
|     const paddle::Tensor& qkv,  // [token_num, num_heads, head_dim]
 | |
|     const paddle::Tensor&
 | |
|         cache_k,  // [max_block_num, num_heads, block_size, head_dim]
 | |
|     const paddle::Tensor&
 | |
|         cache_v,  // [max_block_num, num_heads, head_dim, block_size]
 | |
|     const paddle::optional<paddle::Tensor>& attn_mask,
 | |
|     const paddle::optional<paddle::Tensor>&
 | |
|         cache_k_scale,  // [num_kv_heads, head_dim]
 | |
|     const paddle::optional<paddle::Tensor>&
 | |
|         cache_v_scale,  // [num_kv_heads, head_dim]
 | |
|     const paddle::optional<paddle::Tensor>&
 | |
|         cache_k_zp,  // [num_kv_heads, head_dim]
 | |
|     const paddle::optional<paddle::Tensor>&
 | |
|         cache_v_zp,  // [num_kv_heads, head_dim]
 | |
|     const paddle::optional<paddle::Tensor>&
 | |
|         shift_bias,  // [num_kv_heads, head_dim]
 | |
|     const paddle::optional<paddle::Tensor>&
 | |
|         smooth_weight,  // [num_kv_heads, head_dim]
 | |
|     const paddle::Tensor& seq_lens_q,
 | |
|     const paddle::Tensor& seq_lens_kv,
 | |
|     const paddle::Tensor& seq_lens_encoder,
 | |
|     const paddle::Tensor& batch_id_per_token,
 | |
|     const paddle::Tensor& cu_seqlens_q,
 | |
|     const paddle::Tensor& block_table,
 | |
|     const paddle::Tensor& batch_ids,
 | |
|     const paddle::Tensor& tile_ids_per_batch,
 | |
|     const int num_blocks,
 | |
|     const int block_shape_q,
 | |
|     const int max_seq_len,
 | |
|     const int max_dec_len,
 | |
|     const float quant_max_bound,
 | |
|     const float quant_min_bound,
 | |
|     const float in_scale,
 | |
|     const int max_partition_size,
 | |
|     const int encoder_max_partition_size,
 | |
|     const int speculate_max_draft_token_num,
 | |
|     const bool causal,
 | |
|     const bool is_decoder,
 | |
|     const bool enable_prefill,
 | |
|     cudaStream_t& stream,
 | |
|     paddle::Tensor* out);
 | |
| 
 | |
| template <typename T, typename OutT>
 | |
| void CascadeAppendAttentionKernel(
 | |
|     const AppendAttnMetaData& meta_data,
 | |
|     const paddle::Tensor& qkv,  // [token_num, num_heads, head_dim]
 | |
|     const paddle::Tensor&
 | |
|         cache_k,  // [max_block_num, num_heads, block_size, head_dim]
 | |
|     const paddle::Tensor&
 | |
|         cache_v,  // [max_block_num, num_heads, head_dim, block_size]
 | |
|     const paddle::optional<paddle::Tensor>& attn_mask,
 | |
|     const paddle::optional<paddle::Tensor>&
 | |
|         cache_k_scale,  // [num_kv_heads, head_dim]
 | |
|     const paddle::optional<paddle::Tensor>&
 | |
|         cache_v_scale,  // [num_kv_heads, head_dim]
 | |
|     const paddle::optional<paddle::Tensor>&
 | |
|         cache_k_zp,  // [num_kv_heads, head_dim]
 | |
|     const paddle::optional<paddle::Tensor>&
 | |
|         cache_v_zp,  // [num_kv_heads, head_dim]
 | |
|     const paddle::optional<paddle::Tensor>&
 | |
|         shift_bias,  // [num_kv_heads, head_dim]
 | |
|     const paddle::optional<paddle::Tensor>&
 | |
|         smooth_weight,  // [num_kv_heads, head_dim]
 | |
|     const paddle::Tensor& seq_lens_q,
 | |
|     const paddle::Tensor& seq_lens_kv,
 | |
|     const paddle::Tensor& seq_lens_encoder,
 | |
|     const paddle::Tensor& batch_id_per_token,
 | |
|     const paddle::Tensor& cu_seqlens_q,
 | |
|     const paddle::Tensor& block_table,
 | |
|     const paddle::Tensor& batch_ids,
 | |
|     const paddle::Tensor& tile_ids_per_batch,
 | |
|     const std::string& cache_quant_type_str,
 | |
|     const int num_blocks,
 | |
|     const int block_shape_q,
 | |
|     const int max_seq_len,
 | |
|     const int max_dec_len,
 | |
|     const float quant_max_bound,
 | |
|     const float quant_min_bound,
 | |
|     const float in_scale,
 | |
|     const int max_partition_size,
 | |
|     const int encoder_max_partition_size,
 | |
|     const int speculate_max_draft_token_num,
 | |
|     const bool causal,
 | |
|     const bool is_decoder,
 | |
|     const bool enable_prefill,
 | |
|     cudaStream_t& stream,
 | |
|     paddle::Tensor* out) {
 | |
|     if (cache_quant_type_str == "none") {
 | |
|         CascadeAppendAttentionC16Kernel<T, OutT>(meta_data,
 | |
|                                                 qkv,
 | |
|                                                 cache_k,
 | |
|                                                 cache_v,
 | |
|                                                 attn_mask,
 | |
|                                                 cache_k_scale,
 | |
|                                                 cache_v_scale,
 | |
|                                                 cache_k_zp,
 | |
|                                                 cache_v_zp,
 | |
|                                                 shift_bias,
 | |
|                                                 smooth_weight,
 | |
|                                                 seq_lens_q,
 | |
|                                                 seq_lens_kv,
 | |
|                                                 seq_lens_encoder,
 | |
|                                                 batch_id_per_token,
 | |
|                                                 cu_seqlens_q,
 | |
|                                                 block_table,
 | |
|                                                 batch_ids,
 | |
|                                                 tile_ids_per_batch,
 | |
|                                                 num_blocks,
 | |
|                                                 block_shape_q,
 | |
|                                                 max_seq_len,
 | |
|                                                 max_dec_len,
 | |
|                                                 quant_max_bound,
 | |
|                                                 quant_min_bound,
 | |
|                                                 in_scale,
 | |
|                                                 max_partition_size,
 | |
|                                                 encoder_max_partition_size,
 | |
|                                                 speculate_max_draft_token_num,
 | |
|                                                 causal,
 | |
|                                                 is_decoder,
 | |
|                                                 enable_prefill,
 | |
|                                                 stream,
 | |
|                                                 out);
 | |
|     } else if (cache_quant_type_str == "cache_int8") {
 | |
|         CascadeAppendAttentionC8Kernel<T, OutT>(meta_data,
 | |
|                                                 qkv,
 | |
|                                                 cache_k,
 | |
|                                                 cache_v,
 | |
|                                                 attn_mask,
 | |
|                                                 cache_k_scale,
 | |
|                                                 cache_v_scale,
 | |
|                                                 cache_k_zp,
 | |
|                                                 cache_v_zp,
 | |
|                                                 shift_bias,
 | |
|                                                 smooth_weight,
 | |
|                                                 seq_lens_q,
 | |
|                                                 seq_lens_kv,
 | |
|                                                 seq_lens_encoder,
 | |
|                                                 batch_id_per_token,
 | |
|                                                 cu_seqlens_q,
 | |
|                                                 block_table,
 | |
|                                                 batch_ids,
 | |
|                                                 tile_ids_per_batch,
 | |
|                                                 num_blocks,
 | |
|                                                 block_shape_q,
 | |
|                                                 max_seq_len,
 | |
|                                                 max_dec_len,
 | |
|                                                 quant_max_bound,
 | |
|                                                 quant_min_bound,
 | |
|                                                 in_scale,
 | |
|                                                 max_partition_size,
 | |
|                                                 encoder_max_partition_size,
 | |
|                                                 speculate_max_draft_token_num,
 | |
|                                                 causal,
 | |
|                                                 is_decoder,
 | |
|                                                 enable_prefill,
 | |
|                                                 stream,
 | |
|                                                 out);
 | |
|     } else if (cache_quant_type_str == "cache_fp8") {
 | |
|         CascadeAppendAttentionC8Kernel<T, OutT, true>(meta_data,
 | |
|                                                 qkv,
 | |
|                                                 cache_k,
 | |
|                                                 cache_v,
 | |
|                                                 attn_mask,
 | |
|                                                 cache_k_scale,
 | |
|                                                 cache_v_scale,
 | |
|                                                 cache_k_zp,
 | |
|                                                 cache_v_zp,
 | |
|                                                 shift_bias,
 | |
|                                                 smooth_weight,
 | |
|                                                 seq_lens_q,
 | |
|                                                 seq_lens_kv,
 | |
|                                                 seq_lens_encoder,
 | |
|                                                 batch_id_per_token,
 | |
|                                                 cu_seqlens_q,
 | |
|                                                 block_table,
 | |
|                                                 batch_ids,
 | |
|                                                 tile_ids_per_batch,
 | |
|                                                 num_blocks,
 | |
|                                                 block_shape_q,
 | |
|                                                 max_seq_len,
 | |
|                                                 max_dec_len,
 | |
|                                                 quant_max_bound,
 | |
|                                                 quant_min_bound,
 | |
|                                                 in_scale,
 | |
|                                                 max_partition_size,
 | |
|                                                 encoder_max_partition_size,
 | |
|                                                 speculate_max_draft_token_num,
 | |
|                                                 causal,
 | |
|                                                 is_decoder,
 | |
|                                                 enable_prefill,
 | |
|                                                 stream,
 | |
|                                                 out);
 | |
|     } else if (cache_quant_type_str == "cache_int4_zp") {
 | |
|         CascadeAppendAttentionC4Kernel<T, OutT>(meta_data,
 | |
|                                                 qkv,
 | |
|                                                 cache_k,
 | |
|                                                 cache_v,
 | |
|                                                 attn_mask,
 | |
|                                                 cache_k_scale,
 | |
|                                                 cache_v_scale,
 | |
|                                                 cache_k_zp,
 | |
|                                                 cache_v_zp,
 | |
|                                                 shift_bias,
 | |
|                                                 smooth_weight,
 | |
|                                                 seq_lens_q,
 | |
|                                                 seq_lens_kv,
 | |
|                                                 seq_lens_encoder,
 | |
|                                                 batch_id_per_token,
 | |
|                                                 cu_seqlens_q,
 | |
|                                                 block_table,
 | |
|                                                 batch_ids,
 | |
|                                                 tile_ids_per_batch,
 | |
|                                                 num_blocks,
 | |
|                                                 block_shape_q,
 | |
|                                                 max_seq_len,
 | |
|                                                 max_dec_len,
 | |
|                                                 quant_max_bound,
 | |
|                                                 quant_min_bound,
 | |
|                                                 in_scale,
 | |
|                                                 max_partition_size,
 | |
|                                                 encoder_max_partition_size,
 | |
|                                                 speculate_max_draft_token_num,
 | |
|                                                 causal,
 | |
|                                                 is_decoder,
 | |
|                                                 enable_prefill,
 | |
|                                                 stream,
 | |
|                                                 out);
 | |
|     } else {
 | |
|         PD_THROW(
 | |
|             "cache_quant_type_str should be one of [none, cache_int8, "
 | |
|             "cache_int4_zp]");
 | |
|     }
 | |
| }
 | 
