diff --git a/custom_ops/gpu_ops/append_attn/speculate_write_cache_with_rope_impl.cuh b/custom_ops/gpu_ops/append_attn/speculate_write_cache_with_rope_impl.cuh index bf0a22b6e..30d3f9196 100644 --- a/custom_ops/gpu_ops/append_attn/speculate_write_cache_with_rope_impl.cuh +++ b/custom_ops/gpu_ops/append_attn/speculate_write_cache_with_rope_impl.cuh @@ -31,6 +31,7 @@ __global__ void append_speculate_cache_T_rope_qk_norm_kernel( const int* __restrict__ batch_id_per_token, // [num_tokens] const int* __restrict__ cu_seqlens_q, const int* __restrict__ seq_lens_decoder, // [bsz] + const int* __restrict__ seq_lens_encoder, // [bsz] const float* __restrict__ cos_emb, const float* __restrict__ sin_emb, const float* @@ -75,7 +76,7 @@ __global__ void append_speculate_cache_T_rope_qk_norm_kernel( const int ori_bi = batch_id_per_token[token_id]; if (ori_bi == -1) continue; // NOTE(gongshaotian): For CUDAGraph padding - if (seq_lens_decoder[ori_bi] == 0) continue; + if (seq_lens_encoder[ori_bi] > 0) continue; const int bias = linear_index % hidden_size; const int hi = bias / head_size; // q + k + v const int h_bias = bias % head_size; @@ -87,7 +88,7 @@ __global__ void append_speculate_cache_T_rope_qk_norm_kernel( const int* block_table_now = block_tables + ori_bi * max_blocks_per_seq; const int block_idx = block_table_now[write_seq_id / block_size]; if (block_idx < 0) { - return; // NOTE(gongshaotian): For CUDAGraph padding + continue; // NOTE(gongshaotian): For CUDAGraph padding } const int block_offset = write_seq_id % block_size; @@ -343,6 +344,7 @@ __global__ void append_speculate_cache_rope_kernel( const int* __restrict__ batch_id_per_token, // [num_tokens] const int* __restrict__ cu_seqlens_q, const int* __restrict__ seq_lens_decoder, // [bsz] + const int* __restrict__ seq_lens_encoder, // [bsz] const float* __restrict__ cos_emb, const float* __restrict__ sin_emb, const float* @@ -380,7 +382,7 @@ __global__ void append_speculate_cache_rope_kernel( const int ori_bi = batch_id_per_token[token_id]; if (ori_bi == -1) continue; // NOTE(gongshaotian): For CUDAGraph padding - if (seq_lens_decoder[ori_bi] == 0) continue; + if (seq_lens_encoder[ori_bi] > 0) continue; const int bias = linear_index % hidden_size; const int hi = bias / head_size; // q + k + v const int h_bias = bias % head_size; @@ -392,7 +394,7 @@ __global__ void append_speculate_cache_rope_kernel( const int* block_table_now = block_tables + ori_bi * max_blocks_per_seq; const int block_idx = block_table_now[write_seq_id / block_size]; if (block_idx < 0) { - return; // NOTE(gongshaotian): For CUDAGraph padding + continue; // NOTE(gongshaotian): For CUDAGraph padding } const int block_offset = write_seq_id % block_size; @@ -473,6 +475,7 @@ __global__ void append_speculate_cache_neox_rope_kernel( const int* __restrict__ batch_id_per_token, // [num_tokens] const int* __restrict__ cu_seqlens_q, const int* __restrict__ seq_lens_decoder, // [bsz] + const int* __restrict__ seq_lens_encoder, // [bsz] const float* __restrict__ cos_emb, const float* __restrict__ sin_emb, const float* @@ -509,7 +512,7 @@ __global__ void append_speculate_cache_neox_rope_kernel( const int token_id = linear_index / half_hidden_size; const int ori_bi = batch_id_per_token[token_id]; if (ori_bi == -1) continue; // NOTE(gongshaotian): For CUDAGraph padding - if (seq_lens_decoder[ori_bi] == 0) continue; + if (seq_lens_encoder[ori_bi] > 0) continue; const int bias = linear_index % half_hidden_size; const int hi = bias / half_head_size; // q + k + v const int h_bias = bias % half_head_size; @@ -521,7 +524,7 @@ __global__ void append_speculate_cache_neox_rope_kernel( const int* block_table_now = block_tables + ori_bi * max_blocks_per_seq; const int block_idx = block_table_now[write_seq_id / block_size]; if (block_idx < 0) { - return; // NOTE(gongshaotian): For CUDAGraph padding + continue; // NOTE(gongshaotian): For CUDAGraph padding } const int block_offset = write_seq_id % block_size; diff --git a/custom_ops/gpu_ops/append_attn/speculate_write_cache_with_rope_kernel.cu b/custom_ops/gpu_ops/append_attn/speculate_write_cache_with_rope_kernel.cu index 3a9305df2..513f384b2 100644 --- a/custom_ops/gpu_ops/append_attn/speculate_write_cache_with_rope_kernel.cu +++ b/custom_ops/gpu_ops/append_attn/speculate_write_cache_with_rope_kernel.cu @@ -67,6 +67,7 @@ void append_speculate_cache_rope_qk_norm(const QKV_TYPE* qkv, batch_id_per_token, cu_seqlens_q, seq_lens, + seq_lens_encoder, cos_emb, sin_emb, qkv_out_scales, @@ -134,6 +135,7 @@ void append_speculate_cache_rope(const QKV_TYPE* qkv, batch_id_per_token, cu_seqlens_q, seq_lens, + seq_lens_encoder, cos_emb, sin_emb, qkv_out_scales, @@ -158,6 +160,7 @@ void append_speculate_cache_rope(const QKV_TYPE* qkv, batch_id_per_token, cu_seqlens_q, seq_lens, + seq_lens_encoder, cos_emb, sin_emb, qkv_out_scales,