// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include "helper.h" #include "utils.cuh" template void CascadeAppendAttentionC16Kernel( const AppendAttnMetaData& meta_data, const paddle::Tensor& qkv, // [token_num, num_heads, head_dim] const paddle::Tensor& cache_k, // [max_block_num, num_heads, block_size, head_dim] const paddle::Tensor& cache_v, // [max_block_num, num_heads, head_dim, block_size] const paddle::optional& attn_mask, const paddle::optional& cache_k_scale, // [num_kv_heads, head_dim] const paddle::optional& cache_v_scale, // [num_kv_heads, head_dim] const paddle::optional& cache_k_zp, // [num_kv_heads, head_dim] const paddle::optional& cache_v_zp, // [num_kv_heads, head_dim] const paddle::optional& shift_bias, // [num_kv_heads, head_dim] const paddle::optional& smooth_weight, // [num_kv_heads, head_dim] const paddle::Tensor& seq_lens_q, const paddle::Tensor& seq_lens_kv, const paddle::Tensor& seq_lens_encoder, const paddle::Tensor& batch_id_per_token, const paddle::Tensor& cu_seqlens_q, const paddle::Tensor& block_table, const paddle::Tensor& batch_ids, const paddle::Tensor& tile_ids_per_batch, const int num_blocks, const int block_shape_q, const int max_seq_len, const int max_dec_len, const float quant_max_bound, const float quant_min_bound, const float in_scale, const int max_partition_size, const int encoder_max_partition_size, const int speculate_max_draft_token_num, const bool causal, const bool is_decoder, const bool enable_prefill, cudaStream_t& stream, paddle::Tensor* out); template void CascadeAppendAttentionC8Kernel( const AppendAttnMetaData& meta_data, const paddle::Tensor& qkv, // [token_num, num_heads, head_dim] const paddle::Tensor& cache_k, // [max_block_num, num_heads, block_size, head_dim] const paddle::Tensor& cache_v, // [max_block_num, num_heads, head_dim, block_size] const paddle::optional& attn_mask, const paddle::optional& cache_k_scale, // [num_kv_heads, head_dim] const paddle::optional& cache_v_scale, // [num_kv_heads, head_dim] const paddle::optional& cache_k_zp, // [num_kv_heads, head_dim] const paddle::optional& cache_v_zp, // [num_kv_heads, head_dim] const paddle::optional& shift_bias, // [num_kv_heads, head_dim] const paddle::optional& smooth_weight, // [num_kv_heads, head_dim] const paddle::Tensor& seq_lens_q, const paddle::Tensor& seq_lens_kv, const paddle::Tensor& seq_lens_encoder, const paddle::Tensor& batch_id_per_token, const paddle::Tensor& cu_seqlens_q, const paddle::Tensor& block_table, const paddle::Tensor& batch_ids, const paddle::Tensor& tile_ids_per_batch, const int num_blocks, const int block_shape_q, const int max_seq_len, const int max_dec_len, const float quant_max_bound, const float quant_min_bound, const float in_scale, const int max_partition_size, const int encoder_max_partition_size, const int speculate_max_draft_token_num, const bool causal, const bool is_decoder, const bool enable_prefill, const std::string& cache_quant_type_str, cudaStream_t& stream, paddle::Tensor* out); template void CascadeAppendAttentionC4Kernel( const AppendAttnMetaData& meta_data, const paddle::Tensor& qkv, // [token_num, num_heads, head_dim] const paddle::Tensor& cache_k, // [max_block_num, num_heads, block_size, head_dim] const paddle::Tensor& cache_v, // [max_block_num, num_heads, head_dim, block_size] const paddle::optional& attn_mask, const paddle::optional& cache_k_scale, // [num_kv_heads, head_dim] const paddle::optional& cache_v_scale, // [num_kv_heads, head_dim] const paddle::optional& cache_k_zp, // [num_kv_heads, head_dim] const paddle::optional& cache_v_zp, // [num_kv_heads, head_dim] const paddle::optional& shift_bias, // [num_kv_heads, head_dim] const paddle::optional& smooth_weight, // [num_kv_heads, head_dim] const paddle::Tensor& seq_lens_q, const paddle::Tensor& seq_lens_kv, const paddle::Tensor& seq_lens_encoder, const paddle::Tensor& batch_id_per_token, const paddle::Tensor& cu_seqlens_q, const paddle::Tensor& block_table, const paddle::Tensor& batch_ids, const paddle::Tensor& tile_ids_per_batch, const int num_blocks, const int block_shape_q, const int max_seq_len, const int max_dec_len, const float quant_max_bound, const float quant_min_bound, const float in_scale, const int max_partition_size, const int encoder_max_partition_size, const int speculate_max_draft_token_num, const bool causal, const bool is_decoder, const bool enable_prefill, cudaStream_t& stream, paddle::Tensor* out); template void CascadeAppendAttentionKernel( const AppendAttnMetaData& meta_data, const paddle::Tensor& qkv, // [token_num, num_heads, head_dim] const paddle::Tensor& cache_k, // [max_block_num, num_heads, block_size, head_dim] const paddle::Tensor& cache_v, // [max_block_num, num_heads, head_dim, block_size] const paddle::optional& attn_mask, const paddle::optional& cache_k_scale, // [num_kv_heads, head_dim] const paddle::optional& cache_v_scale, // [num_kv_heads, head_dim] const paddle::optional& cache_k_zp, // [num_kv_heads, head_dim] const paddle::optional& cache_v_zp, // [num_kv_heads, head_dim] const paddle::optional& shift_bias, // [num_kv_heads, head_dim] const paddle::optional& smooth_weight, // [num_kv_heads, head_dim] const paddle::Tensor& seq_lens_q, const paddle::Tensor& seq_lens_kv, const paddle::Tensor& seq_lens_encoder, const paddle::Tensor& batch_id_per_token, const paddle::Tensor& cu_seqlens_q, const paddle::Tensor& block_table, const paddle::Tensor& batch_ids, const paddle::Tensor& tile_ids_per_batch, const std::string& cache_quant_type_str, const int num_blocks, const int block_shape_q, const int max_seq_len, const int max_dec_len, const float quant_max_bound, const float quant_min_bound, const float in_scale, const int max_partition_size, const int encoder_max_partition_size, const int speculate_max_draft_token_num, const bool causal, const bool is_decoder, const bool enable_prefill, cudaStream_t& stream, paddle::Tensor* out) { if (cache_quant_type_str == "none") { CascadeAppendAttentionC16Kernel(meta_data, qkv, cache_k, cache_v, attn_mask, cache_k_scale, cache_v_scale, cache_k_zp, cache_v_zp, shift_bias, smooth_weight, seq_lens_q, seq_lens_kv, seq_lens_encoder, batch_id_per_token, cu_seqlens_q, block_table, batch_ids, tile_ids_per_batch, num_blocks, block_shape_q, max_seq_len, max_dec_len, quant_max_bound, quant_min_bound, in_scale, max_partition_size, encoder_max_partition_size, speculate_max_draft_token_num, causal, is_decoder, enable_prefill, stream, out); // } else if (cache_quant_type_str == "cache_int8") { // CascadeAppendAttentionC8Kernel(meta_data, // qkv, // cache_k, // cache_v, // attn_mask, // cache_k_scale, // cache_v_scale, // cache_k_zp, // cache_v_zp, // shift_bias, // smooth_weight, // seq_lens_q, // seq_lens_kv, // seq_lens_encoder, // batch_id_per_token, // cu_seqlens_q, // block_table, // batch_ids, // tile_ids_per_batch, // num_blocks, // block_shape_q, // max_seq_len, // max_dec_len, // quant_max_bound, // quant_min_bound, // in_scale, // max_partition_size, // encoder_max_partition_size, // speculate_max_draft_token_num, // causal, // is_decoder, // enable_prefill, // cache_quant_type_str, // stream, // out); // } else if (cache_quant_type_str == "cache_fp8" or cache_quant_type_str == "block_wise_fp8") { // CascadeAppendAttentionC8Kernel(meta_data, // qkv, // cache_k, // cache_v, // attn_mask, // cache_k_scale, // cache_v_scale, // cache_k_zp, // cache_v_zp, // shift_bias, // smooth_weight, // seq_lens_q, // seq_lens_kv, // seq_lens_encoder, // batch_id_per_token, // cu_seqlens_q, // block_table, // batch_ids, // tile_ids_per_batch, // num_blocks, // block_shape_q, // max_seq_len, // max_dec_len, // quant_max_bound, // quant_min_bound, // in_scale, // max_partition_size, // encoder_max_partition_size, // speculate_max_draft_token_num, // causal, // is_decoder, // enable_prefill, // cache_quant_type_str, // stream, // out); // } else if (cache_quant_type_str == "cache_int4_zp") { // CascadeAppendAttentionC4Kernel(meta_data, // qkv, // cache_k, // cache_v, // attn_mask, // cache_k_scale, // cache_v_scale, // cache_k_zp, // cache_v_zp, // shift_bias, // smooth_weight, // seq_lens_q, // seq_lens_kv, // seq_lens_encoder, // batch_id_per_token, // cu_seqlens_q, // block_table, // batch_ids, // tile_ids_per_batch, // num_blocks, // block_shape_q, // max_seq_len, // max_dec_len, // quant_max_bound, // quant_min_bound, // in_scale, // max_partition_size, // encoder_max_partition_size, // speculate_max_draft_token_num, // causal, // is_decoder, // enable_prefill, // stream, // out); } else { PD_THROW( "cache_quant_type_str should be one of [none, cache_int8, " "cache_int4_zp]"); } }