diff --git a/custom_ops/gpu_ops/append_attention.cu b/custom_ops/gpu_ops/append_attention.cu index 3de883daa..2ba7555e7 100644 --- a/custom_ops/gpu_ops/append_attention.cu +++ b/custom_ops/gpu_ops/append_attention.cu @@ -46,7 +46,7 @@ std::vector AppendAttentionKernel( const paddle::Tensor& seq_lens_encoder, const paddle::Tensor& seq_lens_decoder, const paddle::Tensor& seq_lens_this_time, - const paddle::Tensor& padding_offsets, + const paddle::Tensor& batch_id_per_token, const paddle::Tensor& cu_seqlens_q, const paddle::Tensor& block_tables, const paddle::Tensor& encoder_batch_ids, @@ -165,7 +165,7 @@ std::vector AppendAttentionKernel( seq_lens_this_time, seq_lens_decoder, seq_lens_encoder, - padding_offsets, + batch_id_per_token, cu_seqlens_q, block_tables, lambda_batch_ids, @@ -202,7 +202,7 @@ std::vector AppendAttentionKernel( seq_lens_this_time, seq_lens_encoder, seq_lens_decoder, - padding_offsets, + batch_id_per_token, cu_seqlens_q, block_tables, kv_batch_ids, @@ -274,7 +274,7 @@ std::vector AppendAttentionKernel( qkv, // [token_num, num_heads, head_dim] seq_lens_decoder, seq_lens_encoder, - padding_offsets, + batch_id_per_token, cu_seqlens_q, block_tables, rotary_embs, @@ -297,7 +297,7 @@ std::vector AppendAttentionKernel( qkv_out, // [token_num, num_heads, head_dim] seq_lens_decoder, seq_lens_encoder, - padding_offsets, + batch_id_per_token, cu_seqlens_q, block_tables, rotary_embs, @@ -322,7 +322,7 @@ std::vector AppendAttentionKernel( qkv, // [token_num, num_heads, head_dim] seq_lens_decoder, seq_lens_encoder, - padding_offsets, + batch_id_per_token, cu_seqlens_q, block_tables, rotary_embs, @@ -346,7 +346,7 @@ std::vector AppendAttentionKernel( qkv_out, // [token_num, num_heads, head_dim] seq_lens_decoder, seq_lens_encoder, - padding_offsets, + batch_id_per_token, cu_seqlens_q, block_tables, rotary_embs, @@ -403,7 +403,7 @@ std::vector AppendAttention( const paddle::Tensor& seq_lens_encoder, const paddle::Tensor& seq_lens_decoder, const paddle::Tensor& seq_lens_this_time, - const paddle::Tensor& padding_offsets, + const paddle::Tensor& batch_id_per_token, const paddle::Tensor& cu_seqlens_q, const paddle::Tensor& block_tables, const paddle::Tensor& encoder_batch_ids, @@ -473,7 +473,7 @@ std::vector AppendAttention( seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, - padding_offsets, + batch_id_per_token, cu_seqlens_q, block_tables, encoder_batch_ids, @@ -550,7 +550,7 @@ std::vector> AppendAttentionInferShape( const std::vector& seq_lens_encoder_shape, const std::vector& seq_lens_decoder_shape, const std::vector& seq_lens_this_time_shape, - const std::vector& padding_offsets_shape, + const std::vector& batch_id_per_token_shape, const std::vector& cu_seqlens_q_shape, const std::vector& block_tables_shape, const std::vector& encoder_batch_ids_shape, @@ -610,7 +610,7 @@ std::vector AppendAttentionInferDtype( const paddle::DataType& seq_lens_encoder_dtype, const paddle::DataType& seq_lens_decoder_dtype, const paddle::DataType& seq_lens_this_time_dtype, - const paddle::DataType& padding_offsets_dtype, + const paddle::DataType& batch_id_per_token_dtype, const paddle::DataType& cu_seqlens_q_dtype, const paddle::DataType& block_tables_dtype, const paddle::DataType& encoder_batch_ids_dtype, @@ -688,7 +688,7 @@ PD_BUILD_STATIC_OP(append_attention) "seq_lens_encoder", "seq_lens_decoder", "seq_lens_this_time", - "padding_offsets", + "batch_id_per_token", "cu_seqlens_q", "block_tables", "encoder_batch_ids", diff --git a/custom_ops/gpu_ops/append_attn/append_attention_c16_impl.cuh 
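Reviewer note: throughout this patch, `padding_offsets` (a per-token offset back into the padded `[bsz, max_seq_len]` layout) is replaced by `batch_id_per_token`, a `[token_num]` table that stores, for each token of the packed ragged batch, the id of the request it belongs to. A minimal host-side sketch of how such a table could be materialized from `cu_seqlens_q` (the exclusive prefix sum of per-request query lengths); `BuildBatchIdPerToken` is a hypothetical helper, not part of this diff:

```cpp
#include <vector>

// Hypothetical helper: expand cu_seqlens_q[bsz + 1] into a per-token batch id.
std::vector<int> BuildBatchIdPerToken(const std::vector<int>& cu_seqlens_q) {
  const int bsz = static_cast<int>(cu_seqlens_q.size()) - 1;
  std::vector<int> batch_id_per_token(cu_seqlens_q[bsz]);
  for (int b = 0; b < bsz; ++b) {
    // Every packed token in [cu_seqlens_q[b], cu_seqlens_q[b + 1]) belongs to request b.
    for (int t = cu_seqlens_q[b]; t < cu_seqlens_q[b + 1]; ++t) {
      batch_id_per_token[t] = b;
    }
  }
  return batch_id_per_token;
}
```

For example, query lengths {3, 2} give `cu_seqlens_q = {0, 3, 5}` and `batch_id_per_token = {0, 0, 0, 1, 1}`.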
b/custom_ops/gpu_ops/append_attn/append_attention_c16_impl.cuh index fb40e56e4..b7d8441c6 100644 --- a/custom_ops/gpu_ops/append_attn/append_attention_c16_impl.cuh +++ b/custom_ops/gpu_ops/append_attn/append_attention_c16_impl.cuh @@ -773,7 +773,7 @@ void MultiQueryAppendAttention( const paddle::Tensor &seq_lens_q, const paddle::Tensor &seq_lens_kv, const paddle::Tensor &seq_lens_encoder, - const paddle::Tensor &padding_offsets, + const paddle::Tensor &batch_id_per_token, const paddle::Tensor &cu_seqlens_q, const paddle::Tensor &block_table, const paddle::Tensor &batch_ids, @@ -1007,7 +1007,8 @@ void MultiQueryAppendAttention( seq_lens_q.data(), seq_lens_kv.data(), seq_lens_encoder.data(), - padding_offsets.data(), + batch_id_per_token.data(), + cu_seqlens_q.data(), shift_bias ? reinterpret_cast( const_cast(shift_bias.get().data())) : nullptr, @@ -1240,7 +1241,8 @@ void MultiQueryAppendAttention( seq_lens_q.data(), seq_lens_kv.data(), seq_lens_encoder.data(), - padding_offsets.data(), + batch_id_per_token.data(), + cu_seqlens_q.data(), shift_bias ? reinterpret_cast( const_cast(shift_bias.get().data())) : nullptr, @@ -1287,7 +1289,7 @@ void CascadeAppendAttentionC16Kernel( const paddle::Tensor& seq_lens_q, const paddle::Tensor& seq_lens_kv, const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& padding_offsets, + const paddle::Tensor& batch_id_per_token, const paddle::Tensor& cu_seqlens_q, const paddle::Tensor& block_table, const paddle::Tensor& batch_ids, @@ -1350,7 +1352,7 @@ void CascadeAppendAttentionC16Kernel( seq_lens_q, seq_lens_kv, seq_lens_encoder, - padding_offsets, + batch_id_per_token, cu_seqlens_q, block_table, batch_ids, diff --git a/custom_ops/gpu_ops/append_attn/append_attention_c4_impl.cuh b/custom_ops/gpu_ops/append_attn/append_attention_c4_impl.cuh index 4349612d5..9f003af88 100644 --- a/custom_ops/gpu_ops/append_attn/append_attention_c4_impl.cuh +++ b/custom_ops/gpu_ops/append_attn/append_attention_c4_impl.cuh @@ -960,7 +960,7 @@ void MultiQueryAppendC4Attention( const paddle::Tensor &seq_lens_q, const paddle::Tensor &seq_lens_kv, const paddle::Tensor &seq_lens_encoder, - const paddle::Tensor &padding_offsets, + const paddle::Tensor &batch_id_per_token, const paddle::Tensor &cu_seqlens_q, const paddle::Tensor &block_table, const paddle::Tensor &batch_ids, @@ -1219,7 +1219,8 @@ void MultiQueryAppendC4Attention( seq_lens_q.data(), seq_lens_kv.data(), seq_lens_encoder.data(), - padding_offsets.data(), + batch_id_per_token.data(), + cu_seqlens_q.data(), shift_bias ? reinterpret_cast( const_cast(shift_bias.get().data())) : nullptr, @@ -1477,7 +1478,8 @@ void MultiQueryAppendC4Attention( seq_lens_q.data(), seq_lens_kv.data(), seq_lens_encoder.data(), - padding_offsets.data(), + batch_id_per_token.data(), + cu_seqlens_q.data(), shift_bias ? 
reinterpret_cast( const_cast(shift_bias.get().data())) : nullptr, @@ -1524,7 +1526,7 @@ void CascadeAppendAttentionC4Kernel( const paddle::Tensor& seq_lens_q, const paddle::Tensor& seq_lens_kv, const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& padding_offsets, + const paddle::Tensor& batch_id_per_token, const paddle::Tensor& cu_seqlens_q, const paddle::Tensor& block_table, const paddle::Tensor& batch_ids, @@ -1591,7 +1593,7 @@ void CascadeAppendAttentionC4Kernel( seq_lens_q, seq_lens_kv, seq_lens_encoder, - padding_offsets, + batch_id_per_token, cu_seqlens_q, block_table, batch_ids, diff --git a/custom_ops/gpu_ops/append_attn/append_attention_c8_impl.cuh b/custom_ops/gpu_ops/append_attn/append_attention_c8_impl.cuh index 01fb581e3..3b72597e0 100644 --- a/custom_ops/gpu_ops/append_attn/append_attention_c8_impl.cuh +++ b/custom_ops/gpu_ops/append_attn/append_attention_c8_impl.cuh @@ -897,7 +897,7 @@ void MultiQueryAppendC8Attention( const paddle::Tensor &seq_lens_q, const paddle::Tensor &seq_lens_kv, const paddle::Tensor &seq_lens_encoder, - const paddle::Tensor &padding_offsets, + const paddle::Tensor &batch_id_per_token, const paddle::Tensor &cu_seqlens_q, const paddle::Tensor &block_table, const paddle::Tensor &batch_ids, @@ -1179,7 +1179,8 @@ void MultiQueryAppendC8Attention( seq_lens_q.data(), seq_lens_kv.data(), seq_lens_encoder.data(), - padding_offsets.data(), + batch_id_per_token.data(), + cu_seqlens_q.data(), shift_bias ? reinterpret_cast( const_cast(shift_bias.get().data())) : nullptr, @@ -1450,7 +1451,8 @@ void MultiQueryAppendC8Attention( seq_lens_q.data(), seq_lens_kv.data(), seq_lens_encoder.data(), - padding_offsets.data(), + batch_id_per_token.data(), + cu_seqlens_q.data(), shift_bias ? reinterpret_cast( const_cast(shift_bias.get().data())) : nullptr, @@ -1497,7 +1499,7 @@ void CascadeAppendAttentionC8Kernel( const paddle::Tensor& seq_lens_q, const paddle::Tensor& seq_lens_kv, const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& padding_offsets, + const paddle::Tensor& batch_id_per_token, const paddle::Tensor& cu_seqlens_q, const paddle::Tensor& block_table, const paddle::Tensor& batch_ids, @@ -1562,7 +1564,7 @@ void CascadeAppendAttentionC8Kernel( seq_lens_q, seq_lens_kv, seq_lens_encoder, - padding_offsets, + batch_id_per_token, cu_seqlens_q, block_table, batch_ids, diff --git a/custom_ops/gpu_ops/append_attn/append_attention_func.cuh b/custom_ops/gpu_ops/append_attn/append_attention_func.cuh index 2eb94e44a..8b6802d27 100644 --- a/custom_ops/gpu_ops/append_attn/append_attention_func.cuh +++ b/custom_ops/gpu_ops/append_attn/append_attention_func.cuh @@ -1852,7 +1852,7 @@ __global__ void merge_multi_chunks_kernel( const float* __restrict__ multi_d, // [token_num, num_chunks, num_heads] const int* __restrict__ seq_lens_q, const int* __restrict__ seq_lens_kv, - const int* __restrict__ padding_offsets, + const int* __restrict__ batch_id_per_token, const T* __restrict__ shift_bias, // [q_num_heads * HEAD_DIM] const T* __restrict__ smooth_weight, // [q_num_heads * HEAD_DIM] T* __restrict__ out, @@ -1866,8 +1866,7 @@ __global__ void merge_multi_chunks_kernel( const int head_dim) { const int vid = threadIdx.x, hid = threadIdx.y; const int qid = blockIdx.x; - const uint32_t ori_token_id = qid + padding_offsets[qid]; - const uint32_t bid = ori_token_id / max_seq_len; + const uint32_t bid = batch_id_per_token[qid]; if (seq_lens_q[bid] <= 0 || seq_lens_kv[bid] <= 0) { return; } @@ -2240,7 +2239,8 @@ __global__ void merge_multi_chunks_v2_kernel( const int 
*__restrict__ seq_lens_q, const int *__restrict__ seq_lens_kv, const int *__restrict__ seq_lens_encoder, - const int *__restrict__ padding_offsets, + const int *__restrict__ batch_id_per_token, + const int *__restrict__ cu_seqlens_q, const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM] const T *__restrict__ smooth_weight, // [q_num_heads * HEAD_DIM] OutT *__restrict__ out, @@ -2259,9 +2259,8 @@ __global__ void merge_multi_chunks_v2_kernel( __shared__ T smem[bdy * HEAD_DIM]; __shared__ float md_smem[bdy * 2]; for (int qid = blockIdx.x; qid < token_num; qid += gridDim.x) { - const uint32_t ori_token_id = qid + padding_offsets[qid]; - const uint32_t bid = ori_token_id / max_seq_len; - const uint32_t local_seq_id = ori_token_id % max_seq_len; + const uint32_t bid = batch_id_per_token[qid]; + const uint32_t local_seq_id = qid - cu_seqlens_q[bid]; const int seq_len_q = seq_lens_q[bid]; if (seq_len_q == 0) continue; int seq_len_kv = seq_lens_kv[bid]; diff --git a/custom_ops/gpu_ops/append_attn/append_attention_kernel.h b/custom_ops/gpu_ops/append_attn/append_attention_kernel.h index 59e502f2a..8799c0a70 100644 --- a/custom_ops/gpu_ops/append_attn/append_attention_kernel.h +++ b/custom_ops/gpu_ops/append_attn/append_attention_kernel.h @@ -40,7 +40,7 @@ void CascadeAppendAttentionC16Kernel( const paddle::Tensor& seq_lens_q, const paddle::Tensor& seq_lens_kv, const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& padding_offsets, + const paddle::Tensor& batch_id_per_token, const paddle::Tensor& cu_seqlens_q, const paddle::Tensor& block_table, const paddle::Tensor& batch_ids, @@ -85,7 +85,7 @@ void CascadeAppendAttentionC8Kernel( const paddle::Tensor& seq_lens_q, const paddle::Tensor& seq_lens_kv, const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& padding_offsets, + const paddle::Tensor& batch_id_per_token, const paddle::Tensor& cu_seqlens_q, const paddle::Tensor& block_table, const paddle::Tensor& batch_ids, @@ -130,7 +130,7 @@ void CascadeAppendAttentionC4Kernel( const paddle::Tensor& seq_lens_q, const paddle::Tensor& seq_lens_kv, const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& padding_offsets, + const paddle::Tensor& batch_id_per_token, const paddle::Tensor& cu_seqlens_q, const paddle::Tensor& block_table, const paddle::Tensor& batch_ids, @@ -175,7 +175,7 @@ void CascadeAppendAttentionKernel( const paddle::Tensor& seq_lens_q, const paddle::Tensor& seq_lens_kv, const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& padding_offsets, + const paddle::Tensor& batch_id_per_token, const paddle::Tensor& cu_seqlens_q, const paddle::Tensor& block_table, const paddle::Tensor& batch_ids, @@ -211,7 +211,7 @@ void CascadeAppendAttentionKernel( seq_lens_q, seq_lens_kv, seq_lens_encoder, - padding_offsets, + batch_id_per_token, cu_seqlens_q, block_table, batch_ids, @@ -246,7 +246,7 @@ void CascadeAppendAttentionKernel( seq_lens_q, seq_lens_kv, seq_lens_encoder, - padding_offsets, + batch_id_per_token, cu_seqlens_q, block_table, batch_ids, @@ -281,7 +281,7 @@ void CascadeAppendAttentionKernel( seq_lens_q, seq_lens_kv, seq_lens_encoder, - padding_offsets, + batch_id_per_token, cu_seqlens_q, block_table, batch_ids, @@ -316,7 +316,7 @@ void CascadeAppendAttentionKernel( seq_lens_q, seq_lens_kv, seq_lens_encoder, - padding_offsets, + batch_id_per_token, cu_seqlens_q, block_table, batch_ids, diff --git a/custom_ops/gpu_ops/append_attn/decode_attention_kernel.cu b/custom_ops/gpu_ops/append_attn/decode_attention_kernel.cu index 098d3c6f4..701ba42df 100644 --- 
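Reviewer note: the `merge_multi_chunks*` hunks above show the core of the change. A packed token is now mapped to its batch and its local position with two O(1) lookups instead of padding arithmetic, which also drops the dependence on `max_seq_len`. A sketch of the new mapping (`TokenToBatchPos` is an illustrative helper, not part of this diff):

```cpp
// New mapping, as used by merge_multi_chunks_v2_kernel above.
__device__ __forceinline__ void TokenToBatchPos(
    const int* batch_id_per_token,  // [token_num]
    const int* cu_seqlens_q,        // [bsz + 1]
    int qid, int* bid, int* local_seq_id) {
  *bid = batch_id_per_token[qid];            // direct lookup
  *local_seq_id = qid - cu_seqlens_q[*bid];  // offset inside this request
  // Previously: ori_token_id = qid + padding_offsets[qid];
  //             bid          = ori_token_id / max_seq_len;
  //             local_seq_id = ori_token_id % max_seq_len;
}
```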
a/custom_ops/gpu_ops/append_attn/decode_attention_kernel.cu +++ b/custom_ops/gpu_ops/append_attn/decode_attention_kernel.cu @@ -317,7 +317,7 @@ void MultiQueryDecoderAttention( const paddle::optional& smooth_weight, const paddle::Tensor &seq_lens_q, const paddle::Tensor &seq_lens_kv, - const paddle::Tensor &padding_offsets, + const paddle::Tensor &batch_id_per_token, const paddle::Tensor &cu_seqlens_q, const paddle::Tensor &block_table, const int max_seq_len, @@ -483,7 +483,7 @@ void DecodeMLAAttentionKernel( const paddle::optional& smooth_weight, const paddle::Tensor &seq_lens_q, // q_seq_len is 1 const paddle::Tensor &seq_lens_kv, - const paddle::Tensor &padding_offsets, + const paddle::Tensor &batch_id_per_token, const paddle::Tensor &cu_seqlens_q, const paddle::Tensor &block_table, int max_seq_len, @@ -513,7 +513,7 @@ void DecodeMLAAttentionKernel( {DISPATCH_BLOCK_SIZE(block_size, BLOCK_SIZE, {DISPATCH_DEAL_EACH_TIME(deal_each_time, DEAL_EACH_TIME, {MultiQueryDecoderAttention( - meta_data, stream, q, cache_k, cache_v, attn_mask, shift_bias, smooth_weight, seq_lens_q, seq_lens_kv, padding_offsets, cu_seqlens_q, + meta_data, stream, q, cache_k, cache_v, attn_mask, shift_bias, smooth_weight, seq_lens_q, seq_lens_kv, batch_id_per_token, cu_seqlens_q, block_table, max_seq_len, max_dec_len, rope_scale, rope_theta, softmax_scale, in_scale, out);})})})})})}); } @@ -527,7 +527,7 @@ template void DecodeMLAAttentionKernel( const paddle::optional& smooth_weight, const paddle::Tensor &seq_lens_q, // q_seq_len is 1 const paddle::Tensor &seq_lens_kv, - const paddle::Tensor &padding_offsets, + const paddle::Tensor &batch_id_per_token, const paddle::Tensor &cu_seqlens_q, const paddle::Tensor &block_table, int max_seq_len, @@ -548,7 +548,7 @@ template void DecodeMLAAttentionKernel( const paddle::optional& smooth_weight, const paddle::Tensor &seq_lens_q, // q_seq_len is 1 const paddle::Tensor &seq_lens_kv, - const paddle::Tensor &padding_offsets, + const paddle::Tensor &batch_id_per_token, const paddle::Tensor &cu_seqlens_q, const paddle::Tensor &block_table, int max_seq_len, diff --git a/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_impl.cuh b/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_impl.cuh index f3915caed..67066efc2 100644 --- a/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_impl.cuh +++ b/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_impl.cuh @@ -28,7 +28,7 @@ __global__ void append_decode_cache_T_rope_kernel( // head_size // 2] T* __restrict__ qkv_out, const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] - const int* __restrict__ padding_offsets, // [num_tokens] + const int* __restrict__ batch_id_per_token, // [num_tokens] const int* __restrict__ cu_seqlens_q, const int* __restrict__ seq_lens, // [bsz] const int* __restrict__ seq_lens_encoder, // [bsz] @@ -134,7 +134,7 @@ __global__ void append_decode_cache_T_rope_kernel( // head_size // 2] T* __restrict__ qkv_out, const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] - const int* __restrict__ padding_offsets, // [num_tokens] + const int* __restrict__ batch_id_per_token, // [num_tokens] const int* __restrict__ cu_seqlens_q, const int* __restrict__ seq_lens, // [bsz] const int* __restrict__ seq_lens_encoder, // [bsz] @@ -254,7 +254,7 @@ __global__ void append_decode_cache_T_neox_rope_kernel( // head_size // 2] T* __restrict__ qkv_out, const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] - const int* __restrict__ padding_offsets, // [num_tokens] + 
const int* __restrict__ batch_id_per_token, // [num_tokens] const int* __restrict__ cu_seqlens_q, const int* __restrict__ seq_lens, // [bsz] const int* __restrict__ seq_lens_encoder, // [bsz] @@ -366,7 +366,7 @@ __global__ void append_decode_cache_T_neox_rope_kernel( // head_size // 2] T* __restrict__ qkv_out, const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] - const int* __restrict__ padding_offsets, // [num_tokens] + const int* __restrict__ batch_id_per_token, // [num_tokens] const int* __restrict__ cu_seqlens_q, const int* __restrict__ seq_lens, // [bsz] const int* __restrict__ seq_lens_encoder, // [bsz] @@ -498,7 +498,7 @@ __global__ void append_decode_cache_int8_rope_kernel( // block_size, head_size // 2] T* __restrict__ qkv_out, const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] - const int* __restrict__ padding_offsets, // [num_tokens] + const int* __restrict__ batch_id_per_token, // [num_tokens] const int* __restrict__ cu_seqlens_q, const int* __restrict__ seq_lens, // [bsz] const int* __restrict__ seq_lens_encoder, // [bsz] @@ -745,7 +745,7 @@ __global__ void append_decode_cache_int8_rope_kernel( // block_size, head_size // 2] T* __restrict__ qkv_out, const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] - const int* __restrict__ padding_offsets, // [num_tokens] + const int* __restrict__ batch_id_per_token, // [num_tokens] const int* __restrict__ cu_seqlens_q, const int* __restrict__ seq_lens, // [bsz] const int* __restrict__ seq_lens_encoder, // [bsz] @@ -1047,7 +1047,7 @@ __global__ void append_decode_cache_int8_neox_rope_kernel( // block_size, head_size // 2] T* __restrict__ qkv_out, const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] - const int* __restrict__ padding_offsets, // [num_tokens] + const int* __restrict__ batch_id_per_token, // [num_tokens] const int* __restrict__ cu_seqlens_q, const int* __restrict__ seq_lens, // [bsz] const int* __restrict__ seq_lens_encoder, // [bsz] @@ -1346,7 +1346,7 @@ __global__ void append_decode_cache_int8_neox_rope_kernel( // block_size, head_size // 2] T* __restrict__ qkv_out, const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] - const int* __restrict__ padding_offsets, // [num_tokens] + const int* __restrict__ batch_id_per_token, // [num_tokens] const int* __restrict__ cu_seqlens_q, const int* __restrict__ seq_lens, // [bsz] const int* __restrict__ seq_lens_encoder, // [bsz] @@ -1739,7 +1739,7 @@ __global__ void append_decode_cache_int4_rope_kernel( // block_size, head_size // 2] T* __restrict__ qkv_out, const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] - const int* __restrict__ padding_offsets, // [num_tokens] + const int* __restrict__ batch_id_per_token, // [num_tokens] const int* __restrict__ cu_seqlens_q, const int* __restrict__ seq_lens, // [bsz] const int* __restrict__ seq_lens_encoder, // [bsz] @@ -2034,7 +2034,7 @@ __global__ void append_decode_cache_int4_rope_kernel( // block_size, head_size // 2] T* __restrict__ qkv_out, const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] - const int* __restrict__ padding_offsets, // [num_tokens] + const int* __restrict__ batch_id_per_token, // [num_tokens] const int* __restrict__ cu_seqlens_q, const int* __restrict__ seq_lens, // [bsz] const int* __restrict__ seq_lens_encoder, // [bsz] @@ -2362,7 +2362,7 @@ __global__ void append_decode_cache_int4_neox_rope_kernel( // block_size, head_size // 2] T* __restrict__ qkv_out, const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] 
- const int* __restrict__ padding_offsets, // [num_tokens] + const int* __restrict__ batch_id_per_token, // [num_tokens] const int* __restrict__ cu_seqlens_q, const int* __restrict__ seq_lens, // [bsz] const int* __restrict__ seq_lens_encoder, // [bsz] @@ -2732,7 +2732,7 @@ __global__ void append_decode_cache_int4_neox_rope_kernel( // block_size, head_size // 2] T* __restrict__ qkv_out, const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] - const int* __restrict__ padding_offsets, // [num_tokens] + const int* __restrict__ batch_id_per_token, // [num_tokens] const int* __restrict__ cu_seqlens_q, const int* __restrict__ seq_lens, // [bsz] const int* __restrict__ seq_lens_encoder, // [bsz] diff --git a/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_kernel.cu b/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_kernel.cu index cb62e9048..fe72d120a 100644 --- a/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_kernel.cu +++ b/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_kernel.cu @@ -21,7 +21,7 @@ void append_decode_cache_rope(const QKV_TYPE* qkv, T* value_cache, T* qkv_out, const int* block_tables, - const int* padding_offsets, + const int* batch_id_per_token, const int* cu_seqlens_q, const int* seq_lens, const int* seq_lens_encoder, @@ -57,7 +57,7 @@ void append_decode_cache_rope(const QKV_TYPE* qkv, value_cache, qkv_out, block_tables, - padding_offsets, + batch_id_per_token, cu_seqlens_q, seq_lens, seq_lens_encoder, @@ -79,7 +79,7 @@ void append_decode_cache_rope(const QKV_TYPE* qkv, value_cache, qkv_out, block_tables, - padding_offsets, + batch_id_per_token, cu_seqlens_q, seq_lens, seq_lens_encoder, @@ -102,7 +102,7 @@ void append_decode_cache_rope(const QKV_TYPE* qkv, value_cache, qkv_out, block_tables, - padding_offsets, + batch_id_per_token, cu_seqlens_q, seq_lens, seq_lens_encoder, @@ -125,7 +125,7 @@ void append_decode_cache_rope(const QKV_TYPE* qkv, value_cache, qkv_out, block_tables, - padding_offsets, + batch_id_per_token, cu_seqlens_q, seq_lens, seq_lens_encoder, @@ -149,7 +149,7 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv, uint8_t* value_cache, T* qkv_out, const int* block_tables, - const int* padding_offsets, + const int* batch_id_per_token, const int* cu_seqlens_q, const int* seq_lens, const int* seq_lens_encoder, @@ -182,7 +182,7 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv, value_cache, qkv_out, block_tables, - padding_offsets, + batch_id_per_token, cu_seqlens_q, seq_lens, seq_lens_encoder, @@ -207,7 +207,7 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv, value_cache, qkv_out, block_tables, - padding_offsets, + batch_id_per_token, cu_seqlens_q, seq_lens, seq_lens_encoder, @@ -232,7 +232,7 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv, value_cache, qkv_out, block_tables, - padding_offsets, + batch_id_per_token, cu_seqlens_q, seq_lens, seq_lens_encoder, @@ -257,7 +257,7 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv, value_cache, qkv_out, block_tables, - padding_offsets, + batch_id_per_token, cu_seqlens_q, seq_lens, seq_lens_encoder, @@ -282,7 +282,7 @@ void append_decode_cache_int4_rope(const QKV_TYPE* qkv, uint8_t* value_cache, T* qkv_out, const int* block_tables, - const int* padding_offsets, + const int* batch_id_per_token, const int* cu_seqlens_q, const int* seq_lens, const int* seq_lens_encoder, @@ -317,7 +317,7 @@ void append_decode_cache_int4_rope(const QKV_TYPE* qkv, value_cache, qkv_out, block_tables, - padding_offsets, + batch_id_per_token, 
cu_seqlens_q, seq_lens, seq_lens_encoder, @@ -344,7 +344,7 @@ void append_decode_cache_int4_rope(const QKV_TYPE* qkv, value_cache, qkv_out, block_tables, - padding_offsets, + batch_id_per_token, cu_seqlens_q, seq_lens, seq_lens_encoder, @@ -371,7 +371,7 @@ void append_decode_cache_int4_rope(const QKV_TYPE* qkv, value_cache, qkv_out, block_tables, - padding_offsets, + batch_id_per_token, cu_seqlens_q, seq_lens, seq_lens_encoder, @@ -398,7 +398,7 @@ void append_decode_cache_int4_rope(const QKV_TYPE* qkv, value_cache, qkv_out, block_tables, - padding_offsets, + batch_id_per_token, cu_seqlens_q, seq_lens, seq_lens_encoder, @@ -424,7 +424,7 @@ void DecoderWriteCacheWithRoPEKernel( const paddle::Tensor& qkv, const paddle::Tensor& seq_lens, const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& padding_offsets, + const paddle::Tensor& batch_id_per_token, const paddle::Tensor& cu_seqlens_q, const paddle::Tensor& block_tables, const paddle::optional& rotary_embs, @@ -471,7 +471,7 @@ void DecoderWriteCacheWithRoPEKernel( reinterpret_cast(value_cache_out->data()), reinterpret_cast(qkv_out->data()), block_tables.data(), - padding_offsets.data(), + batch_id_per_token.data(), cu_seqlens_q.data(), seq_lens.data(), seq_lens_encoder.data(), @@ -503,7 +503,7 @@ void DecoderWriteCacheWithRoPEKernel( value_cache_out->data(), reinterpret_cast(qkv_out->data()), block_tables.data(), - padding_offsets.data(), + batch_id_per_token.data(), cu_seqlens_q.data(), seq_lens.data(), seq_lens_encoder.data(), @@ -536,7 +536,7 @@ void DecoderWriteCacheWithRoPEKernel( value_cache_out->data(), reinterpret_cast(qkv_out->data()), block_tables.data(), - padding_offsets.data(), + batch_id_per_token.data(), cu_seqlens_q.data(), seq_lens.data(), seq_lens_encoder.data(), @@ -570,7 +570,7 @@ void DecoderWriteCacheWithRoPEKernel( value_cache_out->data(), reinterpret_cast(qkv_out->data()), block_tables.data(), - padding_offsets.data(), + batch_id_per_token.data(), cu_seqlens_q.data(), seq_lens.data(), seq_lens_encoder.data(), @@ -603,7 +603,7 @@ void DecoderWriteCacheWithRoPEKernel( value_cache_out->data(), reinterpret_cast(const_cast(qkv_out->data())), block_tables.data(), - padding_offsets.data(), + batch_id_per_token.data(), cu_seqlens_q.data(), seq_lens.data(), seq_lens_encoder.data(), @@ -650,7 +650,7 @@ template void DecoderWriteCacheWithRoPEKernel( // kv_num_heads, head_dim] if GQA) const paddle::Tensor& seq_lens, const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& padding_offsets, + const paddle::Tensor& batch_id_per_token, const paddle::Tensor& cu_seqlens_q, const paddle::Tensor& block_tables, const paddle::optional& rotary_embs, @@ -677,7 +677,7 @@ DecoderWriteCacheWithRoPEKernel( // kv_num_heads, head_dim] if GQA) const paddle::Tensor& seq_lens, const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& padding_offsets, + const paddle::Tensor& batch_id_per_token, const paddle::Tensor& cu_seqlens_q, const paddle::Tensor& block_tables, const paddle::optional& rotary_embs, @@ -703,7 +703,7 @@ template void DecoderWriteCacheWithRoPEKernel( // kv_num_heads, head_dim] if GQA) const paddle::Tensor& seq_lens, const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& padding_offsets, + const paddle::Tensor& batch_id_per_token, const paddle::Tensor& cu_seqlens_q, const paddle::Tensor& block_tables, const paddle::optional& rotary_embs, @@ -729,7 +729,7 @@ template void DecoderWriteCacheWithRoPEKernel( // kv_num_heads, head_dim] if GQA) const paddle::Tensor& seq_lens, const paddle::Tensor& 
seq_lens_encoder, - const paddle::Tensor& padding_offsets, + const paddle::Tensor& batch_id_per_token, const paddle::Tensor& cu_seqlens_q, const paddle::Tensor& block_tables, const paddle::optional& rotary_embs, diff --git a/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_kernel.h b/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_kernel.h index 602f5a007..c25f68211 100644 --- a/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_kernel.h +++ b/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_kernel.h @@ -23,7 +23,7 @@ void DecoderWriteCacheWithRoPEKernel( // kv_num_heads, head_dim] if GQA) const paddle::Tensor& seq_lens, const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& padding_offsets, + const paddle::Tensor& batch_id_per_token, const paddle::Tensor& cu_seqlens_q, const paddle::Tensor& block_tables, const paddle::optional& rotary_embs, diff --git a/custom_ops/gpu_ops/append_attn/encoder_write_cache_with_rope_impl.cuh b/custom_ops/gpu_ops/append_attn/encoder_write_cache_with_rope_impl.cuh index 0d2dc165e..a2da51ef1 100644 --- a/custom_ops/gpu_ops/append_attn/encoder_write_cache_with_rope_impl.cuh +++ b/custom_ops/gpu_ops/append_attn/encoder_write_cache_with_rope_impl.cuh @@ -23,7 +23,8 @@ __global__ void VariableLengthRotaryKernel( const int *qkv, const float *cos_emb, // [1, 1, seq_len, dim_head / 2] const float *sin_emb, - const int *padding_offsets, + const int *batch_id_per_token, + const int *cu_seqlens_q, const int *seq_lens, const int *seq_lens_decoder, const float *qkv_out_scales, // [3, num_head, dim_head] @@ -52,8 +53,7 @@ __global__ void VariableLengthRotaryKernel( linear_index < elem_cnt; linear_index += step) { const int token_idx = linear_index / offset; - const int ori_token_idx = token_idx + padding_offsets[token_idx]; - const int ori_bi = ori_token_idx / seq_len; + const int ori_bi = batch_id_per_token[token_idx]; if (seq_lens && seq_lens[ori_bi] == 0) continue; const int bias = linear_index % offset; const int qkv_id = bias / hidden_size; @@ -61,7 +61,7 @@ __global__ void VariableLengthRotaryKernel( const int hi = qkv_bias / last_dim; const int h_bias = qkv_bias % last_dim; - const int ori_seq_id = ori_token_idx % seq_len + seq_lens_decoder[ori_bi]; + const int ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi]; const int emb_idx = ori_seq_id * half_lastdim + h_bias / 2; const int bias_idx = qkv_id * hidden_size + hi * last_dim + h_bias; @@ -107,7 +107,8 @@ __global__ void VariableLengthRotaryKernel( const T *qkv, const float *cos_emb, // [1, 1, seq_len, dim_head / 2] const float *sin_emb, - const int *padding_offsets, + const int *batch_id_per_token, + const int *cu_seqlens_q, const int *seq_lens, const int *seq_lens_decoder, T *qkv_out, @@ -130,8 +131,7 @@ __global__ void VariableLengthRotaryKernel( linear_index < elem_cnt; linear_index += step) { const int token_idx = linear_index / offset; - const int ori_token_idx = token_idx + padding_offsets[token_idx]; - const int ori_bi = ori_token_idx / seq_len; + const int ori_bi = batch_id_per_token[token_idx]; if (seq_lens && seq_lens[ori_bi] == 0) continue; const int bias = linear_index % offset; const int qkv_id = bias / hidden_size; @@ -139,7 +139,7 @@ __global__ void VariableLengthRotaryKernel( const int hi = qkv_bias / last_dim; const int h_bias = qkv_bias % last_dim; - const int ori_seq_id = ori_token_idx % seq_len + seq_lens_decoder[ori_bi]; + const int ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi]; const 
int emb_idx = ori_seq_id * half_lastdim + h_bias / 2; const int64_t base_idx = token_idx * 3 * hidden_size + @@ -167,7 +167,8 @@ __global__ void NeoxVariableLengthRotaryKernel( const int *qkv, const float *cos_emb, // [1, 1, seq_len, dim_head / 2] const float *sin_emb, - const int *padding_offsets, + const int *batch_id_per_token, + const int *cu_seqlens_q, const int *seq_lens, const int *seq_lens_decoder, const float *qkv_out_scales, // [3, num_head, dim_head] @@ -199,8 +200,7 @@ linear_index < elem_cnt; linear_index += step) { const int token_idx = linear_index / offset; - const int ori_token_idx = token_idx + padding_offsets[token_idx]; - const int ori_bi = ori_token_idx / seq_len; + const int ori_bi = batch_id_per_token[token_idx]; if (seq_lens && seq_lens[ori_bi] == 0) continue; const int bias = linear_index % offset; const int qkv_id = bias / hidden_size; @@ -208,7 +208,7 @@ const int hi = qkv_bias / half_lastdim; const int h_bias = qkv_bias % half_lastdim; - const int ori_seq_id = ori_token_idx % seq_len + seq_lens_decoder[ori_bi]; + const int ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi]; const int emb_idx = ori_seq_id * last_dim + h_bias; const int bias_idx_left = @@ -261,7 +261,8 @@ __global__ void NeoxVariableLengthRotaryKernel( const T *qkv, const float *cos_emb, // [1, 1, seq_len, dim_head / 2] const float *sin_emb, - const int *padding_offsets, + const int *batch_id_per_token, + const int *cu_seqlens_q, const int *seq_lens, const int *seq_lens_decoder, T *qkv_out, @@ -285,8 +286,7 @@ linear_index < elem_cnt; linear_index += step) { const int token_idx = linear_index / offset; - const int ori_token_idx = token_idx + padding_offsets[token_idx]; - const int ori_bi = ori_token_idx / seq_len; + const int ori_bi = batch_id_per_token[token_idx]; if (seq_lens && seq_lens[ori_bi] == 0) continue; const int bias = linear_index % offset; const int qkv_id = bias / hidden_size; @@ -294,7 +294,7 @@ const int hi = qkv_bias / half_lastdim; const int h_bias = qkv_bias % half_lastdim; - const int ori_seq_id = ori_token_idx % seq_len + seq_lens_decoder[ori_bi]; + const int ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi]; const int emb_idx = ori_seq_id * last_dim + h_bias; const int base_idx_left = token_idx * 3 * full_hidden_size + @@ -327,7 +327,8 @@ __global__ void GQAVariableLengthRotaryKernel( const int *qkv, const float *cos_emb, // [1, 1, seq_len, dim_head / 2] const float *sin_emb, - const int *padding_offsets, + const int *batch_id_per_token, + const int *cu_seqlens_q, const int *seq_lens, const int *seq_lens_decoder, const float *qkv_out_scales, // [3, q_num_head, dim_head] @@ -357,14 +358,13 @@ linear_index < elem_cnt; linear_index += step) { const int token_idx = linear_index / offset; - const int ori_token_idx = token_idx + padding_offsets[token_idx]; - const int ori_bi = ori_token_idx / seq_len; + const int ori_bi = batch_id_per_token[token_idx]; if (seq_lens[ori_bi] == 0) continue; const int bias = linear_index % offset; const int hi = bias / last_dim; const int h_bias = bias % last_dim; - const int ori_seq_id = ori_token_idx % seq_len + seq_lens_decoder[ori_bi]; + const int ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi]; const int64_t emb_idx = ori_seq_id *
half_lastdim + h_bias / 2; const int64_t bias_idx = hi * last_dim + h_bias; @@ -410,7 +410,8 @@ __global__ void GQAVariableLengthRotaryKernel( const T *qkv, const float *cos_emb, const float *sin_emb, - const int *padding_offsets, + const int *batch_id_per_token, + const int *cu_seqlens_q, const int *seq_lens, const int *seq_lens_decoder, T *qkv_out, @@ -434,14 +435,13 @@ linear_index < elem_cnt; linear_index += step) { const int token_idx = linear_index / offset; - const int ori_token_idx = token_idx + padding_offsets[token_idx]; - const int ori_bi = ori_token_idx / seq_len; + const int ori_bi = batch_id_per_token[token_idx]; if (seq_lens[ori_bi] == 0) continue; const int bias = linear_index % offset; const int hi = bias / last_dim; const int h_bias = bias % last_dim; - const int ori_seq_id = ori_token_idx % seq_len + seq_lens_decoder[ori_bi]; + const int ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi]; const int64_t emb_idx = ori_seq_id * half_lastdim + h_bias / 2; const int64_t base_idx = @@ -472,7 +472,8 @@ __global__ void GQAVariableLengthRotaryQuantKVKernel(const int *qkv, const float *cos_emb, // [1, 1, seq_len, dim_head / 2] const float *sin_emb, const float *qkv_out_scales, - const int *padding_offsets, + const int *batch_id_per_token, + const int *cu_seqlens_q, const int *seq_lens, const int *seq_lens_decoder, const T *qkv_biases, @@ -504,15 +505,13 @@ linear_index < elem_cnt; linear_index += step) { const int token_idx = linear_index / offset; - const int ori_token_idx = token_idx + padding_offsets[token_idx]; - const int ori_bi = ori_token_idx / seq_len; + const int ori_bi = batch_id_per_token[token_idx]; if (seq_lens[ori_bi] == 0) continue; const int bias = linear_index % offset; const int hi = bias / last_dim; const int h_bias = bias % last_dim; - int ori_seq_id; - ori_seq_id = ori_token_idx % seq_len + seq_lens_decoder[ori_bi]; + int ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi]; const int64_t emb_idx = ori_seq_id * half_lastdim + h_bias / 2; const int64_t bias_idx = hi * last_dim + h_bias; @@ -561,7 +560,8 @@ template __global__ void GQAVariableLengthRotaryQuantKVKernel(const T *qkv, const float *cos_emb, // [1, 1, seq_len, dim_head / 2] const float *sin_emb, - const int *padding_offsets, + const int *batch_id_per_token, + const int *cu_seqlens_q, const int *seq_lens, const int *seq_lens_decoder, const T *qkv_biases, @@ -590,15 +590,13 @@ __global__ void GQAVariableLengthRotaryQuantKVKernel(const T *qkv, linear_index < elem_cnt; linear_index += step) { const int token_idx = linear_index / offset; - const int ori_token_idx = token_idx + padding_offsets[token_idx]; - const int ori_bi = ori_token_idx / seq_len; + const int ori_bi = batch_id_per_token[token_idx]; if (seq_lens[ori_bi] == 0) continue; const int bias = linear_index % offset; const int hi = bias / last_dim; const int h_bias = bias % last_dim; - int ori_seq_id; - ori_seq_id = ori_token_idx % seq_len + seq_lens_decoder[ori_bi]; + int ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi]; const int64_t emb_idx = ori_seq_id * half_lastdim + h_bias / 2; const int64_t bias_idx = hi * last_dim + h_bias; @@ -645,7 +643,8 @@ __global__ void GQANeoxVariableLengthRotaryKernel( const int *qkv, const float *cos_emb, // [1, 1, seq_len, dim_head / 2] const float *sin_emb, - const int *padding_offsets, + const int *batch_id_per_token, +
const int *cu_seqlens_q, const int *seq_lens, const int *seq_lens_decoder, const float *qkv_out_scales, // [3, q_num_head, dim_head] @@ -676,14 +675,13 @@ __global__ void GQANeoxVariableLengthRotaryKernel( linear_index < elem_cnt; linear_index += step) { const int token_idx = linear_index / offset; - const int ori_token_idx = token_idx + padding_offsets[token_idx]; - const int ori_bi = ori_token_idx / seq_len; + const int ori_bi = batch_id_per_token[token_idx]; if (seq_lens && seq_lens[ori_bi] == 0) continue; const int bias = linear_index % offset; const int hi = bias / half_lastdim; const int h_bias = bias % half_lastdim; - const int ori_seq_id = ori_token_idx % seq_len + seq_lens_decoder[ori_bi]; + const int ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi]; const int emb_idx = ori_seq_id * last_dim + h_bias; const int bias_idx_left = hi * last_dim + h_bias; @@ -736,7 +734,8 @@ __global__ void GQANeoxVariableLengthRotaryKernel( const T *qkv, const float *cos_emb, const float *sin_emb, - const int *padding_offsets, + const int *batch_id_per_token, + const int *cu_seqlens_q, const int *seq_lens, const int *seq_lens_decoder, const float *qkv_out_scales, @@ -761,14 +760,13 @@ __global__ void GQANeoxVariableLengthRotaryKernel( linear_index < elem_cnt; linear_index += step) { const int token_idx = linear_index / offset; - const int ori_token_idx = token_idx + padding_offsets[token_idx]; - const int ori_bi = ori_token_idx / seq_len; + const int ori_bi = batch_id_per_token[token_idx]; if (seq_lens && seq_lens[ori_bi] == 0) continue; const int bias = linear_index % offset; const int hi = bias / half_lastdim; const int h_bias = bias % half_lastdim; - const int ori_seq_id = ori_token_idx % seq_len + seq_lens_decoder[ori_bi]; + const int ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi]; const int emb_idx = ori_seq_id * last_dim + h_bias; const int base_idx_left = @@ -805,7 +803,8 @@ __global__ void cache_kernel( T *__restrict__ value_cache, // [num_blocks, kv_num_heads, block_size, // head_size] const int *__restrict__ block_tables, // [bsz, max_blocks_per_seq] - const int *__restrict__ padding_offsets, // [num_tokens] + const int *__restrict__ batch_id_per_token, // [num_tokens] + const int *__restrict__ cu_seqlens_q, // [bsz] const int *__restrict__ seq_lens, // [bsz] const int *__restrict__ seq_lens_decoder, // [bsz] const int max_seq_len, @@ -831,11 +830,9 @@ __global__ void cache_kernel( const uint32_t qkv_bias = bias % hidden_size; const uint32_t hi = qkv_bias / head_size; const uint32_t h_bias = qkv_bias % head_size; - const uint32_t ori_token_idx = token_idx + padding_offsets[token_idx]; - const uint32_t ori_bi = ori_token_idx / max_seq_len; + const uint32_t ori_bi = batch_id_per_token[token_idx]; if (seq_lens[ori_bi] == 0) continue; - const uint32_t ori_seq_id = - ori_token_idx % max_seq_len + seq_lens_decoder[ori_bi]; + const uint32_t ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi]; const int32_t *block_table_now = nullptr; @@ -878,7 +875,7 @@ __global__ void append_write_cache_kv_c8_qkv( const int *__restrict__ tile_ids, const int *__restrict__ seq_lens_this_time, const int *__restrict__ seq_lens_decoder, - const int *__restrict__ padding_offsets, + const int *__restrict__ batch_id_per_token, const int *__restrict__ cu_seqlens_q, const int *__restrict__ block_tables, const int max_seq_len, @@ -1117,7 +1114,7 @@ __global__ void append_write_cache_kv_c4_qkv( const int *__restrict__ tile_ids, const int 
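Reviewer note: in the rotary kernels and in `cache_kernel` above, the absolute sequence position of a token is now its offset inside the newly appended chunk plus the length already decoded for that request. An illustrative restatement (`RotaryEmbIdx` is a hypothetical helper, not part of this diff):

```cpp
// Absolute position = offset inside the appended chunk + tokens already decoded.
__device__ __forceinline__ int RotaryEmbIdx(
    const int* batch_id_per_token, const int* cu_seqlens_q,
    const int* seq_lens_decoder, int token_idx, int h_bias, int half_lastdim) {
  const int ori_bi = batch_id_per_token[token_idx];
  const int ori_seq_id =
      (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi];
  return ori_seq_id * half_lastdim + h_bias / 2;  // row into the cos/sin tables
}
```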
*__restrict__ seq_lens_this_time, const int *__restrict__ seq_lens_decoder, - const int *__restrict__ padding_offsets, + const int *__restrict__ batch_id_per_token, const int *__restrict__ cu_seqlens_q, const int *__restrict__ block_tables, const int max_seq_len, @@ -1405,7 +1402,8 @@ void rotary_qk_variable( const float *qkv_out_scales, // [3, num_head, dim_head] const T *qkv_bias, const float *rotary_emb, // [2, 1, 1, seq_len, dim_head / 2] - const int *padding_offsets, + const int *batch_id_per_token, + const int *cu_seqlens_q, const int *seq_lens, const int *seq_lens_decoder, const int token_num, @@ -1437,7 +1435,8 @@ void rotary_qk_variable( reinterpret_cast(qkv_input), cos_emb, sin_emb, - padding_offsets, + batch_id_per_token, + cu_seqlens_q, seq_lens, seq_lens_decoder, qkv_out_scales, @@ -1453,7 +1452,8 @@ void rotary_qk_variable( reinterpret_cast(qkv_input), cos_emb, sin_emb, - padding_offsets, + batch_id_per_token, + cu_seqlens_q, seq_lens, seq_lens_decoder, qkv_out, @@ -1471,7 +1471,8 @@ void rotary_qk_variable( reinterpret_cast(qkv_input), cos_emb, sin_emb, - padding_offsets, + batch_id_per_token, + cu_seqlens_q, seq_lens, seq_lens_decoder, qkv_out_scales, @@ -1487,7 +1488,8 @@ void rotary_qk_variable( reinterpret_cast(qkv_input), cos_emb, sin_emb, - padding_offsets, + batch_id_per_token, + cu_seqlens_q, seq_lens, seq_lens_decoder, qkv_out, @@ -1506,7 +1508,8 @@ void gqa_rotary_qk_variable( const float *qkv_out_scales, // [3, num_head, dim_head] const T *qkv_bias, const float *rotary_emb, // [2, 1, 1, seq_len, dim_head / 2] - const int *padding_offsets, + const int *batch_id_per_token, + const int *cu_seqlens_q, const int *seq_lens, const int *seq_lens_decoder, const int token_num, @@ -1541,7 +1544,8 @@ void gqa_rotary_qk_variable( reinterpret_cast(qkv_input), cos_emb, sin_emb, - padding_offsets, + batch_id_per_token, + cu_seqlens_q, seq_lens, seq_lens_decoder, qkv_out_scales, @@ -1559,7 +1563,8 @@ void gqa_rotary_qk_variable( reinterpret_cast(qkv_input), cos_emb, sin_emb, - padding_offsets, + batch_id_per_token, + cu_seqlens_q, seq_lens, seq_lens_decoder, qkv_out, @@ -1579,7 +1584,8 @@ void gqa_rotary_qk_variable( reinterpret_cast(qkv_input), cos_emb, sin_emb, - padding_offsets, + batch_id_per_token, + cu_seqlens_q, seq_lens, seq_lens_decoder, qkv_out_scales, @@ -1596,7 +1602,8 @@ void gqa_rotary_qk_variable( reinterpret_cast(qkv_input), cos_emb, sin_emb, - padding_offsets, + batch_id_per_token, + cu_seqlens_q, seq_lens, seq_lens_decoder, qkv_out_scales, @@ -1620,7 +1627,8 @@ void gqa_rotary_qk_quant_variable( const T *cache_k_scales, const T *cache_v_scales, const float *rotary_emb, // [2, 1, 1, seq_len, dim_head / 2] - const int *padding_offsets, + const int *batch_id_per_token, + const int *cu_seqlens_q, const int *seq_lens, const int *seq_lens_decoder, const int token_num, @@ -1652,7 +1660,8 @@ void gqa_rotary_qk_quant_variable( cos_emb, sin_emb, qkv_out_scales, - padding_offsets, + batch_id_per_token, + cu_seqlens_q, seq_lens, seq_lens_decoder, qkv_bias, @@ -1671,7 +1680,8 @@ void gqa_rotary_qk_quant_variable( reinterpret_cast(qkv_input), cos_emb, sin_emb, - padding_offsets, + batch_id_per_token, + cu_seqlens_q, seq_lens, seq_lens_decoder, qkv_bias, @@ -1697,7 +1707,8 @@ void CascadeAppendWriteCacheKVQKV( &qkv, // [token_num, 3, num_head, head_dim] ([token_num, num_head + 2 * // kv_num_heads, head_dim] if GQA) const paddle::Tensor &block_table, - const paddle::Tensor &padding_offsets, + const paddle::Tensor &batch_id_per_token, + const paddle::Tensor &cu_seqlens_q, const 
paddle::Tensor &seq_lens_encoder, const paddle::Tensor &seq_lens_decoder, const int max_seq_len, @@ -1723,7 +1734,8 @@ void CascadeAppendWriteCacheKVQKV( reinterpret_cast(key_cache_out->data()), reinterpret_cast(value_cache_out->data()), block_table.data(), - padding_offsets.data(), + batch_id_per_token.data(), + cu_seqlens_q.data(), seq_lens_encoder.data(), seq_lens_decoder.data(), max_seq_len, @@ -1747,7 +1759,7 @@ void CascadeAppendWriteCacheKVC8QKV( const paddle::Tensor &cache_v_scale, // [num_kv_heads, head_dim] const paddle::Tensor &seq_lens_this_time, const paddle::Tensor &seq_lens_decoder, - const paddle::Tensor &padding_offsets, + const paddle::Tensor &batch_id_per_token, const paddle::Tensor &cu_seqlens_q, const paddle::Tensor &block_table, const paddle::Tensor &batch_ids, @@ -1812,7 +1824,7 @@ void CascadeAppendWriteCacheKVC8QKV( tile_ids_per_batch.data(), seq_lens_this_time.data(), seq_lens_decoder.data(), - padding_offsets.data(), + batch_id_per_token.data(), cu_seqlens_q.data(), block_table.data(), max_seq_len, @@ -1835,7 +1847,7 @@ void CascadeAppendWriteCacheKVC4QKV( const paddle::Tensor &cache_v_zp, // [num_kv_heads, head_dim] const paddle::Tensor &seq_lens_this_time, const paddle::Tensor &seq_lens_decoder, - const paddle::Tensor &padding_offsets, + const paddle::Tensor &batch_id_per_token, const paddle::Tensor &cu_seqlens_q, const paddle::Tensor &block_table, const paddle::Tensor &batch_ids, @@ -1882,7 +1894,7 @@ void CascadeAppendWriteCacheKVC4QKV( tile_ids_per_batch.data(), seq_lens_this_time.data(), seq_lens_decoder.data(), - padding_offsets.data(), + batch_id_per_token.data(), cu_seqlens_q.data(), block_table.data(), max_seq_len, diff --git a/custom_ops/gpu_ops/append_attn/encoder_write_cache_with_rope_kernel.h b/custom_ops/gpu_ops/append_attn/encoder_write_cache_with_rope_kernel.h index 18850a1fc..5eb238216 100644 --- a/custom_ops/gpu_ops/append_attn/encoder_write_cache_with_rope_kernel.h +++ b/custom_ops/gpu_ops/append_attn/encoder_write_cache_with_rope_kernel.h @@ -25,7 +25,7 @@ void EncoderWriteCacheWithRopeKernel( const paddle::Tensor& seq_lens_this_time, const paddle::Tensor& seq_lens_encoder, const paddle::Tensor& seq_lens_decoder, - const paddle::Tensor& padding_offsets, + const paddle::Tensor& batch_id_per_token, const paddle::Tensor& cu_seqlens_q, const paddle::Tensor& block_tables, const paddle::Tensor& batch_ids, @@ -63,7 +63,8 @@ void EncoderWriteCacheWithRopeKernel( qkv_out_scales ? qkv_out_scales.get().data() : nullptr, qkv_biases ? qkv_biases.get().data() : nullptr, rotary_embs.get().data(), - padding_offsets.data(), + batch_id_per_token.data(), + cu_seqlens_q.data(), seq_lens_encoder.data(), seq_lens_decoder.data(), token_num, @@ -82,7 +83,8 @@ void EncoderWriteCacheWithRopeKernel( qkv_out_scales ? qkv_out_scales.get().data() : nullptr, qkv_biases ? qkv_biases.get().data() : nullptr, rotary_embs.get().data(), - padding_offsets.data(), + batch_id_per_token.data(), + cu_seqlens_q.data(), seq_lens_encoder.data(), seq_lens_decoder.data(), token_num, @@ -103,7 +105,8 @@ void EncoderWriteCacheWithRopeKernel( cache_k_scale ? cache_k_scale.get().data() : nullptr, cache_v_scale ? 
cache_v_scale.get().data() : nullptr, rotary_embs.get().data(), - padding_offsets.data(), + batch_id_per_token.data(), + cu_seqlens_q.data(), seq_lens_encoder.data(), seq_lens_decoder.data(), token_num, @@ -123,7 +126,8 @@ void EncoderWriteCacheWithRopeKernel( CascadeAppendWriteCacheKVQKV(meta_data, *qkv_out, block_tables, - padding_offsets, + batch_id_per_token, + cu_seqlens_q, seq_lens_encoder, seq_lens_decoder, max_seq_len, @@ -142,7 +146,7 @@ void EncoderWriteCacheWithRopeKernel( cache_v_scale.get(), seq_lens_this_time, seq_lens_decoder, - padding_offsets, + batch_id_per_token, cu_seqlens_q, block_tables, batch_ids, @@ -169,7 +173,7 @@ void EncoderWriteCacheWithRopeKernel( cache_v_zp.get(), seq_lens_this_time, seq_lens_decoder, - padding_offsets, + batch_id_per_token, cu_seqlens_q, block_tables, batch_ids, diff --git a/custom_ops/gpu_ops/append_attn/gqa_rope_write_cache.cu b/custom_ops/gpu_ops/append_attn/gqa_rope_write_cache.cu index 541fb0f39..f63f36a6b 100644 --- a/custom_ops/gpu_ops/append_attn/gqa_rope_write_cache.cu +++ b/custom_ops/gpu_ops/append_attn/gqa_rope_write_cache.cu @@ -25,7 +25,8 @@ __global__ void GQAVariableLengthRotarySplitKernel( const T *qkv, const float *cos_emb, const float *sin_emb, - const int *padding_offsets, + const int *batch_id_per_token, + const int *cu_seqlens_q, const int *seq_lens, const int *seq_lens_decoder, const int *cu_seqlens_k, @@ -52,14 +53,13 @@ __global__ void GQAVariableLengthRotarySplitKernel( linear_index < elem_cnt; linear_index += step) { const int token_idx = linear_index / offset; - const int ori_token_idx = token_idx + padding_offsets[token_idx]; - const int ori_bi = ori_token_idx / seq_len; + const int ori_bi = batch_id_per_token[token_idx]; if (seq_lens[ori_bi] == 0) continue; const int bias = linear_index % offset; const int hi = bias / last_dim; const int h_bias = bias % last_dim; - const int ori_seq_id = ori_token_idx % seq_len + seq_lens_decoder[ori_bi]; + const int ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi]; const int kv_write_idx = cu_seqlens_k[ori_bi] + ori_seq_id; const int64_t emb_idx = ori_seq_id * half_lastdim + h_bias / 2; @@ -108,9 +108,10 @@ void gqa_rotary_qk_split_variable( T *v, const T *qkv_input, const float *rotary_emb, // [2, 1, 1, seq_len, dim_head / 2] - const int *padding_offsets, + const int *batch_id_per_token, const int *seq_lens_encoder, const int *seq_lens_decoder, + const int *cu_seqlens_q, const int *cu_seqlens_k, const int token_num, const int num_heads, @@ -133,7 +134,8 @@ void gqa_rotary_qk_split_variable( qkv_input, cos_emb, sin_emb, - padding_offsets, + batch_id_per_token, + cu_seqlens_q, seq_lens_encoder, seq_lens_decoder, cu_seqlens_k, @@ -421,7 +423,7 @@ std::vector GQARopeWriteCacheKernel( const paddle::Tensor& seq_lens_this_time, const paddle::Tensor& seq_lens_encoder, const paddle::Tensor& seq_lens_decoder, - const paddle::Tensor& padding_offsets, + const paddle::Tensor& batch_id_per_token, const paddle::Tensor& block_tables, const paddle::Tensor& kv_batch_ids, const paddle::Tensor& kv_tile_ids, @@ -492,9 +494,10 @@ std::vector GQARopeWriteCacheKernel( v.data(), qkv.data(), rotary_embs.data(), - padding_offsets.data(), + batch_id_per_token.data(), seq_lens_encoder.data(), seq_lens_decoder.data(), + cu_seqlens_q.data(), cu_seqlens_k.data(), token_num, num_heads, @@ -509,7 +512,8 @@ std::vector GQARopeWriteCacheKernel( meta_data, qkv_out, block_tables, - padding_offsets, + batch_id_per_token, + cu_seqlens_q, seq_lens_encoder, seq_lens_decoder, max_seq_len, @@ 
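Reviewer note: `GQAVariableLengthRotarySplitKernel` in `gqa_rope_write_cache.cu` scatters the K/V of the new chunk into a packed buffer, so the destination row combines the per-request KV prefix (`cu_seqlens_k`) with the token's absolute position. A restatement of that math (`KvWriteIdx` is a hypothetical helper, not part of this diff):

```cpp
__device__ __forceinline__ int KvWriteIdx(
    const int* batch_id_per_token, const int* cu_seqlens_q,
    const int* seq_lens_decoder, const int* cu_seqlens_k, int token_idx) {
  const int ori_bi = batch_id_per_token[token_idx];
  const int ori_seq_id =
      (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi];
  return cu_seqlens_k[ori_bi] + ori_seq_id;  // row in the packed K/V buffer
}
```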
-526,7 +530,7 @@ std::vector GQARopeWriteCacheKernel( cache_v_quant_scales.get(), seq_lens_this_time, seq_lens_decoder, - padding_offsets, + batch_id_per_token, cu_seqlens_q, block_tables, kv_batch_ids, @@ -593,7 +597,7 @@ PD_BUILD_STATIC_OP(gqa_rope_write_cache) "seq_lens_this_time", "seq_lens_encoder", "seq_lens_decoder", - "padding_offsets", + "batch_id_per_token", "block_tables", "kv_batch_ids", "kv_tile_ids_per_batch", diff --git a/custom_ops/gpu_ops/append_attn/mla_cache_kernel.cu b/custom_ops/gpu_ops/append_attn/mla_cache_kernel.cu index 75c7b4076..d2ee6bd73 100644 --- a/custom_ops/gpu_ops/append_attn/mla_cache_kernel.cu +++ b/custom_ops/gpu_ops/append_attn/mla_cache_kernel.cu @@ -22,7 +22,7 @@ std::vector PrefillMLAWriteCache( const paddle::Tensor& kv_pe, const paddle::Tensor& seq_lens, const paddle::Tensor& seq_lens_decoder, - const paddle::Tensor& padding_offsets, + const paddle::Tensor& batch_id_per_token, const paddle::Tensor& cu_seqlens_q, const paddle::Tensor& block_tables, const int max_seq_len, @@ -53,7 +53,7 @@ std::vector PrefillMLAWriteCache( reinterpret_cast(const_cast(kv_pe.data())), reinterpret_cast(kv_cache->data()), block_tables.data(), - padding_offsets.data(), + batch_id_per_token.data(), cu_seqlens_q.data(), seq_lens.data(), seq_lens_decoder.data(), @@ -73,7 +73,7 @@ std::vector PrefillMLAWriteCacheKernel( const paddle::Tensor& kv_cache, const paddle::Tensor& seq_lens, const paddle::Tensor& seq_lens_decoder, - const paddle::Tensor& padding_offsets, + const paddle::Tensor& batch_id_per_token, const paddle::Tensor& cu_seqlens_q, const paddle::Tensor& block_tables, const std::string& cache_quant_type_str, @@ -99,7 +99,7 @@ std::vector PrefillMLAWriteCacheKernel( kv_pe, seq_lens, seq_lens_decoder, - padding_offsets, + batch_id_per_token, cu_seqlens_q, block_tables, max_seq_len, @@ -112,7 +112,7 @@ std::vector PrefillMLAWriteCacheKernel( kv_pe, seq_lens, seq_lens_decoder, - padding_offsets, + batch_id_per_token, cu_seqlens_q, block_tables, max_seq_len, @@ -130,7 +130,7 @@ std::vector DecodeMLAWriteCache( const paddle::Tensor& kv_pe, const paddle::Tensor& seq_lens, const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& padding_offsets, + const paddle::Tensor& batch_id_per_token, const paddle::Tensor& cu_seqlens_q, const paddle::Tensor& block_tables, const int max_seq_len, @@ -164,7 +164,7 @@ std::vector DecodeMLAWriteCache( reinterpret_cast(const_cast(kv_pe.data())), reinterpret_cast(kv_cache->data()), block_tables.data(), - padding_offsets.data(), + batch_id_per_token.data(), cu_seqlens_q.data(), seq_lens.data(), seq_lens_encoder.data(), @@ -205,7 +205,7 @@ std::vector DecodeMLAWriteCacheKernel( const paddle::Tensor& kv_cache, const paddle::Tensor& seq_lens, const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& padding_offsets, + const paddle::Tensor& batch_id_per_token, const paddle::Tensor& cu_seqlens_q, const paddle::Tensor& block_tables, const std::string& cache_quant_type_str, @@ -232,7 +232,7 @@ std::vector DecodeMLAWriteCacheKernel( kv_pe, seq_lens, seq_lens_encoder, - padding_offsets, + batch_id_per_token, cu_seqlens_q, block_tables, max_seq_len, @@ -246,7 +246,7 @@ std::vector DecodeMLAWriteCacheKernel( kv_pe, seq_lens, seq_lens_encoder, - padding_offsets, + batch_id_per_token, cu_seqlens_q, block_tables, max_seq_len, @@ -265,7 +265,7 @@ PD_BUILD_OP(prefill_mla_write_cache) "kv_cache", "seq_lens", "seq_lens_decoder", - "padding_offsets", + "batch_id_per_token", "cu_seqlens_q", "block_tables"}) .Outputs({"kv_cache_out"}) @@ -280,7 +280,7 @@ 
PD_BUILD_OP(decode_mla_write_cache) "kv_cache", "seq_lens", "seq_lens_encoder", - "padding_offsets", + "batch_id_per_token", "cu_seqlens_q", "block_tables"}) .Outputs({"kv_cache_out"}) diff --git a/custom_ops/gpu_ops/append_attn/mla_cache_kernel.cuh b/custom_ops/gpu_ops/append_attn/mla_cache_kernel.cuh index 31c73d2f4..2efcb7a8c 100644 --- a/custom_ops/gpu_ops/append_attn/mla_cache_kernel.cuh +++ b/custom_ops/gpu_ops/append_attn/mla_cache_kernel.cuh @@ -95,7 +95,7 @@ __global__ void speculate_decode_absorb_cache_kernel( T* __restrict__ kv_cache, // [num_blocks, kv_num_heads, block_size, // nope_size] const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] - const int* __restrict__ padding_offsets, + const int* __restrict__ batch_id_per_token, const int* __restrict__ cu_seqlens_q, const int* __restrict__ seq_lens, // [bsz] const int* __restrict__ seq_lens_encoder, // [bsz] @@ -121,7 +121,7 @@ __global__ void speculate_decode_absorb_cache_kernel( linear_index < elem_cnt; linear_index += step) { const int token_id = linear_index / hidden_size; - const int ori_bi = (token_id + padding_offsets[token_id]) / max_seq_len; + const int ori_bi = batch_id_per_token[token_id]; if (seq_lens[ori_bi] == 0) continue; const int bias = linear_index % hidden_size; const int start_token_idx = cu_seqlens_q[ori_bi]; @@ -178,7 +178,7 @@ __global__ void prefill_absorb_cache_kernel( T* __restrict__ kv_cache, // [num_blocks, kv_num_heads, block_size, // nope_size] const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] - const int* __restrict__ padding_offsets, + const int* __restrict__ batch_id_per_token, const int* __restrict__ cu_seqlens_q, const int* __restrict__ seq_lens, // [bsz] const int* __restrict__ seq_lens_decoder, // [bsz] @@ -204,11 +204,9 @@ __global__ void prefill_absorb_cache_kernel( linear_index += step) { const uint32_t token_idx = linear_index / hidden_size; const uint32_t bias = linear_index % hidden_size; - const uint32_t ori_token_idx = token_idx + padding_offsets[token_idx]; - const uint32_t ori_bi = ori_token_idx / max_seq_len; + const uint32_t ori_bi = batch_id_per_token[token_idx]; if (seq_lens[ori_bi] == 0) continue; - const uint32_t ori_seq_id = - ori_token_idx % max_seq_len + seq_lens_decoder[ori_bi]; + const uint32_t ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi]; const int* block_table_now = nullptr; block_table_now = block_tables + ori_bi * max_blocks_per_seq; diff --git a/custom_ops/gpu_ops/append_attn/multi_head_latent_attention_kernel.h b/custom_ops/gpu_ops/append_attn/multi_head_latent_attention_kernel.h index 3d5b8c8fd..4d81b99a7 100644 --- a/custom_ops/gpu_ops/append_attn/multi_head_latent_attention_kernel.h +++ b/custom_ops/gpu_ops/append_attn/multi_head_latent_attention_kernel.h @@ -26,7 +26,7 @@ void DecodeMLAAttentionKernel( const paddle::optional& smooth_weight, const paddle::Tensor &seq_lens_q, // q_seq_len is 1 const paddle::Tensor &seq_lens_kv, - const paddle::Tensor &padding_offsets, + const paddle::Tensor &batch_id_per_token, const paddle::Tensor &cu_seqlens_q, const paddle::Tensor &block_table, int max_seq_len, diff --git a/custom_ops/gpu_ops/append_attn/speculate_write_cache_with_rope_impl.cuh b/custom_ops/gpu_ops/append_attn/speculate_write_cache_with_rope_impl.cuh index f5e6f8db4..ed8952ad5 100644 --- a/custom_ops/gpu_ops/append_attn/speculate_write_cache_with_rope_impl.cuh +++ b/custom_ops/gpu_ops/append_attn/speculate_write_cache_with_rope_impl.cuh @@ -26,7 +26,7 @@ __global__ void append_clear_cache_int8_block( 
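Reviewer note: the MLA cache kernels above show both the decode and prefill variants of the same substitution. A small worked example (illustrative numbers, not from this PR) checking that the old and new mappings agree while the new one never touches `max_seq_len`:

```cpp
#include <cassert>

int main() {
  // Illustrative numbers: bsz = 2, query lengths {3, 2}, max_seq_len = 8.
  const int cu_seqlens_q[] = {0, 3, 5};
  const int batch_id_per_token[] = {0, 0, 0, 1, 1};
  const int padding_offsets[] = {0, 0, 0, 5, 5};  // old scheme
  const int max_seq_len = 8, qid = 3;
  // Old mapping: pad the token back into [bsz, max_seq_len], then div/mod.
  const int ori_token_id = qid + padding_offsets[qid];          // 3 + 5 = 8
  assert(ori_token_id / max_seq_len == batch_id_per_token[qid]);  // 1 == 1
  assert(ori_token_id % max_seq_len ==
         qid - cu_seqlens_q[batch_id_per_token[qid]]);            // 0 == 0
  return 0;
}
```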
// block_size, head_size // 2] const int* __restrict__ seq_lens, const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] - const int* __restrict__ padding_offsets, // [num_tokens] + const int* __restrict__ batch_id_per_token, // [num_tokens] const int* __restrict__ cu_seqlens_q, const int* __restrict__ seq_lens_encoder, // [bsz] const int max_seq_len, @@ -41,8 +41,8 @@ __global__ void append_clear_cache_int8_block( const int wid = tid / 32; const int lane_id = tid % 32; const int token_id = blockIdx.x; - const int ori_token_id = token_id + padding_offsets[token_id]; - const int bid = ori_token_id / max_seq_len; + + const int bid = batch_id_per_token[token_id]; const int start_token_idx = cu_seqlens_q[bid]; const int head_idx = blockIdx.y * NUM_WARPS + wid; @@ -100,7 +100,7 @@ __global__ void append_clear_cache_int4_block( // block_size, head_size // 2] const int* __restrict__ seq_lens, const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] - const int* __restrict__ padding_offsets, // [num_tokens] + const int* __restrict__ batch_id_per_token, // [num_tokens] const int* __restrict__ cu_seqlens_q, const int* __restrict__ seq_lens_encoder, // [bsz] const int max_seq_len, @@ -115,8 +115,8 @@ __global__ void append_clear_cache_int4_block( const int wid = tid / 32; const int lane_id = tid % 32; const int token_id = blockIdx.x; - const int ori_token_id = token_id + padding_offsets[token_id]; - const int bid = ori_token_id / max_seq_len; + + const int bid = batch_id_per_token[token_id]; const int start_token_idx = cu_seqlens_q[bid]; const int head_idx = blockIdx.y * NUM_WARPS + wid; @@ -178,7 +178,7 @@ __global__ void append_speculate_cache_rope_kernel( // head_size // 2] T* __restrict__ q_out, const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] - const int* __restrict__ padding_offsets, // [num_tokens] + const int* __restrict__ batch_id_per_token, // [num_tokens] const int* __restrict__ cu_seqlens_q, const int* __restrict__ seq_lens_decoder, // [bsz] const float* __restrict__ cos_emb, @@ -214,7 +214,7 @@ __global__ void append_speculate_cache_rope_kernel( linear_index < elem_cnt; linear_index += step) { const int token_id = linear_index / hidden_size; - const int ori_bi = (token_id + padding_offsets[token_id]) / max_seq_len; + const int ori_bi = batch_id_per_token[token_id]; if (seq_lens_decoder[ori_bi] == 0) continue; const int bias = linear_index % hidden_size; const int hi = bias / head_size; // q + k + v @@ -311,7 +311,7 @@ __global__ void append_speculate_cache_neox_rope_kernel( // head_size // 2] T* __restrict__ qkv_out, const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] - const int* __restrict__ padding_offsets, // [num_tokens] + const int* __restrict__ batch_id_per_token, // [num_tokens] const int* __restrict__ cu_seqlens_q, const int* __restrict__ seq_lens_decoder, // [bsz] const float* __restrict__ cos_emb, @@ -347,7 +347,7 @@ __global__ void append_speculate_cache_neox_rope_kernel( linear_index < elem_cnt; linear_index += step) { const int token_id = linear_index / half_hidden_size; - const int ori_bi = (token_id + padding_offsets[token_id]) / max_seq_len; + const int ori_bi = batch_id_per_token[token_id]; if (seq_lens_decoder[ori_bi] == 0) continue; const int bias = linear_index % half_hidden_size; const int hi = bias / half_head_size; // q + k + v @@ -458,7 +458,7 @@ __global__ void append_speculate_cache_int8_rope_kernel( // block_size, head_size // 2] T* __restrict__ qkv_out, const int* __restrict__ block_tables, // [bsz, 
max_blocks_per_seq] - const int* __restrict__ padding_offsets, // [num_tokens] + const int* __restrict__ batch_id_per_token, // [num_tokens] const int* __restrict__ cu_seqlens_q, const int* __restrict__ seq_lens, // [bsz] const int* __restrict__ seq_lens_encoder, // [bsz] @@ -484,8 +484,8 @@ __global__ void append_speculate_cache_int8_rope_kernel( const int wid = tid / 32; const int lane_id = tid % 32; const int token_id = blockIdx.x; - const int ori_token_id = token_id + padding_offsets[token_id]; - const int bid = ori_token_id / max_seq_len; + + const int bid = batch_id_per_token[token_id]; const int start_token_idx = cu_seqlens_q[bid]; const int head_idx = blockIdx.y * NUM_WARPS + wid; @@ -690,7 +690,7 @@ __global__ void append_speculate_cache_int8_neox_rope_kernel( // block_size, head_size // 2] T* __restrict__ qkv_out, const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] - const int* __restrict__ padding_offsets, // [num_tokens] + const int* __restrict__ batch_id_per_token, // [num_tokens] const int* __restrict__ cu_seqlens_q, const int* __restrict__ seq_lens, // [bsz] const int* __restrict__ seq_lens_encoder, // [bsz] @@ -716,8 +716,8 @@ __global__ void append_speculate_cache_int8_neox_rope_kernel( const int wid = tid / 32; const int lane_id = tid % 32; const int token_id = blockIdx.x; - const int ori_token_id = token_id + padding_offsets[token_id]; - const int bid = ori_token_id / max_seq_len; + + const int bid = batch_id_per_token[token_id]; const int start_token_idx = cu_seqlens_q[bid]; const int head_idx = blockIdx.y * NUM_WARPS + wid; @@ -1068,7 +1068,7 @@ __global__ void append_speculate_cache_int4_rope_kernel( // block_size, head_size // 2] T* __restrict__ qkv_out, const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] - const int* __restrict__ padding_offsets, // [num_tokens] + const int* __restrict__ batch_id_per_token, // [num_tokens] const int* __restrict__ cu_seqlens_q, const int* __restrict__ seq_lens, // [bsz] const int* __restrict__ seq_lens_encoder, // [bsz] @@ -1097,8 +1097,8 @@ __global__ void append_speculate_cache_int4_rope_kernel( const int lane_id = tid % 32; const int token_id = blockIdx.x; - const int ori_token_id = token_id + padding_offsets[token_id]; - const int bid = ori_token_id / max_seq_len; + + const int bid = batch_id_per_token[token_id]; const int start_token_idx = cu_seqlens_q[bid]; const int head_idx = blockIdx.y * NUM_WARPS + wid; @@ -1374,7 +1374,7 @@ __global__ void append_speculate_cache_int4_neox_rope_kernel( // block_size, head_size // 2] T* __restrict__ qkv_out, const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] - const int* __restrict__ padding_offsets, // [num_tokens] + const int* __restrict__ batch_id_per_token, // [num_tokens] const int* __restrict__ cu_seqlens_q, const int* __restrict__ seq_lens, // [bsz] const int* __restrict__ seq_lens_encoder, // [bsz] @@ -1403,8 +1403,8 @@ __global__ void append_speculate_cache_int4_neox_rope_kernel( const int lane_id = tid % 32; const int token_id = blockIdx.x; - const int ori_token_id = token_id + padding_offsets[token_id]; - const int bid = ori_token_id / max_seq_len; + + const int bid = batch_id_per_token[token_id]; const int start_token_idx = cu_seqlens_q[bid]; const int head_idx = blockIdx.y * NUM_WARPS + wid; diff --git a/custom_ops/gpu_ops/append_attn/speculate_write_cache_with_rope_kernel.cu b/custom_ops/gpu_ops/append_attn/speculate_write_cache_with_rope_kernel.cu index a4fbfe69d..b7c533a38 100644 --- 
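Every speculate_write_cache_with_rope_impl.cuh kernel above repeats the same substitution: bid = batch_id_per_token[token_id] replaces the divide-by-max_seq_len reconstruction, and the token's write position in the KV cache is then derived from cu_seqlens_q plus the per-batch decode length. A CPU sketch of that addressing with hypothetical table contents (block-table translation omitted):

#include <cstdio>

int main() {
  // Two batches with 3 and 2 freshly appended tokens (hypothetical values).
  const int batch_id_per_token[] = {0, 0, 0, 1, 1};
  const int cu_seqlens_q[]       = {0, 3, 5};
  const int seq_lens_decoder[]   = {7, 2};  // tokens already held in each cache

  for (int token_id = 0; token_id < 5; ++token_id) {
    const int bid = batch_id_per_token[token_id];   // direct lookup, as in the hunks above
    const int start_token_idx = cu_seqlens_q[bid];  // first packed token of this batch
    // Absolute slot of this token inside its sequence, i.e. where the kernel
    // writes in the cache before the block table maps slots to blocks.
    const int write_seq_id = seq_lens_decoder[bid] + (token_id - start_token_idx);
    printf("token %d -> batch %d, seq position %d\n", token_id, bid, write_seq_id);
  }
  return 0;
}

Note that start_token_idx = cu_seqlens_q[bid] already appears as unchanged context in these hunks; the patch only changes where bid comes from.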
a/custom_ops/gpu_ops/append_attn/speculate_write_cache_with_rope_kernel.cu +++ b/custom_ops/gpu_ops/append_attn/speculate_write_cache_with_rope_kernel.cu @@ -22,7 +22,7 @@ void append_speculate_cache_rope(const QKV_TYPE* qkv, T* value_cache, T* qkv_out, const int* block_tables, - const int* padding_offsets, + const int* batch_id_per_token, const int* cu_seqlens_q, const int* seq_lens, const int* seq_lens_encoder, @@ -59,7 +59,7 @@ void append_speculate_cache_rope(const QKV_TYPE* qkv, value_cache, qkv_out, block_tables, - padding_offsets, + batch_id_per_token, cu_seqlens_q, seq_lens, cos_emb, @@ -82,7 +82,7 @@ void append_speculate_cache_rope(const QKV_TYPE* qkv, value_cache, qkv_out, block_tables, - padding_offsets, + batch_id_per_token, cu_seqlens_q, seq_lens, cos_emb, @@ -106,7 +106,7 @@ void append_speculate_cache_int8_rope(const QKV_TYPE* qkv, uint8_t* value_cache, T* qkv_out, const int* block_tables, - const int* padding_offsets, + const int* batch_id_per_token, const int* cu_seqlens_q, const int* seq_lens, const int* seq_lens_encoder, @@ -136,7 +136,7 @@ void append_speculate_cache_int8_rope(const QKV_TYPE* qkv, value_cache, seq_lens, block_tables, - padding_offsets, + batch_id_per_token, cu_seqlens_q, seq_lens_encoder, max_seq_len, @@ -151,7 +151,7 @@ void append_speculate_cache_int8_rope(const QKV_TYPE* qkv, value_cache, qkv_out, block_tables, - padding_offsets, + batch_id_per_token, cu_seqlens_q, seq_lens, seq_lens_encoder, @@ -175,7 +175,7 @@ void append_speculate_cache_int8_rope(const QKV_TYPE* qkv, value_cache, qkv_out, block_tables, - padding_offsets, + batch_id_per_token, cu_seqlens_q, seq_lens, seq_lens_encoder, @@ -201,7 +201,7 @@ void append_speculate_cache_int4_rope(const QKV_TYPE* qkv, uint8_t* value_cache, T* qkv_out, const int* block_tables, - const int* padding_offsets, + const int* batch_id_per_token, const int* cu_seqlens_q, const int* seq_lens, const int* seq_lens_encoder, @@ -233,7 +233,7 @@ void append_speculate_cache_int4_rope(const QKV_TYPE* qkv, value_cache, seq_lens, block_tables, - padding_offsets, + batch_id_per_token, cu_seqlens_q, seq_lens_encoder, max_seq_len, @@ -248,7 +248,7 @@ void append_speculate_cache_int4_rope(const QKV_TYPE* qkv, value_cache, qkv_out, block_tables, - padding_offsets, + batch_id_per_token, cu_seqlens_q, seq_lens, seq_lens_encoder, @@ -274,7 +274,7 @@ void append_speculate_cache_int4_rope(const QKV_TYPE* qkv, value_cache, qkv_out, block_tables, - padding_offsets, + batch_id_per_token, cu_seqlens_q, seq_lens, seq_lens_encoder, @@ -301,7 +301,7 @@ void SpeculateWriteCacheWithRoPEKernel( const paddle::Tensor& qkv, const paddle::Tensor& seq_lens, const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& padding_offsets, + const paddle::Tensor& batch_id_per_token, const paddle::Tensor& cu_seqlens_q, const paddle::Tensor& block_tables, const paddle::optional& rotary_embs, @@ -349,7 +349,7 @@ void SpeculateWriteCacheWithRoPEKernel( reinterpret_cast(value_cache_out->data()), reinterpret_cast(qkv_out->data()), block_tables.data(), - padding_offsets.data(), + batch_id_per_token.data(), cu_seqlens_q.data(), seq_lens.data(), seq_lens_encoder.data(), @@ -376,7 +376,7 @@ void SpeculateWriteCacheWithRoPEKernel( value_cache_out->data(), reinterpret_cast(qkv_out->data()), block_tables.data(), - padding_offsets.data(), + batch_id_per_token.data(), cu_seqlens_q.data(), seq_lens.data(), seq_lens_encoder.data(), @@ -409,7 +409,7 @@ void SpeculateWriteCacheWithRoPEKernel( value_cache_out->data(), reinterpret_cast(qkv_out->data()), 
block_tables.data(), - padding_offsets.data(), + batch_id_per_token.data(), cu_seqlens_q.data(), seq_lens.data(), seq_lens_encoder.data(), @@ -442,7 +442,7 @@ void SpeculateWriteCacheWithRoPEKernel( value_cache_out->data(), reinterpret_cast(const_cast(qkv_out->data())), block_tables.data(), - padding_offsets.data(), + batch_id_per_token.data(), cu_seqlens_q.data(), seq_lens.data(), seq_lens_encoder.data(), @@ -488,7 +488,7 @@ template void SpeculateWriteCacheWithRoPEKernel( // gqa_group_size, head_dim] if GQA) const paddle::Tensor& seq_lens, const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& padding_offsets, + const paddle::Tensor& batch_id_per_token, const paddle::Tensor& cu_seqlens_q, const paddle::Tensor& block_tables, const paddle::optional& rotary_embs, @@ -514,7 +514,7 @@ SpeculateWriteCacheWithRoPEKernel( // gqa_group_size, head_dim] if GQA) const paddle::Tensor& seq_lens, const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& padding_offsets, + const paddle::Tensor& batch_id_per_token, const paddle::Tensor& cu_seqlens_q, const paddle::Tensor& block_tables, const paddle::optional& rotary_embs, @@ -539,7 +539,7 @@ template void SpeculateWriteCacheWithRoPEKernel( // gqa_group_size, head_dim] if GQA) const paddle::Tensor& seq_lens, const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& padding_offsets, + const paddle::Tensor& batch_id_per_token, const paddle::Tensor& cu_seqlens_q, const paddle::Tensor& block_tables, const paddle::optional& rotary_embs, @@ -566,7 +566,7 @@ SpeculateWriteCacheWithRoPEKernel( // gqa_group_size, head_dim] if GQA) const paddle::Tensor& seq_lens, const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& padding_offsets, + const paddle::Tensor& batch_id_per_token, const paddle::Tensor& cu_seqlens_q, const paddle::Tensor& block_tables, const paddle::optional& rotary_embs, diff --git a/custom_ops/gpu_ops/append_attn/speculate_write_cache_with_rope_kernel.h b/custom_ops/gpu_ops/append_attn/speculate_write_cache_with_rope_kernel.h index f687328cc..bb192f5a9 100644 --- a/custom_ops/gpu_ops/append_attn/speculate_write_cache_with_rope_kernel.h +++ b/custom_ops/gpu_ops/append_attn/speculate_write_cache_with_rope_kernel.h @@ -23,7 +23,7 @@ void SpeculateWriteCacheWithRoPEKernel( // gqa_group_size, head_dim] if GQA) const paddle::Tensor& seq_lens, const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& padding_offsets, + const paddle::Tensor& batch_id_per_token, const paddle::Tensor& cu_seqlens_q, const paddle::Tensor& block_tables, const paddle::optional& rotary_embs, diff --git a/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c16_bfloat16_bfloat16_kernel.cu b/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c16_bfloat16_bfloat16_kernel.cu index f3bc4761c..93db78513 100644 --- a/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c16_bfloat16_bfloat16_kernel.cu +++ b/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c16_bfloat16_bfloat16_kernel.cu @@ -37,7 +37,7 @@ template void CascadeAppendAttentionC16Kernel( const paddle::Tensor& seq_lens_q, const paddle::Tensor& seq_lens_kv, const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& padding_offsets, + const paddle::Tensor& batch_id_per_token, const paddle::Tensor& cu_seqlens_q, const paddle::Tensor& block_table, const paddle::Tensor& batch_ids, diff --git a/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c16_float16_float16_kernel.cu 
b/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c16_float16_float16_kernel.cu index ab52a6ac0..436250238 100644 --- a/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c16_float16_float16_kernel.cu +++ b/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c16_float16_float16_kernel.cu @@ -36,7 +36,7 @@ template void CascadeAppendAttentionC16Kernel( const paddle::Tensor& seq_lens_q, const paddle::Tensor& seq_lens_kv, const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& padding_offsets, + const paddle::Tensor& batch_id_per_token, const paddle::Tensor& cu_seqlens_q, const paddle::Tensor& block_table, const paddle::Tensor& batch_ids, diff --git a/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c16_float16_fp8_kernel.cu b/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c16_float16_fp8_kernel.cu index a77fb1fb5..daaad4de6 100644 --- a/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c16_float16_fp8_kernel.cu +++ b/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c16_float16_fp8_kernel.cu @@ -36,7 +36,7 @@ template void CascadeAppendAttentionC16Kernel( const paddle::Tensor& seq_lens_q, const paddle::Tensor& seq_lens_kv, const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& padding_offsets, + const paddle::Tensor& batch_id_per_token, const paddle::Tensor& cu_seqlens_q, const paddle::Tensor& block_table, const paddle::Tensor& batch_ids, diff --git a/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c4_bfloat16_bfloat16_kernel.cu b/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c4_bfloat16_bfloat16_kernel.cu index 5217f04da..923f9b0d3 100644 --- a/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c4_bfloat16_bfloat16_kernel.cu +++ b/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c4_bfloat16_bfloat16_kernel.cu @@ -36,7 +36,7 @@ template void CascadeAppendAttentionC4Kernel const paddle::Tensor& seq_lens_q, const paddle::Tensor& seq_lens_kv, const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& padding_offsets, + const paddle::Tensor& batch_id_per_token, const paddle::Tensor& cu_seqlens_q, const paddle::Tensor& block_table, const paddle::Tensor& batch_ids, diff --git a/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c4_bfloat16_fp8_kernel.cu b/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c4_bfloat16_fp8_kernel.cu index 2b0eb912f..888c410bb 100644 --- a/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c4_bfloat16_fp8_kernel.cu +++ b/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c4_bfloat16_fp8_kernel.cu @@ -36,7 +36,7 @@ template void CascadeAppendAttentionC4Kernel( const paddle::Tensor& seq_lens_q, const paddle::Tensor& seq_lens_kv, const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& padding_offsets, + const paddle::Tensor& batch_id_per_token, const paddle::Tensor& cu_seqlens_q, const paddle::Tensor& block_table, const paddle::Tensor& batch_ids, diff --git a/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c4_float16_float16_kernel.cu b/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c4_float16_float16_kernel.cu index 1fc66ee97..656374937 100644 --- a/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c4_float16_float16_kernel.cu +++ 
b/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c4_float16_float16_kernel.cu @@ -37,7 +37,7 @@ template void CascadeAppendAttentionC4Kernel( const paddle::Tensor& seq_lens_q, const paddle::Tensor& seq_lens_kv, const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& padding_offsets, + const paddle::Tensor& batch_id_per_token, const paddle::Tensor& cu_seqlens_q, const paddle::Tensor& block_table, const paddle::Tensor& batch_ids, diff --git a/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c4_float16_fp8_kernel.cu b/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c4_float16_fp8_kernel.cu index 6aa7e8689..fba62df2b 100644 --- a/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c4_float16_fp8_kernel.cu +++ b/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c4_float16_fp8_kernel.cu @@ -36,7 +36,7 @@ template void CascadeAppendAttentionC4Kernel( const paddle::Tensor& seq_lens_q, const paddle::Tensor& seq_lens_kv, const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& padding_offsets, + const paddle::Tensor& batch_id_per_token, const paddle::Tensor& cu_seqlens_q, const paddle::Tensor& block_table, const paddle::Tensor& batch_ids, diff --git a/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c8_bfloat16_bfloat16_kernel.cu b/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c8_bfloat16_bfloat16_kernel.cu index d30d803ea..e860a0462 100644 --- a/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c8_bfloat16_bfloat16_kernel.cu +++ b/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c8_bfloat16_bfloat16_kernel.cu @@ -38,7 +38,7 @@ CascadeAppendAttentionC8Kernel( const paddle::Tensor& seq_lens_q, const paddle::Tensor& seq_lens_kv, const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& padding_offsets, + const paddle::Tensor& batch_id_per_token, const paddle::Tensor& cu_seqlens_q, const paddle::Tensor& block_table, const paddle::Tensor& batch_ids, @@ -85,7 +85,7 @@ CascadeAppendAttentionC8Kernel( const paddle::Tensor& seq_lens_q, const paddle::Tensor& seq_lens_kv, const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& padding_offsets, + const paddle::Tensor& batch_id_per_token, const paddle::Tensor& cu_seqlens_q, const paddle::Tensor& block_table, const paddle::Tensor& batch_ids, diff --git a/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c8_bfloat16_fp8_kernel.cu b/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c8_bfloat16_fp8_kernel.cu index 4a1072003..3b61ecd16 100644 --- a/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c8_bfloat16_fp8_kernel.cu +++ b/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c8_bfloat16_fp8_kernel.cu @@ -36,7 +36,7 @@ template void CascadeAppendAttentionC8Kernel( const paddle::Tensor& seq_lens_q, const paddle::Tensor& seq_lens_kv, const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& padding_offsets, + const paddle::Tensor& batch_id_per_token, const paddle::Tensor& cu_seqlens_q, const paddle::Tensor& block_table, const paddle::Tensor& batch_ids, @@ -82,7 +82,7 @@ template void CascadeAppendAttentionC8Kernel( const paddle::Tensor& seq_lens_q, const paddle::Tensor& seq_lens_kv, const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& padding_offsets, + const paddle::Tensor& batch_id_per_token, const paddle::Tensor& cu_seqlens_q, const 
paddle::Tensor& block_table, const paddle::Tensor& batch_ids, diff --git a/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c8_float16_float16_kernel.cu b/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c8_float16_float16_kernel.cu index 1364fd65b..4d7b11d99 100644 --- a/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c8_float16_float16_kernel.cu +++ b/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c8_float16_float16_kernel.cu @@ -36,7 +36,7 @@ template void CascadeAppendAttentionC8Kernel( const paddle::Tensor& seq_lens_q, const paddle::Tensor& seq_lens_kv, const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& padding_offsets, + const paddle::Tensor& batch_id_per_token, const paddle::Tensor& cu_seqlens_q, const paddle::Tensor& block_table, const paddle::Tensor& batch_ids, @@ -81,7 +81,7 @@ template void CascadeAppendAttentionC8Kernel( const paddle::Tensor& seq_lens_q, const paddle::Tensor& seq_lens_kv, const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& padding_offsets, + const paddle::Tensor& batch_id_per_token, const paddle::Tensor& cu_seqlens_q, const paddle::Tensor& block_table, const paddle::Tensor& batch_ids, diff --git a/custom_ops/gpu_ops/append_attn/template_instantiation/encoder_write_cache_with_rope_bfloat16_bfloat16_kernel.cu b/custom_ops/gpu_ops/append_attn/template_instantiation/encoder_write_cache_with_rope_bfloat16_bfloat16_kernel.cu index 7012ac314..8d786ce58 100644 --- a/custom_ops/gpu_ops/append_attn/template_instantiation/encoder_write_cache_with_rope_bfloat16_bfloat16_kernel.cu +++ b/custom_ops/gpu_ops/append_attn/template_instantiation/encoder_write_cache_with_rope_bfloat16_bfloat16_kernel.cu @@ -22,7 +22,7 @@ EncoderWriteCacheWithRopeKernel( const paddle::Tensor& seq_lens_this_time, const paddle::Tensor& seq_lens_encoder, const paddle::Tensor& seq_lens_decoder, - const paddle::Tensor& padding_offsets, + const paddle::Tensor& batch_id_per_token, const paddle::Tensor& cu_seqlens_q, const paddle::Tensor& block_tables, const paddle::Tensor& batch_ids, diff --git a/custom_ops/gpu_ops/append_attn/template_instantiation/encoder_write_cache_with_rope_bfloat16_int_kernel.cu b/custom_ops/gpu_ops/append_attn/template_instantiation/encoder_write_cache_with_rope_bfloat16_int_kernel.cu index 31e5629f4..a34da8258 100644 --- a/custom_ops/gpu_ops/append_attn/template_instantiation/encoder_write_cache_with_rope_bfloat16_int_kernel.cu +++ b/custom_ops/gpu_ops/append_attn/template_instantiation/encoder_write_cache_with_rope_bfloat16_int_kernel.cu @@ -21,7 +21,7 @@ template void EncoderWriteCacheWithRopeKernel( const paddle::Tensor& seq_lens_this_time, const paddle::Tensor& seq_lens_encoder, const paddle::Tensor& seq_lens_decoder, - const paddle::Tensor& padding_offsets, + const paddle::Tensor& batch_id_per_token, const paddle::Tensor& cu_seqlens_q, const paddle::Tensor& block_tables, const paddle::Tensor& batch_ids, diff --git a/custom_ops/gpu_ops/append_attn/template_instantiation/encoder_write_cache_with_rope_float16_float16_kernel.cu b/custom_ops/gpu_ops/append_attn/template_instantiation/encoder_write_cache_with_rope_float16_float16_kernel.cu index afa712b52..42f07ee8b 100644 --- a/custom_ops/gpu_ops/append_attn/template_instantiation/encoder_write_cache_with_rope_float16_float16_kernel.cu +++ b/custom_ops/gpu_ops/append_attn/template_instantiation/encoder_write_cache_with_rope_float16_float16_kernel.cu @@ -21,7 +21,7 @@ template void EncoderWriteCacheWithRopeKernel( 
const paddle::Tensor& seq_lens_this_time, const paddle::Tensor& seq_lens_encoder, const paddle::Tensor& seq_lens_decoder, - const paddle::Tensor& padding_offsets, + const paddle::Tensor& batch_id_per_token, const paddle::Tensor& cu_seqlens_q, const paddle::Tensor& block_tables, const paddle::Tensor& batch_ids, diff --git a/custom_ops/gpu_ops/append_attn/template_instantiation/encoder_write_cache_with_rope_float16_int_kernel.cu b/custom_ops/gpu_ops/append_attn/template_instantiation/encoder_write_cache_with_rope_float16_int_kernel.cu index 663c0f693..ef3d3832e 100644 --- a/custom_ops/gpu_ops/append_attn/template_instantiation/encoder_write_cache_with_rope_float16_int_kernel.cu +++ b/custom_ops/gpu_ops/append_attn/template_instantiation/encoder_write_cache_with_rope_float16_int_kernel.cu @@ -21,7 +21,7 @@ template void EncoderWriteCacheWithRopeKernel( const paddle::Tensor& seq_lens_this_time, const paddle::Tensor& seq_lens_encoder, const paddle::Tensor& seq_lens_decoder, - const paddle::Tensor& padding_offsets, + const paddle::Tensor& batch_id_per_token, const paddle::Tensor& cu_seqlens_q, const paddle::Tensor& block_tables, const paddle::Tensor& batch_ids, diff --git a/custom_ops/gpu_ops/cpp_extensions.cc b/custom_ops/gpu_ops/cpp_extensions.cc index d61f0fd07..38bd4b67f 100644 --- a/custom_ops/gpu_ops/cpp_extensions.cc +++ b/custom_ops/gpu_ops/cpp_extensions.cc @@ -54,7 +54,7 @@ std::vector AppendAttention( const paddle::Tensor &value_cache, const paddle::Tensor &seq_lens_encoder, const paddle::Tensor &seq_lens_decoder, const paddle::Tensor &seq_lens_this_time, - const paddle::Tensor &padding_offsets, const paddle::Tensor &cu_seqlens_q, + const paddle::Tensor &batch_id_per_token, const paddle::Tensor &cu_seqlens_q, const paddle::Tensor &block_tables, const paddle::Tensor &encoder_batch_ids, const paddle::Tensor &encoder_tile_ids_per_batch, const paddle::Tensor &encoder_num_blocks, @@ -94,7 +94,7 @@ std::vector GQARopeWriteCacheKernel( const paddle::Tensor &seq_lens_this_time, const paddle::Tensor &seq_lens_encoder, const paddle::Tensor &seq_lens_decoder, - const paddle::Tensor &padding_offsets, + const paddle::Tensor &batch_id_per_token, const paddle::Tensor &block_tables, const paddle::Tensor &kv_batch_ids, const paddle::Tensor &kv_tile_ids, const paddle::Tensor &kv_num_blocks, const paddle::Tensor &cache_batch_ids, const paddle::Tensor &cache_tile_ids, @@ -330,7 +330,7 @@ std::vector DecodeMLAWriteCacheKernel( const paddle::Tensor& kv_cache, const paddle::Tensor& seq_lens, const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& padding_offsets, + const paddle::Tensor& batch_id_per_token, const paddle::Tensor& cu_seqlens_q, const paddle::Tensor& block_tables, const std::string& cache_quant_type_str, @@ -343,7 +343,7 @@ std::vector DecodeMLAWriteCacheKernel( const paddle::Tensor& kv_cache, const paddle::Tensor& seq_lens, const paddle::Tensor& seq_lens_decoder, - const paddle::Tensor& padding_offsets, + const paddle::Tensor& batch_id_per_token, const paddle::Tensor& cu_seqlens_q, const paddle::Tensor& block_tables, const std::string& cache_quant_type_str, @@ -369,7 +369,7 @@ std::vector MultiHeadLatentAttention( const paddle::Tensor& seq_lens_decoder, const paddle::Tensor& seq_lens_this_time, const paddle::Tensor& cu_seqlens_q, - const paddle::Tensor& padding_offsets, + const paddle::Tensor& batch_id_per_token, const paddle::Tensor& block_tables, const paddle::Tensor& encoder_batch_ids, const paddle::Tensor& encoder_tile_ids_per_batch, diff --git 
a/custom_ops/gpu_ops/get_padding_offset.cu b/custom_ops/gpu_ops/get_padding_offset.cu index 2e1152e42..92c252f17 100644 --- a/custom_ops/gpu_ops/get_padding_offset.cu +++ b/custom_ops/gpu_ops/get_padding_offset.cu @@ -46,7 +46,7 @@ __global__ void GetPaddingOffsetKernel(int *padding_offset, const int ti = threadIdx.x; int cum_offset = bi == 0 ? 0 : cum_offsets[bi - 1]; for (int i = ti; i < seq_lens[bi]; i += blockDim.x) { - padding_offset[bi * max_seq_len - cum_offset + i] = cum_offset; + padding_offset[bi * max_seq_len - cum_offset + i] = bi; } if (ti == 0) { cum_offsets_out[bi] = cum_offset; diff --git a/custom_ops/gpu_ops/speculate_decoding/speculate_get_padding_offset.cu b/custom_ops/gpu_ops/speculate_decoding/speculate_get_padding_offset.cu index 49eeb5a6a..2fbfff160 100644 --- a/custom_ops/gpu_ops/speculate_decoding/speculate_get_padding_offset.cu +++ b/custom_ops/gpu_ops/speculate_decoding/speculate_get_padding_offset.cu @@ -53,7 +53,7 @@ __global__ void SpeculateGetPaddingOffsetKernel(int* padding_offset, const int ti = threadIdx.x; int cum_offset = bi == 0 ? 0 : cum_offsets[bi - 1]; for (int i = ti; i < seq_lens[bi]; i += blockDim.x) { - padding_offset[bi * max_seq_len - cum_offset + i] = cum_offset; + padding_offset[bi * max_seq_len - cum_offset + i] = bi; } if (ti == 0) { cum_offsets_out[bi] = cum_offset; diff --git a/fastdeploy/model_executor/forward_meta.py b/fastdeploy/model_executor/forward_meta.py index 17ab2e9ad..15395b419 100644 --- a/fastdeploy/model_executor/forward_meta.py +++ b/fastdeploy/model_executor/forward_meta.py @@ -85,8 +85,8 @@ class ForwardMeta(): # Accumulated offset cum_offsets: Optional[paddle.Tensor] = None - # Offset tensor, used to restore the position of ids_remove_madding after padding removal to the original input_ids - padding_offset: Optional[paddle.Tensor] = None + # batch_id_per_token tensor, used to indicate which batch each token belongs to after padding is removed from the original input_ids + batch_id_per_token: Optional[paddle.Tensor] = None # Accumulated sequence length of query cu_seqlens_q: Optional[paddle.Tensor] = None # Accumulated sequence length of key diff --git a/fastdeploy/model_executor/layers/attention/append_attn_backend.py b/fastdeploy/model_executor/layers/attention/append_attn_backend.py index f41deb62e..7da552d70 100644 --- a/fastdeploy/model_executor/layers/attention/append_attn_backend.py +++ b/fastdeploy/model_executor/layers/attention/append_attn_backend.py @@ -216,7 +216,7 @@ class AppendAttentionBackend(AttentionBackend): forward_meta.seq_lens_encoder, forward_meta.seq_lens_decoder, forward_meta.seq_lens_this_time, - forward_meta.padding_offset, + forward_meta.batch_id_per_token, forward_meta.cu_seqlens_q, metadata.block_tables, metadata.encoder_batch_ids, diff --git a/fastdeploy/model_executor/layers/attention/ops/append_attention.py b/fastdeploy/model_executor/layers/attention/ops/append_attention.py index 516c1a647..979e8fd64 100644 --- a/fastdeploy/model_executor/layers/attention/ops/append_attention.py +++ b/fastdeploy/model_executor/layers/attention/ops/append_attention.py @@ -32,7 +32,7 @@ def append_attention( seq_lens_encoder: paddle.Tensor, seq_lens_decoder: paddle.Tensor, seq_lens_this_time: paddle.Tensor, - padding_offsets: paddle.Tensor, + batch_id_per_token: paddle.Tensor, cu_seqlens_q: paddle.Tensor, block_tables: paddle.Tensor, encoder_batch_ids: paddle.Tensor, @@ -86,7 +86,7 @@ def append_attention( seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, - padding_offsets, + batch_id_per_token,
cu_seqlens_q, block_tables, encoder_batch_ids, diff --git a/fastdeploy/model_executor/pre_and_post_process.py b/fastdeploy/model_executor/pre_and_post_process.py index 4c3337b4d..a92d946ee 100644 --- a/fastdeploy/model_executor/pre_and_post_process.py +++ b/fastdeploy/model_executor/pre_and_post_process.py @@ -72,7 +72,7 @@ def pre_process( Return: ids_remove_padding: cum_offsets: - padding_offset: + batch_id_per_token: cu_seqlens_q: cu_seqlens_k: """ @@ -85,7 +85,7 @@ def pre_process( ( ids_remove_padding, cum_offsets, - padding_offset, + batch_id_per_token, cu_seqlens_q, cu_seqlens_k, ) = speculate_get_padding_offset( @@ -115,12 +115,12 @@ def pre_process( ( ids_remove_padding, cum_offsets, - padding_offset, + batch_id_per_token, cu_seqlens_q, cu_seqlens_k, ) = get_padding_offset(input_ids, cum_offsets_now, token_num, seq_lens_this_time) - return (ids_remove_padding, cum_offsets, padding_offset, cu_seqlens_q, + return (ids_remove_padding, cum_offsets, batch_id_per_token, cu_seqlens_q, cu_seqlens_k, output_cum_offsets, output_padding_offset) diff --git a/fastdeploy/spec_decode/mtp.py b/fastdeploy/spec_decode/mtp.py index a5e7f600f..68eafa9b8 100644 --- a/fastdeploy/spec_decode/mtp.py +++ b/fastdeploy/spec_decode/mtp.py @@ -272,8 +272,8 @@ class MTPProposer(Proposer): self.main_model_inputs["ids_remove_padding"]) self.model_inputs["cum_offsets"] = paddle.clone( self.main_model_inputs["cum_offsets"]) - self.model_inputs["padding_offset"] = paddle.clone( - self.main_model_inputs["padding_offset"]) + self.model_inputs["batch_id_per_token"] = paddle.clone( + self.main_model_inputs["batch_id_per_token"]) self.model_inputs["cu_seqlens_q"] = paddle.clone( self.main_model_inputs["cu_seqlens_q"]) self.model_inputs["cu_seqlens_k"] = paddle.clone( @@ -447,7 +447,7 @@ class MTPProposer(Proposer): seq_lens_decoder=self.model_inputs["seq_lens_decoder"], seq_lens_this_time=self.model_inputs["seq_lens_this_time"], cum_offsets=self.model_inputs["cum_offsets"], - padding_offset=self.model_inputs["padding_offset"], + batch_id_per_token=self.model_inputs["batch_id_per_token"], cu_seqlens_q=self.model_inputs["cu_seqlens_q"], cu_seqlens_k=self.model_inputs["cu_seqlens_k"], block_tables=self.model_inputs["block_tables"], @@ -542,7 +542,7 @@ class MTPProposer(Proposer): ( ids_remove_padding, cum_offsets, - padding_offset, + batch_id_per_token, cu_seqlens_q, cu_seqlens_k, output_cum_offsets, @@ -560,8 +560,8 @@ class MTPProposer(Proposer): self.model_inputs["ids_remove_padding"].copy_( ids_remove_padding, False) self.model_inputs["cum_offsets"].copy_(cum_offsets, False) - self.model_inputs["padding_offset"].copy_( - padding_offset, False) + self.model_inputs["batch_id_per_token"].copy_( + batch_id_per_token, False) self.model_inputs["cu_seqlens_q"].copy_(cu_seqlens_q, False) self.model_inputs["cu_seqlens_k"].copy_(cu_seqlens_k, False) # for speculative decoding diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index 4d6bac01d..66e08f03e 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -559,7 +559,7 @@ class GPUModelRunner(ModelRunnerBase): self.share_inputs["cum_offsets"] = paddle.full([max_num_seqs, 1], 0, dtype='int32') - self.share_inputs["padding_offset"] = paddle.full([max_num_seqs, 1], + self.share_inputs["batch_id_per_token"] = paddle.full([max_num_seqs, 1], 0, dtype='int32') self.share_inputs["cu_seqlens_q"] = paddle.full([max_num_seqs, 1], @@ -670,7 +670,7 @@ class GPUModelRunner(ModelRunnerBase): ( 
ids_remove_padding, cum_offsets, - padding_offset, + batch_id_per_token, cu_seqlens_q, cu_seqlens_k, output_cum_offsets, @@ -685,7 +685,7 @@ class GPUModelRunner(ModelRunnerBase): self.share_inputs["ids_remove_padding"].copy_(ids_remove_padding, False) self.share_inputs["cum_offsets"].copy_(cum_offsets, False) - self.share_inputs["padding_offset"].copy_(padding_offset, False) + self.share_inputs["batch_id_per_token"].copy_(batch_id_per_token, False) self.share_inputs["cu_seqlens_q"].copy_(cu_seqlens_q, False) self.share_inputs["cu_seqlens_k"].copy_(cu_seqlens_k, False) @@ -762,7 +762,7 @@ class GPUModelRunner(ModelRunnerBase): seq_lens_decoder=self.share_inputs["seq_lens_decoder"], seq_lens_this_time=self.share_inputs["seq_lens_this_time"], cum_offsets=self.share_inputs["cum_offsets"], - padding_offset=self.share_inputs["padding_offset"], + batch_id_per_token=self.share_inputs["batch_id_per_token"], cu_seqlens_q=self.share_inputs["cu_seqlens_q"], cu_seqlens_k=self.share_inputs["cu_seqlens_k"], block_tables=self.share_inputs["block_tables"],
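For reference, the semantic pivot of the whole patch is the one-token change in get_padding_offset.cu and its speculative twin: each valid slot of the output array now records the owning batch id bi rather than the cumulative padding cum_offset (the kernel parameter itself still reads padding_offset in this patch; it is renamed below for clarity). A host-side transcription of the modified loop, run on a hypothetical batch, shows the values the renamed tensor now holds:

#include <cstdio>
#include <vector>

int main() {
  const std::vector<int> seq_lens = {3, 1, 2};  // hypothetical batch
  const int max_seq_len = 4;
  const int token_num = 3 + 1 + 2;

  // cum_offsets[bi] = padding accumulated through batch bi, as computed upstream.
  std::vector<int> cum_offsets(seq_lens.size());
  int acc = 0;
  for (size_t bi = 0; bi < seq_lens.size(); ++bi) {
    acc += max_seq_len - seq_lens[bi];
    cum_offsets[bi] = acc;
  }

  std::vector<int> batch_id_per_token(token_num);
  for (int bi = 0; bi < (int)seq_lens.size(); ++bi) {   // one CUDA block per batch
    const int cum_offset = bi == 0 ? 0 : cum_offsets[bi - 1];
    for (int i = 0; i < seq_lens[bi]; ++i) {            // threads stride over the sequence
      batch_id_per_token[bi * max_seq_len - cum_offset + i] = bi;  // was: = cum_offset
    }
  }
  for (int v : batch_id_per_token) printf("%d ", v);    // prints: 0 0 0 1 2 2
  printf("\n");
  return 0;
}

Under the old contents ({0, 0, 0, 1, 4, 4} for this batch), every consumer had to add the token index back and divide by max_seq_len to recover the batch; the new contents make that a single load, which is exactly what the kernel hunks earlier in the patch exploit.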