mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
fix c4 prompt_cache
This commit is contained in:
@@ -586,9 +586,9 @@ __global__ void append_cache_kv_c4(
|
||||
#pragma unroll
|
||||
for (uint32_t i = wid * 32 + tid; i < HEAD_DIM; i += 128) {
|
||||
cache_k_scale_smem[i] = cache_k_scale_now[i];
|
||||
cache_k_zero_point_smem[i] = cache_k_zp_now[i] - static_cast<T>(136.f);
|
||||
cache_k_zero_point_smem[i] = cache_k_zp_now[i] + static_cast<T>(136.f);
|
||||
cache_v_scale_smem[i] = cache_v_scale_now[i];
|
||||
cache_v_zero_point_smem[i] = cache_v_zp_now[i] - static_cast<T>(136.f);
|
||||
cache_v_zero_point_smem[i] = cache_v_zp_now[i] + static_cast<T>(136.f);
|
||||
}
|
||||
|
||||
smem_t k_smem(smem);
|
||||
@@ -640,25 +640,25 @@ __global__ void append_cache_kv_c4(
|
||||
convert_int4(frag_dq_T + 8, k_frag[2 * i + 1]);
|
||||
|
||||
if (row_idx < end_idx) {
|
||||
k_tile_ptr0[0] = frag_dq_T[0] * cache_k_scale_smem[col_idx] + cache_k_zero_point_smem[col_idx];
|
||||
k_tile_ptr0[1] = frag_dq_T[1] * cache_k_scale_smem[col_idx + 1] + cache_k_zero_point_smem[col_idx + 1];
|
||||
k_tile_ptr0[8] = frag_dq_T[2] * cache_k_scale_smem[col_idx + 8] + cache_k_zero_point_smem[col_idx + 8];
|
||||
k_tile_ptr0[9] = frag_dq_T[3] * cache_k_scale_smem[col_idx + 9] + cache_k_zero_point_smem[col_idx + 9];
|
||||
k_tile_ptr0[16] = frag_dq_T[8] * cache_k_scale_smem[col_idx + 16] + cache_k_zero_point_smem[col_idx + 16];
|
||||
k_tile_ptr0[17] = frag_dq_T[9] * cache_k_scale_smem[col_idx + 17] + cache_k_zero_point_smem[col_idx + 17];
|
||||
k_tile_ptr0[24] = frag_dq_T[10] * cache_k_scale_smem[col_idx + 24] + cache_k_zero_point_smem[col_idx + 24];
|
||||
k_tile_ptr0[25] = frag_dq_T[11] * cache_k_scale_smem[col_idx + 25] + cache_k_zero_point_smem[col_idx + 25];
|
||||
k_tile_ptr0[0] = (frag_dq_T[0] - cache_k_zero_point_smem[col_idx]) * cache_k_scale_smem[col_idx];
|
||||
k_tile_ptr0[1] = (frag_dq_T[1] - cache_k_zero_point_smem[col_idx + 1]) * cache_k_scale_smem[col_idx + 1];
|
||||
k_tile_ptr0[8] = (frag_dq_T[2] - cache_k_zero_point_smem[col_idx + 8]) * cache_k_scale_smem[col_idx + 8];
|
||||
k_tile_ptr0[9] = (frag_dq_T[3] - cache_k_zero_point_smem[col_idx + 9]) * cache_k_scale_smem[col_idx + 9];
|
||||
k_tile_ptr0[16] = (frag_dq_T[8] - cache_k_zero_point_smem[col_idx + 16]) * cache_k_scale_smem[col_idx + 16];
|
||||
k_tile_ptr0[17] = (frag_dq_T[9] - cache_k_zero_point_smem[col_idx + 17]) * cache_k_scale_smem[col_idx + 17];
|
||||
k_tile_ptr0[24] = (frag_dq_T[10] - cache_k_zero_point_smem[col_idx + 24]) * cache_k_scale_smem[col_idx + 24];
|
||||
k_tile_ptr0[25] = (frag_dq_T[11] - cache_k_zero_point_smem[col_idx + 25]) * cache_k_scale_smem[col_idx + 25];
|
||||
}
|
||||
|
||||
if (row_idx + 8 < end_idx) {
|
||||
k_tile_ptr1[0] = frag_dq_T[4] * cache_k_scale_smem[col_idx] + cache_k_zero_point_smem[col_idx];
|
||||
k_tile_ptr1[1] = frag_dq_T[5] * cache_k_scale_smem[col_idx + 1] + cache_k_zero_point_smem[col_idx + 1];
|
||||
k_tile_ptr1[8] = frag_dq_T[6] * cache_k_scale_smem[col_idx + 8] + cache_k_zero_point_smem[col_idx + 8];
|
||||
k_tile_ptr1[9] = frag_dq_T[7] * cache_k_scale_smem[col_idx + 9] + cache_k_zero_point_smem[col_idx + 9];
|
||||
k_tile_ptr1[16] = frag_dq_T[12] * cache_k_scale_smem[col_idx + 16] + cache_k_zero_point_smem[col_idx + 16];
|
||||
k_tile_ptr1[17] = frag_dq_T[13] * cache_k_scale_smem[col_idx + 17] + cache_k_zero_point_smem[col_idx + 17];
|
||||
k_tile_ptr1[24] = frag_dq_T[14] * cache_k_scale_smem[col_idx + 24] + cache_k_zero_point_smem[col_idx + 24];
|
||||
k_tile_ptr1[25] = frag_dq_T[15] * cache_k_scale_smem[col_idx + 25] + cache_k_zero_point_smem[col_idx + 25];
|
||||
k_tile_ptr1[0] = (frag_dq_T[4] - cache_k_zero_point_smem[col_idx]) * cache_k_scale_smem[col_idx];
|
||||
k_tile_ptr1[1] = (frag_dq_T[5] - cache_k_zero_point_smem[col_idx + 1]) * cache_k_scale_smem[col_idx + 1];
|
||||
k_tile_ptr1[8] = (frag_dq_T[6] - cache_k_zero_point_smem[col_idx + 8]) * cache_k_scale_smem[col_idx + 8];
|
||||
k_tile_ptr1[9] = (frag_dq_T[7] - cache_k_zero_point_smem[col_idx + 9]) * cache_k_scale_smem[col_idx + 9];
|
||||
k_tile_ptr1[16] = (frag_dq_T[12] - cache_k_zero_point_smem[col_idx + 16]) * cache_k_scale_smem[col_idx + 16];
|
||||
k_tile_ptr1[17] = (frag_dq_T[13] - cache_k_zero_point_smem[col_idx + 17]) * cache_k_scale_smem[col_idx + 17];
|
||||
k_tile_ptr1[24] = (frag_dq_T[14] - cache_k_zero_point_smem[col_idx + 24]) * cache_k_scale_smem[col_idx + 24];
|
||||
k_tile_ptr1[25] = (frag_dq_T[15] - cache_k_zero_point_smem[col_idx + 25]) * cache_k_scale_smem[col_idx + 25];
|
||||
}
|
||||
col_idx += 32;
|
||||
}
|
||||
@@ -711,36 +711,36 @@ __global__ void append_cache_kv_c4(
|
||||
convert_int4(frag_dq_T, v_frag[2 * i]);
|
||||
convert_int4(frag_dq_T + 8, v_frag[2 * i + 1]);
|
||||
if (kv_idx < end_idx) {
|
||||
v_tile_ptr0[0] = frag_dq_T[0] * cache_v_scale_smem[dim_idx] + cache_v_zero_point_smem[dim_idx];
|
||||
v_tile_ptr1[0] = frag_dq_T[4] * cache_v_scale_smem[dim_idx + 8] + cache_v_zero_point_smem[dim_idx + 8];
|
||||
v_tile_ptr0[0] = (frag_dq_T[0] - cache_v_zero_point_smem[dim_idx]) * cache_v_scale_smem[dim_idx];
|
||||
v_tile_ptr1[0] = (frag_dq_T[4] - cache_v_zero_point_smem[dim_idx + 8]) * cache_v_scale_smem[dim_idx + 8];
|
||||
}
|
||||
if (kv_idx + 1 < end_idx) {
|
||||
v_tile_ptr0[kv_t_stride] = frag_dq_T[1] * cache_v_scale_smem[dim_idx] + cache_v_zero_point_smem[dim_idx];
|
||||
v_tile_ptr1[kv_t_stride] = frag_dq_T[5] * cache_v_scale_smem[dim_idx + 8] + cache_v_zero_point_smem[dim_idx + 8];
|
||||
v_tile_ptr0[kv_t_stride] = (frag_dq_T[1] - cache_v_zero_point_smem[dim_idx]) * cache_v_scale_smem[dim_idx];
|
||||
v_tile_ptr1[kv_t_stride] = (frag_dq_T[5] - cache_v_zero_point_smem[dim_idx + 8]) * cache_v_scale_smem[dim_idx + 8];
|
||||
}
|
||||
if (kv_idx + 8 < end_idx) {
|
||||
v_tile_ptr0[8 * kv_t_stride] = frag_dq_T[2] * cache_v_scale_smem[dim_idx] + cache_v_zero_point_smem[dim_idx];
|
||||
v_tile_ptr1[8 * kv_t_stride] = frag_dq_T[6] * cache_v_scale_smem[dim_idx + 8] + cache_v_zero_point_smem[dim_idx + 8];
|
||||
v_tile_ptr0[8 * kv_t_stride] = (frag_dq_T[2] - cache_v_zero_point_smem[dim_idx]) * cache_v_scale_smem[dim_idx];
|
||||
v_tile_ptr1[8 * kv_t_stride] = (frag_dq_T[6] - cache_v_zero_point_smem[dim_idx + 8]) * cache_v_scale_smem[dim_idx + 8];
|
||||
}
|
||||
if (kv_idx + 9 < end_idx) {
|
||||
v_tile_ptr0[9 * kv_t_stride] = frag_dq_T[3] * cache_v_scale_smem[dim_idx] + cache_v_zero_point_smem[dim_idx];
|
||||
v_tile_ptr1[9 * kv_t_stride] = frag_dq_T[7] * cache_v_scale_smem[dim_idx + 8] + cache_v_zero_point_smem[dim_idx + 8];
|
||||
v_tile_ptr0[9 * kv_t_stride] = (frag_dq_T[3] - cache_v_zero_point_smem[dim_idx]) * cache_v_scale_smem[dim_idx];
|
||||
v_tile_ptr1[9 * kv_t_stride] = (frag_dq_T[7] - cache_v_zero_point_smem[dim_idx + 8]) * cache_v_scale_smem[dim_idx + 8];
|
||||
}
|
||||
if (kv_idx + 16 < end_idx) {
|
||||
v_tile_ptr0[16 * kv_t_stride] = frag_dq_T[8] * cache_v_scale_smem[dim_idx] + cache_v_zero_point_smem[dim_idx];
|
||||
v_tile_ptr1[16 * kv_t_stride] = frag_dq_T[12] * cache_v_scale_smem[dim_idx + 8] + cache_v_zero_point_smem[dim_idx + 8];
|
||||
v_tile_ptr0[16 * kv_t_stride] = (frag_dq_T[8] - cache_v_zero_point_smem[dim_idx]) * cache_v_scale_smem[dim_idx];
|
||||
v_tile_ptr1[16 * kv_t_stride] = (frag_dq_T[12] - cache_v_zero_point_smem[dim_idx + 8]) * cache_v_scale_smem[dim_idx + 8];
|
||||
}
|
||||
if (kv_idx + 17 < end_idx) {
|
||||
v_tile_ptr0[17 * kv_t_stride] = frag_dq_T[9] * cache_v_scale_smem[dim_idx] + cache_v_zero_point_smem[dim_idx];
|
||||
v_tile_ptr1[17 * kv_t_stride] = frag_dq_T[13] * cache_v_scale_smem[dim_idx + 8] + cache_v_zero_point_smem[dim_idx + 8];
|
||||
v_tile_ptr0[17 * kv_t_stride] = (frag_dq_T[9] - cache_v_zero_point_smem[dim_idx]) * cache_v_scale_smem[dim_idx];
|
||||
v_tile_ptr1[17 * kv_t_stride] = (frag_dq_T[13] - cache_v_zero_point_smem[dim_idx + 8]) * cache_v_scale_smem[dim_idx + 8];
|
||||
}
|
||||
if (kv_idx + 24 < end_idx) {
|
||||
v_tile_ptr0[24 * kv_t_stride] = frag_dq_T[10] * cache_v_scale_smem[dim_idx] + cache_v_zero_point_smem[dim_idx];
|
||||
v_tile_ptr1[24 * kv_t_stride] = frag_dq_T[14] * cache_v_scale_smem[dim_idx + 8] + cache_v_zero_point_smem[dim_idx + 8];
|
||||
v_tile_ptr0[24 * kv_t_stride] = (frag_dq_T[10] - cache_v_zero_point_smem[dim_idx]) * cache_v_scale_smem[dim_idx];
|
||||
v_tile_ptr1[24 * kv_t_stride] = (frag_dq_T[14] - cache_v_zero_point_smem[dim_idx + 8]) * cache_v_scale_smem[dim_idx + 8];
|
||||
}
|
||||
if (kv_idx + 25 < end_idx) {
|
||||
v_tile_ptr0[25 * kv_t_stride] = frag_dq_T[11] * cache_v_scale_smem[dim_idx] + cache_v_zero_point_smem[dim_idx];
|
||||
v_tile_ptr1[25 * kv_t_stride] = frag_dq_T[15] * cache_v_scale_smem[dim_idx + 8] + cache_v_zero_point_smem[dim_idx + 8];
|
||||
v_tile_ptr0[25 * kv_t_stride] = (frag_dq_T[11] - cache_v_zero_point_smem[dim_idx]) * cache_v_scale_smem[dim_idx];
|
||||
v_tile_ptr1[25 * kv_t_stride] = (frag_dq_T[15] - cache_v_zero_point_smem[dim_idx + 8]) * cache_v_scale_smem[dim_idx + 8];
|
||||
}
|
||||
kv_idx += 32;
|
||||
}
|
||||
@@ -956,6 +956,30 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
|
||||
rotary_embs.dims()[2],
|
||||
head_dim,
|
||||
stream);
|
||||
|
||||
if (token_num < kv_token_num) {
|
||||
AppendCacheKV<data_t, 128, 64>(
|
||||
key_cache,
|
||||
value_cache,
|
||||
cache_k_dequant_scales.get(),
|
||||
cache_v_dequant_scales.get(),
|
||||
cache_k_zp.get(),
|
||||
cache_v_zp.get(),
|
||||
seq_lens_this_time,
|
||||
seq_lens_decoder,
|
||||
cu_seqlens_k,
|
||||
block_tables,
|
||||
cache_batch_ids,
|
||||
cache_tile_ids,
|
||||
cache_num_blocks,
|
||||
max_blocks_per_seq,
|
||||
kv_num_heads,
|
||||
cache_quant_type,
|
||||
&k,
|
||||
&v,
|
||||
stream
|
||||
);
|
||||
}
|
||||
// write cache
|
||||
if (cache_quant_type == "none") {
|
||||
CascadeAppendWriteCacheKVQKV<data_t>(
|
||||
@@ -1038,30 +1062,6 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (token_num < kv_token_num) {
|
||||
AppendCacheKV<data_t, 128, 64>(
|
||||
key_cache,
|
||||
value_cache,
|
||||
cache_k_dequant_scales.get(),
|
||||
cache_v_dequant_scales.get(),
|
||||
cache_k_zp.get(),
|
||||
cache_v_zp.get(),
|
||||
seq_lens_this_time,
|
||||
seq_lens_decoder,
|
||||
cu_seqlens_k,
|
||||
block_tables,
|
||||
cache_batch_ids,
|
||||
cache_tile_ids,
|
||||
cache_num_blocks,
|
||||
max_blocks_per_seq,
|
||||
kv_num_heads,
|
||||
cache_quant_type,
|
||||
&k,
|
||||
&v,
|
||||
stream
|
||||
);
|
||||
}
|
||||
return {q, k, v, qkv_out};
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user