mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 16:48:03 +08:00
[Inference, rename] remove padding_offsets from atten use batch_id_per_token (#2880)
* remove padding_offsets from atten
This commit is contained in:
@@ -1852,7 +1852,7 @@ __global__ void merge_multi_chunks_kernel(
|
||||
const float* __restrict__ multi_d, // [token_num, num_chunks, num_heads]
|
||||
const int* __restrict__ seq_lens_q,
|
||||
const int* __restrict__ seq_lens_kv,
|
||||
const int* __restrict__ padding_offsets,
|
||||
const int* __restrict__ batch_id_per_token,
|
||||
const T* __restrict__ shift_bias, // [q_num_heads * HEAD_DIM]
|
||||
const T* __restrict__ smooth_weight, // [q_num_heads * HEAD_DIM]
|
||||
T* __restrict__ out,
|
||||
@@ -1866,8 +1866,7 @@ __global__ void merge_multi_chunks_kernel(
|
||||
const int head_dim) {
|
||||
const int vid = threadIdx.x, hid = threadIdx.y;
|
||||
const int qid = blockIdx.x;
|
||||
const uint32_t ori_token_id = qid + padding_offsets[qid];
|
||||
const uint32_t bid = ori_token_id / max_seq_len;
|
||||
const uint32_t bid = batch_id_per_token[qid];
|
||||
if (seq_lens_q[bid] <= 0 || seq_lens_kv[bid] <= 0) {
|
||||
return;
|
||||
}
|
||||
@@ -2240,7 +2239,8 @@ __global__ void merge_multi_chunks_v2_kernel(
|
||||
const int *__restrict__ seq_lens_q,
|
||||
const int *__restrict__ seq_lens_kv,
|
||||
const int *__restrict__ seq_lens_encoder,
|
||||
const int *__restrict__ padding_offsets,
|
||||
const int *__restrict__ batch_id_per_token,
|
||||
const int *__restrict__ cu_seqlens_q,
|
||||
const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM]
|
||||
const T *__restrict__ smooth_weight, // [q_num_heads * HEAD_DIM]
|
||||
OutT *__restrict__ out,
|
||||
@@ -2259,9 +2259,8 @@ __global__ void merge_multi_chunks_v2_kernel(
|
||||
__shared__ T smem[bdy * HEAD_DIM];
|
||||
__shared__ float md_smem[bdy * 2];
|
||||
for (int qid = blockIdx.x; qid < token_num; qid += gridDim.x) {
|
||||
const uint32_t ori_token_id = qid + padding_offsets[qid];
|
||||
const uint32_t bid = ori_token_id / max_seq_len;
|
||||
const uint32_t local_seq_id = ori_token_id % max_seq_len;
|
||||
const uint32_t bid = batch_id_per_token[qid];
|
||||
const uint32_t local_seq_id = qid - cu_seqlens_q[bid];
|
||||
const int seq_len_q = seq_lens_q[bid];
|
||||
if (seq_len_q == 0) continue;
|
||||
int seq_len_kv = seq_lens_kv[bid];
|
||||
|
Reference in New Issue
Block a user