[Feature][SpeculativeDecoding]Support tree-attention (#3514)

* support tree-attention

* fix merge bug

* fix unit-test api

* fix merge bug
This commit is contained in:
freeliuzc
2025-08-22 13:36:41 +08:00
committed by GitHub
parent cc88671507
commit 76759108c9
5 changed files with 446 additions and 20 deletions

View File

@@ -247,13 +247,16 @@ __global__ void multi_query_append_attention_kernel(
NUM_WARPS,
num_frags_x,
num_frags_y,
num_frags_z>(q_base_seq_id_this_block,
num_frags_z>(nullptr,
q_base_seq_id_this_block,
kv_idx_base,
q_len,
kv_len,
chunk_end,
-1,
s_frag,
mask_offset_this_seq);
}
// update m,d
@@ -410,6 +413,7 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
const int *__restrict__ cu_seqlens_q,
const int *__restrict__ block_table, // [bsz, block_num_per_seq]
const int *__restrict__ mask_offset,
const bool *__restrict__ attn_mask, // [bsz, max_q, max_q] for tree-mask
const int max_seq_len,
const int max_dec_len,
const int max_block_num_per_seq,
@@ -423,7 +427,8 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
float *__restrict__ tmp_m, // [token_num, num_chunks, num_heads]
float *__restrict__ tmp_d, // [token_num, num_chunks, num_heads]
OutT *__restrict__ out,
const int speculate_max_draft_token_num = 5) {
const int speculate_max_draft_token_num = 5,
const uint32_t attn_mask_len = -1) {
constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b<T>();
static_assert(NUM_WARP_Q == 1, "NUM_WARP_Q must be 1");
static_assert(NUM_WARP_KV == 4, "NUM_WARP_KV must be 4");
@@ -544,8 +549,7 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
const uint32_t mask_check_iteration =
(CAUSAL ? (min(chunk_len,
sub_if_greater_or_zero(
kv_len - q_len +
tile_id * num_rows_per_block / GROUP_SIZE,
kv_len - q_len,
chunk_start)))
: mask_offset ? 0 : chunk_len) /
(NUM_WARP_KV * num_frags_z * 16);
@@ -615,11 +619,13 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
NUM_WARPS,
num_frags_x,
num_frags_y,
num_frags_z>(q_base_seq_id_this_block,
num_frags_z>(attn_mask ? attn_mask + batch_id * attn_mask_len *attn_mask_len : nullptr,
q_base_seq_id_this_block,
kv_idx_base + wid * num_frags_z * 16,
q_len,
kv_len,
chunk_end,
attn_mask_len,
s_frag,
mask_offset_this_seq);
}
@@ -1069,6 +1075,13 @@ void MultiQueryAppendAttention(
chunk_size = static_cast<uint32_t>(encoder_max_partition_size);
}
uint32_t attn_mask_len;
if (attn_mask) {
attn_mask_len = attn_mask.get().shape()[1];
} else {
attn_mask_len = -1;
}
const int num_chunks = div_up(max_seq_len, chunk_size);
dim3 grids(num_blocks_x_cpu, num_chunks, kv_num_heads);
dim3 blocks(32, num_warps);
@@ -1111,6 +1124,8 @@ void MultiQueryAppendAttention(
cu_seqlens_q.data<int>(),
block_table.data<int>(),
meta_data.mask_offset,
attn_mask ? const_cast<bool *>(attn_mask.get().data<bool>())
: nullptr,
max_seq_len,
max_dec_len,
max_block_num_per_seq,
@@ -1123,7 +1138,8 @@ void MultiQueryAppendAttention(
nullptr,
nullptr,
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
speculate_max_draft_token_num);
speculate_max_draft_token_num,
attn_mask_len);
} else {
phi::Allocator::AllocationPtr tmp_workspace, tmp_m, tmp_d;
if (is_decoder) {
@@ -1180,6 +1196,8 @@ void MultiQueryAppendAttention(
cu_seqlens_q.data<int>(),
block_table.data<int>(),
meta_data.mask_offset,
attn_mask ? const_cast<bool *>(attn_mask.get().data<bool>())
: nullptr,
max_seq_len,
max_dec_len,
max_block_num_per_seq,
@@ -1192,7 +1210,8 @@ void MultiQueryAppendAttention(
static_cast<float *>(tmp_m->ptr()),
static_cast<float *>(tmp_d->ptr()),
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
speculate_max_draft_token_num);
speculate_max_draft_token_num,
attn_mask_len);
// merge
constexpr int vec_size = num_elems_per_128b<NV_TYPE>();

View File

@@ -335,11 +335,13 @@ __global__ void multi_query_append_attention_c4_kernel(
NUM_WARPS,
num_frags_x,
num_frags_y,
num_frags_z>(q_base_seq_id_this_block,
num_frags_z>(nullptr,
q_base_seq_id_this_block,
kv_idx_base,
q_len,
kv_len,
chunk_end,
-1,
s_frag,
mask_offset_this_seq);
}
@@ -509,6 +511,7 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
const int *__restrict__ cu_seqlens_q,
const int *__restrict__ block_table, // [bsz, block_num_per_seq]
const int *__restrict__ mask_offset,
const bool *__restrict__ attn_mask, // [bsz, max_q, max_q] for tree-mask
const int max_seq_len,
const int max_dec_len,
const int max_block_num_per_seq,
@@ -522,7 +525,8 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
float *__restrict__ tmp_m, // [token_num, num_chunks, num_heads]
float *__restrict__ tmp_d, // [token_num, num_chunks, num_heads]
OutT *__restrict__ out,
const int speculate_max_draft_token_num = 5) {
const int speculate_max_draft_token_num = 5,
const uint32_t attn_mask_len = -1) {
constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b<T>();
constexpr uint32_t num_vecs_per_head_k =
HEAD_DIM / 2 / num_elems_per_128b<CacheT>();
@@ -707,8 +711,7 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
const uint32_t mask_check_iteration =
(CAUSAL ? (min(chunk_len,
sub_if_greater_or_zero(
kv_len - q_len +
tile_id * num_rows_per_block / GROUP_SIZE,
kv_len - q_len,
chunk_start)))
: mask_offset ? 0 : chunk_len) /
(NUM_WARP_KV * num_frags_z * 16);
@@ -792,11 +795,13 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
NUM_WARPS,
num_frags_x,
num_frags_y,
num_frags_z>(q_base_seq_id_this_block,
num_frags_z>(attn_mask ? attn_mask + batch_id * attn_mask_len *attn_mask_len : nullptr,
q_base_seq_id_this_block,
kv_idx_base + wid * num_frags_z * 16,
q_len,
kv_len,
chunk_end,
attn_mask_len,
s_frag,
mask_offset_this_seq);
}
@@ -1294,6 +1299,13 @@ void MultiQueryAppendC4Attention(
}
const int num_chunks = div_up(max_seq_len, chunk_size);
uint32_t attn_mask_len;
if (attn_mask) {
attn_mask_len = attn_mask.get().shape()[1];
} else {
attn_mask_len = -1;
}
dim3 grids(num_blocks_x_cpu, num_chunks, kv_num_heads);
dim3 blocks(32, num_warps);
if (num_chunks <= 0) {
@@ -1343,6 +1355,8 @@ void MultiQueryAppendC4Attention(
cu_seqlens_q.data<int>(),
block_table.data<int>(),
meta_data.mask_offset,
attn_mask ? const_cast<bool *>(attn_mask.get().data<bool>())
: nullptr,
max_seq_len,
max_dec_len,
max_block_num_per_seq,
@@ -1355,7 +1369,8 @@ void MultiQueryAppendC4Attention(
nullptr,
nullptr,
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
speculate_max_draft_token_num);
speculate_max_draft_token_num,
attn_mask_len);
} else {
phi::Allocator::AllocationPtr tmp_workspace, tmp_m, tmp_d;
if (is_decoder) {
@@ -1420,6 +1435,8 @@ void MultiQueryAppendC4Attention(
cu_seqlens_q.data<int>(),
block_table.data<int>(),
meta_data.mask_offset,
attn_mask ? const_cast<bool *>(attn_mask.get().data<bool>())
: nullptr,
max_seq_len,
max_dec_len,
max_block_num_per_seq,
@@ -1432,7 +1449,8 @@ void MultiQueryAppendC4Attention(
static_cast<float *>(tmp_m->ptr()),
static_cast<float *>(tmp_d->ptr()),
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
speculate_max_draft_token_num);
speculate_max_draft_token_num,
attn_mask_len);
// merge
constexpr int vec_size = num_elems_per_128b<NV_TYPE>();
if (is_decoder) {

View File

@@ -302,11 +302,13 @@ __global__ void multi_query_append_attention_c8_kernel(
NUM_WARPS,
num_frags_x,
num_frags_y,
num_frags_z>(q_base_seq_id_this_block,
num_frags_z>(nullptr,
q_base_seq_id_this_block,
kv_idx_base,
q_len,
kv_len,
chunk_end,
-1,
s_frag,
mask_offset_this_seq);
}
@@ -478,6 +480,7 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
const int *__restrict__ cu_seqlens_q,
const int *__restrict__ block_table, // [bsz, block_num_per_seq]
const int *__restrict__ mask_offset,
const bool *__restrict__ attn_mask, // [bsz, max_q, max_q] for tree-mask
const int max_seq_len,
const int max_dec_len,
const int max_block_num_per_seq,
@@ -491,7 +494,8 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
float *__restrict__ tmp_m, // [token_num, num_chunks, num_heads]
float *__restrict__ tmp_d, // [token_num, num_chunks, num_heads]
OutT *__restrict__ out,
const int speculate_max_draft_token_num = 5) {
const int speculate_max_draft_token_num = 5,
const uint32_t attn_mask_len = -1) {
constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b<T>();
constexpr uint32_t num_vecs_per_head_k =
HEAD_DIM / num_elems_per_128b<CacheT>();
@@ -732,13 +736,16 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
NUM_WARPS,
num_frags_x,
num_frags_y,
num_frags_z>(q_base_seq_id_this_block,
num_frags_z>(attn_mask ? attn_mask + batch_id * attn_mask_len *attn_mask_len : nullptr,
q_base_seq_id_this_block,
kv_idx_base + wid * num_frags_z * 16,
q_len,
kv_len,
chunk_end,
attn_mask_len,
s_frag,
mask_offset_this_seq);
}
// update m,d
@@ -1262,6 +1269,13 @@ void MultiQueryAppendC8Attention(
}
const int num_chunks = div_up(max_seq_len, chunk_size);
uint32_t attn_mask_len;
if (attn_mask) {
attn_mask_len = attn_mask.get().shape()[1];
} else {
attn_mask_len = -1;
}
dim3 grids(num_blocks_x_cpu, num_chunks, kv_num_heads);
dim3 blocks(32, num_warps);
if (num_chunks <= 0) {
@@ -1326,6 +1340,8 @@ void MultiQueryAppendC8Attention(
cu_seqlens_q.data<int>(),
block_table.data<int>(),
meta_data.mask_offset,
attn_mask ? const_cast<bool *>(attn_mask.get().data<bool>())
: nullptr,
max_seq_len,
max_dec_len,
max_block_num_per_seq,
@@ -1338,7 +1354,8 @@ void MultiQueryAppendC8Attention(
nullptr,
nullptr,
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
speculate_max_draft_token_num);
speculate_max_draft_token_num,
attn_mask_len);
} else {
phi::Allocator::AllocationPtr tmp_workspace, tmp_m, tmp_d;
if (is_decoder) {
@@ -1397,6 +1414,8 @@ void MultiQueryAppendC8Attention(
cu_seqlens_q.data<int>(),
block_table.data<int>(),
meta_data.mask_offset,
attn_mask ? const_cast<bool *>(attn_mask.get().data<bool>())
: nullptr,
max_seq_len,
max_dec_len,
max_block_num_per_seq,
@@ -1409,7 +1428,8 @@ void MultiQueryAppendC8Attention(
static_cast<float *>(tmp_m->ptr()),
static_cast<float *>(tmp_d->ptr()),
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
speculate_max_draft_token_num);
speculate_max_draft_token_num,
attn_mask_len);
// merge
constexpr int vec_size = num_elems_per_128b<NV_TYPE>();
if (is_decoder) {

View File

@@ -905,11 +905,13 @@ template <typename T,
uint32_t num_frags_y,
uint32_t num_frags_z,
bool IS_SYSTEM = false>
__device__ __forceinline__ void mask_s(const uint32_t qo_idx_base,
__device__ __forceinline__ void mask_s(const bool* attn_mask,
const uint32_t qo_idx_base,
const uint32_t kv_idx_base,
const uint32_t qo_len,
const uint32_t kv_len,
const uint32_t chunk_end,
const uint32_t attn_mask_len,
float (*s_frag)[num_frags_z][8],
const int *mask_offset = nullptr) {
const uint32_t tx = threadIdx.x;
@@ -933,7 +935,13 @@ __device__ __forceinline__ void mask_s(const uint32_t qo_idx_base,
(causal
? (kv_idx > kv_len + q_idx - qo_len || (kv_idx >= chunk_end))
: kv_idx >= chunk_end);
if (attn_mask != nullptr && kv_idx > kv_len - qo_len && kv_idx < chunk_end && q_idx < attn_mask_len) {
const int32_t mask_idx = q_idx * attn_mask_len + kv_idx - kv_len + qo_len;
bool mask = attn_mask[mask_idx];
out_of_boundary |= mask;
}
}
if constexpr (std::is_same<T, half>::value) {
s_frag[fx][fz][reg_id] =
out_of_boundary ? -5e4f : s_frag[fx][fz][reg_id];
@@ -941,6 +949,7 @@ __device__ __forceinline__ void mask_s(const uint32_t qo_idx_base,
s_frag[fx][fz][reg_id] =
out_of_boundary ? -3.0e+30f : s_frag[fx][fz][reg_id];
}
// printf("tid: %d. qk[%u,%u] = %f, mask: %d \n ", threadIdx.x, kv_idx, q_idx, static_cast<float>(s_frag[fx][fz][reg_id]), int(out_of_boundary));
} else {
const uint32_t q_idx = qo_idx_base,
kv_idx = kv_idx_base + fz * 16 + 2 * (tx % 4) +