From 238766e4032fab11340da7dfb7b2ace83158b2fa Mon Sep 17 00:00:00 2001 From: lizhenyun01 <1500424927@qq.com> Date: Mon, 28 Jul 2025 11:11:51 +0800 Subject: [PATCH] fix c4 prompt_cache --- .../append_attn/gqa_rope_write_cache.cu | 116 +++++++++--------- 1 file changed, 58 insertions(+), 58 deletions(-) diff --git a/custom_ops/gpu_ops/append_attn/gqa_rope_write_cache.cu b/custom_ops/gpu_ops/append_attn/gqa_rope_write_cache.cu index 739a0d9c6..20e8b147e 100644 --- a/custom_ops/gpu_ops/append_attn/gqa_rope_write_cache.cu +++ b/custom_ops/gpu_ops/append_attn/gqa_rope_write_cache.cu @@ -586,9 +586,9 @@ __global__ void append_cache_kv_c4( #pragma unroll for (uint32_t i = wid * 32 + tid; i < HEAD_DIM; i += 128) { cache_k_scale_smem[i] = cache_k_scale_now[i]; - cache_k_zero_point_smem[i] = cache_k_zp_now[i] - static_cast(136.f); + cache_k_zero_point_smem[i] = cache_k_zp_now[i] + static_cast(136.f); cache_v_scale_smem[i] = cache_v_scale_now[i]; - cache_v_zero_point_smem[i] = cache_v_zp_now[i] - static_cast(136.f); + cache_v_zero_point_smem[i] = cache_v_zp_now[i] + static_cast(136.f); } smem_t k_smem(smem); @@ -640,25 +640,25 @@ __global__ void append_cache_kv_c4( convert_int4(frag_dq_T + 8, k_frag[2 * i + 1]); if (row_idx < end_idx) { - k_tile_ptr0[0] = frag_dq_T[0] * cache_k_scale_smem[col_idx] + cache_k_zero_point_smem[col_idx]; - k_tile_ptr0[1] = frag_dq_T[1] * cache_k_scale_smem[col_idx + 1] + cache_k_zero_point_smem[col_idx + 1]; - k_tile_ptr0[8] = frag_dq_T[2] * cache_k_scale_smem[col_idx + 8] + cache_k_zero_point_smem[col_idx + 8]; - k_tile_ptr0[9] = frag_dq_T[3] * cache_k_scale_smem[col_idx + 9] + cache_k_zero_point_smem[col_idx + 9]; - k_tile_ptr0[16] = frag_dq_T[8] * cache_k_scale_smem[col_idx + 16] + cache_k_zero_point_smem[col_idx + 16]; - k_tile_ptr0[17] = frag_dq_T[9] * cache_k_scale_smem[col_idx + 17] + cache_k_zero_point_smem[col_idx + 17]; - k_tile_ptr0[24] = frag_dq_T[10] * cache_k_scale_smem[col_idx + 24] + cache_k_zero_point_smem[col_idx + 24]; - k_tile_ptr0[25] = frag_dq_T[11] * cache_k_scale_smem[col_idx + 25] + cache_k_zero_point_smem[col_idx + 25]; + k_tile_ptr0[0] = (frag_dq_T[0] - cache_k_zero_point_smem[col_idx]) * cache_k_scale_smem[col_idx]; + k_tile_ptr0[1] = (frag_dq_T[1] - cache_k_zero_point_smem[col_idx + 1]) * cache_k_scale_smem[col_idx + 1]; + k_tile_ptr0[8] = (frag_dq_T[2] - cache_k_zero_point_smem[col_idx + 8]) * cache_k_scale_smem[col_idx + 8]; + k_tile_ptr0[9] = (frag_dq_T[3] - cache_k_zero_point_smem[col_idx + 9]) * cache_k_scale_smem[col_idx + 9]; + k_tile_ptr0[16] = (frag_dq_T[8] - cache_k_zero_point_smem[col_idx + 16]) * cache_k_scale_smem[col_idx + 16]; + k_tile_ptr0[17] = (frag_dq_T[9] - cache_k_zero_point_smem[col_idx + 17]) * cache_k_scale_smem[col_idx + 17]; + k_tile_ptr0[24] = (frag_dq_T[10] - cache_k_zero_point_smem[col_idx + 24]) * cache_k_scale_smem[col_idx + 24]; + k_tile_ptr0[25] = (frag_dq_T[11] - cache_k_zero_point_smem[col_idx + 25]) * cache_k_scale_smem[col_idx + 25]; } if (row_idx + 8 < end_idx) { - k_tile_ptr1[0] = frag_dq_T[4] * cache_k_scale_smem[col_idx] + cache_k_zero_point_smem[col_idx]; - k_tile_ptr1[1] = frag_dq_T[5] * cache_k_scale_smem[col_idx + 1] + cache_k_zero_point_smem[col_idx + 1]; - k_tile_ptr1[8] = frag_dq_T[6] * cache_k_scale_smem[col_idx + 8] + cache_k_zero_point_smem[col_idx + 8]; - k_tile_ptr1[9] = frag_dq_T[7] * cache_k_scale_smem[col_idx + 9] + cache_k_zero_point_smem[col_idx + 9]; - k_tile_ptr1[16] = frag_dq_T[12] * cache_k_scale_smem[col_idx + 16] + cache_k_zero_point_smem[col_idx + 16]; - k_tile_ptr1[17] = frag_dq_T[13] * cache_k_scale_smem[col_idx + 17] + cache_k_zero_point_smem[col_idx + 17]; - k_tile_ptr1[24] = frag_dq_T[14] * cache_k_scale_smem[col_idx + 24] + cache_k_zero_point_smem[col_idx + 24]; - k_tile_ptr1[25] = frag_dq_T[15] * cache_k_scale_smem[col_idx + 25] + cache_k_zero_point_smem[col_idx + 25]; + k_tile_ptr1[0] = (frag_dq_T[4] - cache_k_zero_point_smem[col_idx]) * cache_k_scale_smem[col_idx]; + k_tile_ptr1[1] = (frag_dq_T[5] - cache_k_zero_point_smem[col_idx + 1]) * cache_k_scale_smem[col_idx + 1]; + k_tile_ptr1[8] = (frag_dq_T[6] - cache_k_zero_point_smem[col_idx + 8]) * cache_k_scale_smem[col_idx + 8]; + k_tile_ptr1[9] = (frag_dq_T[7] - cache_k_zero_point_smem[col_idx + 9]) * cache_k_scale_smem[col_idx + 9]; + k_tile_ptr1[16] = (frag_dq_T[12] - cache_k_zero_point_smem[col_idx + 16]) * cache_k_scale_smem[col_idx + 16]; + k_tile_ptr1[17] = (frag_dq_T[13] - cache_k_zero_point_smem[col_idx + 17]) * cache_k_scale_smem[col_idx + 17]; + k_tile_ptr1[24] = (frag_dq_T[14] - cache_k_zero_point_smem[col_idx + 24]) * cache_k_scale_smem[col_idx + 24]; + k_tile_ptr1[25] = (frag_dq_T[15] - cache_k_zero_point_smem[col_idx + 25]) * cache_k_scale_smem[col_idx + 25]; } col_idx += 32; } @@ -711,36 +711,36 @@ __global__ void append_cache_kv_c4( convert_int4(frag_dq_T, v_frag[2 * i]); convert_int4(frag_dq_T + 8, v_frag[2 * i + 1]); if (kv_idx < end_idx) { - v_tile_ptr0[0] = frag_dq_T[0] * cache_v_scale_smem[dim_idx] + cache_v_zero_point_smem[dim_idx]; - v_tile_ptr1[0] = frag_dq_T[4] * cache_v_scale_smem[dim_idx + 8] + cache_v_zero_point_smem[dim_idx + 8]; + v_tile_ptr0[0] = (frag_dq_T[0] - cache_v_zero_point_smem[dim_idx]) * cache_v_scale_smem[dim_idx]; + v_tile_ptr1[0] = (frag_dq_T[4] - cache_v_zero_point_smem[dim_idx + 8]) * cache_v_scale_smem[dim_idx + 8]; } if (kv_idx + 1 < end_idx) { - v_tile_ptr0[kv_t_stride] = frag_dq_T[1] * cache_v_scale_smem[dim_idx] + cache_v_zero_point_smem[dim_idx]; - v_tile_ptr1[kv_t_stride] = frag_dq_T[5] * cache_v_scale_smem[dim_idx + 8] + cache_v_zero_point_smem[dim_idx + 8]; + v_tile_ptr0[kv_t_stride] = (frag_dq_T[1] - cache_v_zero_point_smem[dim_idx]) * cache_v_scale_smem[dim_idx]; + v_tile_ptr1[kv_t_stride] = (frag_dq_T[5] - cache_v_zero_point_smem[dim_idx + 8]) * cache_v_scale_smem[dim_idx + 8]; } if (kv_idx + 8 < end_idx) { - v_tile_ptr0[8 * kv_t_stride] = frag_dq_T[2] * cache_v_scale_smem[dim_idx] + cache_v_zero_point_smem[dim_idx]; - v_tile_ptr1[8 * kv_t_stride] = frag_dq_T[6] * cache_v_scale_smem[dim_idx + 8] + cache_v_zero_point_smem[dim_idx + 8]; + v_tile_ptr0[8 * kv_t_stride] = (frag_dq_T[2] - cache_v_zero_point_smem[dim_idx]) * cache_v_scale_smem[dim_idx]; + v_tile_ptr1[8 * kv_t_stride] = (frag_dq_T[6] - cache_v_zero_point_smem[dim_idx + 8]) * cache_v_scale_smem[dim_idx + 8]; } if (kv_idx + 9 < end_idx) { - v_tile_ptr0[9 * kv_t_stride] = frag_dq_T[3] * cache_v_scale_smem[dim_idx] + cache_v_zero_point_smem[dim_idx]; - v_tile_ptr1[9 * kv_t_stride] = frag_dq_T[7] * cache_v_scale_smem[dim_idx + 8] + cache_v_zero_point_smem[dim_idx + 8]; + v_tile_ptr0[9 * kv_t_stride] = (frag_dq_T[3] - cache_v_zero_point_smem[dim_idx]) * cache_v_scale_smem[dim_idx]; + v_tile_ptr1[9 * kv_t_stride] = (frag_dq_T[7] - cache_v_zero_point_smem[dim_idx + 8]) * cache_v_scale_smem[dim_idx + 8]; } if (kv_idx + 16 < end_idx) { - v_tile_ptr0[16 * kv_t_stride] = frag_dq_T[8] * cache_v_scale_smem[dim_idx] + cache_v_zero_point_smem[dim_idx]; - v_tile_ptr1[16 * kv_t_stride] = frag_dq_T[12] * cache_v_scale_smem[dim_idx + 8] + cache_v_zero_point_smem[dim_idx + 8]; + v_tile_ptr0[16 * kv_t_stride] = (frag_dq_T[8] - cache_v_zero_point_smem[dim_idx]) * cache_v_scale_smem[dim_idx]; + v_tile_ptr1[16 * kv_t_stride] = (frag_dq_T[12] - cache_v_zero_point_smem[dim_idx + 8]) * cache_v_scale_smem[dim_idx + 8]; } if (kv_idx + 17 < end_idx) { - v_tile_ptr0[17 * kv_t_stride] = frag_dq_T[9] * cache_v_scale_smem[dim_idx] + cache_v_zero_point_smem[dim_idx]; - v_tile_ptr1[17 * kv_t_stride] = frag_dq_T[13] * cache_v_scale_smem[dim_idx + 8] + cache_v_zero_point_smem[dim_idx + 8]; + v_tile_ptr0[17 * kv_t_stride] = (frag_dq_T[9] - cache_v_zero_point_smem[dim_idx]) * cache_v_scale_smem[dim_idx]; + v_tile_ptr1[17 * kv_t_stride] = (frag_dq_T[13] - cache_v_zero_point_smem[dim_idx + 8]) * cache_v_scale_smem[dim_idx + 8]; } if (kv_idx + 24 < end_idx) { - v_tile_ptr0[24 * kv_t_stride] = frag_dq_T[10] * cache_v_scale_smem[dim_idx] + cache_v_zero_point_smem[dim_idx]; - v_tile_ptr1[24 * kv_t_stride] = frag_dq_T[14] * cache_v_scale_smem[dim_idx + 8] + cache_v_zero_point_smem[dim_idx + 8]; + v_tile_ptr0[24 * kv_t_stride] = (frag_dq_T[10] - cache_v_zero_point_smem[dim_idx]) * cache_v_scale_smem[dim_idx]; + v_tile_ptr1[24 * kv_t_stride] = (frag_dq_T[14] - cache_v_zero_point_smem[dim_idx + 8]) * cache_v_scale_smem[dim_idx + 8]; } if (kv_idx + 25 < end_idx) { - v_tile_ptr0[25 * kv_t_stride] = frag_dq_T[11] * cache_v_scale_smem[dim_idx] + cache_v_zero_point_smem[dim_idx]; - v_tile_ptr1[25 * kv_t_stride] = frag_dq_T[15] * cache_v_scale_smem[dim_idx + 8] + cache_v_zero_point_smem[dim_idx + 8]; + v_tile_ptr0[25 * kv_t_stride] = (frag_dq_T[11] - cache_v_zero_point_smem[dim_idx]) * cache_v_scale_smem[dim_idx]; + v_tile_ptr1[25 * kv_t_stride] = (frag_dq_T[15] - cache_v_zero_point_smem[dim_idx + 8]) * cache_v_scale_smem[dim_idx + 8]; } kv_idx += 32; } @@ -956,6 +956,30 @@ std::vector GQARopeWriteCacheKernel( rotary_embs.dims()[2], head_dim, stream); + + if (token_num < kv_token_num) { + AppendCacheKV( + key_cache, + value_cache, + cache_k_dequant_scales.get(), + cache_v_dequant_scales.get(), + cache_k_zp.get(), + cache_v_zp.get(), + seq_lens_this_time, + seq_lens_decoder, + cu_seqlens_k, + block_tables, + cache_batch_ids, + cache_tile_ids, + cache_num_blocks, + max_blocks_per_seq, + kv_num_heads, + cache_quant_type, + &k, + &v, + stream + ); + } // write cache if (cache_quant_type == "none") { CascadeAppendWriteCacheKVQKV( @@ -1038,30 +1062,6 @@ std::vector GQARopeWriteCacheKernel( } } } - - if (token_num < kv_token_num) { - AppendCacheKV( - key_cache, - value_cache, - cache_k_dequant_scales.get(), - cache_v_dequant_scales.get(), - cache_k_zp.get(), - cache_v_zp.get(), - seq_lens_this_time, - seq_lens_decoder, - cu_seqlens_k, - block_tables, - cache_batch_ids, - cache_tile_ids, - cache_num_blocks, - max_blocks_per_seq, - kv_num_heads, - cache_quant_type, - &k, - &v, - stream - ); - } return {q, k, v, qkv_out}; }