[Feature][SpeculativeDecoding]Support tree-attention (#3514)

* support tree-attention

* fix merge bug

* fix unit-test api

* fix merge bug
This commit is contained in:
freeliuzc
2025-08-22 13:36:41 +08:00
committed by GitHub
parent cc88671507
commit 76759108c9
5 changed files with 446 additions and 20 deletions

View File

@@ -247,13 +247,16 @@ __global__ void multi_query_append_attention_kernel(
NUM_WARPS,
num_frags_x,
num_frags_y,
num_frags_z>(q_base_seq_id_this_block,
num_frags_z>(nullptr,
q_base_seq_id_this_block,
kv_idx_base,
q_len,
kv_len,
chunk_end,
-1,
s_frag,
mask_offset_this_seq);
}
// update m,d
@@ -410,6 +413,7 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
const int *__restrict__ cu_seqlens_q,
const int *__restrict__ block_table, // [bsz, block_num_per_seq]
const int *__restrict__ mask_offset,
const bool *__restrict__ attn_mask, // [bsz, max_q, max_q] for tree-mask
const int max_seq_len,
const int max_dec_len,
const int max_block_num_per_seq,
@@ -423,7 +427,8 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
float *__restrict__ tmp_m, // [token_num, num_chunks, num_heads]
float *__restrict__ tmp_d, // [token_num, num_chunks, num_heads]
OutT *__restrict__ out,
const int speculate_max_draft_token_num = 5) {
const int speculate_max_draft_token_num = 5,
const uint32_t attn_mask_len = -1) {
constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b<T>();
static_assert(NUM_WARP_Q == 1, "NUM_WARP_Q must be 1");
static_assert(NUM_WARP_KV == 4, "NUM_WARP_KV must be 4");
@@ -544,8 +549,7 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
const uint32_t mask_check_iteration =
(CAUSAL ? (min(chunk_len,
sub_if_greater_or_zero(
kv_len - q_len +
tile_id * num_rows_per_block / GROUP_SIZE,
kv_len - q_len,
chunk_start)))
: mask_offset ? 0 : chunk_len) /
(NUM_WARP_KV * num_frags_z * 16);
@@ -615,11 +619,13 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
NUM_WARPS,
num_frags_x,
num_frags_y,
num_frags_z>(q_base_seq_id_this_block,
num_frags_z>(attn_mask ? attn_mask + batch_id * attn_mask_len *attn_mask_len : nullptr,
q_base_seq_id_this_block,
kv_idx_base + wid * num_frags_z * 16,
q_len,
kv_len,
chunk_end,
attn_mask_len,
s_frag,
mask_offset_this_seq);
}
@@ -1069,6 +1075,13 @@ void MultiQueryAppendAttention(
chunk_size = static_cast<uint32_t>(encoder_max_partition_size);
}
uint32_t attn_mask_len;
if (attn_mask) {
attn_mask_len = attn_mask.get().shape()[1];
} else {
attn_mask_len = -1;
}
const int num_chunks = div_up(max_seq_len, chunk_size);
dim3 grids(num_blocks_x_cpu, num_chunks, kv_num_heads);
dim3 blocks(32, num_warps);
@@ -1111,6 +1124,8 @@ void MultiQueryAppendAttention(
cu_seqlens_q.data<int>(),
block_table.data<int>(),
meta_data.mask_offset,
attn_mask ? const_cast<bool *>(attn_mask.get().data<bool>())
: nullptr,
max_seq_len,
max_dec_len,
max_block_num_per_seq,
@@ -1123,7 +1138,8 @@ void MultiQueryAppendAttention(
nullptr,
nullptr,
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
speculate_max_draft_token_num);
speculate_max_draft_token_num,
attn_mask_len);
} else {
phi::Allocator::AllocationPtr tmp_workspace, tmp_m, tmp_d;
if (is_decoder) {
@@ -1180,6 +1196,8 @@ void MultiQueryAppendAttention(
cu_seqlens_q.data<int>(),
block_table.data<int>(),
meta_data.mask_offset,
attn_mask ? const_cast<bool *>(attn_mask.get().data<bool>())
: nullptr,
max_seq_len,
max_dec_len,
max_block_num_per_seq,
@@ -1192,7 +1210,8 @@ void MultiQueryAppendAttention(
static_cast<float *>(tmp_m->ptr()),
static_cast<float *>(tmp_d->ptr()),
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
speculate_max_draft_token_num);
speculate_max_draft_token_num,
attn_mask_len);
// merge
constexpr int vec_size = num_elems_per_128b<NV_TYPE>();

View File

@@ -335,11 +335,13 @@ __global__ void multi_query_append_attention_c4_kernel(
NUM_WARPS,
num_frags_x,
num_frags_y,
num_frags_z>(q_base_seq_id_this_block,
num_frags_z>(nullptr,
q_base_seq_id_this_block,
kv_idx_base,
q_len,
kv_len,
chunk_end,
-1,
s_frag,
mask_offset_this_seq);
}
@@ -509,6 +511,7 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
const int *__restrict__ cu_seqlens_q,
const int *__restrict__ block_table, // [bsz, block_num_per_seq]
const int *__restrict__ mask_offset,
const bool *__restrict__ attn_mask, // [bsz, max_q, max_q] for tree-mask
const int max_seq_len,
const int max_dec_len,
const int max_block_num_per_seq,
@@ -522,7 +525,8 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
float *__restrict__ tmp_m, // [token_num, num_chunks, num_heads]
float *__restrict__ tmp_d, // [token_num, num_chunks, num_heads]
OutT *__restrict__ out,
const int speculate_max_draft_token_num = 5) {
const int speculate_max_draft_token_num = 5,
const uint32_t attn_mask_len = -1) {
constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b<T>();
constexpr uint32_t num_vecs_per_head_k =
HEAD_DIM / 2 / num_elems_per_128b<CacheT>();
@@ -707,8 +711,7 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
const uint32_t mask_check_iteration =
(CAUSAL ? (min(chunk_len,
sub_if_greater_or_zero(
kv_len - q_len +
tile_id * num_rows_per_block / GROUP_SIZE,
kv_len - q_len,
chunk_start)))
: mask_offset ? 0 : chunk_len) /
(NUM_WARP_KV * num_frags_z * 16);
@@ -792,11 +795,13 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
NUM_WARPS,
num_frags_x,
num_frags_y,
num_frags_z>(q_base_seq_id_this_block,
num_frags_z>(attn_mask ? attn_mask + batch_id * attn_mask_len *attn_mask_len : nullptr,
q_base_seq_id_this_block,
kv_idx_base + wid * num_frags_z * 16,
q_len,
kv_len,
chunk_end,
attn_mask_len,
s_frag,
mask_offset_this_seq);
}
@@ -1294,6 +1299,13 @@ void MultiQueryAppendC4Attention(
}
const int num_chunks = div_up(max_seq_len, chunk_size);
uint32_t attn_mask_len;
if (attn_mask) {
attn_mask_len = attn_mask.get().shape()[1];
} else {
attn_mask_len = -1;
}
dim3 grids(num_blocks_x_cpu, num_chunks, kv_num_heads);
dim3 blocks(32, num_warps);
if (num_chunks <= 0) {
@@ -1343,6 +1355,8 @@ void MultiQueryAppendC4Attention(
cu_seqlens_q.data<int>(),
block_table.data<int>(),
meta_data.mask_offset,
attn_mask ? const_cast<bool *>(attn_mask.get().data<bool>())
: nullptr,
max_seq_len,
max_dec_len,
max_block_num_per_seq,
@@ -1355,7 +1369,8 @@ void MultiQueryAppendC4Attention(
nullptr,
nullptr,
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
speculate_max_draft_token_num);
speculate_max_draft_token_num,
attn_mask_len);
} else {
phi::Allocator::AllocationPtr tmp_workspace, tmp_m, tmp_d;
if (is_decoder) {
@@ -1420,6 +1435,8 @@ void MultiQueryAppendC4Attention(
cu_seqlens_q.data<int>(),
block_table.data<int>(),
meta_data.mask_offset,
attn_mask ? const_cast<bool *>(attn_mask.get().data<bool>())
: nullptr,
max_seq_len,
max_dec_len,
max_block_num_per_seq,
@@ -1432,7 +1449,8 @@ void MultiQueryAppendC4Attention(
static_cast<float *>(tmp_m->ptr()),
static_cast<float *>(tmp_d->ptr()),
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
speculate_max_draft_token_num);
speculate_max_draft_token_num,
attn_mask_len);
// merge
constexpr int vec_size = num_elems_per_128b<NV_TYPE>();
if (is_decoder) {

View File

@@ -302,11 +302,13 @@ __global__ void multi_query_append_attention_c8_kernel(
NUM_WARPS,
num_frags_x,
num_frags_y,
num_frags_z>(q_base_seq_id_this_block,
num_frags_z>(nullptr,
q_base_seq_id_this_block,
kv_idx_base,
q_len,
kv_len,
chunk_end,
-1,
s_frag,
mask_offset_this_seq);
}
@@ -478,6 +480,7 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
const int *__restrict__ cu_seqlens_q,
const int *__restrict__ block_table, // [bsz, block_num_per_seq]
const int *__restrict__ mask_offset,
const bool *__restrict__ attn_mask, // [bsz, max_q, max_q] for tree-mask
const int max_seq_len,
const int max_dec_len,
const int max_block_num_per_seq,
@@ -491,7 +494,8 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
float *__restrict__ tmp_m, // [token_num, num_chunks, num_heads]
float *__restrict__ tmp_d, // [token_num, num_chunks, num_heads]
OutT *__restrict__ out,
const int speculate_max_draft_token_num = 5) {
const int speculate_max_draft_token_num = 5,
const uint32_t attn_mask_len = -1) {
constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b<T>();
constexpr uint32_t num_vecs_per_head_k =
HEAD_DIM / num_elems_per_128b<CacheT>();
@@ -732,13 +736,16 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
NUM_WARPS,
num_frags_x,
num_frags_y,
num_frags_z>(q_base_seq_id_this_block,
num_frags_z>(attn_mask ? attn_mask + batch_id * attn_mask_len *attn_mask_len : nullptr,
q_base_seq_id_this_block,
kv_idx_base + wid * num_frags_z * 16,
q_len,
kv_len,
chunk_end,
attn_mask_len,
s_frag,
mask_offset_this_seq);
}
// update m,d
@@ -1262,6 +1269,13 @@ void MultiQueryAppendC8Attention(
}
const int num_chunks = div_up(max_seq_len, chunk_size);
uint32_t attn_mask_len;
if (attn_mask) {
attn_mask_len = attn_mask.get().shape()[1];
} else {
attn_mask_len = -1;
}
dim3 grids(num_blocks_x_cpu, num_chunks, kv_num_heads);
dim3 blocks(32, num_warps);
if (num_chunks <= 0) {
@@ -1326,6 +1340,8 @@ void MultiQueryAppendC8Attention(
cu_seqlens_q.data<int>(),
block_table.data<int>(),
meta_data.mask_offset,
attn_mask ? const_cast<bool *>(attn_mask.get().data<bool>())
: nullptr,
max_seq_len,
max_dec_len,
max_block_num_per_seq,
@@ -1338,7 +1354,8 @@ void MultiQueryAppendC8Attention(
nullptr,
nullptr,
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
speculate_max_draft_token_num);
speculate_max_draft_token_num,
attn_mask_len);
} else {
phi::Allocator::AllocationPtr tmp_workspace, tmp_m, tmp_d;
if (is_decoder) {
@@ -1397,6 +1414,8 @@ void MultiQueryAppendC8Attention(
cu_seqlens_q.data<int>(),
block_table.data<int>(),
meta_data.mask_offset,
attn_mask ? const_cast<bool *>(attn_mask.get().data<bool>())
: nullptr,
max_seq_len,
max_dec_len,
max_block_num_per_seq,
@@ -1409,7 +1428,8 @@ void MultiQueryAppendC8Attention(
static_cast<float *>(tmp_m->ptr()),
static_cast<float *>(tmp_d->ptr()),
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
speculate_max_draft_token_num);
speculate_max_draft_token_num,
attn_mask_len);
// merge
constexpr int vec_size = num_elems_per_128b<NV_TYPE>();
if (is_decoder) {

View File

@@ -905,11 +905,13 @@ template <typename T,
uint32_t num_frags_y,
uint32_t num_frags_z,
bool IS_SYSTEM = false>
__device__ __forceinline__ void mask_s(const uint32_t qo_idx_base,
__device__ __forceinline__ void mask_s(const bool* attn_mask,
const uint32_t qo_idx_base,
const uint32_t kv_idx_base,
const uint32_t qo_len,
const uint32_t kv_len,
const uint32_t chunk_end,
const uint32_t attn_mask_len,
float (*s_frag)[num_frags_z][8],
const int *mask_offset = nullptr) {
const uint32_t tx = threadIdx.x;
@@ -933,7 +935,13 @@ __device__ __forceinline__ void mask_s(const uint32_t qo_idx_base,
(causal
? (kv_idx > kv_len + q_idx - qo_len || (kv_idx >= chunk_end))
: kv_idx >= chunk_end);
if (attn_mask != nullptr && kv_idx > kv_len - qo_len && kv_idx < chunk_end && q_idx < attn_mask_len) {
const int32_t mask_idx = q_idx * attn_mask_len + kv_idx - kv_len + qo_len;
bool mask = attn_mask[mask_idx];
out_of_boundary |= mask;
}
}
if constexpr (std::is_same<T, half>::value) {
s_frag[fx][fz][reg_id] =
out_of_boundary ? -5e4f : s_frag[fx][fz][reg_id];
@@ -941,6 +949,7 @@ __device__ __forceinline__ void mask_s(const uint32_t qo_idx_base,
s_frag[fx][fz][reg_id] =
out_of_boundary ? -3.0e+30f : s_frag[fx][fz][reg_id];
}
// printf("tid: %d. qk[%u,%u] = %f, mask: %d \n ", threadIdx.x, kv_idx, q_idx, static_cast<float>(s_frag[fx][fz][reg_id]), int(out_of_boundary));
} else {
const uint32_t q_idx = qo_idx_base,
kv_idx = kv_idx_base + fz * 16 + 2 * (tx % 4) +

View File

@@ -0,0 +1,360 @@
import math
import time
import numpy as np
import paddle
import paddle.nn.functional as F
from fastdeploy.model_executor.layers.attention.ops import (
append_attention,
get_block_shape_and_split_kv_block,
)
paddle.seed(0)
max_seq_len = 32768
encoder_max_partition_size = max_seq_len
max_partition_size = max_seq_len
max_dec_len = 1024
bsz = 64
run_time = 10
warm_up = 2
block_size = 64
head_dim = 128
num_q_head = 20
num_kv_head = 4
dtype = "bfloat16"
rope_3d = False
use_neox_rotary_style = False
CURRENT_Q = [None]
TOTAL_K = []
TOTAL_V = []
def split_qkv(qkv, bsz, seq_len, num_q_head, num_kv_head, head_dim):
# [token_num, (num_q_head + 2 * num_kv_head) * head_dim]
qkv = qkv.reshape([bsz, seq_len, -1, head_dim])
q = qkv[:, :, :num_q_head, :]
# [bsz, seq_len, num_q_head, head_dim]
CURRENT_Q[0] = q
# [bsz, seq_len, num_kv_head, head_dim]
k = qkv[:, :, num_q_head : num_q_head + num_kv_head, :]
TOTAL_K.append(k)
# [bsz, seq_len, num_kv_head, head_dim]
v = qkv[:, :, num_q_head + num_kv_head :, :]
TOTAL_V.append(v)
def get_padding_offset(bsz, seq_lens_this_time, seq_lens_decoder):
batch_id_per_token = []
cu_seqlens_q = paddle.zeros(shape=(bsz + 1), dtype="int32")
cu_seqlens_k = paddle.zeros(shape=(bsz + 1), dtype="int32")
cum_seq_len_q = 0
cum_seq_len_k = 0
for i in range(bsz):
seq_len_now = seq_lens_this_time[i]
seq_len_dec_now = seq_lens_decoder[i]
for j in range(seq_len_now):
batch_id_per_token.append(i)
cum_seq_len_q += seq_len_now
cum_seq_len_k += seq_len_now + seq_len_dec_now
cu_seqlens_q[i + 1] = cum_seq_len_q
cu_seqlens_k[i + 1] = cum_seq_len_k
return paddle.to_tensor(batch_id_per_token, dtype="int32"), cu_seqlens_q, cu_seqlens_k
# block_table
block_num_per_seq = (max_seq_len + block_size - 1) // block_size
max_block_num = block_num_per_seq * bsz
cache_shape = (
max_block_num,
num_kv_head,
block_size,
head_dim,
)
cache_k = paddle.zeros(shape=cache_shape).astype(dtype)
cache_v = paddle.zeros(shape=cache_shape).astype(dtype)
block_tables = paddle.zeros(shape=(bsz, block_num_per_seq), dtype="int32")
free_list = list(range(max_block_num - 1, -1, -1))
for i in range(bsz):
need_block_num = (max_seq_len + block_size - 1) // block_size
for j in range(need_block_num):
block_id = free_list.pop()
block_tables[i, j] = block_id
def ref_attention(q, k, v, num_q_head, num_kv_head, head_dim, mask):
q = q.transpose([0, 2, 1, 3])
if len(k) > 1:
k = paddle.concat(k, axis=1)
else:
k = k[0]
k = k.transpose([0, 2, 1, 3])
if len(v) > 1:
v = paddle.concat(v, axis=1)
else:
v = v[0]
v = v.transpose([0, 2, 1, 3])
total_len = k.shape[2]
scores = q.reshape([bsz, num_kv_head, -1, head_dim]) @ k.transpose([0, 1, 3, 2]) * (1.0 / math.sqrt(head_dim))
scores = scores.reshape([bsz, num_q_head, -1, total_len])
if mask is not None:
if mask.ndim == 2:
mask = mask.unsqueeze(0).unsqueeze(0) # [1,1,q_len,kv_len]
elif mask.ndim == 3:
mask = mask.unsqueeze(1) # [bsz,1,q_len,kv_len]
scores = paddle.add(scores, mask)
weights = F.softmax(scores, axis=-1)
o = weights.reshape([bsz, num_kv_head, -1, total_len]) @ v
return o.reshape([bsz, num_q_head, -1, head_dim]).transpose([0, 2, 1, 3]).reshape([-1, num_q_head, head_dim])
def clear_param():
global CURRENT_Q, TOTAL_K, TOTAL_V
CURRENT_Q = [None]
TOTAL_K = []
TOTAL_V = []
def test_append_c16_attention(q_len, kv_len, prefill=False, attn_mask=None):
if prefill:
seq_lens_enc = [
q_len,
] * bsz
else:
seq_lens_enc = [
0,
] * bsz
seq_lens_dec = [
kv_len,
] * bsz
seq_lens_cur = [
q_len,
] * bsz
token_num = sum(seq_lens_cur)
decoder_step_token_num = 1 if prefill else q_len
seq_lens_encoder = paddle.to_tensor(seq_lens_enc, "int32")
seq_lens_this_time = paddle.to_tensor(seq_lens_cur, "int32")
seq_lens_decoder = paddle.to_tensor(seq_lens_dec, "int32")
batch_id_per_token, cu_seqlens_q, cu_seqlens_k = get_padding_offset(bsz, seq_lens_this_time, seq_lens_decoder)
# random data
qkv_varlen_shape = [token_num, (num_q_head + 2 * num_kv_head) * head_dim]
rotary_embs_shape = [2, 1, max_seq_len, 1, head_dim if use_neox_rotary_style else head_dim // 2]
# qkv_bias_shape = [num_q_head + 2 * num_kv_head, head_dim]
qkv = paddle.randn(shape=qkv_varlen_shape).astype(dtype)
# save q, k, v for ref
split_qkv(qkv, bsz, q_len, num_q_head, num_kv_head, head_dim)
rotary_embs = paddle.randn(shape=rotary_embs_shape).astype("float32")
rotary_embs[0, :, :, :, :] = 1
rotary_embs[1, :, :, :, :] = 0
# qkv_scale = None
# qkv_bias = None
cache_k_scale = None
cache_v_scale = None
cache_k_out_scale = None
cache_v_out_scale = None
# shift_bias = None
# smooth_weight = None
encoder_block_shape_q = 64
decoder_block_shape_q = 16
decode_max_tile_size = (
bsz
* (decoder_step_token_num * (num_q_head // num_kv_head) + decoder_block_shape_q - 1)
/ decoder_block_shape_q
)
decoder_batch_ids = paddle.full([int(decode_max_tile_size)], 0, dtype="int32")
decoder_tile_ids_per_batch = paddle.full([int(decode_max_tile_size)], 0, dtype="int32")
decoder_num_blocks = paddle.full([1], 0, dtype="int32").pin_memory()
max_len_tensor_cpu = paddle.full([8], 0, dtype="int32").cpu()
paddle.device.synchronize()
(
encoder_batch_ids,
encoder_tile_ids_per_batch,
encoder_num_blocks,
kv_batch_ids,
kv_tile_ids_per_batch,
kv_num_blocks,
max_len_kv,
) = get_block_shape_and_split_kv_block(
seq_lens_encoder,
seq_lens_decoder,
seq_lens_this_time,
decoder_batch_ids,
decoder_tile_ids_per_batch,
decoder_num_blocks,
max_len_tensor_cpu,
encoder_block_shape_q,
decoder_block_shape_q,
num_q_head // num_kv_head,
block_size,
decoder_step_token_num,
)
s_time = 0
for i in range(run_time + warm_up):
if i == warm_up:
s_time = time.time()
out = append_attention(
qkv,
cache_k,
cache_v,
seq_lens_encoder,
seq_lens_decoder,
seq_lens_this_time,
batch_id_per_token,
cu_seqlens_q,
block_tables,
encoder_batch_ids,
encoder_tile_ids_per_batch,
encoder_num_blocks,
kv_batch_ids,
kv_tile_ids_per_batch,
kv_num_blocks,
decoder_batch_ids,
decoder_tile_ids_per_batch,
decoder_num_blocks,
max_len_tensor_cpu,
max_len_kv,
rotary_embs,
attn_mask, # attn_mask
None,
None,
cache_k_scale,
cache_v_scale,
cache_k_out_scale,
cache_v_out_scale,
None, # cache_k_zp
None, # cache_v_zp
None,
None,
None,
None,
None,
None,
1e-6,
"bf16",
"none", # cache_quant_type
use_neox_rotary_style,
rope_3d,
max_seq_len,
0.0,
0.0,
-1.0, # out_linear_in_scale
encoder_block_shape_q, # encoder_block_shape_q
decoder_block_shape_q, # decoder_block_shape_q
max_partition_size, # max_partition_size
encoder_max_partition_size, # encoder_max_partition_size
decoder_step_token_num, # speculate_max_draft_token_num
True, # causal
decoder_step_token_num > 1, # speculate_decoder
)
paddle.device.synchronize()
e_time = time.time()
print(f"mean infer time: {np.mean((e_time - s_time) * 1000 / run_time):.2f}")
return out[0].reshape([token_num, num_q_head, head_dim])
def test_naive_speculative_decoding(num_q_head, num_kv_head, head_dim):
prefill_len = 8192
dec_len_q = 5
total_len = prefill_len + dec_len_q
mask = paddle.tril(paddle.ones((bsz, dec_len_q, total_len), dtype="float32"), diagonal=prefill_len)
mask = paddle.where(mask == 1, paddle.zeros_like(mask), paddle.full_like(mask, fill_value=float("-inf")))
test_append_c16_attention(prefill_len, 0, True)
dec_out = test_append_c16_attention(dec_len_q, prefill_len, False)
ref_out = ref_attention(CURRENT_Q[0], TOTAL_K, TOTAL_V, num_q_head, num_kv_head, head_dim, mask)
np.testing.assert_allclose(
ref_out.astype("float32").numpy(), dec_out.astype("float32").numpy(), rtol=1e-03, atol=5e-03
)
def test_mask(num_q_head, num_kv_head, head_dim):
prefill_len = 8192
dec_len_q = 5
total_len = prefill_len + dec_len_q
mask = paddle.tril(paddle.ones((bsz, dec_len_q, total_len), dtype="float32"), diagonal=prefill_len)
mask_ref = paddle.where(mask == 1, paddle.zeros_like(mask), paddle.full_like(mask, fill_value=float("-inf")))
mask_append_attn = mask[:, :, prefill_len:]
mask_append_attn = paddle.where(
mask_append_attn == 1,
paddle.full_like(mask_append_attn, fill_value=False, dtype=bool),
paddle.full_like(mask_append_attn, fill_value=True, dtype=bool),
)
test_append_c16_attention(prefill_len, 0, True)
dec_out = test_append_c16_attention(dec_len_q, prefill_len, False, mask_append_attn)
ref_out = ref_attention(CURRENT_Q[0], TOTAL_K, TOTAL_V, num_q_head, num_kv_head, head_dim, mask_ref)
np.testing.assert_allclose(
ref_out.astype("float32").numpy(), dec_out.astype("float32").numpy(), rtol=1e-03, atol=5e-03
)
def test_tree_mask(num_q_head, num_kv_head, head_dim):
# tree
# [N, N+1, N+1, N+2, N+2]
# N [0, -inf, -inf, -inf, -inf]
# N+1 [0, 0, -inf, -inf, -inf]
# N+1 [0, -inf, 0, -inf, -inf]
# N+2 [0, 0, -inf, 0, -inf]
# N+2 [0, -inf, 0, -inf, 0]
prefill_len = 8192
dec_len_q = 5
total_len = prefill_len + dec_len_q
mask = paddle.tril(paddle.ones((bsz, dec_len_q, total_len), dtype="float32"), diagonal=prefill_len)
mask[:, 2, prefill_len + 1] = 0
mask[:, 3, prefill_len + 2] = 0
mask[:, 4, prefill_len + 1] = 0
mask[:, 4, prefill_len + 3] = 0
mask_ref = paddle.where(mask == 1, paddle.zeros_like(mask), paddle.full_like(mask, fill_value=float("-inf")))
mask_append_attn = mask[:, :, prefill_len:]
mask_append_attn = paddle.where(
mask_append_attn == 1,
paddle.full_like(mask_append_attn, fill_value=False, dtype=bool),
paddle.full_like(mask_append_attn, fill_value=True, dtype=bool),
)
test_append_c16_attention(prefill_len, 0, True)
dec_out = test_append_c16_attention(dec_len_q, prefill_len, False, mask_append_attn)
ref_out = ref_attention(CURRENT_Q[0], TOTAL_K, TOTAL_V, num_q_head, num_kv_head, head_dim, mask_ref)
np.testing.assert_allclose(
ref_out.astype("float32").numpy(), dec_out.astype("float32").numpy(), rtol=1e-03, atol=5e-03
)
if __name__ == "__main__":
test_naive_speculative_decoding(num_q_head, num_kv_head, head_dim)
clear_param()
test_mask(num_q_head, num_kv_head, head_dim)
clear_param()
test_tree_mask(num_q_head, num_kv_head, head_dim)