mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 16:48:03 +08:00
[Feature][SpeculativeDecoding]Support tree-attention (#3514)
* support tree-attention * fix merge bug * fix unit-test api * fix merge bug
This commit is contained in:
@@ -247,13 +247,16 @@ __global__ void multi_query_append_attention_kernel(
|
||||
NUM_WARPS,
|
||||
num_frags_x,
|
||||
num_frags_y,
|
||||
num_frags_z>(q_base_seq_id_this_block,
|
||||
num_frags_z>(nullptr,
|
||||
q_base_seq_id_this_block,
|
||||
kv_idx_base,
|
||||
q_len,
|
||||
kv_len,
|
||||
chunk_end,
|
||||
-1,
|
||||
s_frag,
|
||||
mask_offset_this_seq);
|
||||
|
||||
}
|
||||
|
||||
// update m,d
|
||||
@@ -410,6 +413,7 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
|
||||
const int *__restrict__ cu_seqlens_q,
|
||||
const int *__restrict__ block_table, // [bsz, block_num_per_seq]
|
||||
const int *__restrict__ mask_offset,
|
||||
const bool *__restrict__ attn_mask, // [bsz, max_q, max_q] for tree-mask
|
||||
const int max_seq_len,
|
||||
const int max_dec_len,
|
||||
const int max_block_num_per_seq,
|
||||
@@ -423,7 +427,8 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
|
||||
float *__restrict__ tmp_m, // [token_num, num_chunks, num_heads]
|
||||
float *__restrict__ tmp_d, // [token_num, num_chunks, num_heads]
|
||||
OutT *__restrict__ out,
|
||||
const int speculate_max_draft_token_num = 5) {
|
||||
const int speculate_max_draft_token_num = 5,
|
||||
const uint32_t attn_mask_len = -1) {
|
||||
constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b<T>();
|
||||
static_assert(NUM_WARP_Q == 1, "NUM_WARP_Q must be 1");
|
||||
static_assert(NUM_WARP_KV == 4, "NUM_WARP_KV must be 4");
|
||||
@@ -544,8 +549,7 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
|
||||
const uint32_t mask_check_iteration =
|
||||
(CAUSAL ? (min(chunk_len,
|
||||
sub_if_greater_or_zero(
|
||||
kv_len - q_len +
|
||||
tile_id * num_rows_per_block / GROUP_SIZE,
|
||||
kv_len - q_len,
|
||||
chunk_start)))
|
||||
: mask_offset ? 0 : chunk_len) /
|
||||
(NUM_WARP_KV * num_frags_z * 16);
|
||||
@@ -615,11 +619,13 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
|
||||
NUM_WARPS,
|
||||
num_frags_x,
|
||||
num_frags_y,
|
||||
num_frags_z>(q_base_seq_id_this_block,
|
||||
num_frags_z>(attn_mask ? attn_mask + batch_id * attn_mask_len *attn_mask_len : nullptr,
|
||||
q_base_seq_id_this_block,
|
||||
kv_idx_base + wid * num_frags_z * 16,
|
||||
q_len,
|
||||
kv_len,
|
||||
chunk_end,
|
||||
attn_mask_len,
|
||||
s_frag,
|
||||
mask_offset_this_seq);
|
||||
}
|
||||
@@ -1069,6 +1075,13 @@ void MultiQueryAppendAttention(
|
||||
chunk_size = static_cast<uint32_t>(encoder_max_partition_size);
|
||||
}
|
||||
|
||||
uint32_t attn_mask_len;
|
||||
if (attn_mask) {
|
||||
attn_mask_len = attn_mask.get().shape()[1];
|
||||
} else {
|
||||
attn_mask_len = -1;
|
||||
}
|
||||
|
||||
const int num_chunks = div_up(max_seq_len, chunk_size);
|
||||
dim3 grids(num_blocks_x_cpu, num_chunks, kv_num_heads);
|
||||
dim3 blocks(32, num_warps);
|
||||
@@ -1111,6 +1124,8 @@ void MultiQueryAppendAttention(
|
||||
cu_seqlens_q.data<int>(),
|
||||
block_table.data<int>(),
|
||||
meta_data.mask_offset,
|
||||
attn_mask ? const_cast<bool *>(attn_mask.get().data<bool>())
|
||||
: nullptr,
|
||||
max_seq_len,
|
||||
max_dec_len,
|
||||
max_block_num_per_seq,
|
||||
@@ -1123,7 +1138,8 @@ void MultiQueryAppendAttention(
|
||||
nullptr,
|
||||
nullptr,
|
||||
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
|
||||
speculate_max_draft_token_num);
|
||||
speculate_max_draft_token_num,
|
||||
attn_mask_len);
|
||||
} else {
|
||||
phi::Allocator::AllocationPtr tmp_workspace, tmp_m, tmp_d;
|
||||
if (is_decoder) {
|
||||
@@ -1180,6 +1196,8 @@ void MultiQueryAppendAttention(
|
||||
cu_seqlens_q.data<int>(),
|
||||
block_table.data<int>(),
|
||||
meta_data.mask_offset,
|
||||
attn_mask ? const_cast<bool *>(attn_mask.get().data<bool>())
|
||||
: nullptr,
|
||||
max_seq_len,
|
||||
max_dec_len,
|
||||
max_block_num_per_seq,
|
||||
@@ -1192,7 +1210,8 @@ void MultiQueryAppendAttention(
|
||||
static_cast<float *>(tmp_m->ptr()),
|
||||
static_cast<float *>(tmp_d->ptr()),
|
||||
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
|
||||
speculate_max_draft_token_num);
|
||||
speculate_max_draft_token_num,
|
||||
attn_mask_len);
|
||||
|
||||
// merge
|
||||
constexpr int vec_size = num_elems_per_128b<NV_TYPE>();
|
||||
|
@@ -335,11 +335,13 @@ __global__ void multi_query_append_attention_c4_kernel(
|
||||
NUM_WARPS,
|
||||
num_frags_x,
|
||||
num_frags_y,
|
||||
num_frags_z>(q_base_seq_id_this_block,
|
||||
num_frags_z>(nullptr,
|
||||
q_base_seq_id_this_block,
|
||||
kv_idx_base,
|
||||
q_len,
|
||||
kv_len,
|
||||
chunk_end,
|
||||
-1,
|
||||
s_frag,
|
||||
mask_offset_this_seq);
|
||||
}
|
||||
@@ -509,6 +511,7 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
|
||||
const int *__restrict__ cu_seqlens_q,
|
||||
const int *__restrict__ block_table, // [bsz, block_num_per_seq]
|
||||
const int *__restrict__ mask_offset,
|
||||
const bool *__restrict__ attn_mask, // [bsz, max_q, max_q] for tree-mask
|
||||
const int max_seq_len,
|
||||
const int max_dec_len,
|
||||
const int max_block_num_per_seq,
|
||||
@@ -522,7 +525,8 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
|
||||
float *__restrict__ tmp_m, // [token_num, num_chunks, num_heads]
|
||||
float *__restrict__ tmp_d, // [token_num, num_chunks, num_heads]
|
||||
OutT *__restrict__ out,
|
||||
const int speculate_max_draft_token_num = 5) {
|
||||
const int speculate_max_draft_token_num = 5,
|
||||
const uint32_t attn_mask_len = -1) {
|
||||
constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b<T>();
|
||||
constexpr uint32_t num_vecs_per_head_k =
|
||||
HEAD_DIM / 2 / num_elems_per_128b<CacheT>();
|
||||
@@ -707,8 +711,7 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
|
||||
const uint32_t mask_check_iteration =
|
||||
(CAUSAL ? (min(chunk_len,
|
||||
sub_if_greater_or_zero(
|
||||
kv_len - q_len +
|
||||
tile_id * num_rows_per_block / GROUP_SIZE,
|
||||
kv_len - q_len,
|
||||
chunk_start)))
|
||||
: mask_offset ? 0 : chunk_len) /
|
||||
(NUM_WARP_KV * num_frags_z * 16);
|
||||
@@ -792,11 +795,13 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
|
||||
NUM_WARPS,
|
||||
num_frags_x,
|
||||
num_frags_y,
|
||||
num_frags_z>(q_base_seq_id_this_block,
|
||||
num_frags_z>(attn_mask ? attn_mask + batch_id * attn_mask_len *attn_mask_len : nullptr,
|
||||
q_base_seq_id_this_block,
|
||||
kv_idx_base + wid * num_frags_z * 16,
|
||||
q_len,
|
||||
kv_len,
|
||||
chunk_end,
|
||||
attn_mask_len,
|
||||
s_frag,
|
||||
mask_offset_this_seq);
|
||||
}
|
||||
@@ -1294,6 +1299,13 @@ void MultiQueryAppendC4Attention(
|
||||
}
|
||||
|
||||
const int num_chunks = div_up(max_seq_len, chunk_size);
|
||||
uint32_t attn_mask_len;
|
||||
if (attn_mask) {
|
||||
attn_mask_len = attn_mask.get().shape()[1];
|
||||
} else {
|
||||
attn_mask_len = -1;
|
||||
}
|
||||
|
||||
dim3 grids(num_blocks_x_cpu, num_chunks, kv_num_heads);
|
||||
dim3 blocks(32, num_warps);
|
||||
if (num_chunks <= 0) {
|
||||
@@ -1343,6 +1355,8 @@ void MultiQueryAppendC4Attention(
|
||||
cu_seqlens_q.data<int>(),
|
||||
block_table.data<int>(),
|
||||
meta_data.mask_offset,
|
||||
attn_mask ? const_cast<bool *>(attn_mask.get().data<bool>())
|
||||
: nullptr,
|
||||
max_seq_len,
|
||||
max_dec_len,
|
||||
max_block_num_per_seq,
|
||||
@@ -1355,7 +1369,8 @@ void MultiQueryAppendC4Attention(
|
||||
nullptr,
|
||||
nullptr,
|
||||
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
|
||||
speculate_max_draft_token_num);
|
||||
speculate_max_draft_token_num,
|
||||
attn_mask_len);
|
||||
} else {
|
||||
phi::Allocator::AllocationPtr tmp_workspace, tmp_m, tmp_d;
|
||||
if (is_decoder) {
|
||||
@@ -1420,6 +1435,8 @@ void MultiQueryAppendC4Attention(
|
||||
cu_seqlens_q.data<int>(),
|
||||
block_table.data<int>(),
|
||||
meta_data.mask_offset,
|
||||
attn_mask ? const_cast<bool *>(attn_mask.get().data<bool>())
|
||||
: nullptr,
|
||||
max_seq_len,
|
||||
max_dec_len,
|
||||
max_block_num_per_seq,
|
||||
@@ -1432,7 +1449,8 @@ void MultiQueryAppendC4Attention(
|
||||
static_cast<float *>(tmp_m->ptr()),
|
||||
static_cast<float *>(tmp_d->ptr()),
|
||||
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
|
||||
speculate_max_draft_token_num);
|
||||
speculate_max_draft_token_num,
|
||||
attn_mask_len);
|
||||
// merge
|
||||
constexpr int vec_size = num_elems_per_128b<NV_TYPE>();
|
||||
if (is_decoder) {
|
||||
|
@@ -302,11 +302,13 @@ __global__ void multi_query_append_attention_c8_kernel(
|
||||
NUM_WARPS,
|
||||
num_frags_x,
|
||||
num_frags_y,
|
||||
num_frags_z>(q_base_seq_id_this_block,
|
||||
num_frags_z>(nullptr,
|
||||
q_base_seq_id_this_block,
|
||||
kv_idx_base,
|
||||
q_len,
|
||||
kv_len,
|
||||
chunk_end,
|
||||
-1,
|
||||
s_frag,
|
||||
mask_offset_this_seq);
|
||||
}
|
||||
@@ -478,6 +480,7 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
|
||||
const int *__restrict__ cu_seqlens_q,
|
||||
const int *__restrict__ block_table, // [bsz, block_num_per_seq]
|
||||
const int *__restrict__ mask_offset,
|
||||
const bool *__restrict__ attn_mask, // [bsz, max_q, max_q] for tree-mask
|
||||
const int max_seq_len,
|
||||
const int max_dec_len,
|
||||
const int max_block_num_per_seq,
|
||||
@@ -491,7 +494,8 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
|
||||
float *__restrict__ tmp_m, // [token_num, num_chunks, num_heads]
|
||||
float *__restrict__ tmp_d, // [token_num, num_chunks, num_heads]
|
||||
OutT *__restrict__ out,
|
||||
const int speculate_max_draft_token_num = 5) {
|
||||
const int speculate_max_draft_token_num = 5,
|
||||
const uint32_t attn_mask_len = -1) {
|
||||
constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b<T>();
|
||||
constexpr uint32_t num_vecs_per_head_k =
|
||||
HEAD_DIM / num_elems_per_128b<CacheT>();
|
||||
@@ -732,13 +736,16 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
|
||||
NUM_WARPS,
|
||||
num_frags_x,
|
||||
num_frags_y,
|
||||
num_frags_z>(q_base_seq_id_this_block,
|
||||
num_frags_z>(attn_mask ? attn_mask + batch_id * attn_mask_len *attn_mask_len : nullptr,
|
||||
q_base_seq_id_this_block,
|
||||
kv_idx_base + wid * num_frags_z * 16,
|
||||
q_len,
|
||||
kv_len,
|
||||
chunk_end,
|
||||
attn_mask_len,
|
||||
s_frag,
|
||||
mask_offset_this_seq);
|
||||
|
||||
}
|
||||
|
||||
// update m,d
|
||||
@@ -1262,6 +1269,13 @@ void MultiQueryAppendC8Attention(
|
||||
}
|
||||
|
||||
const int num_chunks = div_up(max_seq_len, chunk_size);
|
||||
uint32_t attn_mask_len;
|
||||
if (attn_mask) {
|
||||
attn_mask_len = attn_mask.get().shape()[1];
|
||||
} else {
|
||||
attn_mask_len = -1;
|
||||
}
|
||||
|
||||
dim3 grids(num_blocks_x_cpu, num_chunks, kv_num_heads);
|
||||
dim3 blocks(32, num_warps);
|
||||
if (num_chunks <= 0) {
|
||||
@@ -1326,6 +1340,8 @@ void MultiQueryAppendC8Attention(
|
||||
cu_seqlens_q.data<int>(),
|
||||
block_table.data<int>(),
|
||||
meta_data.mask_offset,
|
||||
attn_mask ? const_cast<bool *>(attn_mask.get().data<bool>())
|
||||
: nullptr,
|
||||
max_seq_len,
|
||||
max_dec_len,
|
||||
max_block_num_per_seq,
|
||||
@@ -1338,7 +1354,8 @@ void MultiQueryAppendC8Attention(
|
||||
nullptr,
|
||||
nullptr,
|
||||
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
|
||||
speculate_max_draft_token_num);
|
||||
speculate_max_draft_token_num,
|
||||
attn_mask_len);
|
||||
} else {
|
||||
phi::Allocator::AllocationPtr tmp_workspace, tmp_m, tmp_d;
|
||||
if (is_decoder) {
|
||||
@@ -1397,6 +1414,8 @@ void MultiQueryAppendC8Attention(
|
||||
cu_seqlens_q.data<int>(),
|
||||
block_table.data<int>(),
|
||||
meta_data.mask_offset,
|
||||
attn_mask ? const_cast<bool *>(attn_mask.get().data<bool>())
|
||||
: nullptr,
|
||||
max_seq_len,
|
||||
max_dec_len,
|
||||
max_block_num_per_seq,
|
||||
@@ -1409,7 +1428,8 @@ void MultiQueryAppendC8Attention(
|
||||
static_cast<float *>(tmp_m->ptr()),
|
||||
static_cast<float *>(tmp_d->ptr()),
|
||||
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
|
||||
speculate_max_draft_token_num);
|
||||
speculate_max_draft_token_num,
|
||||
attn_mask_len);
|
||||
// merge
|
||||
constexpr int vec_size = num_elems_per_128b<NV_TYPE>();
|
||||
if (is_decoder) {
|
||||
|
@@ -905,11 +905,13 @@ template <typename T,
|
||||
uint32_t num_frags_y,
|
||||
uint32_t num_frags_z,
|
||||
bool IS_SYSTEM = false>
|
||||
__device__ __forceinline__ void mask_s(const uint32_t qo_idx_base,
|
||||
__device__ __forceinline__ void mask_s(const bool* attn_mask,
|
||||
const uint32_t qo_idx_base,
|
||||
const uint32_t kv_idx_base,
|
||||
const uint32_t qo_len,
|
||||
const uint32_t kv_len,
|
||||
const uint32_t chunk_end,
|
||||
const uint32_t attn_mask_len,
|
||||
float (*s_frag)[num_frags_z][8],
|
||||
const int *mask_offset = nullptr) {
|
||||
const uint32_t tx = threadIdx.x;
|
||||
@@ -933,7 +935,13 @@ __device__ __forceinline__ void mask_s(const uint32_t qo_idx_base,
|
||||
(causal
|
||||
? (kv_idx > kv_len + q_idx - qo_len || (kv_idx >= chunk_end))
|
||||
: kv_idx >= chunk_end);
|
||||
if (attn_mask != nullptr && kv_idx > kv_len - qo_len && kv_idx < chunk_end && q_idx < attn_mask_len) {
|
||||
const int32_t mask_idx = q_idx * attn_mask_len + kv_idx - kv_len + qo_len;
|
||||
bool mask = attn_mask[mask_idx];
|
||||
out_of_boundary |= mask;
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (std::is_same<T, half>::value) {
|
||||
s_frag[fx][fz][reg_id] =
|
||||
out_of_boundary ? -5e4f : s_frag[fx][fz][reg_id];
|
||||
@@ -941,6 +949,7 @@ __device__ __forceinline__ void mask_s(const uint32_t qo_idx_base,
|
||||
s_frag[fx][fz][reg_id] =
|
||||
out_of_boundary ? -3.0e+30f : s_frag[fx][fz][reg_id];
|
||||
}
|
||||
// printf("tid: %d. qk[%u,%u] = %f, mask: %d \n ", threadIdx.x, kv_idx, q_idx, static_cast<float>(s_frag[fx][fz][reg_id]), int(out_of_boundary));
|
||||
} else {
|
||||
const uint32_t q_idx = qo_idx_base,
|
||||
kv_idx = kv_idx_base + fz * 16 + 2 * (tx % 4) +
|
||||
|
360
tests/operators/test_tree_mask.py
Normal file
360
tests/operators/test_tree_mask.py
Normal file
@@ -0,0 +1,360 @@
|
||||
import math
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
import paddle.nn.functional as F
|
||||
|
||||
from fastdeploy.model_executor.layers.attention.ops import (
|
||||
append_attention,
|
||||
get_block_shape_and_split_kv_block,
|
||||
)
|
||||
|
||||
paddle.seed(0)
|
||||
|
||||
max_seq_len = 32768
|
||||
encoder_max_partition_size = max_seq_len
|
||||
max_partition_size = max_seq_len
|
||||
|
||||
max_dec_len = 1024
|
||||
bsz = 64
|
||||
run_time = 10
|
||||
warm_up = 2
|
||||
block_size = 64
|
||||
head_dim = 128
|
||||
num_q_head = 20
|
||||
num_kv_head = 4
|
||||
dtype = "bfloat16"
|
||||
|
||||
rope_3d = False
|
||||
use_neox_rotary_style = False
|
||||
CURRENT_Q = [None]
|
||||
TOTAL_K = []
|
||||
TOTAL_V = []
|
||||
|
||||
|
||||
def split_qkv(qkv, bsz, seq_len, num_q_head, num_kv_head, head_dim):
|
||||
# [token_num, (num_q_head + 2 * num_kv_head) * head_dim]
|
||||
qkv = qkv.reshape([bsz, seq_len, -1, head_dim])
|
||||
q = qkv[:, :, :num_q_head, :]
|
||||
# [bsz, seq_len, num_q_head, head_dim]
|
||||
CURRENT_Q[0] = q
|
||||
|
||||
# [bsz, seq_len, num_kv_head, head_dim]
|
||||
k = qkv[:, :, num_q_head : num_q_head + num_kv_head, :]
|
||||
TOTAL_K.append(k)
|
||||
|
||||
# [bsz, seq_len, num_kv_head, head_dim]
|
||||
v = qkv[:, :, num_q_head + num_kv_head :, :]
|
||||
TOTAL_V.append(v)
|
||||
|
||||
|
||||
def get_padding_offset(bsz, seq_lens_this_time, seq_lens_decoder):
|
||||
batch_id_per_token = []
|
||||
cu_seqlens_q = paddle.zeros(shape=(bsz + 1), dtype="int32")
|
||||
cu_seqlens_k = paddle.zeros(shape=(bsz + 1), dtype="int32")
|
||||
cum_seq_len_q = 0
|
||||
cum_seq_len_k = 0
|
||||
for i in range(bsz):
|
||||
seq_len_now = seq_lens_this_time[i]
|
||||
seq_len_dec_now = seq_lens_decoder[i]
|
||||
for j in range(seq_len_now):
|
||||
batch_id_per_token.append(i)
|
||||
cum_seq_len_q += seq_len_now
|
||||
cum_seq_len_k += seq_len_now + seq_len_dec_now
|
||||
cu_seqlens_q[i + 1] = cum_seq_len_q
|
||||
cu_seqlens_k[i + 1] = cum_seq_len_k
|
||||
return paddle.to_tensor(batch_id_per_token, dtype="int32"), cu_seqlens_q, cu_seqlens_k
|
||||
|
||||
|
||||
# block_table
|
||||
block_num_per_seq = (max_seq_len + block_size - 1) // block_size
|
||||
max_block_num = block_num_per_seq * bsz
|
||||
cache_shape = (
|
||||
max_block_num,
|
||||
num_kv_head,
|
||||
block_size,
|
||||
head_dim,
|
||||
)
|
||||
|
||||
cache_k = paddle.zeros(shape=cache_shape).astype(dtype)
|
||||
cache_v = paddle.zeros(shape=cache_shape).astype(dtype)
|
||||
|
||||
block_tables = paddle.zeros(shape=(bsz, block_num_per_seq), dtype="int32")
|
||||
|
||||
free_list = list(range(max_block_num - 1, -1, -1))
|
||||
|
||||
for i in range(bsz):
|
||||
need_block_num = (max_seq_len + block_size - 1) // block_size
|
||||
for j in range(need_block_num):
|
||||
block_id = free_list.pop()
|
||||
block_tables[i, j] = block_id
|
||||
|
||||
|
||||
def ref_attention(q, k, v, num_q_head, num_kv_head, head_dim, mask):
|
||||
q = q.transpose([0, 2, 1, 3])
|
||||
if len(k) > 1:
|
||||
k = paddle.concat(k, axis=1)
|
||||
else:
|
||||
k = k[0]
|
||||
k = k.transpose([0, 2, 1, 3])
|
||||
if len(v) > 1:
|
||||
v = paddle.concat(v, axis=1)
|
||||
else:
|
||||
v = v[0]
|
||||
v = v.transpose([0, 2, 1, 3])
|
||||
total_len = k.shape[2]
|
||||
|
||||
scores = q.reshape([bsz, num_kv_head, -1, head_dim]) @ k.transpose([0, 1, 3, 2]) * (1.0 / math.sqrt(head_dim))
|
||||
scores = scores.reshape([bsz, num_q_head, -1, total_len])
|
||||
|
||||
if mask is not None:
|
||||
if mask.ndim == 2:
|
||||
mask = mask.unsqueeze(0).unsqueeze(0) # [1,1,q_len,kv_len]
|
||||
elif mask.ndim == 3:
|
||||
mask = mask.unsqueeze(1) # [bsz,1,q_len,kv_len]
|
||||
scores = paddle.add(scores, mask)
|
||||
weights = F.softmax(scores, axis=-1)
|
||||
|
||||
o = weights.reshape([bsz, num_kv_head, -1, total_len]) @ v
|
||||
return o.reshape([bsz, num_q_head, -1, head_dim]).transpose([0, 2, 1, 3]).reshape([-1, num_q_head, head_dim])
|
||||
|
||||
|
||||
def clear_param():
|
||||
global CURRENT_Q, TOTAL_K, TOTAL_V
|
||||
CURRENT_Q = [None]
|
||||
TOTAL_K = []
|
||||
TOTAL_V = []
|
||||
|
||||
|
||||
def test_append_c16_attention(q_len, kv_len, prefill=False, attn_mask=None):
|
||||
if prefill:
|
||||
seq_lens_enc = [
|
||||
q_len,
|
||||
] * bsz
|
||||
else:
|
||||
seq_lens_enc = [
|
||||
0,
|
||||
] * bsz
|
||||
|
||||
seq_lens_dec = [
|
||||
kv_len,
|
||||
] * bsz
|
||||
seq_lens_cur = [
|
||||
q_len,
|
||||
] * bsz
|
||||
token_num = sum(seq_lens_cur)
|
||||
decoder_step_token_num = 1 if prefill else q_len
|
||||
|
||||
seq_lens_encoder = paddle.to_tensor(seq_lens_enc, "int32")
|
||||
seq_lens_this_time = paddle.to_tensor(seq_lens_cur, "int32")
|
||||
seq_lens_decoder = paddle.to_tensor(seq_lens_dec, "int32")
|
||||
|
||||
batch_id_per_token, cu_seqlens_q, cu_seqlens_k = get_padding_offset(bsz, seq_lens_this_time, seq_lens_decoder)
|
||||
|
||||
# random data
|
||||
qkv_varlen_shape = [token_num, (num_q_head + 2 * num_kv_head) * head_dim]
|
||||
|
||||
rotary_embs_shape = [2, 1, max_seq_len, 1, head_dim if use_neox_rotary_style else head_dim // 2]
|
||||
# qkv_bias_shape = [num_q_head + 2 * num_kv_head, head_dim]
|
||||
|
||||
qkv = paddle.randn(shape=qkv_varlen_shape).astype(dtype)
|
||||
|
||||
# save q, k, v for ref
|
||||
split_qkv(qkv, bsz, q_len, num_q_head, num_kv_head, head_dim)
|
||||
|
||||
rotary_embs = paddle.randn(shape=rotary_embs_shape).astype("float32")
|
||||
rotary_embs[0, :, :, :, :] = 1
|
||||
rotary_embs[1, :, :, :, :] = 0
|
||||
|
||||
# qkv_scale = None
|
||||
# qkv_bias = None
|
||||
|
||||
cache_k_scale = None
|
||||
cache_v_scale = None
|
||||
cache_k_out_scale = None
|
||||
cache_v_out_scale = None
|
||||
# shift_bias = None
|
||||
# smooth_weight = None
|
||||
|
||||
encoder_block_shape_q = 64
|
||||
decoder_block_shape_q = 16
|
||||
|
||||
decode_max_tile_size = (
|
||||
bsz
|
||||
* (decoder_step_token_num * (num_q_head // num_kv_head) + decoder_block_shape_q - 1)
|
||||
/ decoder_block_shape_q
|
||||
)
|
||||
decoder_batch_ids = paddle.full([int(decode_max_tile_size)], 0, dtype="int32")
|
||||
decoder_tile_ids_per_batch = paddle.full([int(decode_max_tile_size)], 0, dtype="int32")
|
||||
decoder_num_blocks = paddle.full([1], 0, dtype="int32").pin_memory()
|
||||
max_len_tensor_cpu = paddle.full([8], 0, dtype="int32").cpu()
|
||||
paddle.device.synchronize()
|
||||
(
|
||||
encoder_batch_ids,
|
||||
encoder_tile_ids_per_batch,
|
||||
encoder_num_blocks,
|
||||
kv_batch_ids,
|
||||
kv_tile_ids_per_batch,
|
||||
kv_num_blocks,
|
||||
max_len_kv,
|
||||
) = get_block_shape_and_split_kv_block(
|
||||
seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
seq_lens_this_time,
|
||||
decoder_batch_ids,
|
||||
decoder_tile_ids_per_batch,
|
||||
decoder_num_blocks,
|
||||
max_len_tensor_cpu,
|
||||
encoder_block_shape_q,
|
||||
decoder_block_shape_q,
|
||||
num_q_head // num_kv_head,
|
||||
block_size,
|
||||
decoder_step_token_num,
|
||||
)
|
||||
s_time = 0
|
||||
for i in range(run_time + warm_up):
|
||||
if i == warm_up:
|
||||
s_time = time.time()
|
||||
out = append_attention(
|
||||
qkv,
|
||||
cache_k,
|
||||
cache_v,
|
||||
seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
seq_lens_this_time,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
block_tables,
|
||||
encoder_batch_ids,
|
||||
encoder_tile_ids_per_batch,
|
||||
encoder_num_blocks,
|
||||
kv_batch_ids,
|
||||
kv_tile_ids_per_batch,
|
||||
kv_num_blocks,
|
||||
decoder_batch_ids,
|
||||
decoder_tile_ids_per_batch,
|
||||
decoder_num_blocks,
|
||||
max_len_tensor_cpu,
|
||||
max_len_kv,
|
||||
rotary_embs,
|
||||
attn_mask, # attn_mask
|
||||
None,
|
||||
None,
|
||||
cache_k_scale,
|
||||
cache_v_scale,
|
||||
cache_k_out_scale,
|
||||
cache_v_out_scale,
|
||||
None, # cache_k_zp
|
||||
None, # cache_v_zp
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
1e-6,
|
||||
"bf16",
|
||||
"none", # cache_quant_type
|
||||
use_neox_rotary_style,
|
||||
rope_3d,
|
||||
max_seq_len,
|
||||
0.0,
|
||||
0.0,
|
||||
-1.0, # out_linear_in_scale
|
||||
encoder_block_shape_q, # encoder_block_shape_q
|
||||
decoder_block_shape_q, # decoder_block_shape_q
|
||||
max_partition_size, # max_partition_size
|
||||
encoder_max_partition_size, # encoder_max_partition_size
|
||||
decoder_step_token_num, # speculate_max_draft_token_num
|
||||
True, # causal
|
||||
decoder_step_token_num > 1, # speculate_decoder
|
||||
)
|
||||
paddle.device.synchronize()
|
||||
e_time = time.time()
|
||||
print(f"mean infer time: {np.mean((e_time - s_time) * 1000 / run_time):.2f}")
|
||||
return out[0].reshape([token_num, num_q_head, head_dim])
|
||||
|
||||
|
||||
def test_naive_speculative_decoding(num_q_head, num_kv_head, head_dim):
|
||||
prefill_len = 8192
|
||||
dec_len_q = 5
|
||||
total_len = prefill_len + dec_len_q
|
||||
mask = paddle.tril(paddle.ones((bsz, dec_len_q, total_len), dtype="float32"), diagonal=prefill_len)
|
||||
mask = paddle.where(mask == 1, paddle.zeros_like(mask), paddle.full_like(mask, fill_value=float("-inf")))
|
||||
test_append_c16_attention(prefill_len, 0, True)
|
||||
dec_out = test_append_c16_attention(dec_len_q, prefill_len, False)
|
||||
|
||||
ref_out = ref_attention(CURRENT_Q[0], TOTAL_K, TOTAL_V, num_q_head, num_kv_head, head_dim, mask)
|
||||
np.testing.assert_allclose(
|
||||
ref_out.astype("float32").numpy(), dec_out.astype("float32").numpy(), rtol=1e-03, atol=5e-03
|
||||
)
|
||||
|
||||
|
||||
def test_mask(num_q_head, num_kv_head, head_dim):
|
||||
prefill_len = 8192
|
||||
dec_len_q = 5
|
||||
total_len = prefill_len + dec_len_q
|
||||
mask = paddle.tril(paddle.ones((bsz, dec_len_q, total_len), dtype="float32"), diagonal=prefill_len)
|
||||
mask_ref = paddle.where(mask == 1, paddle.zeros_like(mask), paddle.full_like(mask, fill_value=float("-inf")))
|
||||
|
||||
mask_append_attn = mask[:, :, prefill_len:]
|
||||
mask_append_attn = paddle.where(
|
||||
mask_append_attn == 1,
|
||||
paddle.full_like(mask_append_attn, fill_value=False, dtype=bool),
|
||||
paddle.full_like(mask_append_attn, fill_value=True, dtype=bool),
|
||||
)
|
||||
|
||||
test_append_c16_attention(prefill_len, 0, True)
|
||||
dec_out = test_append_c16_attention(dec_len_q, prefill_len, False, mask_append_attn)
|
||||
|
||||
ref_out = ref_attention(CURRENT_Q[0], TOTAL_K, TOTAL_V, num_q_head, num_kv_head, head_dim, mask_ref)
|
||||
|
||||
np.testing.assert_allclose(
|
||||
ref_out.astype("float32").numpy(), dec_out.astype("float32").numpy(), rtol=1e-03, atol=5e-03
|
||||
)
|
||||
|
||||
|
||||
def test_tree_mask(num_q_head, num_kv_head, head_dim):
|
||||
# tree
|
||||
# [N, N+1, N+1, N+2, N+2]
|
||||
# N [0, -inf, -inf, -inf, -inf]
|
||||
# N+1 [0, 0, -inf, -inf, -inf]
|
||||
# N+1 [0, -inf, 0, -inf, -inf]
|
||||
# N+2 [0, 0, -inf, 0, -inf]
|
||||
# N+2 [0, -inf, 0, -inf, 0]
|
||||
prefill_len = 8192
|
||||
dec_len_q = 5
|
||||
total_len = prefill_len + dec_len_q
|
||||
mask = paddle.tril(paddle.ones((bsz, dec_len_q, total_len), dtype="float32"), diagonal=prefill_len)
|
||||
mask[:, 2, prefill_len + 1] = 0
|
||||
mask[:, 3, prefill_len + 2] = 0
|
||||
mask[:, 4, prefill_len + 1] = 0
|
||||
mask[:, 4, prefill_len + 3] = 0
|
||||
|
||||
mask_ref = paddle.where(mask == 1, paddle.zeros_like(mask), paddle.full_like(mask, fill_value=float("-inf")))
|
||||
|
||||
mask_append_attn = mask[:, :, prefill_len:]
|
||||
mask_append_attn = paddle.where(
|
||||
mask_append_attn == 1,
|
||||
paddle.full_like(mask_append_attn, fill_value=False, dtype=bool),
|
||||
paddle.full_like(mask_append_attn, fill_value=True, dtype=bool),
|
||||
)
|
||||
|
||||
test_append_c16_attention(prefill_len, 0, True)
|
||||
dec_out = test_append_c16_attention(dec_len_q, prefill_len, False, mask_append_attn)
|
||||
ref_out = ref_attention(CURRENT_Q[0], TOTAL_K, TOTAL_V, num_q_head, num_kv_head, head_dim, mask_ref)
|
||||
np.testing.assert_allclose(
|
||||
ref_out.astype("float32").numpy(), dec_out.astype("float32").numpy(), rtol=1e-03, atol=5e-03
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
test_naive_speculative_decoding(num_q_head, num_kv_head, head_dim)
|
||||
clear_param()
|
||||
|
||||
test_mask(num_q_head, num_kv_head, head_dim)
|
||||
clear_param()
|
||||
|
||||
test_tree_mask(num_q_head, num_kv_head, head_dim)
|
Reference in New Issue
Block a user