mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-04 16:22:57 +08:00
make append_attn supports mask_offset (#3138)
* make append_attn supports mask_offset * add unittest
This commit is contained in:
@@ -72,6 +72,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
|
|||||||
const paddle::optional<paddle::Tensor>& cache_v_zp,
|
const paddle::optional<paddle::Tensor>& cache_v_zp,
|
||||||
const paddle::optional<paddle::Tensor>& out_linear_shifts,
|
const paddle::optional<paddle::Tensor>& out_linear_shifts,
|
||||||
const paddle::optional<paddle::Tensor>& out_linear_smooths,
|
const paddle::optional<paddle::Tensor>& out_linear_smooths,
|
||||||
|
const paddle::optional<paddle::Tensor>& mask_offset,
|
||||||
const paddle::optional<paddle::Tensor>& kv_signal_data,
|
const paddle::optional<paddle::Tensor>& kv_signal_data,
|
||||||
const paddle::optional<paddle::Tensor>& q_norm_weight,
|
const paddle::optional<paddle::Tensor>& q_norm_weight,
|
||||||
const paddle::optional<paddle::Tensor>& k_norm_weight,
|
const paddle::optional<paddle::Tensor>& k_norm_weight,
|
||||||
@@ -441,6 +442,7 @@ std::vector<paddle::Tensor> AppendAttention(
|
|||||||
const paddle::optional<paddle::Tensor>& cache_v_zp,
|
const paddle::optional<paddle::Tensor>& cache_v_zp,
|
||||||
const paddle::optional<paddle::Tensor>& out_linear_shifts,
|
const paddle::optional<paddle::Tensor>& out_linear_shifts,
|
||||||
const paddle::optional<paddle::Tensor>& out_linear_smooths,
|
const paddle::optional<paddle::Tensor>& out_linear_smooths,
|
||||||
|
const paddle::optional<paddle::Tensor>& mask_offset,
|
||||||
const paddle::optional<paddle::Tensor>& kv_signal_data,
|
const paddle::optional<paddle::Tensor>& kv_signal_data,
|
||||||
const paddle::optional<paddle::Tensor>& q_norm_weight,
|
const paddle::optional<paddle::Tensor>& q_norm_weight,
|
||||||
const paddle::optional<paddle::Tensor>& k_norm_weight,
|
const paddle::optional<paddle::Tensor>& k_norm_weight,
|
||||||
@@ -479,6 +481,10 @@ std::vector<paddle::Tensor> AppendAttention(
|
|||||||
meta_data.block_size = key_cache.dims()[2];
|
meta_data.block_size = key_cache.dims()[2];
|
||||||
meta_data.batch_size = seq_lens_this_time.dims()[0];
|
meta_data.batch_size = seq_lens_this_time.dims()[0];
|
||||||
|
|
||||||
|
if (mask_offset) {
|
||||||
|
meta_data.mask_offset = mask_offset.get().data<int>();
|
||||||
|
}
|
||||||
|
|
||||||
auto dispatch_by_template = [&](auto temp_args) -> std::vector<paddle::Tensor> {
|
auto dispatch_by_template = [&](auto temp_args) -> std::vector<paddle::Tensor> {
|
||||||
return AppendAttentionKernel<type2value<decltype(temp_args)>::value>(
|
return AppendAttentionKernel<type2value<decltype(temp_args)>::value>(
|
||||||
meta_data,
|
meta_data,
|
||||||
@@ -514,6 +520,7 @@ std::vector<paddle::Tensor> AppendAttention(
|
|||||||
cache_v_zp,
|
cache_v_zp,
|
||||||
out_linear_shifts,
|
out_linear_shifts,
|
||||||
out_linear_smooths,
|
out_linear_smooths,
|
||||||
|
mask_offset,
|
||||||
kv_signal_data,
|
kv_signal_data,
|
||||||
q_norm_weight,
|
q_norm_weight,
|
||||||
k_norm_weight,
|
k_norm_weight,
|
||||||
@@ -594,6 +601,7 @@ std::vector<std::vector<int64_t>> AppendAttentionInferShape(
|
|||||||
const paddle::optional<std::vector<int64_t>>& cache_v_zp_shape,
|
const paddle::optional<std::vector<int64_t>>& cache_v_zp_shape,
|
||||||
const paddle::optional<std::vector<int64_t>>& out_linear_shifts_shape,
|
const paddle::optional<std::vector<int64_t>>& out_linear_shifts_shape,
|
||||||
const paddle::optional<std::vector<int64_t>>& out_linear_smooths_shape,
|
const paddle::optional<std::vector<int64_t>>& out_linear_smooths_shape,
|
||||||
|
const paddle::optional<std::vector<int64_t>>& mask_offset_shape,
|
||||||
const paddle::optional<std::vector<int64_t>>& kv_signal_data_shape,
|
const paddle::optional<std::vector<int64_t>>& kv_signal_data_shape,
|
||||||
const paddle::optional<std::vector<int64_t>>& q_norm_weight_shape,
|
const paddle::optional<std::vector<int64_t>>& q_norm_weight_shape,
|
||||||
const paddle::optional<std::vector<int64_t>>& k_norm_weight_shape,
|
const paddle::optional<std::vector<int64_t>>& k_norm_weight_shape,
|
||||||
@@ -657,6 +665,7 @@ std::vector<paddle::DataType> AppendAttentionInferDtype(
|
|||||||
const paddle::optional<paddle::DataType>& cache_v_zp_dtype,
|
const paddle::optional<paddle::DataType>& cache_v_zp_dtype,
|
||||||
const paddle::optional<paddle::DataType>& out_linear_shifts_dtype,
|
const paddle::optional<paddle::DataType>& out_linear_shifts_dtype,
|
||||||
const paddle::optional<paddle::DataType>& out_linear_smooths_dtype,
|
const paddle::optional<paddle::DataType>& out_linear_smooths_dtype,
|
||||||
|
const paddle::optional<paddle::DataType>& mask_offset_dtype,
|
||||||
const paddle::optional<paddle::DataType>& kv_signal_data_dtype,
|
const paddle::optional<paddle::DataType>& kv_signal_data_dtype,
|
||||||
const paddle::optional<paddle::DataType>& q_norm_weight_dtype,
|
const paddle::optional<paddle::DataType>& q_norm_weight_dtype,
|
||||||
const paddle::optional<paddle::DataType>& k_norm_weight_dtype,
|
const paddle::optional<paddle::DataType>& k_norm_weight_dtype,
|
||||||
@@ -738,6 +747,7 @@ PD_BUILD_STATIC_OP(append_attention)
|
|||||||
paddle::Optional("cache_v_zp"),
|
paddle::Optional("cache_v_zp"),
|
||||||
paddle::Optional("out_linear_shifts"),
|
paddle::Optional("out_linear_shifts"),
|
||||||
paddle::Optional("out_linear_smooths"),
|
paddle::Optional("out_linear_smooths"),
|
||||||
|
paddle::Optional("mask_offset"),
|
||||||
paddle::Optional("kv_signal_data"),
|
paddle::Optional("kv_signal_data"),
|
||||||
paddle::Optional("q_norm_weight"),
|
paddle::Optional("q_norm_weight"),
|
||||||
paddle::Optional("k_norm_weight")})
|
paddle::Optional("k_norm_weight")})
|
||||||
|
@@ -43,6 +43,7 @@ __global__ void multi_query_append_attention_kernel(
|
|||||||
const int *__restrict__ tile_ids_per_batch,
|
const int *__restrict__ tile_ids_per_batch,
|
||||||
const int *__restrict__ cu_seqlens_q,
|
const int *__restrict__ cu_seqlens_q,
|
||||||
const int *__restrict__ block_table, // [bsz, block_num_per_seq]
|
const int *__restrict__ block_table, // [bsz, block_num_per_seq]
|
||||||
|
const int *__restrict__ mask_offset,
|
||||||
const int max_seq_len,
|
const int max_seq_len,
|
||||||
const int max_dec_len,
|
const int max_dec_len,
|
||||||
const int max_block_num_per_seq,
|
const int max_block_num_per_seq,
|
||||||
@@ -141,6 +142,7 @@ __global__ void multi_query_append_attention_kernel(
|
|||||||
} else {
|
} else {
|
||||||
o_base_ptr_int8 = out + o_offset;
|
o_base_ptr_int8 = out + o_offset;
|
||||||
}
|
}
|
||||||
|
const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id : nullptr;
|
||||||
smem_t qo_smem(smem);
|
smem_t qo_smem(smem);
|
||||||
|
|
||||||
uint32_t q_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
|
uint32_t q_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
|
||||||
@@ -179,7 +181,7 @@ __global__ void multi_query_append_attention_kernel(
|
|||||||
kv_len - q_len +
|
kv_len - q_len +
|
||||||
tile_id * num_rows_per_block / GROUP_SIZE,
|
tile_id * num_rows_per_block / GROUP_SIZE,
|
||||||
chunk_start)))
|
chunk_start)))
|
||||||
: chunk_len) /
|
: mask_offset ? 0 : chunk_len) /
|
||||||
(num_frags_z * 16);
|
(num_frags_z * 16);
|
||||||
uint32_t k_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
|
uint32_t k_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
|
||||||
8 * (tid / 16) + tid % 8, (tid % 16) / 8);
|
8 * (tid / 16) + tid % 8, (tid % 16) / 8);
|
||||||
@@ -250,7 +252,8 @@ __global__ void multi_query_append_attention_kernel(
|
|||||||
q_len,
|
q_len,
|
||||||
kv_len,
|
kv_len,
|
||||||
chunk_end,
|
chunk_end,
|
||||||
s_frag);
|
s_frag,
|
||||||
|
mask_offset_this_seq);
|
||||||
}
|
}
|
||||||
|
|
||||||
// update m,d
|
// update m,d
|
||||||
@@ -406,6 +409,7 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
|
|||||||
const int *__restrict__ tile_ids_per_batch,
|
const int *__restrict__ tile_ids_per_batch,
|
||||||
const int *__restrict__ cu_seqlens_q,
|
const int *__restrict__ cu_seqlens_q,
|
||||||
const int *__restrict__ block_table, // [bsz, block_num_per_seq]
|
const int *__restrict__ block_table, // [bsz, block_num_per_seq]
|
||||||
|
const int *__restrict__ mask_offset,
|
||||||
const int max_seq_len,
|
const int max_seq_len,
|
||||||
const int max_dec_len,
|
const int max_dec_len,
|
||||||
const int max_block_num_per_seq,
|
const int max_block_num_per_seq,
|
||||||
@@ -502,7 +506,7 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
|
|||||||
tid % 8 * num_elems_per_128b<T>();
|
tid % 8 * num_elems_per_128b<T>();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id : nullptr;
|
||||||
smem_t qo_smem(smem);
|
smem_t qo_smem(smem);
|
||||||
|
|
||||||
uint32_t q_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
|
uint32_t q_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
|
||||||
@@ -543,7 +547,7 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
|
|||||||
kv_len - q_len +
|
kv_len - q_len +
|
||||||
tile_id * num_rows_per_block / GROUP_SIZE,
|
tile_id * num_rows_per_block / GROUP_SIZE,
|
||||||
chunk_start)))
|
chunk_start)))
|
||||||
: chunk_len) /
|
: mask_offset ? 0 : chunk_len) /
|
||||||
(NUM_WARP_KV * num_frags_z * 16);
|
(NUM_WARP_KV * num_frags_z * 16);
|
||||||
|
|
||||||
uint32_t k_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
|
uint32_t k_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
|
||||||
@@ -616,7 +620,8 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
|
|||||||
q_len,
|
q_len,
|
||||||
kv_len,
|
kv_len,
|
||||||
chunk_end,
|
chunk_end,
|
||||||
s_frag);
|
s_frag,
|
||||||
|
mask_offset_this_seq);
|
||||||
}
|
}
|
||||||
|
|
||||||
// update m,d
|
// update m,d
|
||||||
@@ -882,6 +887,7 @@ void MultiQueryAppendAttention(
|
|||||||
tile_ids_per_batch.data<int>(),
|
tile_ids_per_batch.data<int>(),
|
||||||
cu_seqlens_q.data<int>(),
|
cu_seqlens_q.data<int>(),
|
||||||
block_table.data<int>(),
|
block_table.data<int>(),
|
||||||
|
meta_data.mask_offset,
|
||||||
max_seq_len,
|
max_seq_len,
|
||||||
max_dec_len,
|
max_dec_len,
|
||||||
max_block_num_per_seq,
|
max_block_num_per_seq,
|
||||||
@@ -939,6 +945,7 @@ void MultiQueryAppendAttention(
|
|||||||
tile_ids_per_batch.data<int>(),
|
tile_ids_per_batch.data<int>(),
|
||||||
cu_seqlens_q.data<int>(),
|
cu_seqlens_q.data<int>(),
|
||||||
block_table.data<int>(),
|
block_table.data<int>(),
|
||||||
|
meta_data.mask_offset,
|
||||||
max_seq_len,
|
max_seq_len,
|
||||||
max_dec_len,
|
max_dec_len,
|
||||||
max_block_num_per_seq,
|
max_block_num_per_seq,
|
||||||
@@ -1103,6 +1110,7 @@ void MultiQueryAppendAttention(
|
|||||||
tile_ids_per_batch.data<int>(),
|
tile_ids_per_batch.data<int>(),
|
||||||
cu_seqlens_q.data<int>(),
|
cu_seqlens_q.data<int>(),
|
||||||
block_table.data<int>(),
|
block_table.data<int>(),
|
||||||
|
meta_data.mask_offset,
|
||||||
max_seq_len,
|
max_seq_len,
|
||||||
max_dec_len,
|
max_dec_len,
|
||||||
max_block_num_per_seq,
|
max_block_num_per_seq,
|
||||||
@@ -1171,6 +1179,7 @@ void MultiQueryAppendAttention(
|
|||||||
tile_ids_per_batch.data<int>(),
|
tile_ids_per_batch.data<int>(),
|
||||||
cu_seqlens_q.data<int>(),
|
cu_seqlens_q.data<int>(),
|
||||||
block_table.data<int>(),
|
block_table.data<int>(),
|
||||||
|
meta_data.mask_offset,
|
||||||
max_seq_len,
|
max_seq_len,
|
||||||
max_dec_len,
|
max_dec_len,
|
||||||
max_block_num_per_seq,
|
max_block_num_per_seq,
|
||||||
|
@@ -48,6 +48,7 @@ __global__ void multi_query_append_attention_c4_kernel(
|
|||||||
const int *__restrict__ tile_ids_per_batch,
|
const int *__restrict__ tile_ids_per_batch,
|
||||||
const int *__restrict__ cu_seqlens_q,
|
const int *__restrict__ cu_seqlens_q,
|
||||||
const int *__restrict__ block_table, // [bsz, block_num_per_seq]
|
const int *__restrict__ block_table, // [bsz, block_num_per_seq]
|
||||||
|
const int *__restrict__ mask_offset,
|
||||||
const int max_seq_len,
|
const int max_seq_len,
|
||||||
const int max_dec_len,
|
const int max_dec_len,
|
||||||
const int max_block_num_per_seq,
|
const int max_block_num_per_seq,
|
||||||
@@ -172,6 +173,7 @@ __global__ void multi_query_append_attention_c4_kernel(
|
|||||||
} else {
|
} else {
|
||||||
o_base_ptr_int8 = out + o_offset;
|
o_base_ptr_int8 = out + o_offset;
|
||||||
}
|
}
|
||||||
|
const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id : nullptr;
|
||||||
smem_t qo_smem(smem);
|
smem_t qo_smem(smem);
|
||||||
|
|
||||||
uint32_t q_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
|
uint32_t q_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
|
||||||
@@ -248,7 +250,7 @@ __global__ void multi_query_append_attention_c4_kernel(
|
|||||||
kv_len - q_len +
|
kv_len - q_len +
|
||||||
tile_id * num_rows_per_block / GROUP_SIZE,
|
tile_id * num_rows_per_block / GROUP_SIZE,
|
||||||
chunk_start)))
|
chunk_start)))
|
||||||
: chunk_len) /
|
: mask_offset ? 0 : chunk_len) /
|
||||||
(num_frags_z * 16);
|
(num_frags_z * 16);
|
||||||
|
|
||||||
uint32_t k_smem_offset_r =
|
uint32_t k_smem_offset_r =
|
||||||
@@ -338,7 +340,8 @@ __global__ void multi_query_append_attention_c4_kernel(
|
|||||||
q_len,
|
q_len,
|
||||||
kv_len,
|
kv_len,
|
||||||
chunk_end,
|
chunk_end,
|
||||||
s_frag);
|
s_frag,
|
||||||
|
mask_offset_this_seq);
|
||||||
}
|
}
|
||||||
|
|
||||||
update_mdo_states<num_frags_x, num_frags_y, num_frags_z>(
|
update_mdo_states<num_frags_x, num_frags_y, num_frags_z>(
|
||||||
@@ -505,6 +508,7 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
|
|||||||
const int *__restrict__ tile_ids_per_batch,
|
const int *__restrict__ tile_ids_per_batch,
|
||||||
const int *__restrict__ cu_seqlens_q,
|
const int *__restrict__ cu_seqlens_q,
|
||||||
const int *__restrict__ block_table, // [bsz, block_num_per_seq]
|
const int *__restrict__ block_table, // [bsz, block_num_per_seq]
|
||||||
|
const int *__restrict__ mask_offset,
|
||||||
const int max_seq_len,
|
const int max_seq_len,
|
||||||
const int max_dec_len,
|
const int max_dec_len,
|
||||||
const int max_block_num_per_seq,
|
const int max_block_num_per_seq,
|
||||||
@@ -627,7 +631,7 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
|
|||||||
tid % 8 * num_elems_per_128b<T>();
|
tid % 8 * num_elems_per_128b<T>();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id : nullptr;
|
||||||
smem_t qo_smem(smem);
|
smem_t qo_smem(smem);
|
||||||
|
|
||||||
uint32_t q_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
|
uint32_t q_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
|
||||||
@@ -706,7 +710,7 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
|
|||||||
kv_len - q_len +
|
kv_len - q_len +
|
||||||
tile_id * num_rows_per_block / GROUP_SIZE,
|
tile_id * num_rows_per_block / GROUP_SIZE,
|
||||||
chunk_start)))
|
chunk_start)))
|
||||||
: chunk_len) /
|
: mask_offset ? 0 : chunk_len) /
|
||||||
(NUM_WARP_KV * num_frags_z * 16);
|
(NUM_WARP_KV * num_frags_z * 16);
|
||||||
|
|
||||||
uint32_t k_smem_offset_r =
|
uint32_t k_smem_offset_r =
|
||||||
@@ -793,7 +797,8 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
|
|||||||
q_len,
|
q_len,
|
||||||
kv_len,
|
kv_len,
|
||||||
chunk_end,
|
chunk_end,
|
||||||
s_frag);
|
s_frag,
|
||||||
|
mask_offset_this_seq);
|
||||||
}
|
}
|
||||||
|
|
||||||
update_mdo_states<num_frags_x, num_frags_y, num_frags_z>(
|
update_mdo_states<num_frags_x, num_frags_y, num_frags_z>(
|
||||||
@@ -1088,6 +1093,7 @@ void MultiQueryAppendC4Attention(
|
|||||||
tile_ids_per_batch.data<int>(),
|
tile_ids_per_batch.data<int>(),
|
||||||
cu_seqlens_q.data<int>(),
|
cu_seqlens_q.data<int>(),
|
||||||
block_table.data<int>(),
|
block_table.data<int>(),
|
||||||
|
meta_data.mask_offset,
|
||||||
max_seq_len,
|
max_seq_len,
|
||||||
max_dec_len,
|
max_dec_len,
|
||||||
max_block_num_per_seq,
|
max_block_num_per_seq,
|
||||||
@@ -1151,6 +1157,7 @@ void MultiQueryAppendC4Attention(
|
|||||||
tile_ids_per_batch.data<int>(),
|
tile_ids_per_batch.data<int>(),
|
||||||
cu_seqlens_q.data<int>(),
|
cu_seqlens_q.data<int>(),
|
||||||
block_table.data<int>(),
|
block_table.data<int>(),
|
||||||
|
meta_data.mask_offset,
|
||||||
max_seq_len,
|
max_seq_len,
|
||||||
max_dec_len,
|
max_dec_len,
|
||||||
max_block_num_per_seq,
|
max_block_num_per_seq,
|
||||||
@@ -1335,6 +1342,7 @@ void MultiQueryAppendC4Attention(
|
|||||||
tile_ids_per_batch.data<int>(),
|
tile_ids_per_batch.data<int>(),
|
||||||
cu_seqlens_q.data<int>(),
|
cu_seqlens_q.data<int>(),
|
||||||
block_table.data<int>(),
|
block_table.data<int>(),
|
||||||
|
meta_data.mask_offset,
|
||||||
max_seq_len,
|
max_seq_len,
|
||||||
max_dec_len,
|
max_dec_len,
|
||||||
max_block_num_per_seq,
|
max_block_num_per_seq,
|
||||||
@@ -1411,6 +1419,7 @@ void MultiQueryAppendC4Attention(
|
|||||||
tile_ids_per_batch.data<int>(),
|
tile_ids_per_batch.data<int>(),
|
||||||
cu_seqlens_q.data<int>(),
|
cu_seqlens_q.data<int>(),
|
||||||
block_table.data<int>(),
|
block_table.data<int>(),
|
||||||
|
meta_data.mask_offset,
|
||||||
max_seq_len,
|
max_seq_len,
|
||||||
max_dec_len,
|
max_dec_len,
|
||||||
max_block_num_per_seq,
|
max_block_num_per_seq,
|
||||||
|
@@ -48,6 +48,7 @@ __global__ void multi_query_append_attention_c8_kernel(
|
|||||||
const int *__restrict__ tile_ids_per_batch,
|
const int *__restrict__ tile_ids_per_batch,
|
||||||
const int *__restrict__ cu_seqlens_q,
|
const int *__restrict__ cu_seqlens_q,
|
||||||
const int *__restrict__ block_table, // [bsz, block_num_per_seq]
|
const int *__restrict__ block_table, // [bsz, block_num_per_seq]
|
||||||
|
const int *__restrict__ mask_offset,
|
||||||
const int max_seq_len,
|
const int max_seq_len,
|
||||||
const int max_dec_len,
|
const int max_dec_len,
|
||||||
const int max_block_num_per_seq,
|
const int max_block_num_per_seq,
|
||||||
@@ -179,6 +180,7 @@ __global__ void multi_query_append_attention_c8_kernel(
|
|||||||
} else {
|
} else {
|
||||||
o_base_ptr_int8 = out + o_offset;
|
o_base_ptr_int8 = out + o_offset;
|
||||||
}
|
}
|
||||||
|
const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id : nullptr;
|
||||||
smem_t qo_smem(smem);
|
smem_t qo_smem(smem);
|
||||||
|
|
||||||
uint32_t q_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
|
uint32_t q_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
|
||||||
@@ -216,7 +218,7 @@ __global__ void multi_query_append_attention_c8_kernel(
|
|||||||
kv_len - q_len +
|
kv_len - q_len +
|
||||||
tile_id * num_rows_per_block / GROUP_SIZE,
|
tile_id * num_rows_per_block / GROUP_SIZE,
|
||||||
chunk_start)))
|
chunk_start)))
|
||||||
: chunk_len) /
|
: mask_offset ? 0 : chunk_len) /
|
||||||
(num_frags_z * 16);
|
(num_frags_z * 16);
|
||||||
|
|
||||||
uint32_t k_smem_offset_r =
|
uint32_t k_smem_offset_r =
|
||||||
@@ -305,7 +307,8 @@ __global__ void multi_query_append_attention_c8_kernel(
|
|||||||
q_len,
|
q_len,
|
||||||
kv_len,
|
kv_len,
|
||||||
chunk_end,
|
chunk_end,
|
||||||
s_frag);
|
s_frag,
|
||||||
|
mask_offset_this_seq);
|
||||||
}
|
}
|
||||||
|
|
||||||
// update m,d
|
// update m,d
|
||||||
@@ -474,6 +477,7 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
|
|||||||
const int *__restrict__ tile_ids_per_batch,
|
const int *__restrict__ tile_ids_per_batch,
|
||||||
const int *__restrict__ cu_seqlens_q,
|
const int *__restrict__ cu_seqlens_q,
|
||||||
const int *__restrict__ block_table, // [bsz, block_num_per_seq]
|
const int *__restrict__ block_table, // [bsz, block_num_per_seq]
|
||||||
|
const int *__restrict__ mask_offset,
|
||||||
const int max_seq_len,
|
const int max_seq_len,
|
||||||
const int max_dec_len,
|
const int max_dec_len,
|
||||||
const int max_block_num_per_seq,
|
const int max_block_num_per_seq,
|
||||||
@@ -601,7 +605,7 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
|
|||||||
tid % 8 * num_elems_per_128b<T>();
|
tid % 8 * num_elems_per_128b<T>();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id : nullptr;
|
||||||
smem_t qo_smem(smem);
|
smem_t qo_smem(smem);
|
||||||
|
|
||||||
uint32_t q_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
|
uint32_t q_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
|
||||||
@@ -642,7 +646,7 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
|
|||||||
kv_len - q_len +
|
kv_len - q_len +
|
||||||
tile_id * num_rows_per_block / GROUP_SIZE,
|
tile_id * num_rows_per_block / GROUP_SIZE,
|
||||||
chunk_start)))
|
chunk_start)))
|
||||||
: chunk_len) /
|
: mask_offset ? 0 : chunk_len) /
|
||||||
(NUM_WARP_KV * num_frags_z * 16);
|
(NUM_WARP_KV * num_frags_z * 16);
|
||||||
|
|
||||||
uint32_t k_smem_offset_r =
|
uint32_t k_smem_offset_r =
|
||||||
@@ -733,7 +737,8 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
|
|||||||
q_len,
|
q_len,
|
||||||
kv_len,
|
kv_len,
|
||||||
chunk_end,
|
chunk_end,
|
||||||
s_frag);
|
s_frag,
|
||||||
|
mask_offset_this_seq);
|
||||||
}
|
}
|
||||||
|
|
||||||
// update m,d
|
// update m,d
|
||||||
@@ -1054,6 +1059,7 @@ void MultiQueryAppendC8Attention(
|
|||||||
tile_ids_per_batch.data<int>(),
|
tile_ids_per_batch.data<int>(),
|
||||||
cu_seqlens_q.data<int>(),
|
cu_seqlens_q.data<int>(),
|
||||||
block_table.data<int>(),
|
block_table.data<int>(),
|
||||||
|
meta_data.mask_offset,
|
||||||
max_seq_len,
|
max_seq_len,
|
||||||
max_dec_len,
|
max_dec_len,
|
||||||
max_block_num_per_seq,
|
max_block_num_per_seq,
|
||||||
@@ -1111,6 +1117,7 @@ void MultiQueryAppendC8Attention(
|
|||||||
tile_ids_per_batch.data<int>(),
|
tile_ids_per_batch.data<int>(),
|
||||||
cu_seqlens_q.data<int>(),
|
cu_seqlens_q.data<int>(),
|
||||||
block_table.data<int>(),
|
block_table.data<int>(),
|
||||||
|
meta_data.mask_offset,
|
||||||
max_seq_len,
|
max_seq_len,
|
||||||
max_dec_len,
|
max_dec_len,
|
||||||
max_block_num_per_seq,
|
max_block_num_per_seq,
|
||||||
@@ -1318,6 +1325,7 @@ void MultiQueryAppendC8Attention(
|
|||||||
tile_ids_per_batch.data<int>(),
|
tile_ids_per_batch.data<int>(),
|
||||||
cu_seqlens_q.data<int>(),
|
cu_seqlens_q.data<int>(),
|
||||||
block_table.data<int>(),
|
block_table.data<int>(),
|
||||||
|
meta_data.mask_offset,
|
||||||
max_seq_len,
|
max_seq_len,
|
||||||
max_dec_len,
|
max_dec_len,
|
||||||
max_block_num_per_seq,
|
max_block_num_per_seq,
|
||||||
@@ -1388,6 +1396,7 @@ void MultiQueryAppendC8Attention(
|
|||||||
tile_ids_per_batch.data<int>(),
|
tile_ids_per_batch.data<int>(),
|
||||||
cu_seqlens_q.data<int>(),
|
cu_seqlens_q.data<int>(),
|
||||||
block_table.data<int>(),
|
block_table.data<int>(),
|
||||||
|
meta_data.mask_offset,
|
||||||
max_seq_len,
|
max_seq_len,
|
||||||
max_dec_len,
|
max_dec_len,
|
||||||
max_block_num_per_seq,
|
max_block_num_per_seq,
|
||||||
|
@@ -910,7 +910,8 @@ __device__ __forceinline__ void mask_s(const uint32_t qo_idx_base,
|
|||||||
const uint32_t qo_len,
|
const uint32_t qo_len,
|
||||||
const uint32_t kv_len,
|
const uint32_t kv_len,
|
||||||
const uint32_t chunk_end,
|
const uint32_t chunk_end,
|
||||||
float (*s_frag)[num_frags_z][8]) {
|
float (*s_frag)[num_frags_z][8],
|
||||||
|
const int *mask_offset = nullptr) {
|
||||||
const uint32_t tx = threadIdx.x;
|
const uint32_t tx = threadIdx.x;
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (uint32_t fx = 0; fx < num_frags_x; ++fx) {
|
for (uint32_t fx = 0; fx < num_frags_x; ++fx) {
|
||||||
@@ -924,10 +925,15 @@ __device__ __forceinline__ void mask_s(const uint32_t qo_idx_base,
|
|||||||
group_size,
|
group_size,
|
||||||
kv_idx = kv_idx_base + fz * 16 + 2 * (tx % 4) +
|
kv_idx = kv_idx_base + fz * 16 + 2 * (tx % 4) +
|
||||||
8 * (reg_id / 4) + reg_id % 2;
|
8 * (reg_id / 4) + reg_id % 2;
|
||||||
const bool out_of_boundary =
|
bool out_of_boundary;
|
||||||
|
if (mask_offset) {
|
||||||
|
out_of_boundary = q_idx < qo_len ? (kv_idx > mask_offset[q_idx]) : true;
|
||||||
|
} else {
|
||||||
|
out_of_boundary =
|
||||||
(causal
|
(causal
|
||||||
? (kv_idx > kv_len + q_idx - qo_len || (kv_idx >= chunk_end))
|
? (kv_idx > kv_len + q_idx - qo_len || (kv_idx >= chunk_end))
|
||||||
: kv_idx >= chunk_end);
|
: kv_idx >= chunk_end);
|
||||||
|
}
|
||||||
if constexpr (std::is_same<T, half>::value) {
|
if constexpr (std::is_same<T, half>::value) {
|
||||||
s_frag[fx][fz][reg_id] =
|
s_frag[fx][fz][reg_id] =
|
||||||
out_of_boundary ? -5e4f : s_frag[fx][fz][reg_id];
|
out_of_boundary ? -5e4f : s_frag[fx][fz][reg_id];
|
||||||
|
@@ -27,6 +27,7 @@ struct AppendAttnMetaData {
|
|||||||
int head_dims;
|
int head_dims;
|
||||||
int head_dims_v;
|
int head_dims_v;
|
||||||
int max_blocks_per_seq;
|
int max_blocks_per_seq;
|
||||||
|
const int *mask_offset = nullptr;
|
||||||
};
|
};
|
||||||
|
|
||||||
__forceinline__ __host__ __device__ int div_up(int a, int b) {
|
__forceinline__ __host__ __device__ int div_up(int a, int b) {
|
||||||
@@ -477,6 +478,9 @@ __forceinline__ __host__ __device__ void vec_cast<nv_bfloat16, float>(
|
|||||||
if (causal) { \
|
if (causal) { \
|
||||||
constexpr bool CAUSAL = true; \
|
constexpr bool CAUSAL = true; \
|
||||||
__VA_ARGS__ \
|
__VA_ARGS__ \
|
||||||
|
} else { \
|
||||||
|
constexpr bool CAUSAL = false; \
|
||||||
|
__VA_ARGS__ \
|
||||||
}
|
}
|
||||||
|
|
||||||
#define DISPATCH_ENABLE_PREFILL(enable_prefill, ENABLE_PREFILL, ...) \
|
#define DISPATCH_ENABLE_PREFILL(enable_prefill, ENABLE_PREFILL, ...) \
|
||||||
|
@@ -77,6 +77,7 @@ std::vector<paddle::Tensor> AppendAttention(
|
|||||||
const paddle::optional<paddle::Tensor> &cache_v_zp,
|
const paddle::optional<paddle::Tensor> &cache_v_zp,
|
||||||
const paddle::optional<paddle::Tensor> &out_linear_shifts,
|
const paddle::optional<paddle::Tensor> &out_linear_shifts,
|
||||||
const paddle::optional<paddle::Tensor> &out_linear_smooths,
|
const paddle::optional<paddle::Tensor> &out_linear_smooths,
|
||||||
|
const paddle::optional<paddle::Tensor> &mask_offset,
|
||||||
const paddle::optional<paddle::Tensor> &kv_signal_data,
|
const paddle::optional<paddle::Tensor> &kv_signal_data,
|
||||||
const paddle::optional<paddle::Tensor>& q_norm_weight,
|
const paddle::optional<paddle::Tensor>& q_norm_weight,
|
||||||
const paddle::optional<paddle::Tensor>& k_norm_weight,
|
const paddle::optional<paddle::Tensor>& k_norm_weight,
|
||||||
|
@@ -62,6 +62,7 @@ class AppendAttentionMetadata(AttentionMetadata):
|
|||||||
block_tables: Optional[paddle.Tensor] = None
|
block_tables: Optional[paddle.Tensor] = None
|
||||||
rotary_embs: Optional[paddle.Tensor] = None
|
rotary_embs: Optional[paddle.Tensor] = None
|
||||||
attn_mask: Optional[paddle.Tensor] = None
|
attn_mask: Optional[paddle.Tensor] = None
|
||||||
|
mask_offset: Optional[paddle.Tensor] = None
|
||||||
_fuse_kernel_compute_dtype: str = "bf16"
|
_fuse_kernel_compute_dtype: str = "bf16"
|
||||||
|
|
||||||
# pd_disaggregation
|
# pd_disaggregation
|
||||||
@@ -261,6 +262,7 @@ class AppendAttentionBackend(AttentionBackend):
|
|||||||
getattr(layer, "cache_v_zp", None),
|
getattr(layer, "cache_v_zp", None),
|
||||||
layer.linear_shift,
|
layer.linear_shift,
|
||||||
layer.linear_smooth,
|
layer.linear_smooth,
|
||||||
|
metadata.mask_offset,
|
||||||
metadata.kv_signal_data_list[layer.layer_id],
|
metadata.kv_signal_data_list[layer.layer_id],
|
||||||
getattr(layer, "q_norm_weight", None),
|
getattr(layer, "q_norm_weight", None),
|
||||||
getattr(layer, "k_norm_weight", None),
|
getattr(layer, "k_norm_weight", None),
|
||||||
|
@@ -59,6 +59,7 @@ def append_attention(
|
|||||||
cache_v_zp: Optional[paddle.Tensor] = None,
|
cache_v_zp: Optional[paddle.Tensor] = None,
|
||||||
linear_shift: Optional[paddle.Tensor] = None,
|
linear_shift: Optional[paddle.Tensor] = None,
|
||||||
linear_smooth: Optional[paddle.Tensor] = None,
|
linear_smooth: Optional[paddle.Tensor] = None,
|
||||||
|
mask_offset: Optional[paddle.Tensor] = None,
|
||||||
kv_signal_data: Optional[paddle.Tensor] = None,
|
kv_signal_data: Optional[paddle.Tensor] = None,
|
||||||
q_norm_weight: Optional[paddle.Tensor] = None,
|
q_norm_weight: Optional[paddle.Tensor] = None,
|
||||||
k_norm_weight: Optional[paddle.Tensor] = None,
|
k_norm_weight: Optional[paddle.Tensor] = None,
|
||||||
@@ -116,6 +117,7 @@ def append_attention(
|
|||||||
cache_v_zp,
|
cache_v_zp,
|
||||||
linear_shift,
|
linear_shift,
|
||||||
linear_smooth,
|
linear_smooth,
|
||||||
|
mask_offset,
|
||||||
kv_signal_data,
|
kv_signal_data,
|
||||||
q_norm_weight,
|
q_norm_weight,
|
||||||
k_norm_weight,
|
k_norm_weight,
|
||||||
|
@@ -349,6 +349,7 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase):
|
|||||||
self.rope_theta = 10000
|
self.rope_theta = 10000
|
||||||
self.dtype = "float16"
|
self.dtype = "float16"
|
||||||
self.use_qk_norm = True
|
self.use_qk_norm = True
|
||||||
|
self.use_mask_offset = False
|
||||||
self.init_tensor()
|
self.init_tensor()
|
||||||
|
|
||||||
def init_tensor(self):
|
def init_tensor(self):
|
||||||
@@ -404,6 +405,12 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase):
|
|||||||
self.cu_seqlens_k,
|
self.cu_seqlens_k,
|
||||||
) = get_padding_offset(self.batch_size, self.seq_len, self.seq_lens_this_time)
|
) = get_padding_offset(self.batch_size, self.seq_len, self.seq_lens_this_time)
|
||||||
self.token_num = self.padding_offset.shape[0]
|
self.token_num = self.padding_offset.shape[0]
|
||||||
|
self.mask_offset = None
|
||||||
|
if self.use_mask_offset:
|
||||||
|
self.mask_offset = paddle.full(self.seq_len * self.batch_size, 0, "int32")
|
||||||
|
for i in range(self.batch_size):
|
||||||
|
for j in range(self.seq_len):
|
||||||
|
self.mask_offset[i * self.seq_len + j] = j
|
||||||
|
|
||||||
def cmp_append_attention(self, naive_cache_k=None, naive_cache_v=None, attn_mask=None):
|
def cmp_append_attention(self, naive_cache_k=None, naive_cache_v=None, attn_mask=None):
|
||||||
paddle.disable_static()
|
paddle.disable_static()
|
||||||
@@ -505,6 +512,7 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase):
|
|||||||
None, # cache_v_zp
|
None, # cache_v_zp
|
||||||
None, # linear_shift
|
None, # linear_shift
|
||||||
None, # linear_smooth
|
None, # linear_smooth
|
||||||
|
self.mask_offset, # mask_offset
|
||||||
None, # kv_signal_data
|
None, # kv_signal_data
|
||||||
q_norm_weight, # q_norm_weight
|
q_norm_weight, # q_norm_weight
|
||||||
k_norm_weight, # k_norm_weight
|
k_norm_weight, # k_norm_weight
|
||||||
@@ -560,6 +568,8 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase):
|
|||||||
# encoder
|
# encoder
|
||||||
# self.seq_lens_encoder,self.seq_lens_decoder,self.max_enc_len_this_time,self.max_dec_len_this_time=get_encoder_decoder_len(self.batch_size,self.seq_len)
|
# self.seq_lens_encoder,self.seq_lens_decoder,self.max_enc_len_this_time,self.max_dec_len_this_time=get_encoder_decoder_len(self.batch_size,self.seq_len)
|
||||||
self.seq_lens_this_time = self.seq_lens_encoder
|
self.seq_lens_this_time = self.seq_lens_encoder
|
||||||
|
if self.use_mask_offset:
|
||||||
|
print("encoder mask_offset: ", self.mask_offset)
|
||||||
self.cmp_append_attention(attn_mask=self.attention_mask)
|
self.cmp_append_attention(attn_mask=self.attention_mask)
|
||||||
naive_cache_k, naive_cache_v = block_cache_to_naive_cache(
|
naive_cache_k, naive_cache_v = block_cache_to_naive_cache(
|
||||||
self.cache_k,
|
self.cache_k,
|
||||||
@@ -590,6 +600,11 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase):
|
|||||||
self.cu_seqlens_q,
|
self.cu_seqlens_q,
|
||||||
self.cu_seqlens_k,
|
self.cu_seqlens_k,
|
||||||
) = get_padding_offset(self.batch_size, 1, self.seq_lens_this_time)
|
) = get_padding_offset(self.batch_size, 1, self.seq_lens_this_time)
|
||||||
|
if self.use_mask_offset:
|
||||||
|
self.mask_offset = paddle.full(self.batch_size, 0, "int32")
|
||||||
|
for i in range(self.batch_size):
|
||||||
|
self.mask_offset[i] = self.seq_lens_dec[i]
|
||||||
|
print("decoder mask_offset: ", self.mask_offset)
|
||||||
self.cmp_append_attention(naive_cache_k, naive_cache_v, None)
|
self.cmp_append_attention(naive_cache_k, naive_cache_v, None)
|
||||||
|
|
||||||
|
|
||||||
@@ -614,6 +629,7 @@ class TestAppendGroupQueryAttnWithNeoXRope(TestAppendGroupQueryAttnWithRope):
|
|||||||
self.rope_theta = 10000
|
self.rope_theta = 10000
|
||||||
self.dtype = "float16"
|
self.dtype = "float16"
|
||||||
self.use_qk_norm = False
|
self.use_qk_norm = False
|
||||||
|
self.use_mask_offset = True
|
||||||
self.init_tensor()
|
self.init_tensor()
|
||||||
|
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user