make append_attn support mask_offset (#3138)

* make append_attn support mask_offset

* add unittest
Author: lzy
Date: 2025-08-14 18:40:55 +08:00 (committed by GitHub)
parent 6031f9a5f5
commit 1e06b9fa6d

10 changed files with 88 additions and 20 deletions

View File

@@ -72,6 +72,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
     const paddle::optional<paddle::Tensor>& cache_v_zp,
     const paddle::optional<paddle::Tensor>& out_linear_shifts,
     const paddle::optional<paddle::Tensor>& out_linear_smooths,
+    const paddle::optional<paddle::Tensor>& mask_offset,
     const paddle::optional<paddle::Tensor>& kv_signal_data,
     const paddle::optional<paddle::Tensor>& q_norm_weight,
     const paddle::optional<paddle::Tensor>& k_norm_weight,
@@ -441,6 +442,7 @@ std::vector<paddle::Tensor> AppendAttention(
     const paddle::optional<paddle::Tensor>& cache_v_zp,
     const paddle::optional<paddle::Tensor>& out_linear_shifts,
     const paddle::optional<paddle::Tensor>& out_linear_smooths,
+    const paddle::optional<paddle::Tensor>& mask_offset,
     const paddle::optional<paddle::Tensor>& kv_signal_data,
     const paddle::optional<paddle::Tensor>& q_norm_weight,
     const paddle::optional<paddle::Tensor>& k_norm_weight,
@@ -479,6 +481,10 @@ std::vector<paddle::Tensor> AppendAttention(
   meta_data.block_size = key_cache.dims()[2];
   meta_data.batch_size = seq_lens_this_time.dims()[0];
+  if (mask_offset) {
+    meta_data.mask_offset = mask_offset.get().data<int>();
+  }

   auto dispatch_by_template = [&](auto temp_args) -> std::vector<paddle::Tensor> {
     return AppendAttentionKernel<type2value<decltype(temp_args)>::value>(
         meta_data,
@@ -514,6 +520,7 @@ std::vector<paddle::Tensor> AppendAttention(
         cache_v_zp,
         out_linear_shifts,
         out_linear_smooths,
+        mask_offset,
         kv_signal_data,
         q_norm_weight,
         k_norm_weight,
@@ -594,6 +601,7 @@ std::vector<std::vector<int64_t>> AppendAttentionInferShape(
     const paddle::optional<std::vector<int64_t>>& cache_v_zp_shape,
     const paddle::optional<std::vector<int64_t>>& out_linear_shifts_shape,
     const paddle::optional<std::vector<int64_t>>& out_linear_smooths_shape,
+    const paddle::optional<std::vector<int64_t>>& mask_offset_shape,
     const paddle::optional<std::vector<int64_t>>& kv_signal_data_shape,
     const paddle::optional<std::vector<int64_t>>& q_norm_weight_shape,
     const paddle::optional<std::vector<int64_t>>& k_norm_weight_shape,
@@ -657,6 +665,7 @@ std::vector<paddle::DataType> AppendAttentionInferDtype(
     const paddle::optional<paddle::DataType>& cache_v_zp_dtype,
     const paddle::optional<paddle::DataType>& out_linear_shifts_dtype,
     const paddle::optional<paddle::DataType>& out_linear_smooths_dtype,
+    const paddle::optional<paddle::DataType>& mask_offset_dtype,
     const paddle::optional<paddle::DataType>& kv_signal_data_dtype,
     const paddle::optional<paddle::DataType>& q_norm_weight_dtype,
     const paddle::optional<paddle::DataType>& k_norm_weight_dtype,
@@ -738,6 +747,7 @@ PD_BUILD_STATIC_OP(append_attention)
             paddle::Optional("cache_v_zp"),
             paddle::Optional("out_linear_shifts"),
             paddle::Optional("out_linear_smooths"),
+            paddle::Optional("mask_offset"),
             paddle::Optional("kv_signal_data"),
             paddle::Optional("q_norm_weight"),
             paddle::Optional("k_norm_weight")})

View File

@@ -43,6 +43,7 @@ __global__ void multi_query_append_attention_kernel(
     const int *__restrict__ tile_ids_per_batch,
     const int *__restrict__ cu_seqlens_q,
     const int *__restrict__ block_table,  // [bsz, block_num_per_seq]
+    const int *__restrict__ mask_offset,
     const int max_seq_len,
     const int max_dec_len,
     const int max_block_num_per_seq,
@@ -141,6 +142,7 @@ __global__ void multi_query_append_attention_kernel(
   } else {
     o_base_ptr_int8 = out + o_offset;
   }
+  const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id : nullptr;
   smem_t qo_smem(smem);

   uint32_t q_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
@@ -179,7 +181,7 @@ __global__ void multi_query_append_attention_kernel(
                          kv_len - q_len +
                              tile_id * num_rows_per_block / GROUP_SIZE,
                          chunk_start)))
-              : chunk_len) /
+              : mask_offset ? 0 : chunk_len) /
       (num_frags_z * 16);
   uint32_t k_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
       8 * (tid / 16) + tid % 8, (tid % 16) / 8);
@@ -250,7 +252,8 @@ __global__ void multi_query_append_attention_kernel(
                 q_len,
                 kv_len,
                 chunk_end,
-                s_frag);
+                s_frag,
+                mask_offset_this_seq);
     }

     // update m,d
@@ -406,6 +409,7 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
     const int *__restrict__ tile_ids_per_batch,
     const int *__restrict__ cu_seqlens_q,
     const int *__restrict__ block_table,  // [bsz, block_num_per_seq]
+    const int *__restrict__ mask_offset,
     const int max_seq_len,
     const int max_dec_len,
     const int max_block_num_per_seq,
@@ -502,7 +506,7 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
                       tid % 8 * num_elems_per_128b<T>();
     }
   }
+  const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id : nullptr;
   smem_t qo_smem(smem);

   uint32_t q_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
@@ -543,7 +547,7 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
                          kv_len - q_len +
                              tile_id * num_rows_per_block / GROUP_SIZE,
                          chunk_start)))
-              : chunk_len) /
+              : mask_offset ? 0 : chunk_len) /
       (NUM_WARP_KV * num_frags_z * 16);
   uint32_t k_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
@@ -616,7 +620,8 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
                 q_len,
                 kv_len,
                 chunk_end,
-                s_frag);
+                s_frag,
+                mask_offset_this_seq);
     }

     // update m,d
@@ -882,6 +887,7 @@ void MultiQueryAppendAttention(
         tile_ids_per_batch.data<int>(),
         cu_seqlens_q.data<int>(),
         block_table.data<int>(),
+        meta_data.mask_offset,
         max_seq_len,
         max_dec_len,
         max_block_num_per_seq,
@@ -939,6 +945,7 @@ void MultiQueryAppendAttention(
         tile_ids_per_batch.data<int>(),
         cu_seqlens_q.data<int>(),
         block_table.data<int>(),
+        meta_data.mask_offset,
         max_seq_len,
         max_dec_len,
         max_block_num_per_seq,
@@ -1103,6 +1110,7 @@ void MultiQueryAppendAttention(
         tile_ids_per_batch.data<int>(),
         cu_seqlens_q.data<int>(),
         block_table.data<int>(),
+        meta_data.mask_offset,
         max_seq_len,
         max_dec_len,
         max_block_num_per_seq,
@@ -1171,6 +1179,7 @@ void MultiQueryAppendAttention(
         tile_ids_per_batch.data<int>(),
         cu_seqlens_q.data<int>(),
         block_table.data<int>(),
+        meta_data.mask_offset,
         max_seq_len,
         max_dec_len,
         max_block_num_per_seq,

View File

@@ -48,6 +48,7 @@ __global__ void multi_query_append_attention_c4_kernel(
     const int *__restrict__ tile_ids_per_batch,
     const int *__restrict__ cu_seqlens_q,
     const int *__restrict__ block_table,  // [bsz, block_num_per_seq]
+    const int *__restrict__ mask_offset,
     const int max_seq_len,
     const int max_dec_len,
     const int max_block_num_per_seq,
@@ -172,6 +173,7 @@ __global__ void multi_query_append_attention_c4_kernel(
   } else {
     o_base_ptr_int8 = out + o_offset;
   }
+  const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id : nullptr;
   smem_t qo_smem(smem);

   uint32_t q_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
@@ -248,7 +250,7 @@ __global__ void multi_query_append_attention_c4_kernel(
                          kv_len - q_len +
                              tile_id * num_rows_per_block / GROUP_SIZE,
                          chunk_start)))
-              : chunk_len) /
+              : mask_offset ? 0 : chunk_len) /
       (num_frags_z * 16);
   uint32_t k_smem_offset_r =
@@ -338,7 +340,8 @@ __global__ void multi_query_append_attention_c4_kernel(
                 q_len,
                 kv_len,
                 chunk_end,
-                s_frag);
+                s_frag,
+                mask_offset_this_seq);
     }

     update_mdo_states<num_frags_x, num_frags_y, num_frags_z>(
@@ -505,6 +508,7 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
     const int *__restrict__ tile_ids_per_batch,
     const int *__restrict__ cu_seqlens_q,
     const int *__restrict__ block_table,  // [bsz, block_num_per_seq]
+    const int *__restrict__ mask_offset,
     const int max_seq_len,
     const int max_dec_len,
     const int max_block_num_per_seq,
@@ -627,7 +631,7 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
                       tid % 8 * num_elems_per_128b<T>();
     }
   }
+  const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id : nullptr;
   smem_t qo_smem(smem);

   uint32_t q_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
@@ -706,7 +710,7 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
                          kv_len - q_len +
                              tile_id * num_rows_per_block / GROUP_SIZE,
                          chunk_start)))
-              : chunk_len) /
+              : mask_offset ? 0 : chunk_len) /
       (NUM_WARP_KV * num_frags_z * 16);
   uint32_t k_smem_offset_r =
@@ -793,7 +797,8 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
                 q_len,
                 kv_len,
                 chunk_end,
-                s_frag);
+                s_frag,
+                mask_offset_this_seq);
     }

     update_mdo_states<num_frags_x, num_frags_y, num_frags_z>(
@@ -1088,6 +1093,7 @@ void MultiQueryAppendC4Attention(
         tile_ids_per_batch.data<int>(),
         cu_seqlens_q.data<int>(),
         block_table.data<int>(),
+        meta_data.mask_offset,
         max_seq_len,
         max_dec_len,
         max_block_num_per_seq,
@@ -1151,6 +1157,7 @@ void MultiQueryAppendC4Attention(
         tile_ids_per_batch.data<int>(),
         cu_seqlens_q.data<int>(),
         block_table.data<int>(),
+        meta_data.mask_offset,
         max_seq_len,
         max_dec_len,
         max_block_num_per_seq,
@@ -1335,6 +1342,7 @@ void MultiQueryAppendC4Attention(
         tile_ids_per_batch.data<int>(),
         cu_seqlens_q.data<int>(),
         block_table.data<int>(),
+        meta_data.mask_offset,
         max_seq_len,
         max_dec_len,
         max_block_num_per_seq,
@@ -1411,6 +1419,7 @@ void MultiQueryAppendC4Attention(
         tile_ids_per_batch.data<int>(),
         cu_seqlens_q.data<int>(),
         block_table.data<int>(),
+        meta_data.mask_offset,
         max_seq_len,
         max_dec_len,
         max_block_num_per_seq,

View File

@@ -48,6 +48,7 @@ __global__ void multi_query_append_attention_c8_kernel(
     const int *__restrict__ tile_ids_per_batch,
     const int *__restrict__ cu_seqlens_q,
     const int *__restrict__ block_table,  // [bsz, block_num_per_seq]
+    const int *__restrict__ mask_offset,
     const int max_seq_len,
     const int max_dec_len,
     const int max_block_num_per_seq,
@@ -179,6 +180,7 @@ __global__ void multi_query_append_attention_c8_kernel(
   } else {
     o_base_ptr_int8 = out + o_offset;
   }
+  const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id : nullptr;
   smem_t qo_smem(smem);

   uint32_t q_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
@@ -216,7 +218,7 @@ __global__ void multi_query_append_attention_c8_kernel(
                          kv_len - q_len +
                              tile_id * num_rows_per_block / GROUP_SIZE,
                          chunk_start)))
-              : chunk_len) /
+              : mask_offset ? 0 : chunk_len) /
       (num_frags_z * 16);
   uint32_t k_smem_offset_r =
@@ -305,7 +307,8 @@ __global__ void multi_query_append_attention_c8_kernel(
                 q_len,
                 kv_len,
                 chunk_end,
-                s_frag);
+                s_frag,
+                mask_offset_this_seq);
     }

     // update m,d
@@ -474,6 +477,7 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
     const int *__restrict__ tile_ids_per_batch,
     const int *__restrict__ cu_seqlens_q,
     const int *__restrict__ block_table,  // [bsz, block_num_per_seq]
+    const int *__restrict__ mask_offset,
     const int max_seq_len,
     const int max_dec_len,
     const int max_block_num_per_seq,
@@ -601,7 +605,7 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
                       tid % 8 * num_elems_per_128b<T>();
     }
   }
+  const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id : nullptr;
   smem_t qo_smem(smem);

   uint32_t q_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
@@ -642,7 +646,7 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
                          kv_len - q_len +
                              tile_id * num_rows_per_block / GROUP_SIZE,
                          chunk_start)))
-              : chunk_len) /
+              : mask_offset ? 0 : chunk_len) /
       (NUM_WARP_KV * num_frags_z * 16);
   uint32_t k_smem_offset_r =
@@ -733,7 +737,8 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
                 q_len,
                 kv_len,
                 chunk_end,
-                s_frag);
+                s_frag,
+                mask_offset_this_seq);
     }

     // update m,d
@@ -1054,6 +1059,7 @@ void MultiQueryAppendC8Attention(
         tile_ids_per_batch.data<int>(),
         cu_seqlens_q.data<int>(),
         block_table.data<int>(),
+        meta_data.mask_offset,
         max_seq_len,
         max_dec_len,
         max_block_num_per_seq,
@@ -1111,6 +1117,7 @@ void MultiQueryAppendC8Attention(
         tile_ids_per_batch.data<int>(),
         cu_seqlens_q.data<int>(),
         block_table.data<int>(),
+        meta_data.mask_offset,
         max_seq_len,
         max_dec_len,
         max_block_num_per_seq,
@@ -1318,6 +1325,7 @@ void MultiQueryAppendC8Attention(
         tile_ids_per_batch.data<int>(),
         cu_seqlens_q.data<int>(),
         block_table.data<int>(),
+        meta_data.mask_offset,
         max_seq_len,
         max_dec_len,
         max_block_num_per_seq,
@@ -1388,6 +1396,7 @@ void MultiQueryAppendC8Attention(
         tile_ids_per_batch.data<int>(),
         cu_seqlens_q.data<int>(),
         block_table.data<int>(),
+        meta_data.mask_offset,
         max_seq_len,
         max_dec_len,
         max_block_num_per_seq,

View File

@@ -910,7 +910,8 @@ __device__ __forceinline__ void mask_s(const uint32_t qo_idx_base,
                                         const uint32_t qo_len,
                                         const uint32_t kv_len,
                                         const uint32_t chunk_end,
-                                        float (*s_frag)[num_frags_z][8]) {
+                                        float (*s_frag)[num_frags_z][8],
+                                        const int *mask_offset = nullptr) {
   const uint32_t tx = threadIdx.x;
 #pragma unroll
   for (uint32_t fx = 0; fx < num_frags_x; ++fx) {
@@ -924,10 +925,15 @@ __device__ __forceinline__ void mask_s(const uint32_t qo_idx_base,
               group_size,
           kv_idx = kv_idx_base + fz * 16 + 2 * (tx % 4) +
                    8 * (reg_id / 4) + reg_id % 2;
-          const bool out_of_boundary =
-              (causal
-                   ? (kv_idx > kv_len + q_idx - qo_len || (kv_idx >= chunk_end))
-                   : kv_idx >= chunk_end);
+          bool out_of_boundary;
+          if (mask_offset) {
+            out_of_boundary = q_idx < qo_len ? (kv_idx > mask_offset[q_idx]) : true;
+          } else {
+            out_of_boundary =
+                (causal
+                     ? (kv_idx > kv_len + q_idx - qo_len || (kv_idx >= chunk_end))
+                     : kv_idx >= chunk_end);
+          }
           if constexpr (std::is_same<T, half>::value) {
             s_frag[fx][fz][reg_id] =
                 out_of_boundary ? -5e4f : s_frag[fx][fz][reg_id];
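
Note: the hunk above is the heart of the feature. When mask_offset is non-null, the usual causal bound is replaced by a per-query bound: query row q_idx may attend to kv_idx only while kv_idx <= mask_offset[q_idx], and rows past qo_len are masked entirely. A minimal NumPy sketch of that boundary rule (function and variable names here are illustrative, not taken from the kernel):

    import numpy as np

    def mask_scores_ref(s, mask_offset, qo_len, neg_inf=-5e4):
        # s: [qo_rows, kv_len] attention scores for one sequence
        # mask_offset[q]: last kv index that query row q may attend to
        out = s.copy()
        qo_rows, kv_len = out.shape
        for q_idx in range(qo_rows):
            for kv_idx in range(kv_len):
                # same rule as the kernel branch, padded rows fully masked
                out_of_boundary = (kv_idx > mask_offset[q_idx]) if q_idx < qo_len else True
                if out_of_boundary:
                    out[q_idx, kv_idx] = neg_inf  # same sentinel the half path uses
        return out

Setting mask_offset[q] = q reproduces the causal pattern, which is exactly what the unit test below relies on.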

View File

@@ -27,6 +27,7 @@ struct AppendAttnMetaData {
   int head_dims;
   int head_dims_v;
   int max_blocks_per_seq;
+  const int *mask_offset = nullptr;
 };

 __forceinline__ __host__ __device__ int div_up(int a, int b) {
@@ -477,6 +478,9 @@ __forceinline__ __host__ __device__ void vec_cast<nv_bfloat16, float>(
   if (causal) {                                                  \
     constexpr bool CAUSAL = true;                                \
     __VA_ARGS__                                                  \
+  } else {                                                       \
+    constexpr bool CAUSAL = false;                               \
+    __VA_ARGS__                                                  \
   }

 #define DISPATCH_ENABLE_PREFILL(enable_prefill, ENABLE_PREFILL, ...) \

View File

@@ -77,6 +77,7 @@ std::vector<paddle::Tensor> AppendAttention(
     const paddle::optional<paddle::Tensor> &cache_v_zp,
     const paddle::optional<paddle::Tensor> &out_linear_shifts,
     const paddle::optional<paddle::Tensor> &out_linear_smooths,
+    const paddle::optional<paddle::Tensor> &mask_offset,
     const paddle::optional<paddle::Tensor> &kv_signal_data,
     const paddle::optional<paddle::Tensor>& q_norm_weight,
     const paddle::optional<paddle::Tensor>& k_norm_weight,

View File

@@ -62,6 +62,7 @@ class AppendAttentionMetadata(AttentionMetadata):
     block_tables: Optional[paddle.Tensor] = None
     rotary_embs: Optional[paddle.Tensor] = None
     attn_mask: Optional[paddle.Tensor] = None
+    mask_offset: Optional[paddle.Tensor] = None
     _fuse_kernel_compute_dtype: str = "bf16"

     # pd_disaggregation
@@ -261,6 +262,7 @@ class AppendAttentionBackend(AttentionBackend):
             getattr(layer, "cache_v_zp", None),
             layer.linear_shift,
             layer.linear_smooth,
+            metadata.mask_offset,
             metadata.kv_signal_data_list[layer.layer_id],
             getattr(layer, "q_norm_weight", None),
             getattr(layer, "k_norm_weight", None),

View File

@@ -59,6 +59,7 @@ def append_attention(
     cache_v_zp: Optional[paddle.Tensor] = None,
     linear_shift: Optional[paddle.Tensor] = None,
     linear_smooth: Optional[paddle.Tensor] = None,
+    mask_offset: Optional[paddle.Tensor] = None,
     kv_signal_data: Optional[paddle.Tensor] = None,
     q_norm_weight: Optional[paddle.Tensor] = None,
     k_norm_weight: Optional[paddle.Tensor] = None,
@@ -116,6 +117,7 @@ def append_attention(
         cache_v_zp,
         linear_shift,
         linear_smooth,
+        mask_offset,
         kv_signal_data,
         q_norm_weight,
         k_norm_weight,
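
Note: with the Python wrapper extended, a caller supplies one int32 bound per query token and it is threaded through unchanged to the custom op. A hedged usage sketch (argument list abridged; shapes follow the unit test in the next file):

    import paddle

    batch_size, seq_len = 2, 8
    # token j of each sequence may attend up to kv index j (causal-equivalent bounds)
    mask_offset = paddle.tile(paddle.arange(seq_len, dtype="int32"), [batch_size])
    # out = append_attention(qkv, ..., mask_offset=mask_offset, ...)  # remaining args omitted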

View File

@@ -349,6 +349,7 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase):
         self.rope_theta = 10000
         self.dtype = "float16"
         self.use_qk_norm = True
+        self.use_mask_offset = False
         self.init_tensor()

     def init_tensor(self):
@@ -404,6 +405,12 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase):
             self.cu_seqlens_k,
         ) = get_padding_offset(self.batch_size, self.seq_len, self.seq_lens_this_time)
         self.token_num = self.padding_offset.shape[0]
+        self.mask_offset = None
+        if self.use_mask_offset:
+            self.mask_offset = paddle.full(self.seq_len * self.batch_size, 0, "int32")
+            for i in range(self.batch_size):
+                for j in range(self.seq_len):
+                    self.mask_offset[i * self.seq_len + j] = j

     def cmp_append_attention(self, naive_cache_k=None, naive_cache_v=None, attn_mask=None):
         paddle.disable_static()
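
Note: the loop above writes j at flat position i * seq_len + j, granting each prefill token exactly its causal prefix, so the mask_offset path is expected to match the causal reference. An equivalent vectorized construction (illustrative, not part of the diff):

    # same values without the Python-level loops
    self.mask_offset = paddle.tile(paddle.arange(self.seq_len, dtype="int32"), [self.batch_size])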
@@ -505,6 +512,7 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase):
             None,  # cache_v_zp
             None,  # linear_shift
             None,  # linear_smooth
+            self.mask_offset,  # mask_offset
             None,  # kv_signal_data
             q_norm_weight,  # q_norm_weight
             k_norm_weight,  # k_norm_weight
@@ -560,6 +568,8 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase):
         # encoder
         # self.seq_lens_encoder,self.seq_lens_decoder,self.max_enc_len_this_time,self.max_dec_len_this_time=get_encoder_decoder_len(self.batch_size,self.seq_len)
         self.seq_lens_this_time = self.seq_lens_encoder
+        if self.use_mask_offset:
+            print("encoder mask_offset: ", self.mask_offset)
         self.cmp_append_attention(attn_mask=self.attention_mask)
         naive_cache_k, naive_cache_v = block_cache_to_naive_cache(
             self.cache_k,
@@ -590,6 +600,11 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase):
             self.cu_seqlens_q,
             self.cu_seqlens_k,
         ) = get_padding_offset(self.batch_size, 1, self.seq_lens_this_time)
+        if self.use_mask_offset:
+            self.mask_offset = paddle.full(self.batch_size, 0, "int32")
+            for i in range(self.batch_size):
+                self.mask_offset[i] = self.seq_lens_dec[i]
+            print("decoder mask_offset: ", self.mask_offset)
         self.cmp_append_attention(naive_cache_k, naive_cache_v, None)
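
Note: in the decode step each sequence contributes a single new query token, and its bound is set to the number of tokens already cached, so that token can see the whole decoded prefix. Assuming seq_lens_dec is a plain per-sequence length list, the loop is equivalent to this hypothetical one-liner:

    # vectorized form of the decoder mask_offset construction (assumption, not in the diff)
    self.mask_offset = paddle.to_tensor(self.seq_lens_dec, dtype="int32")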
@@ -614,6 +629,7 @@ class TestAppendGroupQueryAttnWithNeoXRope(TestAppendGroupQueryAttnWithRope):
         self.rope_theta = 10000
         self.dtype = "float16"
         self.use_qk_norm = False
+        self.use_mask_offset = True
         self.init_tensor()