mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 08:37:06 +08:00
[Inference, rename] remove padding_offsets from atten use batch_id_per_token (#2880)
* remove padding_offsets from atten
This commit is contained in:
@@ -95,7 +95,7 @@ __global__ void speculate_decode_absorb_cache_kernel(
|
||||
T* __restrict__ kv_cache, // [num_blocks, kv_num_heads, block_size,
|
||||
// nope_size]
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
const int* __restrict__ padding_offsets,
|
||||
const int* __restrict__ batch_id_per_token,
|
||||
const int* __restrict__ cu_seqlens_q,
|
||||
const int* __restrict__ seq_lens, // [bsz]
|
||||
const int* __restrict__ seq_lens_encoder, // [bsz]
|
||||
@@ -121,7 +121,7 @@ __global__ void speculate_decode_absorb_cache_kernel(
|
||||
linear_index < elem_cnt;
|
||||
linear_index += step) {
|
||||
const int token_id = linear_index / hidden_size;
|
||||
const int ori_bi = (token_id + padding_offsets[token_id]) / max_seq_len;
|
||||
const int ori_bi = batch_id_per_token[token_id];
|
||||
if (seq_lens[ori_bi] == 0) continue;
|
||||
const int bias = linear_index % hidden_size;
|
||||
const int start_token_idx = cu_seqlens_q[ori_bi];
|
||||
@@ -178,7 +178,7 @@ __global__ void prefill_absorb_cache_kernel(
|
||||
T* __restrict__ kv_cache, // [num_blocks, kv_num_heads, block_size,
|
||||
// nope_size]
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
const int* __restrict__ padding_offsets,
|
||||
const int* __restrict__ batch_id_per_token,
|
||||
const int* __restrict__ cu_seqlens_q,
|
||||
const int* __restrict__ seq_lens, // [bsz]
|
||||
const int* __restrict__ seq_lens_decoder, // [bsz]
|
||||
@@ -204,11 +204,9 @@ __global__ void prefill_absorb_cache_kernel(
|
||||
linear_index += step) {
|
||||
const uint32_t token_idx = linear_index / hidden_size;
|
||||
const uint32_t bias = linear_index % hidden_size;
|
||||
const uint32_t ori_token_idx = token_idx + padding_offsets[token_idx];
|
||||
const uint32_t ori_bi = ori_token_idx / max_seq_len;
|
||||
const uint32_t ori_bi = batch_id_per_token[token_idx];
|
||||
if (seq_lens[ori_bi] == 0) continue;
|
||||
const uint32_t ori_seq_id =
|
||||
ori_token_idx % max_seq_len + seq_lens_decoder[ori_bi];
|
||||
const uint32_t ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi];
|
||||
|
||||
const int* block_table_now = nullptr;
|
||||
block_table_now = block_tables + ori_bi * max_blocks_per_seq;
|
||||
|
Reference in New Issue
Block a user