[Cherry-Pick][CI] Fix write qknorm cache bug in speculative decoding (#5491) (#5617)

* [liuzichang spent 10 days] fix write qknorm cache bug

* fix 'fix cachekv bug'
Author: freeliuzc
Date: 2025-12-17 20:08:51 +08:00 (committed by GitHub)
Parent: d67b64d5e1
Commit: d7d633a285
2 changed files with 12 additions and 6 deletions
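
The same two fixes land in each of the three cache-write kernels below. First, the per-token guard now skips batches that are still in the prefill (encoder) phase by testing seq_lens_encoder[ori_bi] > 0, instead of bailing out when seq_lens_decoder[ori_bi] == 0. Second, a negative (padded) block-table entry now triggers continue rather than return, so the rest of the thread's grid-stride iterations still run. A minimal sketch of the corrected guards, with a simplified, assumed signature (not the patched kernel itself):

// Minimal sketch (assumed, simplified signature) of the corrected guards
// inside the grid-stride loop of the patched kernels.
__global__ void speculate_write_cache_sketch(
    const int* __restrict__ batch_id_per_token,  // [num_tokens]
    const int* __restrict__ seq_lens_encoder,    // [bsz]
    const int* __restrict__ block_tables,        // [bsz, max_blocks_per_seq]
    const int max_blocks_per_seq,
    const int hidden_size,
    const int elem_cnt) {
  for (int linear_index = blockIdx.x * blockDim.x + threadIdx.x;
       linear_index < elem_cnt;
       linear_index += gridDim.x * blockDim.x) {
    const int token_id = linear_index / hidden_size;
    const int ori_bi = batch_id_per_token[token_id];
    if (ori_bi == -1) continue;                  // CUDAGraph padding slot
    if (seq_lens_encoder[ori_bi] > 0) continue;  // batch still in prefill;
                                                 // the encoder path owns
                                                 // this step's cache write
    const int block_idx = block_tables[ori_bi * max_blocks_per_seq];
    if (block_idx < 0) {
      continue;  // padded block-table entry: continue, not return
    }
    // ... apply RoPE / qk norm and write this element to the KV cache ...
  }
}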

File 1 of 2: kernel definitions

@@ -31,6 +31,7 @@ __global__ void append_speculate_cache_T_rope_qk_norm_kernel(
     const int* __restrict__ batch_id_per_token, // [num_tokens]
     const int* __restrict__ cu_seqlens_q,
     const int* __restrict__ seq_lens_decoder, // [bsz]
+    const int* __restrict__ seq_lens_encoder, // [bsz]
     const float* __restrict__ cos_emb,
     const float* __restrict__ sin_emb,
     const float*
@@ -75,7 +76,7 @@ __global__ void append_speculate_cache_T_rope_qk_norm_kernel(
     const int ori_bi = batch_id_per_token[token_id];
     if (ori_bi == -1) continue; // NOTE(gongshaotian): For CUDAGraph padding
-    if (seq_lens_decoder[ori_bi] == 0) continue;
+    if (seq_lens_encoder[ori_bi] > 0) continue;
     const int bias = linear_index % hidden_size;
     const int hi = bias / head_size; // q + k + v
     const int h_bias = bias % head_size;
@@ -87,7 +88,7 @@ __global__ void append_speculate_cache_T_rope_qk_norm_kernel(
     const int* block_table_now = block_tables + ori_bi * max_blocks_per_seq;
     const int block_idx = block_table_now[write_seq_id / block_size];
     if (block_idx < 0) {
-      return; // NOTE(gongshaotian): For CUDAGraph padding
+      continue; // NOTE(gongshaotian): For CUDAGraph padding
     }
     const int block_offset = write_seq_id % block_size;
@@ -343,6 +344,7 @@ __global__ void append_speculate_cache_rope_kernel(
     const int* __restrict__ batch_id_per_token, // [num_tokens]
     const int* __restrict__ cu_seqlens_q,
     const int* __restrict__ seq_lens_decoder, // [bsz]
+    const int* __restrict__ seq_lens_encoder, // [bsz]
     const float* __restrict__ cos_emb,
     const float* __restrict__ sin_emb,
     const float*
@@ -380,7 +382,7 @@ __global__ void append_speculate_cache_rope_kernel(
     const int ori_bi = batch_id_per_token[token_id];
     if (ori_bi == -1) continue; // NOTE(gongshaotian): For CUDAGraph padding
-    if (seq_lens_decoder[ori_bi] == 0) continue;
+    if (seq_lens_encoder[ori_bi] > 0) continue;
     const int bias = linear_index % hidden_size;
     const int hi = bias / head_size; // q + k + v
     const int h_bias = bias % head_size;
@@ -392,7 +394,7 @@ __global__ void append_speculate_cache_rope_kernel(
     const int* block_table_now = block_tables + ori_bi * max_blocks_per_seq;
     const int block_idx = block_table_now[write_seq_id / block_size];
     if (block_idx < 0) {
-      return; // NOTE(gongshaotian): For CUDAGraph padding
+      continue; // NOTE(gongshaotian): For CUDAGraph padding
     }
     const int block_offset = write_seq_id % block_size;
@@ -473,6 +475,7 @@ __global__ void append_speculate_cache_neox_rope_kernel(
     const int* __restrict__ batch_id_per_token, // [num_tokens]
     const int* __restrict__ cu_seqlens_q,
     const int* __restrict__ seq_lens_decoder, // [bsz]
+    const int* __restrict__ seq_lens_encoder, // [bsz]
     const float* __restrict__ cos_emb,
     const float* __restrict__ sin_emb,
     const float*
@@ -509,7 +512,7 @@ __global__ void append_speculate_cache_neox_rope_kernel(
     const int token_id = linear_index / half_hidden_size;
     const int ori_bi = batch_id_per_token[token_id];
     if (ori_bi == -1) continue; // NOTE(gongshaotian): For CUDAGraph padding
-    if (seq_lens_decoder[ori_bi] == 0) continue;
+    if (seq_lens_encoder[ori_bi] > 0) continue;
     const int bias = linear_index % half_hidden_size;
     const int hi = bias / half_head_size; // q + k + v
     const int h_bias = bias % half_head_size;
@@ -521,7 +524,7 @@ __global__ void append_speculate_cache_neox_rope_kernel(
     const int* block_table_now = block_tables + ori_bi * max_blocks_per_seq;
     const int block_idx = block_table_now[write_seq_id / block_size];
     if (block_idx < 0) {
-      return; // NOTE(gongshaotian): For CUDAGraph padding
+      continue; // NOTE(gongshaotian): For CUDAGraph padding
     }
     const int block_offset = write_seq_id % block_size;
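
All three kernels above use a grid-stride loop, so one thread can own several elements; the old return on a padded block-table entry therefore silently dropped every later element mapped to that thread, leaving cache entries for valid tokens unwritten. A standalone toy kernel (hypothetical, not from this patch) isolating the difference:

// Toy kernel: return vs. continue in a grid-stride loop. A return on one
// bad element would also skip i + stride, i + 2 * stride, ... for this
// thread; continue only skips the bad element.
__global__ void mark_valid(const int* __restrict__ flags,
                           int* __restrict__ out,
                           const int n) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
       i += gridDim.x * blockDim.x) {
    if (flags[i] < 0) continue;  // was the return-style bug pattern
    out[i] = 1;
  }
}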

File 2 of 2: kernel launch sites

@@ -67,6 +67,7 @@ void append_speculate_cache_rope_qk_norm(const QKV_TYPE* qkv,
       batch_id_per_token,
       cu_seqlens_q,
       seq_lens,
+      seq_lens_encoder,
       cos_emb,
       sin_emb,
       qkv_out_scales,
@@ -134,6 +135,7 @@ void append_speculate_cache_rope(const QKV_TYPE* qkv,
       batch_id_per_token,
       cu_seqlens_q,
       seq_lens,
+      seq_lens_encoder,
       cos_emb,
       sin_emb,
       qkv_out_scales,
@@ -158,6 +160,7 @@ void append_speculate_cache_rope(const QKV_TYPE* qkv,
       batch_id_per_token,
       cu_seqlens_q,
       seq_lens,
+      seq_lens_encoder,
       cos_emb,
       sin_emb,
       qkv_out_scales,
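
The second file only plumbs the new argument through: every launch site in the host-side wrappers now passes seq_lens_encoder alongside the decoder lengths so the per-token guard in the kernels can see it. A hedged sketch of a wrapper after the change, reusing the simplified kernel from the sketch above (signature and launch configuration assumed):

// Hedged sketch (assumed signature) of a host-side wrapper after the patch:
// forward seq_lens_encoder to the kernel at every launch site.
void launch_speculate_write_cache_sketch(const int* batch_id_per_token,
                                         const int* seq_lens_encoder,  // new
                                         const int* block_tables,
                                         const int max_blocks_per_seq,
                                         const int hidden_size,
                                         const int elem_cnt,
                                         cudaStream_t stream) {
  const int threads = 128;
  const int blocks = (elem_cnt + threads - 1) / threads;
  speculate_write_cache_sketch<<<blocks, threads, 0, stream>>>(
      batch_id_per_token,
      seq_lens_encoder,  // now reaches the prefill guard in the kernel
      block_tables,
      max_blocks_per_seq,
      hidden_size,
      elem_cnt);
}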