From af49b81ffd63484bc57694e43efb429ae6baf4ab Mon Sep 17 00:00:00 2001 From: lzy <569782149@qq.com> Date: Mon, 8 Sep 2025 11:41:29 +0800 Subject: [PATCH] supports dynamic Cfp8 (#3767) * supports dynamic Cfp8 * add unittest --- custom_ops/gpu_ops/append_attention.cu | 4 +- .../append_attn/append_attention_c8_impl.cuh | 302 ++++++---- .../append_attn/append_attention_func.cuh | 199 ++++++- .../append_attn/append_attention_kernel.h | 5 +- .../decoder_write_cache_with_rope_impl.cuh | 288 ++++++++++ .../decoder_write_cache_with_rope_kernel.cu | 64 ++- .../encoder_write_cache_with_rope_impl.cuh | 520 ++++++++++++++++-- .../encoder_write_cache_with_rope_kernel.h | 4 +- .../append_attn/gqa_rope_write_cache.cu | 4 +- ...d_attention_c8_bfloat16_bfloat16_kernel.cu | 2 + ...append_attention_c8_bfloat16_fp8_kernel.cu | 2 + ...ppend_attention_c8_bfloat16_int8_kernel.cu | 2 + ...end_attention_c8_float16_float16_kernel.cu | 2 + .../append_attention_c8_float16_fp8_kerne.cu | 2 + .../append_attention_c8_float16_int8_kerne.cu | 2 + custom_ops/gpu_ops/append_attn/utils.cuh | 9 + .../layers/attention/append_attn_backend.py | 27 +- .../layers/quantization/kv_cache.py | 18 +- fastdeploy/worker/gpu_model_runner.py | 13 + tests/layers/test_append_attention.py | 173 ++++-- 20 files changed, 1417 insertions(+), 225 deletions(-) diff --git a/custom_ops/gpu_ops/append_attention.cu b/custom_ops/gpu_ops/append_attention.cu index 6af601dad..5e4ce35da 100644 --- a/custom_ops/gpu_ops/append_attention.cu +++ b/custom_ops/gpu_ops/append_attention.cu @@ -140,8 +140,8 @@ void AppendAttentionKernel( key_cache, value_cache, attn_mask, - cache_k_dequant_scales, - cache_v_dequant_scales, + cache_quant_type_str == "block_wise_fp8" ? cache_k_quant_scales : cache_k_dequant_scales, + cache_quant_type_str == "block_wise_fp8" ? cache_v_quant_scales : cache_v_dequant_scales, cache_k_zp, cache_v_zp, out_linear_shifts, diff --git a/custom_ops/gpu_ops/append_attn/append_attention_c8_impl.cuh b/custom_ops/gpu_ops/append_attn/append_attention_c8_impl.cuh index addc24bd1..77ba87814 100644 --- a/custom_ops/gpu_ops/append_attn/append_attention_c8_impl.cuh +++ b/custom_ops/gpu_ops/append_attn/append_attention_c8_impl.cuh @@ -32,14 +32,15 @@ template + bool IsFP8 = false, + bool IsDynamicC8 = false> __global__ void multi_query_append_attention_c8_kernel( T *__restrict__ q, // [token_num, (num_heads + 2* kv_num_head) * head_dim] CacheT *__restrict__ cache_k, // [max_block_num, num_heads, block_size, // head_dim] CacheT *__restrict__ cache_v, - const T *__restrict__ cache_k_scale, // [num_kv_heads] - const T *__restrict__ cache_v_scale, // [num_kv_heads] + const T *__restrict__ cache_k_scale, // [num_kv_heads] or [max_block_num, num_heads, block_size] + const T *__restrict__ cache_v_scale, // [num_kv_heads] or [max_block_num, num_heads, block_size] const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM] const T *__restrict__ smooth_weight, // [q_num_heads * HEAD_DIM] const int *__restrict__ seq_lens, @@ -91,28 +92,30 @@ __global__ void multi_query_append_attention_c8_kernel( return; } - T cache_k_scale_reg[num_frags_y * 4]; - T cache_v_scale_reg[num_frags_y * 2]; - if (is_scale_channel_wise) { - int scale_col_base = threadIdx.x % 4 * 2 + kv_head_idx * HEAD_DIM; - const T *cache_k_scale_cur_head = cache_k_scale + scale_col_base; - for (int i = 0; i < num_frags_y; ++i) { - const int scale_idx = i * 16; - cache_k_scale_reg[i * 4] = cache_k_scale_cur_head[scale_idx]; - cache_k_scale_reg[i * 4 + 1] = cache_k_scale_cur_head[scale_idx + 1]; - cache_k_scale_reg[i * 4 + 2] = cache_k_scale_cur_head[scale_idx + 8]; - cache_k_scale_reg[i * 4 + 3] = cache_k_scale_cur_head[scale_idx + 9]; + T cache_k_scale_reg[IsDynamicC8 ? num_frags_z * 2 : num_frags_y * 4]; + T cache_v_scale_reg[IsDynamicC8 ? num_frags_z * 4 : num_frags_y * 2]; + if constexpr (!IsDynamicC8) { + if constexpr (is_scale_channel_wise) { + int scale_col_base = threadIdx.x % 4 * 2 + kv_head_idx * HEAD_DIM; + const T *cache_k_scale_cur_head = cache_k_scale + scale_col_base; + for (int i = 0; i < num_frags_y; ++i) { + const int scale_idx = i * 16; + cache_k_scale_reg[i * 4] = cache_k_scale_cur_head[scale_idx]; + cache_k_scale_reg[i * 4 + 1] = cache_k_scale_cur_head[scale_idx + 1]; + cache_k_scale_reg[i * 4 + 2] = cache_k_scale_cur_head[scale_idx + 8]; + cache_k_scale_reg[i * 4 + 3] = cache_k_scale_cur_head[scale_idx + 9]; + } + scale_col_base = threadIdx.x / 4 + kv_head_idx * HEAD_DIM; + const T *cache_v_scale_cur_head = cache_v_scale + scale_col_base; + for (int i = 0; i < num_frags_y; ++i) { + const int scale_idx = i * 16; + cache_v_scale_reg[i * 2] = cache_v_scale_cur_head[scale_idx]; + cache_v_scale_reg[i * 2 + 1] = cache_v_scale_cur_head[scale_idx + 8]; + } + } else { + cache_k_scale_reg[0] = cache_k_scale[kv_head_idx]; + cache_v_scale_reg[0] = cache_v_scale[kv_head_idx]; } - scale_col_base = threadIdx.x / 4 + kv_head_idx * HEAD_DIM; - const T *cache_v_scale_cur_head = cache_v_scale + scale_col_base; - for (int i = 0; i < num_frags_y; ++i) { - const int scale_idx = i * 16; - cache_v_scale_reg[i * 2] = cache_v_scale_cur_head[scale_idx]; - cache_v_scale_reg[i * 2 + 1] = cache_v_scale_cur_head[scale_idx + 8]; - } - } else { - cache_k_scale_reg[0] = cache_k_scale[kv_head_idx]; - cache_v_scale_reg[0] = cache_v_scale[kv_head_idx]; } const uint32_t q_end = @@ -201,6 +204,13 @@ __global__ void multi_query_append_attention_c8_kernel( smem_t k_smem(smem + NUM_WARPS * num_frags_x * 16 * HEAD_DIM * sizeof(T)), v_smem(smem + NUM_WARPS * num_frags_x * 16 * HEAD_DIM * sizeof(T) + num_frags_z * 16 * HEAD_DIM * sizeof(CacheT)); + T* k_smem_scale = nullptr; + T* v_smem_scale = nullptr; + if constexpr (IsDynamicC8) { + k_smem_scale = reinterpret_cast(smem + NUM_WARPS * num_frags_x * 16 * HEAD_DIM * sizeof(T) + + num_frags_z * 16 * HEAD_DIM * sizeof(CacheT) * 2); + v_smem_scale = k_smem_scale + num_frags_z * 16; + } const uint32_t num_iterations = div_up( @@ -282,10 +292,22 @@ __global__ void multi_query_append_attention_c8_kernel( #pragma unroll 1 for (uint32_t iter = 0; iter < num_iterations; ++iter) { + if constexpr (IsDynamicC8) { + produce_k_dynamic_scale( + k_smem_scale, + cache_k_scale_reg, + block_table_now, + cache_k_scale, + kv_idx_base, + kv_num_heads, + kv_head_idx, + chunk_end + ); + } wait_group<1>(); __syncthreads(); // s = qk - compute_qk_c8( + compute_qk_c8( &qo_smem, &q_smem_offset_r, &k_smem, @@ -318,6 +340,7 @@ __global__ void multi_query_append_attention_c8_kernel( s_frag, o_frag, m_frag, d_frag); __syncthreads(); + const int ori_kv_idx_base = kv_idx_base; kv_idx_base += num_frags_z * 16; produce_k_blockwise_c8( + v_smem_scale, + cache_v_scale_reg, + block_table_now, + cache_v_scale, + ori_kv_idx_base, + kv_num_heads, + kv_head_idx, + chunk_end + ); + } wait_group<1>(); __syncthreads(); @@ -346,7 +381,9 @@ __global__ void multi_query_append_attention_c8_kernel( BLOCK_SIZE, T, CacheT, - is_scale_channel_wise, IsFP8>( + is_scale_channel_wise, + IsFP8, + IsDynamicC8>( &v_smem, &v_smem_offset_r, s_frag, o_frag, d_frag, cache_v_scale_reg); __syncthreads(); @@ -463,14 +500,15 @@ template + bool IsFP8 = false, + bool IsDynamicC8 = false> __global__ void multi_query_append_attention_c8_warp1_4_kernel( T *__restrict__ q, // [token_num, (num_heads + 2* kv_num_head) * head_dim] CacheT *__restrict__ cache_k, // [max_block_num, num_heads, block_size, // head_dim] CacheT *__restrict__ cache_v, - const T *__restrict__ cache_k_scale, // [num_kv_heads, head_dim] - const T *__restrict__ cache_v_scale, // [num_kv_heads, head_dim] + const T *__restrict__ cache_k_scale, // [num_kv_heads] or [max_block_num, num_heads, block_size] + const T *__restrict__ cache_v_scale, // [num_kv_heads] or [max_block_num, num_heads, block_size] const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM] const T *__restrict__ smooth_weight, // [q_num_heads * HEAD_DIM] const int *__restrict__ seq_lens, @@ -522,28 +560,30 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel( if (q_len <= 0) { return; } - T cache_k_scale_reg[num_frags_y * 4]; - T cache_v_scale_reg[num_frags_y * 2]; - if (is_scale_channel_wise) { - int scale_col_base = threadIdx.x % 4 * 2 + kv_head_idx * HEAD_DIM; - const T *cache_k_scale_cur_head = cache_k_scale + scale_col_base; - for (int i = 0; i < num_frags_y; ++i) { - const int scale_idx = i * 16; - cache_k_scale_reg[i * 4] = cache_k_scale_cur_head[scale_idx]; - cache_k_scale_reg[i * 4 + 1] = cache_k_scale_cur_head[scale_idx + 1]; - cache_k_scale_reg[i * 4 + 2] = cache_k_scale_cur_head[scale_idx + 8]; - cache_k_scale_reg[i * 4 + 3] = cache_k_scale_cur_head[scale_idx + 9]; + T cache_k_scale_reg[IsDynamicC8 ? num_frags_z * 2 : num_frags_y * 4]; + T cache_v_scale_reg[IsDynamicC8 ? num_frags_z * 4 : num_frags_y * 2]; + if constexpr (!IsDynamicC8) { + if constexpr (is_scale_channel_wise) { + int scale_col_base = threadIdx.x % 4 * 2 + kv_head_idx * HEAD_DIM; + const T *cache_k_scale_cur_head = cache_k_scale + scale_col_base; + for (int i = 0; i < num_frags_y; ++i) { + const int scale_idx = i * 16; + cache_k_scale_reg[i * 4] = cache_k_scale_cur_head[scale_idx]; + cache_k_scale_reg[i * 4 + 1] = cache_k_scale_cur_head[scale_idx + 1]; + cache_k_scale_reg[i * 4 + 2] = cache_k_scale_cur_head[scale_idx + 8]; + cache_k_scale_reg[i * 4 + 3] = cache_k_scale_cur_head[scale_idx + 9]; + } + scale_col_base = threadIdx.x / 4 + kv_head_idx * HEAD_DIM; + const T *cache_v_scale_cur_head = cache_v_scale + scale_col_base; + for (int i = 0; i < num_frags_y; ++i) { + const int scale_idx = i * 16; + cache_v_scale_reg[i * 2] = cache_v_scale_cur_head[scale_idx]; + cache_v_scale_reg[i * 2 + 1] = cache_v_scale_cur_head[scale_idx + 8]; + } + } else { + cache_k_scale_reg[0] = cache_k_scale[kv_head_idx]; + cache_v_scale_reg[0] = cache_v_scale[kv_head_idx]; } - scale_col_base = threadIdx.x / 4 + kv_head_idx * HEAD_DIM; - const T *cache_v_scale_cur_head = cache_v_scale + scale_col_base; - for (int i = 0; i < num_frags_y; ++i) { - const int scale_idx = i * 16; - cache_v_scale_reg[i * 2] = cache_v_scale_cur_head[scale_idx]; - cache_v_scale_reg[i * 2 + 1] = cache_v_scale_cur_head[scale_idx + 8]; - } - } else { - cache_k_scale_reg[0] = cache_k_scale[kv_head_idx]; - cache_v_scale_reg[0] = cache_v_scale[kv_head_idx]; } const uint32_t q_end = min(q_len, div_up((tile_id + 1) * num_rows_per_block, GROUP_SIZE)); @@ -634,6 +674,13 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel( smem_t k_smem(smem + num_frags_x * 16 * HEAD_DIM * sizeof(T)), v_smem(smem + num_frags_x * 16 * HEAD_DIM * sizeof(T) + NUM_WARP_KV * num_frags_z * 16 * HEAD_DIM * sizeof(CacheT)); + T* k_smem_scale = nullptr; + T* v_smem_scale = nullptr; + if constexpr (IsDynamicC8) { + k_smem_scale = reinterpret_cast(smem + num_frags_x * 16 * HEAD_DIM * sizeof(T) + + NUM_WARP_KV * num_frags_z * 16 * HEAD_DIM * sizeof(CacheT) * 2); + v_smem_scale = k_smem_scale + NUM_WARP_KV * num_frags_z * 16; + } const uint32_t num_iterations = div_up( CAUSAL @@ -716,11 +763,23 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel( commit_group(); #pragma unroll 1 for (uint32_t iter = 0; iter < num_iterations; ++iter) { + if constexpr (IsDynamicC8) { + produce_k_dynamic_scale( + k_smem_scale, + cache_k_scale_reg, + block_table_now, + cache_k_scale, + kv_idx_base, + kv_num_heads, + kv_head_idx, + chunk_end + ); + } wait_group<1>(); __syncthreads(); // s = qk - compute_qk_c8( + compute_qk_c8( &qo_smem, &q_smem_offset_r, &k_smem, @@ -753,6 +812,7 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel( s_frag, o_frag, m_frag, d_frag); __syncthreads(); + const uint32_t ori_kv_idx_base = kv_idx_base; kv_idx_base += NUM_WARP_KV * num_frags_z * 16; produce_k_blockwise_c8( + v_smem_scale, + cache_v_scale_reg, + block_table_now, + cache_v_scale, + ori_kv_idx_base, + kv_num_heads, + kv_head_idx, + chunk_end + ); + } wait_group<1>(); __syncthreads(); @@ -781,7 +853,9 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel( BLOCK_SIZE, T, CacheT, - is_scale_channel_wise, IsFP8>( + is_scale_channel_wise, + IsFP8, + IsDynamicC8>( &v_smem, &v_smem_offset_r, s_frag, o_frag, d_frag, cache_v_scale_reg); __syncthreads(); @@ -895,7 +969,8 @@ template + bool IsFP8 = false, + bool IsDynamicC8 = false> void MultiQueryAppendC8Attention( const AppendAttnMetaData &meta_data, const paddle::Tensor &qkv, @@ -953,7 +1028,8 @@ void MultiQueryAppendC8Attention( constexpr uint32_t num_frags_z = BLOCK_SIZE / 16; constexpr uint32_t smem_size = num_warps * num_frags_x * 16 * HEAD_DIM * sizeof(T) + - num_frags_z * 16 * HEAD_DIM * sizeof(uint8_t) * 2; + num_frags_z * 16 * HEAD_DIM * sizeof(uint8_t) * 2 + + num_frags_z * 16 * sizeof(T) * 2; auto split_kv_kernel = multi_query_append_attention_c8_kernel; + false, + IsFP8, + IsDynamicC8>; if (is_scale_channel_wise) { split_kv_kernel = multi_query_append_attention_c8_kernel; + true, + IsFP8, + IsDynamicC8>; } if (smem_size >= 48 * 1024) { cudaFuncSetAttribute(split_kv_kernel, @@ -1022,7 +1102,9 @@ void MultiQueryAppendC8Attention( num_frags_y, OUT_NV_TYPE, ENABLE_PREFILL, - false, IsFP8>; + false, + IsFP8, + IsDynamicC8>; if (is_scale_channel_wise) { nosplit_kv_kernel = multi_query_append_attention_c8_kernel; + true, + IsFP8, + IsDynamicC8>; } if (smem_size >= 48 * 1024) { cudaFuncSetAttribute(nosplit_kv_kernel, @@ -1218,7 +1302,8 @@ void MultiQueryAppendC8Attention( constexpr uint32_t num_frags_z = BLOCK_SIZE / 16 / NUM_WARP_KV * 2; constexpr uint32_t smem_size = num_frags_x * 16 * HEAD_DIM * sizeof(T) + - NUM_WARP_KV * num_frags_z * 16 * HEAD_DIM * sizeof(uint8_t) * 2; + NUM_WARP_KV * num_frags_z * 16 * HEAD_DIM * sizeof(uint8_t) * 2 + + NUM_WARP_KV * num_frags_z * 16 * sizeof(T) * 2; auto split_kv_kernel = multi_query_append_attention_c8_warp1_4_kernel; + false, + IsFP8, + IsDynamicC8>; if (is_scale_channel_wise) { split_kv_kernel = multi_query_append_attention_c8_warp1_4_kernel; + true, + IsFP8, + IsDynamicC8>; } if (smem_size >= 48 * 1024) { cudaFuncSetAttribute(split_kv_kernel, @@ -1295,7 +1384,9 @@ void MultiQueryAppendC8Attention( num_frags_y, OUT_NV_TYPE, ENABLE_PREFILL, - false, IsFP8>; + false, + IsFP8, + IsDynamicC8>; if (is_scale_channel_wise) { nosplit_kv_kernel = multi_query_append_attention_c8_warp1_4_kernel; + true, + IsFP8, + IsDynamicC8>; } if (smem_size >= 48 * 1024) { cudaFuncSetAttribute(nosplit_kv_kernel, @@ -1546,6 +1639,7 @@ void CascadeAppendAttentionC8Kernel( const bool causal, const bool is_decoder, const bool enable_prefill, + const std::string& cache_quant_type_str, cudaStream_t& stream, paddle::Tensor* out) { const auto token_num = meta_data.token_nums; @@ -1554,6 +1648,7 @@ void CascadeAppendAttentionC8Kernel( const auto num_heads = meta_data.q_num_heads; const auto group_size = meta_data.q_num_heads / meta_data.kv_num_heads; const auto head_dim = meta_data.head_dims; + bool is_dynamic_cfp8 = cache_quant_type_str == "block_wise_fp8"; DISPATCH_CAUSAL( causal, @@ -1572,43 +1667,46 @@ void CascadeAppendAttentionC8Kernel( BLOCK_SIZE, {DISPATCH_BLOCKSHAPE_Q( block_shape_q, BLOCK_SHAPE_Q, NUM_WARP_Q, { - MultiQueryAppendC8Attention( - meta_data, - qkv, - cache_k, - cache_v, - attn_mask, - cache_k_scale.get(), - cache_v_scale.get(), - shift_bias, - smooth_weight, - seq_lens_q, - seq_lens_kv, - seq_lens_encoder, - batch_id_per_token, - cu_seqlens_q, - block_table, - batch_ids, - tile_ids_per_batch, - num_blocks, - max_seq_len, - max_dec_len, - quant_max_bound, - quant_min_bound, - in_scale, - max_partition_size, - encoder_max_partition_size, - speculate_max_draft_token_num, - is_decoder, - stream, - out); - })})})})})}) + DISPATCH_DyCfp8(is_dynamic_cfp8, IsDynamicC8, { + MultiQueryAppendC8Attention( + meta_data, + qkv, + cache_k, + cache_v, + attn_mask, + cache_k_scale.get(), + cache_v_scale.get(), + shift_bias, + smooth_weight, + seq_lens_q, + seq_lens_kv, + seq_lens_encoder, + batch_id_per_token, + cu_seqlens_q, + block_table, + batch_ids, + tile_ids_per_batch, + num_blocks, + max_seq_len, + max_dec_len, + quant_max_bound, + quant_min_bound, + in_scale, + max_partition_size, + encoder_max_partition_size, + speculate_max_draft_token_num, + is_decoder, + stream, + out); + })})})})})})}) } diff --git a/custom_ops/gpu_ops/append_attn/append_attention_func.cuh b/custom_ops/gpu_ops/append_attn/append_attention_func.cuh index 146d0c30a..24787e8b7 100644 --- a/custom_ops/gpu_ops/append_attn/append_attention_func.cuh +++ b/custom_ops/gpu_ops/append_attn/append_attention_func.cuh @@ -384,6 +384,113 @@ __device__ __forceinline__ void produce_v_blockwise_c8( } } +template +__device__ __forceinline__ void produce_k_dynamic_scale( + T* k_smem_scale, + T* cache_k_reg, + const int* block_table_now, + const T* cache_k_scale, + const uint32_t kv_idx, + const uint32_t kv_num_heads, + const uint32_t kv_head_idx, + const uint32_t chunk_end +) { + const uint32_t tx = threadIdx.x, ty = threadIdx.y; + if constexpr (NUM_WARP_Q == 4) { + // 4 warps shared block_size + const uint32_t tid = ty * 32 + tx; + int block_id = __ldg(&block_table_now[kv_idx / block_size]); + if (block_id < 0) block_id = 0; + const T* cache_k_scale_now = cache_k_scale + block_id * kv_num_heads * block_size + kv_head_idx * block_size; + if (tid < block_size) { + k_smem_scale[tid] = cache_k_scale_now[tid]; + } + __syncthreads(); + const uint32_t row_id = tx / 4; + for (uint32_t fz = 0; fz < num_frags_z; fz++) { + cache_k_reg[fz * 2] = k_smem_scale[fz * 16 + row_id]; + cache_k_reg[fz * 2 + 1] = k_smem_scale[fz * 16 + row_id + 8]; + } + } else { + // 1 warp 32 tokens + const uint32_t kv_idx_now = kv_idx + block_size * ty / 2; + int block_id = __ldg(&block_table_now[kv_idx_now / block_size]); + if (block_id < 0) block_id = 0; + const T* cache_k_scale_now = cache_k_scale + block_id * kv_num_heads * block_size + kv_head_idx * block_size; + const int kv_idx_this_thread = kv_idx + ty * 32 + tx; + if (kv_idx_this_thread < chunk_end) { + k_smem_scale[ty * 32 + tx] = cache_k_scale_now[(ty % 2) * 32 + tx]; + } else { + k_smem_scale[ty * 32 + tx] = 0; + } + __syncwarp(); + const uint32_t row_id = tx / 4; + for (uint32_t fz = 0; fz < num_frags_z; fz++) { + cache_k_reg[fz * 2] = k_smem_scale[ty * 32 + fz * 16 + row_id]; + cache_k_reg[fz * 2 + 1] = k_smem_scale[ty * 32 + fz * 16 + row_id + 8]; + } + } +} + +template +__device__ __forceinline__ void produce_v_dynamic_scale( + T* v_smem_scale, + T* cache_v_reg, + const int* block_table_now, + const T* cache_v_scale, + const uint32_t kv_idx, + const uint32_t kv_num_heads, + const uint32_t kv_head_idx, + const uint32_t chunk_end +) { + const uint32_t tx = threadIdx.x, ty = threadIdx.y; + + if constexpr (NUM_WARP_Q == 4) { + // 4 warps shared block_size + const uint32_t tid = ty * 32 + tx; + int block_id = __ldg(&block_table_now[kv_idx / block_size]); + if (block_id < 0) block_id = 0; + const T* cache_v_scale_now = cache_v_scale + block_id * kv_num_heads * block_size + kv_head_idx * block_size; + if (tid < block_size) { + v_smem_scale[tid] = cache_v_scale_now[tid]; + } + __syncthreads(); + const uint32_t row_id = tx % 4 * 2; + for (uint32_t fz = 0; fz < num_frags_z; fz++) { + cache_v_reg[fz * 4] = v_smem_scale[fz * 16 + row_id]; + cache_v_reg[fz * 4 + 1] = v_smem_scale[fz * 16 + row_id + 1]; + cache_v_reg[fz * 4 + 2] = v_smem_scale[fz * 16 + row_id + 8]; + cache_v_reg[fz * 4 + 3] = v_smem_scale[fz * 16 + row_id + 9]; + } + } else { + // 1 warp 32 tokens + const uint32_t kv_idx_now = kv_idx + block_size * ty / 2; + int block_id = __ldg(&block_table_now[kv_idx_now / block_size]); + if (block_id < 0) block_id = 0; + const T* cache_v_scale_now = cache_v_scale + block_id * kv_num_heads * block_size + kv_head_idx * block_size; + const int kv_idx_this_thread = kv_idx + ty * 32 + tx; + if (kv_idx_this_thread < chunk_end) { + v_smem_scale[ty * 32 + tx] = cache_v_scale_now[(ty % 2) * 32 + tx]; + } else { + v_smem_scale[ty * 32 + tx] = 0; + } + __syncwarp(); + const uint32_t row_id = tx % 4 * 2; + for (uint32_t fz = 0; fz < num_frags_z; fz++) { + cache_v_reg[fz * 4] = v_smem_scale[ty * 32 + fz * 16 + row_id]; + cache_v_reg[fz * 4 + 1] = v_smem_scale[ty * 32 + fz * 16 + row_id + 1]; + cache_v_reg[fz * 4 + 2] = v_smem_scale[ty * 32 + fz * 16 + row_id + 8]; + cache_v_reg[fz * 4 + 3] = v_smem_scale[ty * 32 + fz * 16 + row_id + 9]; + } + } +} + template + bool IsFP8 = false, + bool IsDynamicC8 = false> __device__ __forceinline__ void compute_qk_c8(smem_t* q_smem, uint32_t* q_smem_offset_r, smem_t* k_smem, @@ -860,20 +968,27 @@ __device__ __forceinline__ void compute_qk_c8(smem_t* q_smem, convert_c8(b_frag_dq_T, b_frag[fy * 2]); convert_c8(b_frag_dq_T + 4, b_frag[fy * 2 + 1]); // scale zp - if constexpr (is_scale_channel_wise) { - const int scale_col = (ky * 2 + fy) * 4; - b_frag_dq_T[0] *= cache_k_scale[scale_col]; - b_frag_dq_T[1] *= cache_k_scale[scale_col + 1]; - b_frag_dq_T[2] *= cache_k_scale[scale_col + 2]; - b_frag_dq_T[3] *= cache_k_scale[scale_col + 3]; - b_frag_dq_T[4] *= cache_k_scale[scale_col]; - b_frag_dq_T[5] *= cache_k_scale[scale_col + 1]; - b_frag_dq_T[6] *= cache_k_scale[scale_col + 2]; - b_frag_dq_T[7] *= cache_k_scale[scale_col + 3]; + if constexpr (!IsDynamicC8) { + if constexpr (is_scale_channel_wise) { + const int scale_col = (ky * 2 + fy) * 4; + b_frag_dq_T[0] *= cache_k_scale[scale_col]; + b_frag_dq_T[1] *= cache_k_scale[scale_col + 1]; + b_frag_dq_T[2] *= cache_k_scale[scale_col + 2]; + b_frag_dq_T[3] *= cache_k_scale[scale_col + 3]; + b_frag_dq_T[4] *= cache_k_scale[scale_col]; + b_frag_dq_T[5] *= cache_k_scale[scale_col + 1]; + b_frag_dq_T[6] *= cache_k_scale[scale_col + 2]; + b_frag_dq_T[7] *= cache_k_scale[scale_col + 3]; + } else { +#pragma unroll + for (uint32_t b_i = 0; b_i < 8; ++b_i) { + b_frag_dq_T[b_i] *= cache_k_scale[0]; + } + } } else { #pragma unroll for (uint32_t b_i = 0; b_i < 8; ++b_i) { - b_frag_dq_T[b_i] *= cache_k_scale[0]; + b_frag_dq_T[b_i] *= cache_k_scale[fz * 2 + b_i / 4]; } } #pragma unroll @@ -1093,7 +1208,9 @@ template + bool is_scale_channel_wise = false, + bool IsFP8 = false, + bool IsDynamicC8 = false> __device__ __forceinline__ void compute_sfm_v_c8( smem_t* v_smem, uint32_t* v_smem_offset_r, @@ -1135,16 +1252,28 @@ __device__ __forceinline__ void compute_sfm_v_c8( convert_c8(b_frag_dq_T, b_frag[fz * 2]); convert_c8(b_frag_dq_T + 4, b_frag[fz * 2 + 1]); // scale zp - if constexpr (is_scale_channel_wise) { + if constexpr (!IsDynamicC8) { + if constexpr (is_scale_channel_wise) { #pragma unroll - for (uint32_t b_i = 0; b_i < 8; ++b_i) { - b_frag_dq_T[b_i] *= cache_v_scale[b_i / 4 + fy * 2]; + for (uint32_t b_i = 0; b_i < 8; ++b_i) { + b_frag_dq_T[b_i] *= cache_v_scale[b_i / 4 + fy * 2]; + } + } else { +#pragma unroll + for (uint32_t b_i = 0; b_i < 8; ++b_i) { + b_frag_dq_T[b_i] *= cache_v_scale[0]; + } } } else { -#pragma unroll - for (uint32_t b_i = 0; b_i < 8; ++b_i) { - b_frag_dq_T[b_i] *= cache_v_scale[0]; - } + const int scale_col = (kz * 2 + fz) * 4; + b_frag_dq_T[0] *= cache_v_scale[scale_col]; + b_frag_dq_T[1] *= cache_v_scale[scale_col + 1]; + b_frag_dq_T[2] *= cache_v_scale[scale_col + 2]; + b_frag_dq_T[3] *= cache_v_scale[scale_col + 3]; + b_frag_dq_T[4] *= cache_v_scale[scale_col]; + b_frag_dq_T[5] *= cache_v_scale[scale_col + 1]; + b_frag_dq_T[6] *= cache_v_scale[scale_col + 2]; + b_frag_dq_T[7] *= cache_v_scale[scale_col + 3]; } #pragma unroll for (uint32_t fx = 0; fx < num_frags_x; ++fx) { // m: num_frags_x * 16 @@ -1171,7 +1300,9 @@ template + bool is_scale_channel_wise = false, + bool IsFP8 = false, + bool IsDynamicC8 = false> __device__ __forceinline__ void compute_sfm_v_c8_iter_sq_bvec( smem_t* v_smem, uint32_t* v_smem_offset_r, @@ -1215,16 +1346,28 @@ __device__ __forceinline__ void compute_sfm_v_c8_iter_sq_bvec( convert_c8(b_frag_dq_T, b_frag[fz * 2]); convert_c8(b_frag_dq_T + 4, b_frag[fz * 2 + 1]); // scale zp - if constexpr (is_scale_channel_wise) { + if constexpr (!IsDynamicC8) { + if constexpr (is_scale_channel_wise) { #pragma unroll - for (uint32_t b_i = 0; b_i < 8; ++b_i) { - b_frag_dq_T[b_i] *= cache_v_scale[b_i / 4 + fy * 2]; + for (uint32_t b_i = 0; b_i < 8; ++b_i) { + b_frag_dq_T[b_i] *= cache_v_scale[b_i / 4 + fy * 2]; + } + } else { + #pragma unroll + for (uint32_t b_i = 0; b_i < 8; ++b_i) { + b_frag_dq_T[b_i] *= cache_v_scale[0]; + } } } else { - #pragma unroll - for (uint32_t b_i = 0; b_i < 8; ++b_i) { - b_frag_dq_T[b_i] *= cache_v_scale[0]; - } + const int scale_col = (kz * 2 + fz) * 4; + b_frag_dq_T[0] *= cache_v_scale[scale_col]; + b_frag_dq_T[1] *= cache_v_scale[scale_col + 1]; + b_frag_dq_T[2] *= cache_v_scale[scale_col + 2]; + b_frag_dq_T[3] *= cache_v_scale[scale_col + 3]; + b_frag_dq_T[4] *= cache_v_scale[scale_col]; + b_frag_dq_T[5] *= cache_v_scale[scale_col + 1]; + b_frag_dq_T[6] *= cache_v_scale[scale_col + 2]; + b_frag_dq_T[7] *= cache_v_scale[scale_col + 3]; } #pragma unroll for (uint32_t fx = 0; fx < num_frags_x; ++fx) { // m: num_frags_x * 16 diff --git a/custom_ops/gpu_ops/append_attn/append_attention_kernel.h b/custom_ops/gpu_ops/append_attn/append_attention_kernel.h index 8799c0a70..2cc069592 100644 --- a/custom_ops/gpu_ops/append_attn/append_attention_kernel.h +++ b/custom_ops/gpu_ops/append_attn/append_attention_kernel.h @@ -103,6 +103,7 @@ void CascadeAppendAttentionC8Kernel( const bool causal, const bool is_decoder, const bool enable_prefill, + const std::string& cache_quant_type_str, cudaStream_t& stream, paddle::Tensor* out); @@ -264,9 +265,10 @@ void CascadeAppendAttentionKernel( causal, is_decoder, enable_prefill, + cache_quant_type_str, stream, out); - } else if (cache_quant_type_str == "cache_fp8") { + } else if (cache_quant_type_str == "cache_fp8" or cache_quant_type_str == "block_wise_fp8") { CascadeAppendAttentionC8Kernel(meta_data, qkv, cache_k, @@ -299,6 +301,7 @@ void CascadeAppendAttentionKernel( causal, is_decoder, enable_prefill, + cache_quant_type_str, stream, out); } else if (cache_quant_type_str == "cache_int4_zp") { diff --git a/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_impl.cuh b/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_impl.cuh index 45c9d0a02..2a56caa17 100644 --- a/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_impl.cuh +++ b/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_impl.cuh @@ -674,6 +674,294 @@ __global__ void append_decode_cache_T_neox_rope_kernel( } } +template +__global__ void append_decode_cache_int8_rope_qk_norm_kernel( + const T* __restrict__ quant_qkv, // [bsz, num_heads + 2 * kv_num_heads, + // head_size] + uint8_t* __restrict__ key_cache, // [num_blocks, kv_num_heads, + // block_size, head_size // 2] + uint8_t* __restrict__ value_cache, // [num_blocks, kv_num_heads, + // block_size, head_size // 2] + T* __restrict__ qkv_out, + const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] + const int* __restrict__ batch_id_per_token, // [num_tokens] + const int* __restrict__ cu_seqlens_q, + const int* __restrict__ seq_lens, // [bsz] + const int* __restrict__ seq_lens_encoder, // [bsz] + const float* __restrict__ cos_emb, + const float* __restrict__ sin_emb, + T* __restrict__ cache_k_scale, + T* __restrict__ cache_v_scale, + const float* q_norm_weight, + const float* k_norm_weight, + const int max_seq_len, + const int max_blocks_per_seq, + const int num_heads, + const int block_size, + const float max_bound, + const float min_bound, + const int kv_num_heads, + const bool rope_3d, + const float rms_norm_eps) { + static_assert(HeadDim == 128, "just support HeadDim be 128 now!"); + static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!"); + constexpr int NUM_WARPS = 4; + const int tid = threadIdx.x; + const int wid = tid / 32; + const int lane_id = tid % 32; + const int bid = blockIdx.x, head_idx = blockIdx.y * NUM_WARPS + wid; + int q_head_idx, k_head_idx, v_idx; + const int64_t hidden_size = (num_heads + 2 * kv_num_heads) * HeadDim; + constexpr int half_head_size = HeadDim / 2; + const int start_token_idx = cu_seqlens_q[bid]; + if (seq_lens_encoder[bid] > 0) return; + const int write_seq_id = seq_lens[bid]; + if (write_seq_id == 0) return; + const int* block_table_now = nullptr; + + block_table_now = block_tables + bid * max_blocks_per_seq; + const int block_idx = __ldg(&block_table_now[write_seq_id / block_size]); + const int block_offset = write_seq_id % block_size; + + int cache_offset; + if (head_idx < num_heads) { + cache_offset = 0; + } else if (head_idx < num_heads + 2 * kv_num_heads) { + cache_offset = block_idx * kv_num_heads * block_size + (head_idx - num_heads) % kv_num_heads * block_size + block_offset; + } + T *cache_k_scale_now = cache_k_scale + cache_offset; + T *cache_v_scale_now = cache_v_scale + cache_offset; + + float thread_m2 = 0.0f; + float warp_m2 = 0.0f; + + if (head_idx < num_heads) { + // q + using LoadT = AlignedVector; + using LoadBiasT = AlignedVector; + using LoadOutScaleT = AlignedVector; + constexpr int HalfVecSize = VecSize / 2; + using LoadEmbT = AlignedVector; + + LoadT src_vec; + LoadBiasT out_vec; + LoadEmbT cos_emb_vec; + LoadEmbT sin_emb_vec; + const T* qkv_now = quant_qkv + start_token_idx * hidden_size; + T* qkv_out_now = qkv_out + start_token_idx * hidden_size; +#pragma unroll + for (uint32_t head_bias = lane_id * VecSize; head_bias < HeadDim; + head_bias += 32 * VecSize) { + const int bias_idx = head_idx * HeadDim + head_bias; + Load(&qkv_now[bias_idx], &src_vec); + // q rope + const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2; + const uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx; + Load(&cos_emb[new_emb_idx], &cos_emb_vec); + Load(&sin_emb[new_emb_idx], &sin_emb_vec); +#pragma unroll + for (int i = 0; i < HalfVecSize; i++) { + // dequant + add_bias + rope + float input_left = static_cast(src_vec[2 * i]); + float input_right = static_cast(src_vec[2 * i + 1]); + + const float cos_tmp = cos_emb_vec[i]; + const float sin_tmp = sin_emb_vec[i]; + float tmp1 = input_left * cos_tmp - input_right * sin_tmp; + float tmp2 = input_right * cos_tmp + input_left * sin_tmp; + thread_m2 += tmp1 * tmp1 + tmp2 * tmp2; + out_vec[2 * i] = + static_cast(tmp1); + out_vec[2 * i + 1] = + static_cast(tmp2); + } + // qk norm + if (q_norm_weight) { + WelfordWarpAllReduce(thread_m2, &warp_m2); + float row_variance = + max(warp_m2 / HeadDim, 0.0f); + float row_inv_var = Rsqrt(row_variance + rms_norm_eps); + LoadOutScaleT q_norm_vec; + Load(&q_norm_weight[lane_id * VecSize], &q_norm_vec); + #pragma unroll + for (int i = 0; i < VecSize; i++) { + out_vec[i] = static_cast(static_cast(out_vec[i]) * row_inv_var * q_norm_vec[i]); + } + } + Store(out_vec, &qkv_out_now[bias_idx]); + } + } else if (head_idx < num_heads + 2 * kv_num_heads) { + // k + constexpr int KV_VEC_SIZE = 16 / sizeof(uint8_t); // 16 + using LoadPadKVT = AlignedVector; + const uint32_t kv_head_idx = (head_idx - num_heads) % kv_num_heads; + if (block_offset == 0) { + // pad zero for this kv_head_idx for this block + LoadPadKVT pad_cache_vec; + *(reinterpret_cast(pad_cache_vec.val)) = make_uint4(0, 0, 0, 0); + if (head_idx < num_heads + kv_num_heads) { + constexpr int num_vecs_per_head_dim = HeadDim / KV_VEC_SIZE; + constexpr int num_token_each_time = 32 / num_vecs_per_head_dim; + const uint32_t tgt_idx = + (block_idx * kv_num_heads + kv_head_idx) * block_size * HeadDim + + lane_id % num_vecs_per_head_dim * KV_VEC_SIZE; + for (int block_i = lane_id / num_vecs_per_head_dim; + block_i < block_size; + block_i += num_token_each_time) { + Store(pad_cache_vec, + &key_cache[tgt_idx + block_i * HeadDim]); + } + } else { + const int num_vecs_per_head_dim = block_size / KV_VEC_SIZE; + const int num_token_each_time = 32 / num_vecs_per_head_dim; + const uint32_t tgt_idx = + (block_idx * kv_num_heads + kv_head_idx) * HeadDim * block_size + + lane_id % num_vecs_per_head_dim * KV_VEC_SIZE; + for (int block_i = lane_id / num_vecs_per_head_dim; block_i < HeadDim; + block_i += num_token_each_time) { + Store( + pad_cache_vec, &value_cache[tgt_idx + block_i * block_size]); + } + } + __syncwarp(); + } + + constexpr int K_VEC_SIZE = 4; + constexpr int HALF_K_VEC_SIZE = 2; + using LoadKVResT = AlignedVector; + using LoadKVT = AlignedVector; + using LoadT = AlignedVector; + using LoadBiasT = AlignedVector; + using LoadOutScaleT = AlignedVector; + using LoadEmbT = AlignedVector; + LoadKVResT cache_vec; + LoadT src_vec1, src_vec2; + LoadBiasT out_vec1, out_vec2; + LoadEmbT cos_emb_vec1, cos_emb_vec2; + LoadEmbT sin_emb_vec1, sin_emb_vec2; + + const T* qkv_now = quant_qkv + start_token_idx * hidden_size; + const int head_bias = lane_id / 4 * 16 + lane_id % 4 * 2; + const int bias_idx = head_idx * HeadDim + head_bias; + Load(&qkv_now[bias_idx], &src_vec1); + Load(&qkv_now[bias_idx + 8], &src_vec2); + T scale = T(1.0f); + const int k_head_idx = head_idx - num_heads; + const int v_head_idx = head_idx - num_heads - kv_num_heads; + if (head_idx < num_heads + kv_num_heads) { + const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2; + const uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx; + Load(&cos_emb[new_emb_idx], &cos_emb_vec1); + Load(&cos_emb[new_emb_idx + 4], &cos_emb_vec2); + Load(&sin_emb[new_emb_idx], &sin_emb_vec1); + Load(&sin_emb[new_emb_idx + 4], &sin_emb_vec2); + } + + float input_left = static_cast(src_vec1[0]); + float input_right = static_cast(src_vec1[1]); + if (head_idx < num_heads + kv_num_heads) { + float cos_tmp = cos_emb_vec1[0]; + float sin_tmp = sin_emb_vec1[0]; + float tmp1 = input_left * cos_tmp - input_right * sin_tmp; + float tmp2 = input_right * cos_tmp + input_left * sin_tmp; + thread_m2 += tmp1 * tmp1 + tmp2 * tmp2; + out_vec1[0] = + static_cast(tmp1); + out_vec1[1] = + static_cast(tmp2); + } else { + out_vec1[0] = src_vec1[0]; + out_vec1[1] = src_vec1[1]; + } + + // rope + input_left = static_cast(src_vec2[0]); + input_right = static_cast(src_vec2[1]); + if (head_idx < num_heads + kv_num_heads) { + float cos_tmp = cos_emb_vec2[0]; + float sin_tmp = sin_emb_vec2[0]; + float tmp1 = input_left * cos_tmp - input_right * sin_tmp; + float tmp2 = input_right * cos_tmp + input_left * sin_tmp; + thread_m2 += tmp1 * tmp1 + tmp2 * tmp2; + out_vec2[0] = static_cast(tmp1); + out_vec2[1] = static_cast(tmp2); + } else { + out_vec2[0] = src_vec2[0]; + out_vec2[1] = src_vec2[1]; + } + if (k_norm_weight) { + if (head_idx < num_heads + kv_num_heads) { + LoadOutScaleT k_norm_vec1, k_norm_vec2; + Load(&k_norm_weight[head_bias], &k_norm_vec1); + Load(&k_norm_weight[head_bias + 8], &k_norm_vec2); + // qk norm + WelfordWarpAllReduce(thread_m2, &warp_m2); + float row_variance = + max(warp_m2 / HeadDim, 0.0f); + float row_inv_var = Rsqrt(row_variance + rms_norm_eps); + + for (int i = 0; i < HALF_K_VEC_SIZE; i++) { + out_vec1[i] = static_cast(static_cast(out_vec1[i]) * row_inv_var * k_norm_vec1[i]); + out_vec2[i] = static_cast(static_cast(out_vec2[i]) * row_inv_var * k_norm_vec2[i]); + } + } + } + // reduce max, 1 head per warp + T local_max = -INFINITY; +#pragma unroll + for (int i = 0; i < HALF_K_VEC_SIZE; i++) { + local_max = __hmax(local_max, __habs(out_vec1[i])); + local_max = __hmax(local_max, __habs(out_vec2[i])); + } +#pragma unroll + for (int m_offset = 16; m_offset > 1; m_offset /= 2) { + local_max = __hmax(local_max, __shfl_xor_sync(0xffffffff, local_max, m_offset)); + } + + scale = __hdiv(448, local_max); + + if (lane_id == 0) { + if (head_idx < num_heads + kv_num_heads) { + cache_k_scale_now[0] = __hdiv(1, scale); + } else { + cache_v_scale_now[0] = __hdiv(1, scale); + } + } + +#pragma unroll + for (uint32_t i = 0; i < HALF_K_VEC_SIZE; i++) { + cache_vec[i] = QuantToC8(scale, out_vec1[i], max_bound, min_bound); + cache_vec[i + HALF_K_VEC_SIZE] = QuantToC8(scale, out_vec2[i], max_bound, min_bound); + } + if (head_idx < num_heads + kv_num_heads) { + const int start_block_16 = + block_offset / 16 * 16 + block_offset % 8 + lane_id / 4 % 2 * 8; + const uint32_t tgt_cache_idx = + block_idx * kv_num_heads * block_size * HeadDim + + kv_head_idx * block_size * HeadDim + start_block_16 * HeadDim + + lane_id / 4 / 2 * 32 + (block_offset % 16) / 8 * 16 + lane_id % 4 * 4; + Store(cache_vec, &key_cache[tgt_cache_idx]); + } else { + const uint32_t base_tgt_cache_idx = + block_idx * kv_num_heads * HeadDim * block_size + + kv_head_idx * HeadDim * block_size + + (lane_id / 4 * 16 + lane_id % 4 * 2) * block_size + + block_offset / 16 % 2 * 8 * block_size + block_offset / 16 / 2 * 32; + const uint32_t tgt_cache_idx1 = base_tgt_cache_idx + + block_offset % 8 / 2 * 4 // per 4 + + block_offset % 16 / 8 * 2 // per 2 + + block_offset % 2; // per 1 + const uint32_t tgt_cache_idx2 = tgt_cache_idx1 + block_size; + const uint32_t tgt_cache_idx3 = tgt_cache_idx1 + 16; + const uint32_t tgt_cache_idx4 = tgt_cache_idx3 + block_size; + value_cache[tgt_cache_idx1] = cache_vec[0]; + value_cache[tgt_cache_idx2] = cache_vec[1]; + value_cache[tgt_cache_idx3] = cache_vec[2]; + value_cache[tgt_cache_idx4] = cache_vec[3]; + } + } +} + template __global__ void append_decode_cache_int8_rope_kernel( const T* __restrict__ quant_qkv, // [bsz, num_heads + 2 * kv_num_heads, diff --git a/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_kernel.cu b/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_kernel.cu index d6643ca20..c067efc75 100644 --- a/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_kernel.cu +++ b/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_kernel.cu @@ -553,9 +553,40 @@ void DecoderWriteCacheWithRoPEKernel( q_norm_weight ? q_norm_weight.get().data() : nullptr, k_norm_weight ? k_norm_weight.get().data() : nullptr, rms_norm_eps); + } else if (cache_quant_type_str == "block_wise_fp8") { + constexpr int num_warps = 4; + const int all_warps = + ((num_heads + 2 * kv_num_heads) + num_warps - 1) / num_warps * num_warps; + dim3 grids(bsz, all_warps / num_warps); + append_decode_cache_int8_rope_qk_norm_kernel + <<>>( + reinterpret_cast(qkv_ptr), + key_cache_out->data(), + value_cache_out->data(), + reinterpret_cast(qkv_out->data()), + block_tables.data(), + batch_id_per_token.data(), + cu_seqlens_q.data(), + seq_lens.data(), + seq_lens_encoder.data(), + cos_emb, + sin_emb, + const_cast(reinterpret_cast(cache_k_scale.get().data())), + const_cast(reinterpret_cast((cache_v_scale.get().data()))), + q_norm_weight.get().data(), + k_norm_weight.get().data(), + max_seq_len, + max_blocks_per_seq, + num_heads, + block_size, + 127.0f, + -127.0f, + kv_num_heads, + rope_3d, + rms_norm_eps); } else { PD_THROW( - "append_decode_cache_rope_qk_norm not support cachekv quant yet"); + "append_decode_cache_rope_qk_norm just supports cache_quant_type none/block_wise_fp8"); } } else { if (cache_quant_type_str == "none") { @@ -686,6 +717,37 @@ void DecoderWriteCacheWithRoPEKernel( stream, use_neox_rotary_style, rope_3d); + } else if (cache_quant_type_str == "block_wise_fp8") { + constexpr int num_warps = 4; + const int all_warps = + ((num_heads + 2 * kv_num_heads) + num_warps - 1) / num_warps * num_warps; + dim3 grids(bsz, all_warps / num_warps); + append_decode_cache_int8_rope_qk_norm_kernel + <<>>( + reinterpret_cast(qkv_ptr), + key_cache_out->data(), + value_cache_out->data(), + reinterpret_cast(qkv_out->data()), + block_tables.data(), + batch_id_per_token.data(), + cu_seqlens_q.data(), + seq_lens.data(), + seq_lens_encoder.data(), + cos_emb, + sin_emb, + const_cast(reinterpret_cast(cache_k_scale.get().data())), + const_cast(reinterpret_cast((cache_v_scale.get().data()))), + nullptr, + nullptr, + max_seq_len, + max_blocks_per_seq, + num_heads, + block_size, + 127.0f, + -127.0f, + kv_num_heads, + rope_3d, + rms_norm_eps); } else if (cache_quant_type_str == "cache_int4_zp") { append_decode_cache_int4_rope( reinterpret_cast(qkv_ptr), diff --git a/custom_ops/gpu_ops/append_attn/encoder_write_cache_with_rope_impl.cuh b/custom_ops/gpu_ops/append_attn/encoder_write_cache_with_rope_impl.cuh index 44489bae0..394fb8b34 100644 --- a/custom_ops/gpu_ops/append_attn/encoder_write_cache_with_rope_impl.cuh +++ b/custom_ops/gpu_ops/append_attn/encoder_write_cache_with_rope_impl.cuh @@ -1232,6 +1232,411 @@ __global__ void append_write_cache_kv_c8_qkv( } } +template +__global__ void append_write_cache_kv_c8_qkv_dynamic( + uint8_t *__restrict__ cache_k, + uint8_t *__restrict__ cache_v, + const T *__restrict__ qkv_input, + T *__restrict__ cache_k_scales, // [block_num, num_heads, block_size] + T *__restrict__ cache_v_scales, // [block_num, num_heads, block_size] + const int *__restrict__ batch_ids, + const int *__restrict__ tile_ids, + const int *__restrict__ seq_lens_this_time, + const int *__restrict__ seq_lens_decoder, + const int *__restrict__ batch_id_per_token, + const int *__restrict__ cu_seqlens_q, + const int *__restrict__ block_tables, + const int max_seq_len, + const int max_blocks_per_seq, + const int num_heads, + const int kv_num_heads) { + constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b(); + constexpr uint32_t pad_len = BLOCK_SIZE; + const uint32_t btid = blockIdx.x, kv_head_idx = blockIdx.z; + const T cache_k_scale = cache_k_scales[kv_head_idx]; + const T cache_v_scale = cache_v_scales[kv_head_idx]; + const uint32_t tid = threadIdx.x, wid = threadIdx.y; + const uint32_t batch_id = batch_ids[btid]; + const uint32_t tile_id = tile_ids[btid]; + const uint32_t seq_len_this_time = seq_lens_this_time[batch_id]; + if (seq_len_this_time <= 0) { + return; + } + const int *block_table_now = nullptr; + + block_table_now = block_tables + batch_id * max_blocks_per_seq; + + const uint32_t num_rows_per_block = + NUM_WARPS * num_frags_z * 16; // BLOCK_SIZE + const uint32_t start_len = seq_lens_decoder[batch_id]; + const uint32_t bf_pad_len = start_len % pad_len; + const uint32_t start_len_pad = start_len - bf_pad_len; + const uint32_t end_len = start_len + seq_len_this_time; + + const uint32_t tile_start = start_len_pad + tile_id * num_rows_per_block; + int block_id = __ldg(&block_table_now[tile_start / BLOCK_SIZE]); + uint32_t chunk_start = tile_start + wid * num_frags_z * 16 + tid / 8; + + const uint32_t start_token_idx = cu_seqlens_q[batch_id]; + const uint32_t kv_batch_stride = (num_heads + 2 * kv_num_heads) * HEAD_DIM; + const uint32_t kv_h_stride = HEAD_DIM; + __shared__ T k_smem_ori[num_rows_per_block * HEAD_DIM]; + __shared__ T v_smem_ori[num_rows_per_block * HEAD_DIM]; + __shared__ T v_scale_smem[BLOCK_SIZE]; + if (tile_start >= start_len) { + constexpr int KV_VEC_SIZE = 16 / sizeof(uint8_t); // 16 + using LoadPadKVT = AlignedVector; + // pad zero for this kv_head_idx for this block + LoadPadKVT pad_cache_vec; + *(reinterpret_cast(pad_cache_vec.val)) = make_uint4(0, 0, 0, 0); + // reset k + constexpr int num_vecs_per_head_k = HEAD_DIM / KV_VEC_SIZE; + constexpr int num_token_each_time_k = 32 / num_vecs_per_head_k; + uint32_t tgt_idx = + (block_id * kv_num_heads + kv_head_idx) * BLOCK_SIZE * HEAD_DIM + + tid % num_vecs_per_head_k * KV_VEC_SIZE; + for (int block_i = tid / num_vecs_per_head_k; + block_i < BLOCK_SIZE; + block_i += num_token_each_time_k) { + Store(pad_cache_vec, + &cache_k[tgt_idx + block_i * HEAD_DIM]); + } + + // reset v + const int num_vecs_per_head_v = BLOCK_SIZE / KV_VEC_SIZE; + const int num_token_each_time_v = 32 / num_vecs_per_head_v; + tgt_idx = + (block_id * kv_num_heads + kv_head_idx) * HEAD_DIM * BLOCK_SIZE + + tid % num_vecs_per_head_v * KV_VEC_SIZE; + for (int block_i = tid / num_vecs_per_head_v; block_i < HEAD_DIM; + block_i += num_token_each_time_v) { + Store( + pad_cache_vec, &cache_v[tgt_idx + block_i * BLOCK_SIZE]); + } + } + smem_t k_smem(k_smem_ori); + smem_t v_smem(v_smem_ori); + + uint32_t kv_smem_offset_w = smem_t::get_permuted_offset( + wid * num_frags_z * 16 + tid / 8, tid % 8); // 4 * 8 per warp + + /* + 0 | 1 + 2 | 3 + */ + uint32_t k_smem_offset_r = smem_t::get_permuted_offset( + wid * num_frags_z * 16 + 8 * (tid / 16) + tid % 8, (tid % 16) / 8); + + constexpr uint32_t num_frags_v = num_frags_y / NUM_WARPS; + /* + 0 | 2 + 1 | 3 + */ + uint32_t v_smem_offset_r = smem_t::get_permuted_offset( + tid % 16, wid * num_frags_v * 2 + tid / 16); + + // load kv gmem to smem + const uint32_t real_start_token_idx = start_token_idx - bf_pad_len + + tile_id * num_rows_per_block + + wid * num_frags_z * 16 + tid / 8; + uint32_t k_read_idx = real_start_token_idx * kv_batch_stride + + (num_heads + kv_head_idx) * kv_h_stride + + tid % 8 * num_elems_per_128b(); + uint32_t v_read_idx = real_start_token_idx * kv_batch_stride + + (num_heads + kv_num_heads + kv_head_idx) * kv_h_stride + + tid % 8 * num_elems_per_128b(); +#pragma unroll + for (uint32_t fz = 0; fz < num_frags_z; ++fz) { +#pragma unroll + for (uint32_t j = 0; j < 4; ++j) { +#pragma unroll + for (uint32_t fy = 0; fy < num_frags_y / 4; + ++fy) { // (num_frags_y * 16) / (8 * num_elems_per_128b()) + if (chunk_start >= start_len && chunk_start < end_len) { + k_smem.load_128b_async( + kv_smem_offset_w, qkv_input + k_read_idx, chunk_start < end_len); + v_smem.load_128b_async( + kv_smem_offset_w, qkv_input + v_read_idx, chunk_start < end_len); + } + kv_smem_offset_w = + k_smem.advance_offset_by_column<8>(kv_smem_offset_w, fy); + k_read_idx += 8 * num_elems_per_128b(); + v_read_idx += 8 * num_elems_per_128b(); + } + kv_smem_offset_w = + k_smem.advance_offset_by_row<4, num_vecs_per_head>(kv_smem_offset_w) - + 2 * num_frags_y; + chunk_start += 4; + k_read_idx += + 4 * kv_batch_stride - 2 * num_frags_y * num_elems_per_128b(); + v_read_idx += + 4 * kv_batch_stride - 2 * num_frags_y * num_elems_per_128b(); + } + } + commit_group(); + wait_group<0>(); + __syncthreads(); + + // reduce scale + // 16 rows per warp + uint32_t kv_reduce_frag[4]; + T *kv_reduce_frag_T = reinterpret_cast(kv_reduce_frag); + + T k_local_max_value[num_frags_z * 2]; + T v_local_max_value[num_frags_z * 2]; +#pragma unroll + for (int i = 0; i < num_frags_z * 2; i++) { + k_local_max_value[i] = -INFINITY; + } +#pragma unroll + for (int i = 0; i < num_frags_z * 2; i++) { + v_local_max_value[i] = -INFINITY; + } + const int num_kv_heads = gridDim.z; + const int scale_offset = block_id * num_kv_heads * BLOCK_SIZE + kv_head_idx * BLOCK_SIZE; + T *cache_k_scale_now = cache_k_scales + scale_offset; + T *cache_v_scale_now = cache_v_scales + scale_offset; + // k scale +#pragma unroll + for (uint32_t fz = 0; fz < num_frags_z; ++fz) { +#pragma unroll + for (uint32_t fy = 0; fy < num_frags_y; ++fy) { + // reduce per thread, 4 threads each row + k_smem.ldmatrix_m8n8x4(k_smem_offset_r, kv_reduce_frag); +#pragma unroll + for (int i = 0; i < 4; i++) { + k_local_max_value[fz * 2] = __hmax(__habs(kv_reduce_frag_T[i]), k_local_max_value[fz * 2]); + } +#pragma unroll + for (int i = 0; i < 4; i++) { + k_local_max_value[fz * 2 + 1] = __hmax(__habs(kv_reduce_frag_T[i + 4]), k_local_max_value[fz * 2 + 1]); + } + k_smem_offset_r = k_smem.advance_offset_by_column<2>(k_smem_offset_r, fy); + } + // reduce per row + for (int i = 0; i < 2; i++) { + T local_max_value = __habs(k_local_max_value[fz * 2 + i]); + local_max_value = __hmax(local_max_value, __shfl_xor_sync(0xffffffff, local_max_value, 2)); + local_max_value = __hmax(local_max_value, __shfl_xor_sync(0xffffffff, local_max_value, 1)); + // used for quant + k_local_max_value[fz * 2 + i] = __hdiv(448, local_max_value); + } + // store + if (tid % 4 == 0) { + const int offset_now = wid * num_frags_z * 16 + tid / 4; + // used for dequant + if (tile_start + offset_now >= start_len) { + if (tile_start + offset_now < end_len) { + cache_k_scale_now[offset_now] = __hdiv(1, k_local_max_value[fz * 2]); + } else { + cache_k_scale_now[offset_now] = 0; + } + } + if (tile_start + offset_now + 8 >= start_len) { + if (tile_start + offset_now + 8 < end_len) { + cache_k_scale_now[offset_now + 8] = __hdiv(1, k_local_max_value[fz * 2 + 1]); + } else { + cache_k_scale_now[offset_now + 8] = 0; + } + } + } + __syncthreads(); + k_smem_offset_r -= 2 * num_frags_y; // num_frags_z = 1 + } + // v scale + #pragma unroll + for (uint32_t fz = 0; fz < num_frags_z; ++fz) { +#pragma unroll + for (uint32_t fy = 0; fy < num_frags_y; ++fy) { + // reduce per thread, 4 threads each row + v_smem.ldmatrix_m8n8x4(k_smem_offset_r, kv_reduce_frag); +#pragma unroll + for (int i = 0; i < 4; i++) { + v_local_max_value[fz * 2] = __hmax(__habs(kv_reduce_frag_T[i]), v_local_max_value[fz * 2]); + } +#pragma unroll + for (int i = 0; i < 4; i++) { + v_local_max_value[fz * 2 + 1] = __hmax(__habs(kv_reduce_frag_T[i + 4]), v_local_max_value[fz * 2 + 1]); + } + k_smem_offset_r = v_smem.advance_offset_by_column<2>(k_smem_offset_r, fy); + } + // reduce per row + for (int i = 0; i < 2; i++) { + T local_max_value = __habs(v_local_max_value[fz * 2 + i]); + local_max_value = __hmax(local_max_value, __shfl_xor_sync(0xffffffff, local_max_value, 2)); + local_max_value = __hmax(local_max_value, __shfl_xor_sync(0xffffffff, local_max_value, 1)); + v_local_max_value[fz * 2 + i] = __hdiv(448, local_max_value); + } + // store + if (tid % 4 == 0) { + const int offset_now = wid * num_frags_z * 16 + tid / 4; + // used for dequant + if (tile_start + offset_now >= start_len) { + if (tile_start + offset_now < end_len) { + cache_v_scale_now[offset_now] = __hdiv(1, v_local_max_value[fz * 2]); + v_scale_smem[offset_now] = v_local_max_value[fz * 2]; + } else { + cache_v_scale_now[offset_now] = 0; + v_scale_smem[offset_now] = 0; + } + } + if (tile_start + offset_now + 8 >= start_len) { + if (tile_start + offset_now + 8 < end_len) { + cache_v_scale_now[offset_now + 8] = __hdiv(1, v_local_max_value[fz * 2 + 1]); + v_scale_smem[offset_now + 8] = v_local_max_value[fz * 2 + 1]; + } else { + cache_v_scale_now[offset_now + 8] = 0; + v_scale_smem[offset_now + 8] = 0; + } + } + } + __syncthreads(); + k_smem_offset_r -= 2 * num_frags_y; // num_frags_z = 1 + } + __syncthreads(); + + // mask, quant, store + using LoadKVT = AlignedVector; + LoadKVT cache_vec1; + LoadKVT cache_vec2; + + uint32_t chunk_start_k = tile_start + wid * num_frags_z * 16 + tid / 4; + uint32_t kv_frag[4]; + const uint32_t write_n_stride = kv_num_heads * BLOCK_SIZE * HEAD_DIM; + const uint32_t write_h_stride = BLOCK_SIZE * HEAD_DIM; + const uint32_t write_b_stride = HEAD_DIM; + const uint32_t write_d_stride = BLOCK_SIZE; + uint32_t k_write_idx = block_id * write_n_stride + + kv_head_idx * write_h_stride + + (wid * num_frags_z * 16 + tid / 4) * write_b_stride + + tid % 4 * 4; // 4 * int8 = 8 * int4 = 32bit +#pragma unroll + for (uint32_t fz = 0; fz < num_frags_z; ++fz) { + uint32_t k_write_idx_now_z = k_write_idx + fz * 16 * write_b_stride; +#pragma unroll + for (uint32_t fy = 0; fy < num_frags_y; ++fy) { + uint32_t k_write_idx_now = k_write_idx_now_z + + fy % 2 * 8 * write_b_stride + + fy / 2 * 32; // + fy % 2 * 16; + // load + k_smem.ldmatrix_m8n8x4(k_smem_offset_r, kv_frag); + // quant + T *k_frag_T = reinterpret_cast(kv_frag); + if (bf_pad_len != 0) { + Load(cache_k + k_write_idx_now, &cache_vec1); + Load(cache_k + k_write_idx_now + 16, &cache_vec2); + } +#pragma unroll + for (uint32_t v_id = 0; v_id < 8; ++v_id) { + uint8_t uint_quant_value; + if (chunk_start_k + (v_id / 4) * 8 >= start_len && + chunk_start_k + (v_id / 4) * 8 < end_len) { + uint_quant_value = QuantToC8(k_local_max_value[fz * 2 + v_id / 4], k_frag_T[v_id], 127.0f, -127.0f); + } else { + uint_quant_value = 0; + } + if (bf_pad_len != 0) { + if (v_id < 4) { + cache_vec1[v_id] |= uint_quant_value; + } else { + cache_vec2[v_id % 4] |= uint_quant_value; + } + } else { + if (v_id < 4) { + cache_vec1[v_id] = uint_quant_value; + } else { + cache_vec2[v_id - 4] = uint_quant_value; + } + } + } + // store + Store(cache_vec1, cache_k + k_write_idx_now); + Store(cache_vec2, cache_k + k_write_idx_now + 16); + k_smem_offset_r = k_smem.advance_offset_by_column<2>(k_smem_offset_r, fy); + } + k_smem_offset_r = + k_smem.advance_offset_by_row<16, num_vecs_per_head>(k_smem_offset_r) - + 2 * num_frags_y; + chunk_start_k += 16; + } + + uint32_t chunk_start_v = tile_start + tid % 4 * 2; + uint32_t v_write_idx = block_id * write_n_stride + + kv_head_idx * write_h_stride + + (wid * num_frags_v * 16 + tid / 4) * write_d_stride + + tid % 4 * 4; // 4 * int8 = 8 * int4 = 32bit + const uint32_t num_frags_z_v = num_frags_z * NUM_WARPS; + T v_scales[num_frags_z_v * 4]; + for (int v_i = 0; v_i < num_frags_z_v; v_i++) { + const int offset = v_i * 16; + const int t_offset = tid % 4 * 2; + v_scales[v_i * 4] = v_scale_smem[offset + t_offset]; + v_scales[v_i * 4 + 1] = v_scale_smem[offset + t_offset + 1]; + v_scales[v_i * 4 + 2] = v_scale_smem[offset + t_offset + 8]; + v_scales[v_i * 4 + 3] = v_scale_smem[offset + t_offset + 9]; + } + +#pragma unroll + for (uint32_t fy = 0; fy < num_frags_v; ++fy) { + uint32_t v_write_idx_now_v = v_write_idx + fy * 16 * write_d_stride; +#pragma unroll + for (uint32_t fz = 0; fz < num_frags_z_v; ++fz) { + uint32_t v_write_idx_now = v_write_idx_now_v + + fz % 2 * 8 * write_d_stride + + fz / 2 * 32; // + fz % 2 * 16; + // load + v_smem.ldmatrix_m8n8x4_trans(v_smem_offset_r, kv_frag); + // quant + T *v_frag_T = reinterpret_cast(kv_frag); + if (bf_pad_len != 0) { + Load(cache_v + v_write_idx_now, &cache_vec1); + Load(cache_v + v_write_idx_now + 16, &cache_vec2); + } +#pragma unroll + for (uint32_t v_id = 0; v_id < 8; ++v_id) { + uint8_t uint_quant_value; + if (chunk_start_v + v_id % 2 + (v_id % 4) / 2 * 8 >= start_len && + chunk_start_v + v_id % 2 + (v_id % 4) / 2 * 8 < end_len) { + uint_quant_value = QuantToC8(v_scales[fz * 4 + v_id % 4], v_frag_T[v_id], 127.0f, -127.0f); + // store now + } else { + uint_quant_value = 0; + } + if (bf_pad_len != 0) { + if (v_id < 4) { + cache_vec1[v_id] |= uint_quant_value; + } else { + cache_vec2[v_id % 4] |= uint_quant_value; + } + } else { + if (v_id < 4) { + cache_vec1[v_id] = uint_quant_value; + } else { + cache_vec2[v_id % 4] = uint_quant_value; + } + } + } + // store + Store(cache_vec1, cache_v + v_write_idx_now); + Store(cache_vec2, cache_v + v_write_idx_now + 16); + chunk_start_v += 16; + v_smem_offset_r = + k_smem.advance_offset_by_row<16, num_vecs_per_head>(v_smem_offset_r); + } + v_smem_offset_r = k_smem.advance_offset_by_column<2>( + v_smem_offset_r, wid * num_frags_v + fy) - + 16 * num_frags_z_v * num_vecs_per_head; + chunk_start_v -= 16 * num_frags_z_v; + } +} + // Write Cache KV in Append template ::type; auto max_blocks_per_seq = meta_data.max_blocks_per_seq; auto num_tokens = meta_data.token_nums; auto num_heads = meta_data.q_num_heads; @@ -2027,49 +2433,77 @@ void CascadeAppendWriteCacheKVC8QKV( dim3 blocks(32, num_warps); const uint32_t smem_size = (BLOCK_SIZE * HEAD_DIM) * sizeof(T) * 2; - auto kernel_fn = append_write_cache_kv_c8_qkv; - if (is_fp8) { - kernel_fn = append_write_cache_kv_c8_qkv; + if (cache_quant_type != "block_wise_fp8") { + auto kernel_fn = append_write_cache_kv_c8_qkv; + if (cache_quant_type == "cache_fp8") { + kernel_fn = append_write_cache_kv_c8_qkv; + } + if (is_scale_channel_wise) { + kernel_fn = append_write_cache_kv_c8_qkv; + } + cudaFuncSetAttribute( + kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + kernel_fn<<>>(cache_k_out->data(), + cache_v_out->data(), + qkv.data(), + cache_k_scale.data(), + cache_v_scale.data(), + batch_ids.data(), + tile_ids_per_batch.data(), + seq_lens_this_time.data(), + seq_lens_decoder.data(), + batch_id_per_token.data(), + cu_seqlens_q.data(), + block_table.data(), + max_seq_len, + max_blocks_per_seq, + num_heads, + kv_num_heads); + } else { + auto kernel_fn = append_write_cache_kv_c8_qkv_dynamic; + cudaFuncSetAttribute( + kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + kernel_fn<<>>(cache_k_out->data(), + cache_v_out->data(), + reinterpret_cast(qkv.data()), + const_cast(reinterpret_cast(cache_k_scale.data())), + const_cast(reinterpret_cast(cache_v_scale.data())), + batch_ids.data(), + tile_ids_per_batch.data(), + seq_lens_this_time.data(), + seq_lens_decoder.data(), + batch_id_per_token.data(), + cu_seqlens_q.data(), + block_table.data(), + max_seq_len, + max_blocks_per_seq, + num_heads, + kv_num_heads); } - if (is_scale_channel_wise) { - kernel_fn = append_write_cache_kv_c8_qkv; - } - cudaFuncSetAttribute( - kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); - kernel_fn<<>>(cache_k_out->data(), - cache_v_out->data(), - qkv.data(), - cache_k_scale.data(), - cache_v_scale.data(), - batch_ids.data(), - tile_ids_per_batch.data(), - seq_lens_this_time.data(), - seq_lens_decoder.data(), - batch_id_per_token.data(), - cu_seqlens_q.data(), - block_table.data(), - max_seq_len, - max_blocks_per_seq, - num_heads, - kv_num_heads); } template diff --git a/custom_ops/gpu_ops/append_attn/encoder_write_cache_with_rope_kernel.h b/custom_ops/gpu_ops/append_attn/encoder_write_cache_with_rope_kernel.h index 5af84e73f..b0d66a291 100644 --- a/custom_ops/gpu_ops/append_attn/encoder_write_cache_with_rope_kernel.h +++ b/custom_ops/gpu_ops/append_attn/encoder_write_cache_with_rope_kernel.h @@ -167,7 +167,7 @@ void EncoderWriteCacheWithRopeKernel( stream, key_cache_out, value_cache_out); - } else if (cache_quant_type_str == "cache_int8" or cache_quant_type_str == "cache_fp8") { + } else if (cache_quant_type_str == "cache_int8" or cache_quant_type_str == "cache_fp8" or cache_quant_type_str == "block_wise_fp8") { DISPATCH_HEAD_DIM( head_dim, HEAD_DIM, {DISPATCH_BLOCK_SIZE(block_size, BLOCK_SIZE, { CascadeAppendWriteCacheKVC8QKV( @@ -187,7 +187,7 @@ void EncoderWriteCacheWithRopeKernel( num_blocks, max_seq_len, is_scale_channel_wise, - cache_quant_type_str == "cache_fp8", + cache_quant_type_str, stream, key_cache_out, value_cache_out); diff --git a/custom_ops/gpu_ops/append_attn/gqa_rope_write_cache.cu b/custom_ops/gpu_ops/append_attn/gqa_rope_write_cache.cu index 3b33c750a..2a28bc94f 100644 --- a/custom_ops/gpu_ops/append_attn/gqa_rope_write_cache.cu +++ b/custom_ops/gpu_ops/append_attn/gqa_rope_write_cache.cu @@ -1000,7 +1000,7 @@ std::vector GQARopeWriteCacheKernel( stream, const_cast(&key_cache), const_cast(&value_cache)); - } else if (cache_quant_type == "cache_int8" || cache_quant_type == "cache_fp8") { + } else if (cache_quant_type == "cache_int8" || cache_quant_type == "cache_fp8" || cache_quant_type == "block_wise_fp8") { CascadeAppendWriteCacheKVC8QKV( meta_data, *const_cast(&key_cache), @@ -1018,7 +1018,7 @@ std::vector GQARopeWriteCacheKernel( kv_num_blocks_data, max_seq_len, false, // is_scale_channel_wise - cache_quant_type == "cache_fp8", // is_fp8 + cache_quant_type, stream, const_cast(&key_cache), const_cast(&value_cache)); diff --git a/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c8_bfloat16_bfloat16_kernel.cu b/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c8_bfloat16_bfloat16_kernel.cu index e860a0462..757cccaf9 100644 --- a/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c8_bfloat16_bfloat16_kernel.cu +++ b/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c8_bfloat16_bfloat16_kernel.cu @@ -56,6 +56,7 @@ CascadeAppendAttentionC8Kernel( const bool causal, const bool is_decoder, const bool enable_prefill, + const std::string& cache_quant_type_str, cudaStream_t& stream, paddle::Tensor* out); @@ -103,5 +104,6 @@ CascadeAppendAttentionC8Kernel( const bool causal, const bool is_decoder, const bool enable_prefill, + const std::string& cache_quant_type_str, cudaStream_t& stream, paddle::Tensor* out); diff --git a/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c8_bfloat16_fp8_kernel.cu b/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c8_bfloat16_fp8_kernel.cu index 3b61ecd16..54b0b0be4 100644 --- a/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c8_bfloat16_fp8_kernel.cu +++ b/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c8_bfloat16_fp8_kernel.cu @@ -54,6 +54,7 @@ template void CascadeAppendAttentionC8Kernel( const bool causal, const bool is_decoder, const bool enable_prefill, + const std::string& cache_quant_type_str, cudaStream_t& stream, paddle::Tensor* out); @@ -100,5 +101,6 @@ template void CascadeAppendAttentionC8Kernel( const bool causal, const bool is_decoder, const bool enable_prefill, + const std::string& cache_quant_type_str, cudaStream_t& stream, paddle::Tensor* out); diff --git a/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c8_float16_float16_kernel.cu b/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c8_float16_float16_kernel.cu index 4d7b11d99..153b81ee0 100644 --- a/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c8_float16_float16_kernel.cu +++ b/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c8_float16_float16_kernel.cu @@ -54,6 +54,7 @@ template void CascadeAppendAttentionC8Kernel( const bool causal, const bool is_decoder, const bool enable_prefill, + const std::string& cache_quant_type_str, cudaStream_t& stream, paddle::Tensor* out); @@ -99,5 +100,6 @@ template void CascadeAppendAttentionC8Kernel( const bool causal, const bool is_decoder, const bool enable_prefill, + const std::string& cache_quant_type_str, cudaStream_t& stream, paddle::Tensor* out); diff --git a/custom_ops/gpu_ops/append_attn/utils.cuh b/custom_ops/gpu_ops/append_attn/utils.cuh index 13874a3f9..12d86dade 100644 --- a/custom_ops/gpu_ops/append_attn/utils.cuh +++ b/custom_ops/gpu_ops/append_attn/utils.cuh @@ -441,6 +441,15 @@ __forceinline__ __host__ __device__ void vec_cast( PD_THROW("not support the group_size", group_size); \ } +#define DISPATCH_DyCfp8(is_dynamic_cfp8, IsDynamicC8, ...) \ + if (is_dynamic_cfp8) { \ + constexpr bool IsDynamicC8 = true; \ + __VA_ARGS__ \ + } else { \ + constexpr bool IsDynamicC8 = false; \ + __VA_ARGS__ \ + } + #define DISPATCH_MLA_GROUP_SIZE(group_size, GROUP_SIZE, ...) \ if (group_size == 8) { \ constexpr size_t GROUP_SIZE = 8; \ diff --git a/fastdeploy/model_executor/layers/attention/append_attn_backend.py b/fastdeploy/model_executor/layers/attention/append_attn_backend.py index 59fe071af..aa47aa391 100644 --- a/fastdeploy/model_executor/layers/attention/append_attn_backend.py +++ b/fastdeploy/model_executor/layers/attention/append_attn_backend.py @@ -231,6 +231,17 @@ class AppendAttentionBackend(AttentionBackend): metadata.kv_signal_metadata, layer.layer_id + self.start_layer_index, ) + cache_quant_type_str = getattr(layer, "cache_quant_type_str", "none") + if cache_quant_type_str == "block_wise_fp8": + cache_k = forward_meta.caches[4 * layer.layer_id] + cache_v = forward_meta.caches[4 * layer.layer_id + 1] + cache_k_scales = forward_meta.caches[4 * layer.layer_id + 2] + cache_v_scales = forward_meta.caches[4 * layer.layer_id + 3] + else: + cache_k = forward_meta.caches[2 * layer.layer_id] + cache_v = forward_meta.caches[2 * layer.layer_id + 1] + cache_k_scales = getattr(layer, "cache_k_scale", None) + cache_v_scales = getattr(layer, "cache_v_scale", None) if self.use_output: quant_max_bound = getattr(layer, "quant_max_bound", 0.0) @@ -269,8 +280,8 @@ class AppendAttentionBackend(AttentionBackend): append_attention_with_output( qkv, - forward_meta.caches[2 * layer.layer_id], - forward_meta.caches[2 * layer.layer_id + 1], + cache_k, + cache_v, forward_meta.seq_lens_encoder, forward_meta.seq_lens_decoder, forward_meta.seq_lens_this_time, @@ -293,8 +304,8 @@ class AppendAttentionBackend(AttentionBackend): metadata.attn_mask, layer.qkv_bias, layer.qkv_scale, - getattr(layer, "cache_k_scale", None), - getattr(layer, "cache_v_scale", None), + cache_k_scales, + cache_v_scales, getattr(layer, "cache_k_out_scale", None), getattr(layer, "cache_v_out_scale", None), getattr(layer, "cache_k_zp", None), @@ -325,8 +336,8 @@ class AppendAttentionBackend(AttentionBackend): else: res = append_attention( qkv, - forward_meta.caches[2 * layer.layer_id], - forward_meta.caches[2 * layer.layer_id + 1], + cache_k, + cache_v, forward_meta.seq_lens_encoder, forward_meta.seq_lens_decoder, forward_meta.seq_lens_this_time, @@ -348,8 +359,8 @@ class AppendAttentionBackend(AttentionBackend): metadata.attn_mask, layer.qkv_bias, layer.qkv_scale, - getattr(layer, "cache_k_scale", None), - getattr(layer, "cache_v_scale", None), + cache_k_scales, + cache_v_scales, getattr(layer, "cache_k_out_scale", None), getattr(layer, "cache_v_out_scale", None), getattr(layer, "cache_k_zp", None), diff --git a/fastdeploy/model_executor/layers/quantization/kv_cache.py b/fastdeploy/model_executor/layers/quantization/kv_cache.py index d560e6122..d7727da5c 100644 --- a/fastdeploy/model_executor/layers/quantization/kv_cache.py +++ b/fastdeploy/model_executor/layers/quantization/kv_cache.py @@ -33,6 +33,7 @@ class KvCacheQuantzationTypes(str, Enum): INT8 = "int8" FP8 = "float8_e4m3fn" + BLOCK_WISE_FP8 = "block_wise_fp8" INT8_ZP = "int8_zp" INT4_ZP = "int4_zp" FP8_ZP = "float8_e4m3fn_zp" @@ -62,7 +63,11 @@ class KvCacheQuantConfig(QuantConfigBase): if self.quant_type == KvCacheQuantzationTypes.INT8 or self.quant_type == KvCacheQuantzationTypes.INT8_ZP: self.max_bound = 127.0 - elif self.quant_type == KvCacheQuantzationTypes.FP8 or self.quant_type == KvCacheQuantzationTypes.FP8_ZP: + elif ( + self.quant_type == KvCacheQuantzationTypes.FP8 + or self.quant_type == KvCacheQuantzationTypes.FP8_ZP + or self.quant_type == KvCacheQuantzationTypes.BLOCK_WISE_FP8 + ): self.max_bound = 448.0 elif self.quant_type == KvCacheQuantzationTypes.INT4_ZP: self.max_bound = 7.0 @@ -178,12 +183,17 @@ class KVCacheMethodBase(QuantMethodBase): layer.cache_quant_type_str = "cache_int4_zp" layer.quant_max_bound = 7.0 layer.quant_min_bound = -7.0 + elif self.cache_quant_config.quant_type == KvCacheQuantzationTypes.BLOCK_WISE_FP8: + layer.cache_quant_type_str = "block_wise_fp8" + layer.quant_max_bound = 448.0 + layer.quant_min_bound = -448.0 else: raise NotImplementedError(f"{self.cache_quant_config.quant_type} is not implemented") - self.load_scale(layer, state_dict) - if self.cache_quant_config.has_zero_point: - self.load_zp(layer, state_dict) + if "block_wise" not in layer.cache_quant_type_str: + self.load_scale(layer, state_dict) + if self.cache_quant_config.has_zero_point: + self.load_zp(layer, state_dict) def apply(self, layer): """ diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index 50044b4e8..9579f93b2 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -1023,6 +1023,8 @@ class GPUModelRunner(ModelRunnerBase): kv_cache_shape = self.attn_backends[0].get_kv_cache_shape( max_num_blocks=max_block_num, kv_cache_quant_type=kv_cache_quant_type ) + if kv_cache_quant_type == "block_wise_fp8": + kv_cache_scale_shape = [kv_cache_shape[0], kv_cache_shape[1], kv_cache_shape[2]] local_rank = self.local_rank % self.parallel_config.tensor_parallel_size if not profile and (self.cache_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed"): @@ -1050,6 +1052,17 @@ class GPUModelRunner(ModelRunnerBase): fill_value=0, dtype=cache_type, ) + if kv_cache_quant_type == "block_wise_fp8": + cache_kvs[f"key_cache_scales_{i}"] = paddle.full( + shape=kv_cache_scale_shape, + fill_value=0, + dtype=paddle.get_default_dtype(), + ) + cache_kvs[f"value_cache_scales_{i}"] = paddle.full( + shape=kv_cache_scale_shape, + fill_value=0, + dtype=paddle.get_default_dtype(), + ) self.share_inputs["caches"] = list(cache_kvs.values()) for value in cache_kvs.values(): del value diff --git a/tests/layers/test_append_attention.py b/tests/layers/test_append_attention.py index ee35443a4..8616437ab 100644 --- a/tests/layers/test_append_attention.py +++ b/tests/layers/test_append_attention.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import copy import time import unittest @@ -20,6 +21,7 @@ import paddle from paddle.incubate.nn.functional import fused_rms_norm paddle.seed(10) +np.random.seed(10) class RopeEmbedding: @@ -334,7 +336,7 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase): self.name = "TestAppendGroupQueryAttnWithRope" self.place = paddle.CUDAPlace(0) self.batch_size = 1 - self.q_num_head = 12 + self.q_num_head = 16 self.kv_num_head = 2 self.seq_len = 64 self.max_dec_len = 64 @@ -347,9 +349,10 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase): self.max_seq_len = self.seq_len + self.max_dec_len self.softmax_scale = self.dim_head**-0.5 self.rope_theta = 10000 - self.dtype = "float16" + self.dtype = "bfloat16" self.use_qk_norm = True self.use_mask_offset = False + self.use_dynamic_quant = False self.init_tensor() def init_tensor(self): @@ -391,8 +394,23 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase): ) self.scale = 1.0 / np.sqrt(self.dim_head) - self.cache_k = paddle.zeros(shape=self.cache_shape, dtype=self.dtype) - self.cache_v = paddle.zeros(shape=self.cache_shape, dtype=self.dtype) + if self.use_dynamic_quant: + self.cache_scale_shape = ( + self.max_block_num, + self.kv_num_head, + self.blocksize, + ) + self.cache_k = paddle.zeros(shape=self.cache_shape, dtype="uint8") + self.cache_v = paddle.zeros(shape=self.cache_shape, dtype="uint8") + self.cache_k_T = paddle.zeros(shape=self.cache_shape, dtype=self.dtype) + self.cache_v_T = paddle.zeros(shape=self.cache_shape, dtype=self.dtype) + self.key_cache_scale = paddle.zeros(shape=self.cache_scale_shape, dtype=self.dtype) + self.value_cache_scale = paddle.zeros(shape=self.cache_scale_shape, dtype=self.dtype) + else: + self.cache_k = paddle.zeros(shape=self.cache_shape, dtype=self.dtype) + self.cache_v = paddle.zeros(shape=self.cache_shape, dtype=self.dtype) + self.key_cache_scale = None + self.value_cache_scale = None self.block_tables = paddle.zeros(shape=(self.batch_size, self.block_num_per_seq), dtype="int32") for i in range(self.batch_size): need_block_num = (self.seq_len + self.max_dec_len + self.blocksize - 1) // self.blocksize @@ -415,6 +433,7 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase): def cmp_append_attention(self, naive_cache_k=None, naive_cache_v=None, attn_mask=None): paddle.disable_static() + print("use_dynamic_quant: ", self.use_dynamic_quant) self.token_num = self.seq_len * self.batch_size q, k, v, qkv = get_qkv_and_qkv_concat_tensor( self.batch_size, @@ -472,18 +491,17 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase): self.blocksize, speculate_max_draft_token_num + 1, ) + if self.use_dynamic_quant: + cache_quant_type = "block_wise_fp8" + else: + cache_quant_type = "none" - # Warm up - WARM_UP = 1 - RUN_TIME = 2 - for i in range(WARM_UP + RUN_TIME): - if i == WARM_UP: - paddle.device.synchronize() - start_time = time.time() - out = append_attention( - qkv, - self.cache_k, - self.cache_v, + if self.use_dynamic_quant: + qkv_copy = copy.deepcopy(qkv) + append_attention( + qkv_copy, + self.cache_k_T, + self.cache_v_T, self.seq_lens_encoder, self.seq_lens_decoder, self.seq_lens_this_time, @@ -519,7 +537,69 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase): k_norm_weight, # k_norm_weight 1e-6, "fp16", - "none", # cache_quant_type + "none", + self.use_neox_rotary_style, + False, + self.max_seq_len, + 0.0, # quant_min_bound + 0.0, # quant_max_bound + -1, # out_linear_in_scale + 64, # encoder_block_shape_q + 16, # decoder_block_shape_q + 32768, # max_partition_size + 32768, # encoder_max_partition_size + speculate_max_draft_token_num + 1, # speculate_max_draft_token_num + True, # causal + False, # speculate_decoder + ) + + # Warm up + WARM_UP = 1 + RUN_TIME = 2 + for i in range(WARM_UP + RUN_TIME): + if i == WARM_UP: + paddle.device.synchronize() + start_time = time.time() + out = append_attention( + qkv, + self.cache_k, + self.cache_v, + self.seq_lens_encoder, + self.seq_lens_decoder, + self.seq_lens_this_time, + self.padding_offset, + self.cum_offset, + self.block_tables, + encoder_batch_ids, + encoder_tile_ids_per_batch, + encoder_num_blocks, + kv_batch_ids, + kv_tile_ids_per_batch, + kv_num_blocks, + self.decoder_batch_ids, + self.decoder_tile_ids_per_batch, + self.decoder_num_blocks_cpu, + self.max_len_tensor_cpu, + max_len_kv, + self.rope_emb, # rope_emb + None, # attn_mask + None, # qkv_bias + None, # qkv_out_scales + self.key_cache_scale, # cache_k_quant_scales + self.value_cache_scale, # cache_v_quant_scales + None, # cache_k_dequant_scales + None, # cache_v_dequant_scales + None, # cache_k_zp + None, # cache_v_zp + None, # linear_shift + None, # linear_smooth + self.mask_offset, # mask_offset + None, # kv_signal_data + q_norm_weight, # q_norm_weight + k_norm_weight, # k_norm_weight + 1e-6, + "fp16", + cache_quant_type, self.use_neox_rotary_style, False, self.max_seq_len, @@ -537,13 +617,6 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase): paddle.device.synchronize() end_time = time.time() print(f"[append-attn ut] cost_time:{(end_time - start_time) / RUN_TIME * 1000}ms") - naive_cache_k, naive_cache_v = block_cache_to_naive_cache( - self.cache_k, - self.cache_v, - self.batch_size, - self.block_tables, - self.seq_len, - ) np.testing.assert_allclose( out.numpy(), out_.numpy(), @@ -572,13 +645,22 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase): if self.use_mask_offset: print("encoder mask_offset: ", self.mask_offset) self.cmp_append_attention(attn_mask=self.attention_mask) - naive_cache_k, naive_cache_v = block_cache_to_naive_cache( - self.cache_k, - self.cache_v, - self.batch_size, - self.block_tables, - self.seq_len, - ) + if self.use_dynamic_quant: + naive_cache_k, naive_cache_v = block_cache_to_naive_cache( + self.cache_k_T, + self.cache_v_T, + self.batch_size, + self.block_tables, + self.seq_len, + ) + else: + naive_cache_k, naive_cache_v = block_cache_to_naive_cache( + self.cache_k, + self.cache_v, + self.batch_size, + self.block_tables, + self.seq_len, + ) # decoder self.seq_lens_decoder[:] = self.seq_lens_encoder self.seq_lens_encoder[:] = 0 @@ -613,10 +695,10 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase): class TestAppendGroupQueryAttnWithNeoXRope(TestAppendGroupQueryAttnWithRope): def setUp(self): paddle.disable_static() - self.name = "TestAppendGroupQueryAttnWithRope" + self.name = "TestAppendGroupQueryAttnWithNeoXRope" self.place = paddle.CUDAPlace(0) self.batch_size = 1 - self.q_num_head = 12 + self.q_num_head = 16 self.kv_num_head = 2 self.seq_len = 64 self.max_dec_len = 64 @@ -632,6 +714,33 @@ class TestAppendGroupQueryAttnWithNeoXRope(TestAppendGroupQueryAttnWithRope): self.dtype = "float16" self.use_qk_norm = False self.use_mask_offset = True + self.use_dynamic_quant = False + self.init_tensor() + + +class TestAppendGroupQueryAttnWithRopeDyCfp8(TestAppendGroupQueryAttnWithRope): + def setUp(self): + paddle.disable_static() + self.name = "TestAppendGroupQueryAttnWithRopeDyCfp8" + self.place = paddle.CUDAPlace(0) + self.batch_size = 1 + self.q_num_head = 16 + self.kv_num_head = 2 + self.seq_len = 64 + self.max_dec_len = 64 + self.dim_head = 128 + self.q_hid_dim = self.q_num_head * self.dim_head + self.kv_hid_dim = self.kv_num_head * self.dim_head + self.blocksize = 64 + self.use_neox_rotary_style = False + # max_seq_len = self.seq_len + self.max_dec_len + self.max_seq_len = self.seq_len + self.max_dec_len + self.softmax_scale = self.dim_head**-0.5 + self.rope_theta = 10000 + self.dtype = "bfloat16" + self.use_qk_norm = True + self.use_mask_offset = False + self.use_dynamic_quant = True self.init_tensor()