[Inference, rename] remove padding_offsets from atten use batch_id_per_token (#2880)

* remove padding_offsets from atten
This commit is contained in:
周周周
2025-07-17 18:41:31 +08:00
committed by GitHub
parent d49f8fb30a
commit ddb10ac509
50 changed files with 311 additions and 288 deletions

View File

@@ -95,7 +95,7 @@ __global__ void speculate_decode_absorb_cache_kernel(
T* __restrict__ kv_cache, // [num_blocks, kv_num_heads, block_size,
// nope_size]
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ padding_offsets,
const int* __restrict__ batch_id_per_token,
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_encoder, // [bsz]
@@ -121,7 +121,7 @@ __global__ void speculate_decode_absorb_cache_kernel(
linear_index < elem_cnt;
linear_index += step) {
const int token_id = linear_index / hidden_size;
const int ori_bi = (token_id + padding_offsets[token_id]) / max_seq_len;
const int ori_bi = batch_id_per_token[token_id];
if (seq_lens[ori_bi] == 0) continue;
const int bias = linear_index % hidden_size;
const int start_token_idx = cu_seqlens_q[ori_bi];
@@ -178,7 +178,7 @@ __global__ void prefill_absorb_cache_kernel(
T* __restrict__ kv_cache, // [num_blocks, kv_num_heads, block_size,
// nope_size]
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ padding_offsets,
const int* __restrict__ batch_id_per_token,
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_decoder, // [bsz]
@@ -204,11 +204,9 @@ __global__ void prefill_absorb_cache_kernel(
linear_index += step) {
const uint32_t token_idx = linear_index / hidden_size;
const uint32_t bias = linear_index % hidden_size;
const uint32_t ori_token_idx = token_idx + padding_offsets[token_idx];
const uint32_t ori_bi = ori_token_idx / max_seq_len;
const uint32_t ori_bi = batch_id_per_token[token_idx];
if (seq_lens[ori_bi] == 0) continue;
const uint32_t ori_seq_id =
ori_token_idx % max_seq_len + seq_lens_decoder[ori_bi];
const uint32_t ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi];
const int* block_table_now = nullptr;
block_table_now = block_tables + ori_bi * max_blocks_per_seq;