From 76759108c97cda21563173869e3bbf81fb9addce Mon Sep 17 00:00:00 2001 From: freeliuzc Date: Fri, 22 Aug 2025 13:36:41 +0800 Subject: [PATCH] [Feature][SpeculativeDecoding]Support tree-attention (#3514) * support tree-attention * fix merge bug * fix unit-test api * fix merge bug --- .../append_attn/append_attention_c16_impl.cuh | 33 +- .../append_attn/append_attention_c4_impl.cuh | 32 +- .../append_attn/append_attention_c8_impl.cuh | 30 +- .../append_attn/append_attention_func.cuh | 11 +- tests/operators/test_tree_mask.py | 360 ++++++++++++++++++ 5 files changed, 446 insertions(+), 20 deletions(-) create mode 100644 tests/operators/test_tree_mask.py diff --git a/custom_ops/gpu_ops/append_attn/append_attention_c16_impl.cuh b/custom_ops/gpu_ops/append_attn/append_attention_c16_impl.cuh index 815d0947d..ffef6bc8b 100644 --- a/custom_ops/gpu_ops/append_attn/append_attention_c16_impl.cuh +++ b/custom_ops/gpu_ops/append_attn/append_attention_c16_impl.cuh @@ -247,13 +247,16 @@ __global__ void multi_query_append_attention_kernel( NUM_WARPS, num_frags_x, num_frags_y, - num_frags_z>(q_base_seq_id_this_block, + num_frags_z>(nullptr, + q_base_seq_id_this_block, kv_idx_base, q_len, kv_len, chunk_end, + -1, s_frag, mask_offset_this_seq); + } // update m,d @@ -410,6 +413,7 @@ __global__ void multi_query_append_attention_warp1_4_kernel( const int *__restrict__ cu_seqlens_q, const int *__restrict__ block_table, // [bsz, block_num_per_seq] const int *__restrict__ mask_offset, + const bool *__restrict__ attn_mask, // [bsz, max_q, max_q] for tree-mask const int max_seq_len, const int max_dec_len, const int max_block_num_per_seq, @@ -423,7 +427,8 @@ __global__ void multi_query_append_attention_warp1_4_kernel( float *__restrict__ tmp_m, // [token_num, num_chunks, num_heads] float *__restrict__ tmp_d, // [token_num, num_chunks, num_heads] OutT *__restrict__ out, - const int speculate_max_draft_token_num = 5) { + const int speculate_max_draft_token_num = 5, + const uint32_t attn_mask_len = -1) { constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b(); static_assert(NUM_WARP_Q == 1, "NUM_WARP_Q must be 1"); static_assert(NUM_WARP_KV == 4, "NUM_WARP_KV must be 4"); @@ -544,8 +549,7 @@ __global__ void multi_query_append_attention_warp1_4_kernel( const uint32_t mask_check_iteration = (CAUSAL ? (min(chunk_len, sub_if_greater_or_zero( - kv_len - q_len + - tile_id * num_rows_per_block / GROUP_SIZE, + kv_len - q_len, chunk_start))) : mask_offset ? 0 : chunk_len) / (NUM_WARP_KV * num_frags_z * 16); @@ -615,11 +619,13 @@ __global__ void multi_query_append_attention_warp1_4_kernel( NUM_WARPS, num_frags_x, num_frags_y, - num_frags_z>(q_base_seq_id_this_block, + num_frags_z>(attn_mask ? attn_mask + batch_id * attn_mask_len *attn_mask_len : nullptr, + q_base_seq_id_this_block, kv_idx_base + wid * num_frags_z * 16, q_len, kv_len, chunk_end, + attn_mask_len, s_frag, mask_offset_this_seq); } @@ -1069,6 +1075,13 @@ void MultiQueryAppendAttention( chunk_size = static_cast(encoder_max_partition_size); } + uint32_t attn_mask_len; + if (attn_mask) { + attn_mask_len = attn_mask.get().shape()[1]; + } else { + attn_mask_len = -1; + } + const int num_chunks = div_up(max_seq_len, chunk_size); dim3 grids(num_blocks_x_cpu, num_chunks, kv_num_heads); dim3 blocks(32, num_warps); @@ -1111,6 +1124,8 @@ void MultiQueryAppendAttention( cu_seqlens_q.data(), block_table.data(), meta_data.mask_offset, + attn_mask ? const_cast(attn_mask.get().data()) + : nullptr, max_seq_len, max_dec_len, max_block_num_per_seq, @@ -1123,7 +1138,8 @@ void MultiQueryAppendAttention( nullptr, nullptr, reinterpret_cast(out->data()), - speculate_max_draft_token_num); + speculate_max_draft_token_num, + attn_mask_len); } else { phi::Allocator::AllocationPtr tmp_workspace, tmp_m, tmp_d; if (is_decoder) { @@ -1180,6 +1196,8 @@ void MultiQueryAppendAttention( cu_seqlens_q.data(), block_table.data(), meta_data.mask_offset, + attn_mask ? const_cast(attn_mask.get().data()) + : nullptr, max_seq_len, max_dec_len, max_block_num_per_seq, @@ -1192,7 +1210,8 @@ void MultiQueryAppendAttention( static_cast(tmp_m->ptr()), static_cast(tmp_d->ptr()), reinterpret_cast(out->data()), - speculate_max_draft_token_num); + speculate_max_draft_token_num, + attn_mask_len); // merge constexpr int vec_size = num_elems_per_128b(); diff --git a/custom_ops/gpu_ops/append_attn/append_attention_c4_impl.cuh b/custom_ops/gpu_ops/append_attn/append_attention_c4_impl.cuh index 6f403ed47..509544253 100644 --- a/custom_ops/gpu_ops/append_attn/append_attention_c4_impl.cuh +++ b/custom_ops/gpu_ops/append_attn/append_attention_c4_impl.cuh @@ -335,11 +335,13 @@ __global__ void multi_query_append_attention_c4_kernel( NUM_WARPS, num_frags_x, num_frags_y, - num_frags_z>(q_base_seq_id_this_block, + num_frags_z>(nullptr, + q_base_seq_id_this_block, kv_idx_base, q_len, kv_len, chunk_end, + -1, s_frag, mask_offset_this_seq); } @@ -509,6 +511,7 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel( const int *__restrict__ cu_seqlens_q, const int *__restrict__ block_table, // [bsz, block_num_per_seq] const int *__restrict__ mask_offset, + const bool *__restrict__ attn_mask, // [bsz, max_q, max_q] for tree-mask const int max_seq_len, const int max_dec_len, const int max_block_num_per_seq, @@ -522,7 +525,8 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel( float *__restrict__ tmp_m, // [token_num, num_chunks, num_heads] float *__restrict__ tmp_d, // [token_num, num_chunks, num_heads] OutT *__restrict__ out, - const int speculate_max_draft_token_num = 5) { + const int speculate_max_draft_token_num = 5, + const uint32_t attn_mask_len = -1) { constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b(); constexpr uint32_t num_vecs_per_head_k = HEAD_DIM / 2 / num_elems_per_128b(); @@ -707,8 +711,7 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel( const uint32_t mask_check_iteration = (CAUSAL ? (min(chunk_len, sub_if_greater_or_zero( - kv_len - q_len + - tile_id * num_rows_per_block / GROUP_SIZE, + kv_len - q_len, chunk_start))) : mask_offset ? 0 : chunk_len) / (NUM_WARP_KV * num_frags_z * 16); @@ -792,11 +795,13 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel( NUM_WARPS, num_frags_x, num_frags_y, - num_frags_z>(q_base_seq_id_this_block, + num_frags_z>(attn_mask ? attn_mask + batch_id * attn_mask_len *attn_mask_len : nullptr, + q_base_seq_id_this_block, kv_idx_base + wid * num_frags_z * 16, q_len, kv_len, chunk_end, + attn_mask_len, s_frag, mask_offset_this_seq); } @@ -1294,6 +1299,13 @@ void MultiQueryAppendC4Attention( } const int num_chunks = div_up(max_seq_len, chunk_size); + uint32_t attn_mask_len; + if (attn_mask) { + attn_mask_len = attn_mask.get().shape()[1]; + } else { + attn_mask_len = -1; + } + dim3 grids(num_blocks_x_cpu, num_chunks, kv_num_heads); dim3 blocks(32, num_warps); if (num_chunks <= 0) { @@ -1343,6 +1355,8 @@ void MultiQueryAppendC4Attention( cu_seqlens_q.data(), block_table.data(), meta_data.mask_offset, + attn_mask ? const_cast(attn_mask.get().data()) + : nullptr, max_seq_len, max_dec_len, max_block_num_per_seq, @@ -1355,7 +1369,8 @@ void MultiQueryAppendC4Attention( nullptr, nullptr, reinterpret_cast(out->data()), - speculate_max_draft_token_num); + speculate_max_draft_token_num, + attn_mask_len); } else { phi::Allocator::AllocationPtr tmp_workspace, tmp_m, tmp_d; if (is_decoder) { @@ -1420,6 +1435,8 @@ void MultiQueryAppendC4Attention( cu_seqlens_q.data(), block_table.data(), meta_data.mask_offset, + attn_mask ? const_cast(attn_mask.get().data()) + : nullptr, max_seq_len, max_dec_len, max_block_num_per_seq, @@ -1432,7 +1449,8 @@ void MultiQueryAppendC4Attention( static_cast(tmp_m->ptr()), static_cast(tmp_d->ptr()), reinterpret_cast(out->data()), - speculate_max_draft_token_num); + speculate_max_draft_token_num, + attn_mask_len); // merge constexpr int vec_size = num_elems_per_128b(); if (is_decoder) { diff --git a/custom_ops/gpu_ops/append_attn/append_attention_c8_impl.cuh b/custom_ops/gpu_ops/append_attn/append_attention_c8_impl.cuh index a17d5ae29..2c44c917c 100644 --- a/custom_ops/gpu_ops/append_attn/append_attention_c8_impl.cuh +++ b/custom_ops/gpu_ops/append_attn/append_attention_c8_impl.cuh @@ -302,11 +302,13 @@ __global__ void multi_query_append_attention_c8_kernel( NUM_WARPS, num_frags_x, num_frags_y, - num_frags_z>(q_base_seq_id_this_block, + num_frags_z>(nullptr, + q_base_seq_id_this_block, kv_idx_base, q_len, kv_len, chunk_end, + -1, s_frag, mask_offset_this_seq); } @@ -478,6 +480,7 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel( const int *__restrict__ cu_seqlens_q, const int *__restrict__ block_table, // [bsz, block_num_per_seq] const int *__restrict__ mask_offset, + const bool *__restrict__ attn_mask, // [bsz, max_q, max_q] for tree-mask const int max_seq_len, const int max_dec_len, const int max_block_num_per_seq, @@ -491,7 +494,8 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel( float *__restrict__ tmp_m, // [token_num, num_chunks, num_heads] float *__restrict__ tmp_d, // [token_num, num_chunks, num_heads] OutT *__restrict__ out, - const int speculate_max_draft_token_num = 5) { + const int speculate_max_draft_token_num = 5, + const uint32_t attn_mask_len = -1) { constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b(); constexpr uint32_t num_vecs_per_head_k = HEAD_DIM / num_elems_per_128b(); @@ -732,13 +736,16 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel( NUM_WARPS, num_frags_x, num_frags_y, - num_frags_z>(q_base_seq_id_this_block, + num_frags_z>(attn_mask ? attn_mask + batch_id * attn_mask_len *attn_mask_len : nullptr, + q_base_seq_id_this_block, kv_idx_base + wid * num_frags_z * 16, q_len, kv_len, chunk_end, + attn_mask_len, s_frag, mask_offset_this_seq); + } // update m,d @@ -1262,6 +1269,13 @@ void MultiQueryAppendC8Attention( } const int num_chunks = div_up(max_seq_len, chunk_size); + uint32_t attn_mask_len; + if (attn_mask) { + attn_mask_len = attn_mask.get().shape()[1]; + } else { + attn_mask_len = -1; + } + dim3 grids(num_blocks_x_cpu, num_chunks, kv_num_heads); dim3 blocks(32, num_warps); if (num_chunks <= 0) { @@ -1326,6 +1340,8 @@ void MultiQueryAppendC8Attention( cu_seqlens_q.data(), block_table.data(), meta_data.mask_offset, + attn_mask ? const_cast(attn_mask.get().data()) + : nullptr, max_seq_len, max_dec_len, max_block_num_per_seq, @@ -1338,7 +1354,8 @@ void MultiQueryAppendC8Attention( nullptr, nullptr, reinterpret_cast(out->data()), - speculate_max_draft_token_num); + speculate_max_draft_token_num, + attn_mask_len); } else { phi::Allocator::AllocationPtr tmp_workspace, tmp_m, tmp_d; if (is_decoder) { @@ -1397,6 +1414,8 @@ void MultiQueryAppendC8Attention( cu_seqlens_q.data(), block_table.data(), meta_data.mask_offset, + attn_mask ? const_cast(attn_mask.get().data()) + : nullptr, max_seq_len, max_dec_len, max_block_num_per_seq, @@ -1409,7 +1428,8 @@ void MultiQueryAppendC8Attention( static_cast(tmp_m->ptr()), static_cast(tmp_d->ptr()), reinterpret_cast(out->data()), - speculate_max_draft_token_num); + speculate_max_draft_token_num, + attn_mask_len); // merge constexpr int vec_size = num_elems_per_128b(); if (is_decoder) { diff --git a/custom_ops/gpu_ops/append_attn/append_attention_func.cuh b/custom_ops/gpu_ops/append_attn/append_attention_func.cuh index 060340395..bb32d288a 100644 --- a/custom_ops/gpu_ops/append_attn/append_attention_func.cuh +++ b/custom_ops/gpu_ops/append_attn/append_attention_func.cuh @@ -905,11 +905,13 @@ template -__device__ __forceinline__ void mask_s(const uint32_t qo_idx_base, +__device__ __forceinline__ void mask_s(const bool* attn_mask, + const uint32_t qo_idx_base, const uint32_t kv_idx_base, const uint32_t qo_len, const uint32_t kv_len, const uint32_t chunk_end, + const uint32_t attn_mask_len, float (*s_frag)[num_frags_z][8], const int *mask_offset = nullptr) { const uint32_t tx = threadIdx.x; @@ -933,7 +935,13 @@ __device__ __forceinline__ void mask_s(const uint32_t qo_idx_base, (causal ? (kv_idx > kv_len + q_idx - qo_len || (kv_idx >= chunk_end)) : kv_idx >= chunk_end); + if (attn_mask != nullptr && kv_idx > kv_len - qo_len && kv_idx < chunk_end && q_idx < attn_mask_len) { + const int32_t mask_idx = q_idx * attn_mask_len + kv_idx - kv_len + qo_len; + bool mask = attn_mask[mask_idx]; + out_of_boundary |= mask; + } } + if constexpr (std::is_same::value) { s_frag[fx][fz][reg_id] = out_of_boundary ? -5e4f : s_frag[fx][fz][reg_id]; @@ -941,6 +949,7 @@ __device__ __forceinline__ void mask_s(const uint32_t qo_idx_base, s_frag[fx][fz][reg_id] = out_of_boundary ? -3.0e+30f : s_frag[fx][fz][reg_id]; } + // printf("tid: %d. qk[%u,%u] = %f, mask: %d \n ", threadIdx.x, kv_idx, q_idx, static_cast(s_frag[fx][fz][reg_id]), int(out_of_boundary)); } else { const uint32_t q_idx = qo_idx_base, kv_idx = kv_idx_base + fz * 16 + 2 * (tx % 4) + diff --git a/tests/operators/test_tree_mask.py b/tests/operators/test_tree_mask.py new file mode 100644 index 000000000..ee8be1b3c --- /dev/null +++ b/tests/operators/test_tree_mask.py @@ -0,0 +1,360 @@ +import math +import time + +import numpy as np +import paddle +import paddle.nn.functional as F + +from fastdeploy.model_executor.layers.attention.ops import ( + append_attention, + get_block_shape_and_split_kv_block, +) + +paddle.seed(0) + +max_seq_len = 32768 +encoder_max_partition_size = max_seq_len +max_partition_size = max_seq_len + +max_dec_len = 1024 +bsz = 64 +run_time = 10 +warm_up = 2 +block_size = 64 +head_dim = 128 +num_q_head = 20 +num_kv_head = 4 +dtype = "bfloat16" + +rope_3d = False +use_neox_rotary_style = False +CURRENT_Q = [None] +TOTAL_K = [] +TOTAL_V = [] + + +def split_qkv(qkv, bsz, seq_len, num_q_head, num_kv_head, head_dim): + # [token_num, (num_q_head + 2 * num_kv_head) * head_dim] + qkv = qkv.reshape([bsz, seq_len, -1, head_dim]) + q = qkv[:, :, :num_q_head, :] + # [bsz, seq_len, num_q_head, head_dim] + CURRENT_Q[0] = q + + # [bsz, seq_len, num_kv_head, head_dim] + k = qkv[:, :, num_q_head : num_q_head + num_kv_head, :] + TOTAL_K.append(k) + + # [bsz, seq_len, num_kv_head, head_dim] + v = qkv[:, :, num_q_head + num_kv_head :, :] + TOTAL_V.append(v) + + +def get_padding_offset(bsz, seq_lens_this_time, seq_lens_decoder): + batch_id_per_token = [] + cu_seqlens_q = paddle.zeros(shape=(bsz + 1), dtype="int32") + cu_seqlens_k = paddle.zeros(shape=(bsz + 1), dtype="int32") + cum_seq_len_q = 0 + cum_seq_len_k = 0 + for i in range(bsz): + seq_len_now = seq_lens_this_time[i] + seq_len_dec_now = seq_lens_decoder[i] + for j in range(seq_len_now): + batch_id_per_token.append(i) + cum_seq_len_q += seq_len_now + cum_seq_len_k += seq_len_now + seq_len_dec_now + cu_seqlens_q[i + 1] = cum_seq_len_q + cu_seqlens_k[i + 1] = cum_seq_len_k + return paddle.to_tensor(batch_id_per_token, dtype="int32"), cu_seqlens_q, cu_seqlens_k + + +# block_table +block_num_per_seq = (max_seq_len + block_size - 1) // block_size +max_block_num = block_num_per_seq * bsz +cache_shape = ( + max_block_num, + num_kv_head, + block_size, + head_dim, +) + +cache_k = paddle.zeros(shape=cache_shape).astype(dtype) +cache_v = paddle.zeros(shape=cache_shape).astype(dtype) + +block_tables = paddle.zeros(shape=(bsz, block_num_per_seq), dtype="int32") + +free_list = list(range(max_block_num - 1, -1, -1)) + +for i in range(bsz): + need_block_num = (max_seq_len + block_size - 1) // block_size + for j in range(need_block_num): + block_id = free_list.pop() + block_tables[i, j] = block_id + + +def ref_attention(q, k, v, num_q_head, num_kv_head, head_dim, mask): + q = q.transpose([0, 2, 1, 3]) + if len(k) > 1: + k = paddle.concat(k, axis=1) + else: + k = k[0] + k = k.transpose([0, 2, 1, 3]) + if len(v) > 1: + v = paddle.concat(v, axis=1) + else: + v = v[0] + v = v.transpose([0, 2, 1, 3]) + total_len = k.shape[2] + + scores = q.reshape([bsz, num_kv_head, -1, head_dim]) @ k.transpose([0, 1, 3, 2]) * (1.0 / math.sqrt(head_dim)) + scores = scores.reshape([bsz, num_q_head, -1, total_len]) + + if mask is not None: + if mask.ndim == 2: + mask = mask.unsqueeze(0).unsqueeze(0) # [1,1,q_len,kv_len] + elif mask.ndim == 3: + mask = mask.unsqueeze(1) # [bsz,1,q_len,kv_len] + scores = paddle.add(scores, mask) + weights = F.softmax(scores, axis=-1) + + o = weights.reshape([bsz, num_kv_head, -1, total_len]) @ v + return o.reshape([bsz, num_q_head, -1, head_dim]).transpose([0, 2, 1, 3]).reshape([-1, num_q_head, head_dim]) + + +def clear_param(): + global CURRENT_Q, TOTAL_K, TOTAL_V + CURRENT_Q = [None] + TOTAL_K = [] + TOTAL_V = [] + + +def test_append_c16_attention(q_len, kv_len, prefill=False, attn_mask=None): + if prefill: + seq_lens_enc = [ + q_len, + ] * bsz + else: + seq_lens_enc = [ + 0, + ] * bsz + + seq_lens_dec = [ + kv_len, + ] * bsz + seq_lens_cur = [ + q_len, + ] * bsz + token_num = sum(seq_lens_cur) + decoder_step_token_num = 1 if prefill else q_len + + seq_lens_encoder = paddle.to_tensor(seq_lens_enc, "int32") + seq_lens_this_time = paddle.to_tensor(seq_lens_cur, "int32") + seq_lens_decoder = paddle.to_tensor(seq_lens_dec, "int32") + + batch_id_per_token, cu_seqlens_q, cu_seqlens_k = get_padding_offset(bsz, seq_lens_this_time, seq_lens_decoder) + + # random data + qkv_varlen_shape = [token_num, (num_q_head + 2 * num_kv_head) * head_dim] + + rotary_embs_shape = [2, 1, max_seq_len, 1, head_dim if use_neox_rotary_style else head_dim // 2] + # qkv_bias_shape = [num_q_head + 2 * num_kv_head, head_dim] + + qkv = paddle.randn(shape=qkv_varlen_shape).astype(dtype) + + # save q, k, v for ref + split_qkv(qkv, bsz, q_len, num_q_head, num_kv_head, head_dim) + + rotary_embs = paddle.randn(shape=rotary_embs_shape).astype("float32") + rotary_embs[0, :, :, :, :] = 1 + rotary_embs[1, :, :, :, :] = 0 + + # qkv_scale = None + # qkv_bias = None + + cache_k_scale = None + cache_v_scale = None + cache_k_out_scale = None + cache_v_out_scale = None + # shift_bias = None + # smooth_weight = None + + encoder_block_shape_q = 64 + decoder_block_shape_q = 16 + + decode_max_tile_size = ( + bsz + * (decoder_step_token_num * (num_q_head // num_kv_head) + decoder_block_shape_q - 1) + / decoder_block_shape_q + ) + decoder_batch_ids = paddle.full([int(decode_max_tile_size)], 0, dtype="int32") + decoder_tile_ids_per_batch = paddle.full([int(decode_max_tile_size)], 0, dtype="int32") + decoder_num_blocks = paddle.full([1], 0, dtype="int32").pin_memory() + max_len_tensor_cpu = paddle.full([8], 0, dtype="int32").cpu() + paddle.device.synchronize() + ( + encoder_batch_ids, + encoder_tile_ids_per_batch, + encoder_num_blocks, + kv_batch_ids, + kv_tile_ids_per_batch, + kv_num_blocks, + max_len_kv, + ) = get_block_shape_and_split_kv_block( + seq_lens_encoder, + seq_lens_decoder, + seq_lens_this_time, + decoder_batch_ids, + decoder_tile_ids_per_batch, + decoder_num_blocks, + max_len_tensor_cpu, + encoder_block_shape_q, + decoder_block_shape_q, + num_q_head // num_kv_head, + block_size, + decoder_step_token_num, + ) + s_time = 0 + for i in range(run_time + warm_up): + if i == warm_up: + s_time = time.time() + out = append_attention( + qkv, + cache_k, + cache_v, + seq_lens_encoder, + seq_lens_decoder, + seq_lens_this_time, + batch_id_per_token, + cu_seqlens_q, + block_tables, + encoder_batch_ids, + encoder_tile_ids_per_batch, + encoder_num_blocks, + kv_batch_ids, + kv_tile_ids_per_batch, + kv_num_blocks, + decoder_batch_ids, + decoder_tile_ids_per_batch, + decoder_num_blocks, + max_len_tensor_cpu, + max_len_kv, + rotary_embs, + attn_mask, # attn_mask + None, + None, + cache_k_scale, + cache_v_scale, + cache_k_out_scale, + cache_v_out_scale, + None, # cache_k_zp + None, # cache_v_zp + None, + None, + None, + None, + None, + None, + 1e-6, + "bf16", + "none", # cache_quant_type + use_neox_rotary_style, + rope_3d, + max_seq_len, + 0.0, + 0.0, + -1.0, # out_linear_in_scale + encoder_block_shape_q, # encoder_block_shape_q + decoder_block_shape_q, # decoder_block_shape_q + max_partition_size, # max_partition_size + encoder_max_partition_size, # encoder_max_partition_size + decoder_step_token_num, # speculate_max_draft_token_num + True, # causal + decoder_step_token_num > 1, # speculate_decoder + ) + paddle.device.synchronize() + e_time = time.time() + print(f"mean infer time: {np.mean((e_time - s_time) * 1000 / run_time):.2f}") + return out[0].reshape([token_num, num_q_head, head_dim]) + + +def test_naive_speculative_decoding(num_q_head, num_kv_head, head_dim): + prefill_len = 8192 + dec_len_q = 5 + total_len = prefill_len + dec_len_q + mask = paddle.tril(paddle.ones((bsz, dec_len_q, total_len), dtype="float32"), diagonal=prefill_len) + mask = paddle.where(mask == 1, paddle.zeros_like(mask), paddle.full_like(mask, fill_value=float("-inf"))) + test_append_c16_attention(prefill_len, 0, True) + dec_out = test_append_c16_attention(dec_len_q, prefill_len, False) + + ref_out = ref_attention(CURRENT_Q[0], TOTAL_K, TOTAL_V, num_q_head, num_kv_head, head_dim, mask) + np.testing.assert_allclose( + ref_out.astype("float32").numpy(), dec_out.astype("float32").numpy(), rtol=1e-03, atol=5e-03 + ) + + +def test_mask(num_q_head, num_kv_head, head_dim): + prefill_len = 8192 + dec_len_q = 5 + total_len = prefill_len + dec_len_q + mask = paddle.tril(paddle.ones((bsz, dec_len_q, total_len), dtype="float32"), diagonal=prefill_len) + mask_ref = paddle.where(mask == 1, paddle.zeros_like(mask), paddle.full_like(mask, fill_value=float("-inf"))) + + mask_append_attn = mask[:, :, prefill_len:] + mask_append_attn = paddle.where( + mask_append_attn == 1, + paddle.full_like(mask_append_attn, fill_value=False, dtype=bool), + paddle.full_like(mask_append_attn, fill_value=True, dtype=bool), + ) + + test_append_c16_attention(prefill_len, 0, True) + dec_out = test_append_c16_attention(dec_len_q, prefill_len, False, mask_append_attn) + + ref_out = ref_attention(CURRENT_Q[0], TOTAL_K, TOTAL_V, num_q_head, num_kv_head, head_dim, mask_ref) + + np.testing.assert_allclose( + ref_out.astype("float32").numpy(), dec_out.astype("float32").numpy(), rtol=1e-03, atol=5e-03 + ) + + +def test_tree_mask(num_q_head, num_kv_head, head_dim): + # tree + # [N, N+1, N+1, N+2, N+2] + # N [0, -inf, -inf, -inf, -inf] + # N+1 [0, 0, -inf, -inf, -inf] + # N+1 [0, -inf, 0, -inf, -inf] + # N+2 [0, 0, -inf, 0, -inf] + # N+2 [0, -inf, 0, -inf, 0] + prefill_len = 8192 + dec_len_q = 5 + total_len = prefill_len + dec_len_q + mask = paddle.tril(paddle.ones((bsz, dec_len_q, total_len), dtype="float32"), diagonal=prefill_len) + mask[:, 2, prefill_len + 1] = 0 + mask[:, 3, prefill_len + 2] = 0 + mask[:, 4, prefill_len + 1] = 0 + mask[:, 4, prefill_len + 3] = 0 + + mask_ref = paddle.where(mask == 1, paddle.zeros_like(mask), paddle.full_like(mask, fill_value=float("-inf"))) + + mask_append_attn = mask[:, :, prefill_len:] + mask_append_attn = paddle.where( + mask_append_attn == 1, + paddle.full_like(mask_append_attn, fill_value=False, dtype=bool), + paddle.full_like(mask_append_attn, fill_value=True, dtype=bool), + ) + + test_append_c16_attention(prefill_len, 0, True) + dec_out = test_append_c16_attention(dec_len_q, prefill_len, False, mask_append_attn) + ref_out = ref_attention(CURRENT_Q[0], TOTAL_K, TOTAL_V, num_q_head, num_kv_head, head_dim, mask_ref) + np.testing.assert_allclose( + ref_out.astype("float32").numpy(), dec_out.astype("float32").numpy(), rtol=1e-03, atol=5e-03 + ) + + +if __name__ == "__main__": + + test_naive_speculative_decoding(num_q_head, num_kv_head, head_dim) + clear_param() + + test_mask(num_q_head, num_kv_head, head_dim) + clear_param() + + test_tree_mask(num_q_head, num_kv_head, head_dim)