Add with_output version AppendAttention (#3302)
* get use_output from fd_config
* add clear TODO description
* add mask_offset para to align with develop
* fix bug
* fix use_output logic
* fix sot bug
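The patch keeps append_attention allocating and returning fmha_out, and adds append_attention_with_output, which writes into a tensor the caller preallocates; the backend chooses between them with self.use_output = not fd_config.graph_opt_config.full_cuda_graph (see the Python hunks below). A minimal sketch of the kind of buffer the backend preallocates follows; the sizes here are made up for illustration and are not part of the patch:

import paddle

# Illustrative sizes only; the real values come from qkv.shape[0],
# the layer's num_heads and head_dim, as in the backend change below.
token_num, q_num_heads, head_dim = 16, 8, 128
# Non-quantized path: dtype follows the compute dtype (bf16 here); with
# out_scale > 0 the backend switches to int8 or float8_e4m3fn instead.
fmha_out = paddle.empty([token_num, q_num_heads * head_dim], dtype="bfloat16")
# append_attention_with_output(..., fmha_out, ...) fills this tensor in place
# and returns None; append_attention instead returns a freshly allocated tensor.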
@@ -38,7 +38,7 @@ class type2value<phi::dtype::float16> {
 
 
 template <paddle::DataType D>
-std::vector<paddle::Tensor> AppendAttentionKernel(
+void AppendAttentionKernel(
     const AppendAttnMetaData& meta_data,
     const paddle::Tensor& qkv,
     const paddle::Tensor& key_cache,
@@ -60,6 +60,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
     const paddle::Tensor& decoder_num_blocks,
     const paddle::Tensor& set_max_lengths,
     const paddle::Tensor& max_len_kv,
+    paddle::Tensor& fmha_out,
     const paddle::optional<paddle::Tensor>& rotary_embs,
     const paddle::optional<paddle::Tensor>& attn_mask,
     const paddle::optional<paddle::Tensor>& qkv_bias,
@@ -122,27 +123,6 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
   } else {
     qkv_out = qkv;
   }
-  paddle::Tensor fmha_out;
-  if (out_linear_in_scale > 0.0) {
-    if (fabs(quant_max_bound - 127.0f) < 0.000001) {
-      fmha_out = GetEmptyTensor(
-          {meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims},
-          paddle::DataType::INT8,
-          qkv.place());
-    } else if (fabs(quant_max_bound - 448.0f) < 0.000001) {
-      fmha_out = GetEmptyTensor(
-          {meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims},
-          paddle::DataType::FLOAT8_E4M3FN,
-          qkv.place());
-    }else{
-      PD_THROW("Only supported attr of quant_max_bound in ['127', '448'].");
-    }
-  } else {
-    fmha_out = GetEmptyTensor(
-        {meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims},
-        D,
-        qkv.place());
-  }
 
   auto dispatch_CascadeAppendAttentionKernel = [&](auto temp_args,
                                                    const paddle::Tensor& lambda_batch_ids,
@@ -405,8 +385,6 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
       cudaStreamWaitEvent(main_stream, decoder_event);
     }
   }
-
-  return {fmha_out, qkv_out};
 }
 
 std::vector<paddle::Tensor> AppendAttention(
@@ -481,12 +459,60 @@ std::vector<paddle::Tensor> AppendAttention(
   meta_data.block_size = key_cache.dims()[2];
   meta_data.batch_size = seq_lens_this_time.dims()[0];
 
+  // template dtype generation
+  phi::DataType dtype_id;
+  switch (qkv.dtype()) {
+    case paddle::DataType::FLOAT16: {dtype_id = phi::DataType::FLOAT16; break;}
+    case paddle::DataType::BFLOAT16: {dtype_id = phi::DataType::BFLOAT16; break;}
+    case paddle::DataType::INT32: {
+      if (compute_dtype == "bf16") {
+        dtype_id = phi::DataType::BFLOAT16;
+        break;
+      } else if (compute_dtype == "fp16") {
+        dtype_id = phi::DataType::FLOAT16;
+        break;
+      } else {
+        PD_THROW("Only supported attr of compute_dtype in ['fp16', 'bf16'].");
+        break;
+      }
+    }
+    default: {
+      PD_THROW(
+          "NOT supported data type. "
+          "Only float16 and bfloat16 are supported. ");
+      break;
+    }
+  }
+
+  // fmha_out generation, rewrite from AppendAttentionKernel
+  paddle::Tensor fmha_out;
+  if (out_linear_in_scale > 0.0) {
+    if (fabs(quant_max_bound - 127.0f) < 0.000001) {
+      fmha_out = GetEmptyTensor(
+          {meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims},
+          paddle::DataType::INT8,
+          qkv.place());
+    } else if (fabs(quant_max_bound - 448.0f) < 0.000001) {
+      fmha_out = GetEmptyTensor(
+          {meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims},
+          paddle::DataType::FLOAT8_E4M3FN,
+          qkv.place());
+    } else{
+      PD_THROW("Only supported attr of quant_max_bound in ['127', '448'].");
+    }
+  } else {
+    fmha_out = GetEmptyTensor(
+        {meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims},
+        dtype_id,
+        qkv.place());
+  }
+
   if (mask_offset) {
     meta_data.mask_offset = mask_offset.get().data<int>();
   }
 
-  auto dispatch_by_template = [&](auto temp_args) -> std::vector<paddle::Tensor> {
-    return AppendAttentionKernel<type2value<decltype(temp_args)>::value>(
+  auto dispatch_by_template = [&](auto temp_args) -> void {
+    AppendAttentionKernel<type2value<decltype(temp_args)>::value>(
        meta_data,
        qkv,
        key_cache,
@@ -508,6 +534,7 @@ std::vector<paddle::Tensor> AppendAttention(
        decoder_num_blocks,
        set_max_lengths,
        max_len_kv,
+       fmha_out,
        rotary_embs,
        attn_mask,
        qkv_bias,
@@ -539,20 +566,183 @@ std::vector<paddle::Tensor> AppendAttention(
        speculate_max_draft_token_num,
        causal,
        speculate_decoder);
+  };
+
+  phi::dtype::float16 fp16_dtype;
+  phi::dtype::bfloat16 bp16_dtype;
+  switch (dtype_id){
+    case phi::DataType::FLOAT16: {
+      dispatch_by_template(fp16_dtype);
+      return {fmha_out};
+    }
+    case phi::DataType::BFLOAT16: {
+      dispatch_by_template(bp16_dtype);
+      return {fmha_out};
+    }
+    default:
+      PD_THROW(
+          "NOT supported data type. "
+          "Only float16 and bfloat16 are supported. ");
+      break;
+  }
+
+  return {paddle::Tensor{}};
+}
+
+void AppendAttentionWithOutput(
+    const paddle::Tensor& qkv,
+    const paddle::Tensor& key_cache,
+    const paddle::Tensor& value_cache,
+    const paddle::Tensor& seq_lens_encoder,
+    const paddle::Tensor& seq_lens_decoder,
+    const paddle::Tensor& seq_lens_this_time,
+    const paddle::Tensor& batch_id_per_token,
+    const paddle::Tensor& cu_seqlens_q,
+    const paddle::Tensor& block_tables,
+    const paddle::Tensor& encoder_batch_ids,
+    const paddle::Tensor& encoder_tile_ids_per_batch,
+    const paddle::Tensor& encoder_num_blocks,
+    const paddle::Tensor& kv_batch_ids,
+    const paddle::Tensor& kv_tile_ids_per_batch,
+    const paddle::Tensor& kv_num_blocks,
+    const paddle::Tensor& decoder_batch_ids,
+    const paddle::Tensor& decoder_tile_ids_per_batch,
+    const paddle::Tensor& decoder_num_blocks,
+    const paddle::Tensor& set_max_lengths,
+    const paddle::Tensor& max_len_kv,
+    paddle::Tensor& fmha_out,
+    const paddle::optional<paddle::Tensor>& rotary_embs,
+    const paddle::optional<paddle::Tensor>& attn_mask,
+    const paddle::optional<paddle::Tensor>& qkv_bias,
+    const paddle::optional<paddle::Tensor>& qkv_out_scales,
+    const paddle::optional<paddle::Tensor>& cache_k_quant_scales,
+    const paddle::optional<paddle::Tensor>& cache_v_quant_scales,
+    const paddle::optional<paddle::Tensor>& cache_k_dequant_scales,
+    const paddle::optional<paddle::Tensor>& cache_v_dequant_scales,
+    const paddle::optional<paddle::Tensor>& cache_k_zp,
+    const paddle::optional<paddle::Tensor>& cache_v_zp,
+    const paddle::optional<paddle::Tensor>& out_linear_shifts,
+    const paddle::optional<paddle::Tensor>& out_linear_smooths,
+    const paddle::optional<paddle::Tensor>& mask_offset,
+    const paddle::optional<paddle::Tensor>& kv_signal_data,
+    const paddle::optional<paddle::Tensor>& q_norm_weight,
+    const paddle::optional<paddle::Tensor>& k_norm_weight,
+    const float rms_norm_eps,
+    const std::string& compute_dtype,
+    const std::string& cache_quant_type_str,
+    const bool use_neox_rotary_style,
+    const bool rope_3d,
+    const int max_input_length,
+    const float quant_max_bound,
+    const float quant_min_bound,
+    const float out_linear_in_scale,
+    const int encoder_block_shape_q,
+    const int decoder_block_shape_q,
+    const int max_partition_size,
+    const int encoder_max_partition_size,
+    const int speculate_max_draft_token_num,
+    const bool causal,
+    const bool speculate_decoder) {
+  AppendAttnMetaData meta_data;
+
+  const auto& qkv_dims = qkv.dims();
+  const auto& key_cache_dims = key_cache.dims();
+  meta_data.token_nums = qkv_dims[0];
+  meta_data.kv_num_heads = key_cache_dims[1];
+  meta_data.head_dims = key_cache_dims[3];
+  // TODO: trick method support c4, add attr head_dims in the future
+  if (cache_quant_type_str == "cache_int4_zp") {
+    meta_data.head_dims *= 2;
+  }
+  const int total_num_head =
+      qkv_dims[qkv_dims.size() - 1] / meta_data.head_dims;
+  meta_data.q_num_heads = total_num_head - 2 * meta_data.kv_num_heads;
+
+  meta_data.max_blocks_per_seq = block_tables.dims()[1];
+  meta_data.block_size = key_cache.dims()[2];
+  meta_data.batch_size = seq_lens_this_time.dims()[0];
+
+  if (mask_offset) {
+    meta_data.mask_offset = mask_offset.get().data<int>();
+  }
+
+  auto dispatch_by_template = [&](auto temp_args) -> void {
+    AppendAttentionKernel<type2value<decltype(temp_args)>::value>(
+       meta_data,
+       qkv,
+       key_cache,
+       value_cache,
+       seq_lens_encoder,
+       seq_lens_decoder,
+       seq_lens_this_time,
+       batch_id_per_token,
+       cu_seqlens_q,
+       block_tables,
+       encoder_batch_ids,
+       encoder_tile_ids_per_batch,
+       encoder_num_blocks,
+       kv_batch_ids,
+       kv_tile_ids_per_batch,
+       kv_num_blocks,
+       decoder_batch_ids,
+       decoder_tile_ids_per_batch,
+       decoder_num_blocks,
+       set_max_lengths,
+       max_len_kv,
+       fmha_out,
+       rotary_embs,
+       attn_mask,
+       qkv_bias,
+       qkv_out_scales,
+       cache_k_quant_scales,
+       cache_v_quant_scales,
+       cache_k_dequant_scales,
+       cache_v_dequant_scales,
+       cache_k_zp,
+       cache_v_zp,
+       out_linear_shifts,
+       out_linear_smooths,
+       mask_offset,
+       kv_signal_data,
+       q_norm_weight,
+       k_norm_weight,
+       rms_norm_eps,
+       cache_quant_type_str,
+       use_neox_rotary_style,
+       rope_3d,
+       max_input_length,
+       quant_max_bound,
+       quant_min_bound,
+       out_linear_in_scale,
+       encoder_block_shape_q,
+       decoder_block_shape_q,
+       max_partition_size,
+       encoder_max_partition_size,
+       speculate_max_draft_token_num,
+       causal,
+       speculate_decoder);
   };
 
   phi::dtype::float16 fp16_dtype;
   phi::dtype::bfloat16 bp16_dtype;
 
   switch (qkv.dtype()) {
-    case paddle::DataType::FLOAT16: return dispatch_by_template(fp16_dtype);
-    case paddle::DataType::BFLOAT16: return dispatch_by_template(bp16_dtype);
+    case paddle::DataType::FLOAT16: {
+      dispatch_by_template(fp16_dtype);
+      break;
+    }
+    case paddle::DataType::BFLOAT16: {
+      dispatch_by_template(bp16_dtype);
+      break;
+    }
     case paddle::DataType::INT32: {
       if (compute_dtype == "bf16") {
-        return dispatch_by_template(bp16_dtype);
+        dispatch_by_template(bp16_dtype);
+        break;
       } else if (compute_dtype == "fp16") {
-        return dispatch_by_template(fp16_dtype);
+        dispatch_by_template(fp16_dtype);
+        break;
      } else {
        PD_THROW("Only supported attr of compute_dtype in ['fp16', 'bf16'].");
        break;
@@ -565,9 +755,9 @@ std::vector<paddle::Tensor> AppendAttention(
        break;
      }
    }
-  return {paddle::Tensor{}};
 }
 
+
 std::vector<std::vector<int64_t>> AppendAttentionInferShape(
     const std::vector<int64_t>& qkv_shape,
     const std::vector<int64_t>& key_cache_shape,
@@ -629,7 +819,7 @@ std::vector<std::vector<int64_t>> AppendAttentionInferShape(
   }
   const int total_num_head = qkv_shape[qkv_shape.size() - 1] / head_dim;
   const int num_heads = total_num_head - 2 * kv_num_heads;
-  return {{token_num, num_heads * head_dim}, qkv_shape};
+  return {{token_num, num_heads * head_dim}};
 }
 
 std::vector<paddle::DataType> AppendAttentionInferDtype(
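The C++ changes above funnel every call through a dtype tag (dtype_id) before dispatching the templated kernel. As a reading aid, the same branching can be written as a standalone Python helper; this is an illustration only, not code from the patch:

def pick_compute_dtype(qkv_dtype: str, compute_dtype: str) -> str:
    # Mirrors the switch added to AppendAttention: fp16/bf16 QKV map directly,
    # while an int32 (quantized) QKV falls back to the compute_dtype attribute.
    if qkv_dtype == "float16":
        return "float16"
    if qkv_dtype == "bfloat16":
        return "bfloat16"
    if qkv_dtype == "int32":
        if compute_dtype == "bf16":
            return "bfloat16"
        if compute_dtype == "fp16":
            return "float16"
        raise ValueError("Only supported attr of compute_dtype in ['fp16', 'bf16'].")
    raise ValueError("Only float16 and bfloat16 are supported.")

assert pick_compute_dtype("int32", "bf16") == "bfloat16"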
@@ -688,32 +878,148 @@ std::vector<paddle::DataType> AppendAttentionInferDtype(
   if (compute_dtype == "bf16") {
     if (out_linear_in_scale > 0.0) {
       if (fabs(quant_max_bound - 127.0f) < 0.000001) {
-        return {paddle::DataType::INT8, paddle::DataType::BFLOAT16};
+        return {paddle::DataType::INT8};
       } else if (fabs(quant_max_bound - 448.0f) < 0.000001) {
-        return {paddle::DataType::FLOAT8_E4M3FN, paddle::DataType::BFLOAT16};
+        return {paddle::DataType::FLOAT8_E4M3FN};
       }else{
         PD_THROW("Only supported attr of quant_max_bound in ['127.0', '448.0'].");
       }
     } else {
-      return {paddle::DataType::BFLOAT16, paddle::DataType::BFLOAT16};
+      return {paddle::DataType::BFLOAT16};
     }
   } else if (compute_dtype == "fp16") {
     if (out_linear_in_scale > 0.0) {
       if (fabs(quant_max_bound - 127.0f) < 0.000001) {
-        return {paddle::DataType::INT8, paddle::DataType::FLOAT16};
+        return {paddle::DataType::INT8};
       } else if (fabs(quant_max_bound - 448.0f) < 0.000001) {
-        return {paddle::DataType::FLOAT8_E4M3FN, paddle::DataType::FLOAT16};
+        return {paddle::DataType::FLOAT8_E4M3FN};
       }else{
         PD_THROW("Only supported attr of quant_max_bound in ['127.0', '448.0'].");
       }
     } else {
-      return {paddle::DataType::FLOAT16, paddle::DataType::FLOAT16};
+      return {paddle::DataType::FLOAT16};
     }
   } else {
     PD_THROW("Only supported attr of compute_dtype in ['fp16', 'bf16'].");
   }
 }
 
+std::vector<std::vector<int64_t>> AppendAttentionWithOutputInferShape(
+    const std::vector<int64_t>& qkv_shape,
+    const std::vector<int64_t>& key_cache_shape,
+    const std::vector<int64_t>& value_cache_shape,
+    const std::vector<int64_t>& seq_lens_encoder_shape,
+    const std::vector<int64_t>& seq_lens_decoder_shape,
+    const std::vector<int64_t>& seq_lens_this_time_shape,
+    const std::vector<int64_t>& batch_id_per_token_shape,
+    const std::vector<int64_t>& cu_seqlens_q_shape,
+    const std::vector<int64_t>& block_tables_shape,
+    const std::vector<int64_t>& encoder_batch_ids_shape,
+    const std::vector<int64_t>& encoder_tile_ids_per_batch_shape,
+    const std::vector<int64_t>& encoder_num_blocks_shape,
+    const std::vector<int64_t>& kv_batch_ids_shape,
+    const std::vector<int64_t>& kv_tile_ids_per_batch_shape,
+    const std::vector<int64_t>& kv_num_blocks_shape,
+    const std::vector<int64_t>& decoder_batch_ids_shape,
+    const std::vector<int64_t>& decoder_tile_ids_per_batch_shape,
+    const std::vector<int64_t>& decoder_num_blocks_shape,
+    const std::vector<int64_t>& set_max_lengths_shape,
+    const std::vector<int64_t>& max_len_kv_shape,
+    const std::vector<int64_t>& fmha_out_shape,
+    const paddle::optional<std::vector<int64_t>>& rotary_embs_shape,
+    const paddle::optional<std::vector<int64_t>>& attn_mask_shape,
+    const paddle::optional<std::vector<int64_t>>& qkv_bias_shape,
+    const paddle::optional<std::vector<int64_t>>& qkv_out_scales_shape,
+    const paddle::optional<std::vector<int64_t>>& cache_k_quant_scales_shape,
+    const paddle::optional<std::vector<int64_t>>& cache_v_quant_scales_shape,
+    const paddle::optional<std::vector<int64_t>>& cache_k_dequant_scales_shape,
+    const paddle::optional<std::vector<int64_t>>& cache_v_dequant_scales_shape,
+    const paddle::optional<std::vector<int64_t>>& cache_k_zp_shape,
+    const paddle::optional<std::vector<int64_t>>& cache_v_zp_shape,
+    const paddle::optional<std::vector<int64_t>>& out_linear_shifts_shape,
+    const paddle::optional<std::vector<int64_t>>& out_linear_smooths_shape,
+    const paddle::optional<std::vector<int64_t>>& mask_offset_shape,
+    const paddle::optional<std::vector<int64_t>>& kv_signal_data_shape,
+    const paddle::optional<std::vector<int64_t>>& q_norm_weight_shape,
+    const paddle::optional<std::vector<int64_t>>& k_norm_weight_shape,
+    const float rms_norm_eps,
+    const std::string& compute_dtype,
+    const std::string& cache_quant_type_str,
+    const bool use_neox_rotary_style,
+    const bool rope_3d,
+    const int max_input_length,
+    const float quant_max_bound,
+    const float quant_min_bound,
+    const float out_linear_in_scale,
+    const int encoder_block_shape_q,
+    const int decoder_block_shape_q,
+    const int max_partition_size,
+    const int encoder_max_partition_size,
+    const int speculate_max_draft_token_num,
+    const bool causal,
+    const bool speculate_decoder) {
+  return {fmha_out_shape};
+}
+
+std::vector<paddle::DataType> AppendAttentionWithOutputInferDtype(
+    const paddle::DataType& qkv_dtype,
+    const paddle::DataType& key_cache_dtype,
+    const paddle::DataType& value_cache_dtype,
+    const paddle::DataType& seq_lens_encoder_dtype,
+    const paddle::DataType& seq_lens_decoder_dtype,
+    const paddle::DataType& seq_lens_this_time_dtype,
+    const paddle::DataType& batch_id_per_token_dtype,
+    const paddle::DataType& cu_seqlens_q_dtype,
+    const paddle::DataType& block_tables_dtype,
+    const paddle::DataType& encoder_batch_ids_dtype,
+    const paddle::DataType& encoder_tile_ids_per_batch_dtype,
+    const paddle::DataType& encoder_num_blocks_dtype,
+    const paddle::DataType& kv_batch_ids_dtype,
+    const paddle::DataType& kv_tile_ids_per_batch_dtype,
+    const paddle::DataType& kv_num_blocks_dtype,
+    const paddle::DataType& decoder_batch_ids_dtype,
+    const paddle::DataType& decoder_tile_ids_per_batch_dtype,
+    const paddle::DataType& decoder_num_blocks_dtype,
+    const paddle::DataType& set_max_lengths_dtype,
+    const paddle::DataType& max_len_kv_dtype,
+    const paddle::DataType& fmha_out_dtype,
+    const paddle::optional<paddle::DataType>& rotary_embs_dtype,
+    const paddle::optional<paddle::DataType>& attn_mask_dtype,
+    const paddle::optional<paddle::DataType>& qkv_bias_dtype,
+    const paddle::optional<paddle::DataType>& qkv_out_scales_dtype,
+    const paddle::optional<paddle::DataType>& cache_k_quant_scales_dtype,
+    const paddle::optional<paddle::DataType>& cache_v_quant_scales_dtype,
+    const paddle::optional<paddle::DataType>& cache_k_dequant_scales_dtype,
+    const paddle::optional<paddle::DataType>& cache_v_dequant_scales_dtype,
+    const paddle::optional<paddle::DataType>& cache_k_zp_dtype,
+    const paddle::optional<paddle::DataType>& cache_v_zp_dtype,
+    const paddle::optional<paddle::DataType>& out_linear_shifts_dtype,
+    const paddle::optional<paddle::DataType>& out_linear_smooths_dtype,
+    const paddle::optional<paddle::DataType>& mask_offset_dtype,
+    const paddle::optional<paddle::DataType>& kv_signal_data_dtype,
+    const paddle::optional<paddle::DataType>& q_norm_weight_dtype,
+    const paddle::optional<paddle::DataType>& k_norm_weight_dtype,
+    const float rms_norm_eps,
+    const std::string& compute_dtype,
+    const std::string& cache_quant_type_str,
+    const bool use_neox_rotary_style,
+    const bool rope_3d,
+    const int max_input_length,
+    const float quant_max_bound,
+    const float quant_min_bound,
+    const float out_linear_in_scale,
+    const int encoder_block_shape_q,
+    const int decoder_block_shape_q,
+    const int max_partition_size,
+    const int encoder_max_partition_size,
+    const int speculate_max_draft_token_num,
+    const bool causal,
+    const bool speculate_decoder) {
+  return {fmha_out_dtype};
+}
+
+
 PD_BUILD_STATIC_OP(append_attention)
     .Inputs({"qkv",
              "key_cache",
@@ -751,7 +1057,7 @@ PD_BUILD_STATIC_OP(append_attention)
             paddle::Optional("kv_signal_data"),
             paddle::Optional("q_norm_weight"),
             paddle::Optional("k_norm_weight")})
-    .Outputs({"fmha_out", "qkv_out", "key_cache_out", "value_cache_out"})
+    .Outputs({"fmha_out", "key_cache_out", "value_cache_out"})
     .SetInplaceMap({{"key_cache", "key_cache_out"},
                     {"value_cache", "value_cache_out"}})
     .Attrs({"rms_norm_eps: float",
@@ -774,3 +1080,66 @@ PD_BUILD_STATIC_OP(append_attention)
     .SetKernelFn(PD_KERNEL(AppendAttention))
     .SetInferShapeFn(PD_INFER_SHAPE(AppendAttentionInferShape))
     .SetInferDtypeFn(PD_INFER_DTYPE(AppendAttentionInferDtype));
+
+PD_BUILD_STATIC_OP(append_attention_with_output)
+    .Inputs({"qkv",
+             "key_cache",
+             "value_cache",
+             "seq_lens_encoder",
+             "seq_lens_decoder",
+             "seq_lens_this_time",
+             "batch_id_per_token",
+             "cu_seqlens_q",
+             "block_tables",
+             "encoder_batch_ids",
+             "encoder_tile_ids_per_batch",
+             "encoder_num_blocks",
+             "kv_batch_ids",
+             "kv_tile_ids_per_batch",
+             "kv_num_blocks",
+             "decoder_batch_ids",
+             "decoder_tile_ids_per_batch",
+             "decoder_num_blocks",
+             "set_max_lengths",
+             "max_len_kv",
+             "fmha_out",
+             paddle::Optional("rotary_embs"),
+             paddle::Optional("attn_mask"),
+             paddle::Optional("qkv_bias"),
+             paddle::Optional("qkv_out_scales"),
+             paddle::Optional("cache_k_quant_scales"),
+             paddle::Optional("cache_v_quant_scales"),
+             paddle::Optional("cache_k_dequant_scales"),
+             paddle::Optional("cache_v_dequant_scales"),
+             paddle::Optional("cache_k_zp"),
+             paddle::Optional("cache_v_zp"),
+             paddle::Optional("out_linear_shifts"),
+             paddle::Optional("out_linear_smooths"),
+             paddle::Optional("mask_offset"),
+             paddle::Optional("kv_signal_data"),
+             paddle::Optional("q_norm_weight"),
+             paddle::Optional("k_norm_weight")})
+    .Outputs({"fmha_out_out", "qkv_out", "key_cache_out", "value_cache_out"})
+    .SetInplaceMap({{"fmha_out", "fmha_out_out"},
+                    {"key_cache", "key_cache_out"},
+                    {"value_cache", "value_cache_out"}})
+    .Attrs({"rms_norm_eps: float",
+            "compute_type: std::string",
+            "cache_quant_type: std::string",
+            "use_neox_rotary_style: bool",
+            "rope_3d: bool",
+            "max_input_length: int",
+            "quant_max_bound: float",
+            "quant_min_bound: float",
+            "out_linear_in_scale: float",
+            "encoder_block_shape_q: int",
+            "decoder_block_shape_q: int",
+            "max_partition_size: int",
+            "encoder_max_partition_size: int",
+            "speculate_max_draft_token_num: int",
+            "causal: bool",
+            "speculate_decoder: bool",
+            })
+    .SetKernelFn(PD_KERNEL(AppendAttentionWithOutput))
+    .SetInferShapeFn(PD_INFER_SHAPE(AppendAttentionWithOutputInferShape))
+    .SetInferDtypeFn(PD_INFER_DTYPE(AppendAttentionWithOutputInferDtype));

@@ -91,6 +91,49 @@ std::vector<paddle::Tensor> AppendAttention(
     const int speculate_max_draft_token_num, const bool causal,
     const bool speculate_decoder);
 
+void AppendAttentionWithOutput(
+    const paddle::Tensor &qkv, const paddle::Tensor &key_cache,
+    const paddle::Tensor &value_cache, const paddle::Tensor &seq_lens_encoder,
+    const paddle::Tensor &seq_lens_decoder,
+    const paddle::Tensor &seq_lens_this_time,
+    const paddle::Tensor &batch_id_per_token, const paddle::Tensor &cu_seqlens_q,
+    const paddle::Tensor &block_tables, const paddle::Tensor &encoder_batch_ids,
+    const paddle::Tensor &encoder_tile_ids_per_batch,
+    const paddle::Tensor &encoder_num_blocks,
+    const paddle::Tensor &kv_batch_ids,
+    const paddle::Tensor &kv_tile_ids_per_batch,
+    const paddle::Tensor &kv_num_blocks,
+    const paddle::Tensor &decoder_batch_ids,
+    const paddle::Tensor &decoder_tile_ids_per_batch,
+    const paddle::Tensor &decoder_num_blocks,
+    const paddle::Tensor &set_max_lengths, const paddle::Tensor &max_len_kv,
+    paddle::Tensor &fmha_out,
+    const paddle::optional<paddle::Tensor> &rotary_embs,
+    const paddle::optional<paddle::Tensor> &attn_mask,
+    const paddle::optional<paddle::Tensor> &qkv_bias,
+    const paddle::optional<paddle::Tensor> &qkv_out_scales,
+    const paddle::optional<paddle::Tensor> &cache_k_quant_scales,
+    const paddle::optional<paddle::Tensor> &cache_v_quant_scales,
+    const paddle::optional<paddle::Tensor> &cache_k_dequant_scales,
+    const paddle::optional<paddle::Tensor> &cache_v_dequant_scales,
+    const paddle::optional<paddle::Tensor> &cache_k_zp,
+    const paddle::optional<paddle::Tensor> &cache_v_zp,
+    const paddle::optional<paddle::Tensor> &out_linear_shifts,
+    const paddle::optional<paddle::Tensor> &out_linear_smooths,
+    const paddle::optional<paddle::Tensor> &mask_offset,
+    const paddle::optional<paddle::Tensor> &kv_signal_data,
+    const paddle::optional<paddle::Tensor>& q_norm_weight,
+    const paddle::optional<paddle::Tensor>& k_norm_weight,
+    const float rms_norm_eps,
+    const std::string &compute_dtype, const std::string &cache_quant_type_str,
+    const bool use_neox_rotary_style, const bool rope_3d,
+    const int max_input_length, const float quant_max_bound,
+    const float quant_min_bound, const float out_linear_in_scale,
+    const int encoder_block_shape_q, const int decoder_block_shape_q,
+    const int max_partition_size, const int encoder_max_partition_size,
+    const int speculate_max_draft_token_num, const bool causal,
+    const bool speculate_decoder);
+
 std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
     const paddle::Tensor &qkv, const paddle::Tensor &key_cache,
     const paddle::Tensor &value_cache, const paddle::Tensor &cu_seqlens_q,
@@ -881,6 +924,7 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
   * append_attention
   */
  m.def("append_attention", &AppendAttention, "append attention function");
+  m.def("append_attention_with_output", &AppendAttentionWithOutput, "append attention with output function");
  /**
   * gqa_rope_write_cache.cu
   * gqa_rope_write_cache
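With the pybind m.def and PD_BUILD_STATIC_OP registrations above in place, the Python side reaches the new kernel through the imports that the rest of this patch adds. The two lines below are copied from the later hunks and require a CUDA build of FastDeploy:

# Raw compiled op (CUDA builds only), as imported in the ops wrapper below.
from fastdeploy.model_executor.ops.gpu import (
    append_attention_with_output as append_attention_with_output_gpu,
)
# Thin Python wrapper re-exported for the attention backends.
from fastdeploy.model_executor.layers.attention.ops import append_attention_with_output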
@@ -24,6 +24,7 @@ import paddle
 
 from fastdeploy.model_executor.layers.attention.ops import (
     append_attention,
+    append_attention_with_output,
     get_block_shape_and_split_kv_block,
     init_kv_signal_per_query,
     init_signal_layerwise,
@@ -122,6 +123,7 @@ class AppendAttentionBackend(AttentionBackend):
             fd_config.parallel_config.expert_parallel_rank = 0
 
         self.rank, self.device_id = init_rank_and_device_id(fd_config)
+        self.use_output = not fd_config.graph_opt_config.full_cuda_graph
 
     def init_attention_metadata(self, forward_meta: ForwardMeta):
         """Initialize attntion metadata hence all layers in the forward pass can reuse it."""
@@ -229,58 +231,149 @@ class AppendAttentionBackend(AttentionBackend):
                 layer.layer_id + self.start_layer_index,
             )
 
-        res = append_attention(
-            qkv,
-            forward_meta.caches[2 * layer.layer_id],
-            forward_meta.caches[2 * layer.layer_id + 1],
-            forward_meta.seq_lens_encoder,
-            forward_meta.seq_lens_decoder,
-            forward_meta.seq_lens_this_time,
-            forward_meta.batch_id_per_token,
-            forward_meta.cu_seqlens_q,
-            metadata.block_tables,
-            metadata.encoder_batch_ids,
-            metadata.encoder_tile_ids_per_batch,
-            metadata.encoder_num_blocks,
-            metadata.kv_batch_ids,
-            metadata.kv_tile_ids_per_batch,
-            metadata.kv_num_blocks,
-            forward_meta.decoder_batch_ids,
-            forward_meta.decoder_tile_ids_per_batch,
-            forward_meta.decoder_num_blocks_cpu,
-            forward_meta.max_len_tensor_cpu,
-            metadata.max_len_kv,
-            metadata.rotary_embs,
-            metadata.attn_mask,
-            layer.qkv_bias,
-            layer.qkv_scale,
-            getattr(layer, "cache_k_scale", None),
-            getattr(layer, "cache_v_scale", None),
-            getattr(layer, "cache_k_out_scale", None),
-            getattr(layer, "cache_v_out_scale", None),
-            getattr(layer, "cache_k_zp", None),
-            getattr(layer, "cache_v_zp", None),
-            layer.linear_shift,
-            layer.linear_smooth,
-            metadata.mask_offset,
-            metadata.kv_signal_data_list[layer.layer_id],
-            getattr(layer, "q_norm_weight", None),
-            getattr(layer, "k_norm_weight", None),
-            getattr(layer, "rms_norm_eps", 1e-6),
-            metadata._fuse_kernel_compute_dtype,
-            getattr(layer, "cache_quant_type_str", "none"),
-            layer.use_neox_rotary_style,
-            self.rope_3d,
-            self.max_seq_len,
-            getattr(layer, "quant_max_bound", 0.0),
-            getattr(layer, "quant_min_bound", 0.0),
-            getattr(layer, "out_scale", -1.0),
-            self.encoder_block_shape_q,
-            self.decoder_block_shape_q,
-            metadata.max_partition_size,
-            metadata.encoder_max_partition_size,
-            self.speculate_max_draft_token_num + 1,
-            self.causal,
-            self.speculative_method is not None,
-        )[0]
+        if self.use_output:
+            quant_max_bound = getattr(layer, "quant_max_bound", 0.0)
+            cache_quant_type = getattr(layer, "cache_quant_type_str", "none")
+            compute_type = metadata._fuse_kernel_compute_dtype
+            out_scale = getattr(layer, "out_scale", -1.0)
+            # 1. get output datatype
+            qkv_dtype = qkv.dtype
+            if qkv_dtype == paddle.float16:
+                D_type = paddle.float16
+            elif qkv_dtype == paddle.bfloat16:
+                D_type = paddle.bfloat16
+            elif qkv_dtype == paddle.int32:
+                if compute_type == "bf16":
+                    D_type = paddle.bfloat16
+                elif compute_type == "fp16":
+                    D_type = paddle.float16
+                else:
+                    raise NotImplementedError("Only supported attr of qkv_type in ['float16', 'bfloat16'].")
+            else:
+                raise NotImplementedError("Only supported attr of qkv_type in ['float16', 'bfloat16', 'int32'].")
+            # 2. Extract related parameters
+            token_nums = qkv.shape[0]
+            head_dims = self.head_dim if cache_quant_type != "cache_int4_zp" else self.head_dim * 2
+            q_num_heads = self.num_heads
+            # 3. generate output tensor of different dtypes
+            if out_scale > 0.0:
+                if abs(quant_max_bound - 127) < 0.000001:
+                    res = paddle.empty([token_nums, q_num_heads * head_dims], dtype="int8").to(qkv.place)
+                elif abs(quant_max_bound - 448) < 0.000001:
+                    res = paddle.empty([token_nums, q_num_heads * head_dims], dtype="float8_e4m3fn").to(qkv.place)
+                else:
+                    raise NotImplementedError("Only supported attr of quant_max_bound in ['127', '448'].")
+            else:
+                res = paddle.empty([token_nums, q_num_heads * head_dims], dtype=D_type).to(qkv.place)
+
+            append_attention_with_output(
+                qkv,
+                forward_meta.caches[2 * layer.layer_id],
+                forward_meta.caches[2 * layer.layer_id + 1],
+                forward_meta.seq_lens_encoder,
+                forward_meta.seq_lens_decoder,
+                forward_meta.seq_lens_this_time,
+                forward_meta.batch_id_per_token,
+                forward_meta.cu_seqlens_q,
+                metadata.block_tables,
+                metadata.encoder_batch_ids,
+                metadata.encoder_tile_ids_per_batch,
+                metadata.encoder_num_blocks,
+                metadata.kv_batch_ids,
+                metadata.kv_tile_ids_per_batch,
+                metadata.kv_num_blocks,
+                forward_meta.decoder_batch_ids,
+                forward_meta.decoder_tile_ids_per_batch,
+                forward_meta.decoder_num_blocks_cpu,
+                forward_meta.max_len_tensor_cpu,
+                metadata.max_len_kv,
+                res,
+                metadata.rotary_embs,
+                metadata.attn_mask,
+                layer.qkv_bias,
+                layer.qkv_scale,
+                getattr(layer, "cache_k_scale", None),
+                getattr(layer, "cache_v_scale", None),
+                getattr(layer, "cache_k_out_scale", None),
+                getattr(layer, "cache_v_out_scale", None),
+                getattr(layer, "cache_k_zp", None),
+                getattr(layer, "cache_v_zp", None),
+                layer.linear_shift,
+                layer.linear_smooth,
+                metadata.mask_offset,
+                metadata.kv_signal_data_list[layer.layer_id],
+                getattr(layer, "q_norm_weight", None),
+                getattr(layer, "k_norm_weight", None),
+                getattr(layer, "rms_norm_eps", 1e-6),
+                metadata._fuse_kernel_compute_dtype,
+                getattr(layer, "cache_quant_type_str", "none"),
+                layer.use_neox_rotary_style,
+                self.rope_3d,
+                self.max_seq_len,
+                getattr(layer, "quant_max_bound", 0.0),
+                getattr(layer, "quant_min_bound", 0.0),
+                getattr(layer, "out_scale", -1.0),
+                self.encoder_block_shape_q,
+                self.decoder_block_shape_q,
+                metadata.max_partition_size,
+                metadata.encoder_max_partition_size,
+                self.speculate_max_draft_token_num + 1,
+                self.causal,
+                self.speculative_method is not None,
+            )
+        else:
+            res = append_attention(
+                qkv,
+                forward_meta.caches[2 * layer.layer_id],
+                forward_meta.caches[2 * layer.layer_id + 1],
+                forward_meta.seq_lens_encoder,
+                forward_meta.seq_lens_decoder,
+                forward_meta.seq_lens_this_time,
+                forward_meta.batch_id_per_token,
+                forward_meta.cu_seqlens_q,
+                metadata.block_tables,
+                metadata.encoder_batch_ids,
+                metadata.encoder_tile_ids_per_batch,
+                metadata.encoder_num_blocks,
+                metadata.kv_batch_ids,
+                metadata.kv_tile_ids_per_batch,
+                metadata.kv_num_blocks,
+                forward_meta.decoder_batch_ids,
+                forward_meta.decoder_tile_ids_per_batch,
+                forward_meta.decoder_num_blocks_cpu,
+                forward_meta.max_len_tensor_cpu,
+                metadata.max_len_kv,
+                metadata.rotary_embs,
+                metadata.attn_mask,
+                layer.qkv_bias,
+                layer.qkv_scale,
+                getattr(layer, "cache_k_scale", None),
+                getattr(layer, "cache_v_scale", None),
+                getattr(layer, "cache_k_out_scale", None),
+                getattr(layer, "cache_v_out_scale", None),
+                getattr(layer, "cache_k_zp", None),
+                getattr(layer, "cache_v_zp", None),
+                layer.linear_shift,
+                layer.linear_smooth,
+                metadata.mask_offset,
+                metadata.kv_signal_data_list[layer.layer_id],
+                getattr(layer, "q_norm_weight", None),
+                getattr(layer, "k_norm_weight", None),
+                getattr(layer, "rms_norm_eps", 1e-6),
+                metadata._fuse_kernel_compute_dtype,
+                getattr(layer, "cache_quant_type_str", "none"),
+                layer.use_neox_rotary_style,
+                self.rope_3d,
+                self.max_seq_len,
+                getattr(layer, "quant_max_bound", 0.0),
+                getattr(layer, "quant_min_bound", 0.0),
+                getattr(layer, "out_scale", -1.0),
+                self.encoder_block_shape_q,
+                self.decoder_block_shape_q,
+                metadata.max_partition_size,
+                metadata.encoder_max_partition_size,
+                self.speculate_max_draft_token_num + 1,
+                self.causal,
+                self.speculative_method is not None,
+            )
         return res
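The buffer-allocation branch above reduces to a single rule for the output dtype. Restated as a self-contained helper purely for clarity (illustrative only, not part of the patch):

def output_dtype(out_scale: float, quant_max_bound: float, compute_dtype: str) -> str:
    # Mirrors the backend branch above: a quantized output is produced only
    # when out_scale > 0, with the bound selecting int8 (127) or fp8 (448).
    if out_scale > 0.0:
        if abs(quant_max_bound - 127) < 1e-6:
            return "int8"
        if abs(quant_max_bound - 448) < 1e-6:
            return "float8_e4m3fn"
        raise NotImplementedError("quant_max_bound must be 127 or 448")
    return compute_dtype

assert output_dtype(-1.0, 0.0, "bfloat16") == "bfloat16"
assert output_dtype(1.0, 448.0, "bfloat16") == "float8_e4m3fn"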
@@ -378,7 +378,7 @@ class FlashAttentionBackend(AttentionBackend):
             self.speculate_max_draft_token_num + 1,
             self.causal,
             self.speculative_method is not None,
-        )[0]
+        )
 
         if metadata.max_len_tensor_cpu[1] > 0:
             merge_prefill_decode_output(

@@ -14,7 +14,7 @@
 # limitations under the License.
 """
 
-from .append_attention import append_attention
+from .append_attention import append_attention, append_attention_with_output
 from .get_block_shape_and_split_kv_block import get_block_shape_and_split_kv_block
 from .gqa_rope_write_cache import gqa_rope_write_cache
 from .init_kv_signal_per_query import init_kv_signal_per_query
@@ -25,6 +25,7 @@ from .pre_cache_len_concat import pre_cache_len_concat
 __all__ = [
     "get_block_shape_and_split_kv_block",
     "append_attention",
+    "append_attention_with_output",
     "open_shm_and_get_meta_signal",
     "init_signal_layerwise",
     "gqa_rope_write_cache",

@@ -24,6 +24,9 @@ if current_platform.is_cuda():
     from fastdeploy.model_executor.ops.gpu import (
         append_attention as append_attention_gpu,
     )
+    from fastdeploy.model_executor.ops.gpu import (
+        append_attention_with_output as append_attention_with_output_gpu,
+    )
 
 
 def append_attention(
@@ -141,3 +144,124 @@ def append_attention(
         return out
     else:
         raise NotImplementedError
+
+
+# TODO: (mengyuan) merge w/o output version append attention after
+# finishing developing sub-graph cudagraph capture to reduce
+# compilation volume
+def append_attention_with_output(
+    qkv: paddle.Tensor,
+    key_cache: paddle.Tensor,
+    value_cache: paddle.Tensor,
+    seq_lens_encoder: paddle.Tensor,
+    seq_lens_decoder: paddle.Tensor,
+    seq_lens_this_time: paddle.Tensor,
+    batch_id_per_token: paddle.Tensor,
+    cu_seqlens_q: paddle.Tensor,
+    block_tables: paddle.Tensor,
+    encoder_batch_ids: paddle.Tensor,
+    encoder_tile_ids_per_batch: paddle.Tensor,
+    encoder_num_blocks: paddle.Tensor,
+    kv_batch_ids: paddle.Tensor,
+    kv_tile_ids_per_batch: paddle.Tensor,
+    kv_num_blocks: paddle.Tensor,
+    decoder_batch_ids: paddle.Tensor,
+    decoder_tile_ids_per_batch: paddle.Tensor,
+    decoder_num_blocks: paddle.Tensor,
+    set_max_lengths: paddle.Tensor,
+    max_len_kv: paddle.Tensor,
+    out: paddle.tensor,  # attention output
+    rotary_embs: Optional[paddle.Tensor] = None,
+    attn_mask: Optional[paddle.Tensor] = None,
+    qkv_bias: Optional[paddle.Tensor] = None,
+    qkv_scale: Optional[paddle.Tensor] = None,
+    k_quant_scale: Optional[paddle.Tensor] = None,
+    v_quant_scale: Optional[paddle.Tensor] = None,
+    k_dequant_scale: Optional[paddle.Tensor] = None,
+    v_dequant_scale: Optional[paddle.Tensor] = None,
+    cache_k_zp: Optional[paddle.Tensor] = None,
+    cache_v_zp: Optional[paddle.Tensor] = None,
+    linear_shift: Optional[paddle.Tensor] = None,
+    linear_smooth: Optional[paddle.Tensor] = None,
+    mask_offset: Optional[paddle.Tensor] = None,
+    kv_signal_data: Optional[paddle.Tensor] = None,
+    q_norm_weight: Optional[paddle.Tensor] = None,
+    k_norm_weight: Optional[paddle.Tensor] = None,
+    rms_norm_eps: float = 1e-6,
+    compute_type: str = "bf16",
+    cache_quant_type: str = "none",
+    use_neox_rotary_style: bool = False,
+    rope_3d: bool = False,
+    max_input_length: int = 0,
+    quant_max_bound: float = 0.0,
+    quant_min_bound: float = 0.0,
+    out_linear_in_scale: float = -1.0,
+    encoder_block_shape_q: int = 64,
+    decoder_block_shape_q: int = 16,
+    max_partition_size: int = 32768,
+    encoder_max_partition_size: int = 32768,
+    speculate_max_draft_token_num: int = 1,
+    causal: bool = True,
+    speculate_decoder: bool = False,
+) -> None:
+    """
+    append_attention
+    """
+    if current_platform.is_cuda():
+        append_attention_with_output_gpu(
+            qkv,
+            key_cache,
+            value_cache,
+            seq_lens_encoder,
+            seq_lens_decoder,
+            seq_lens_this_time,
+            batch_id_per_token,
+            cu_seqlens_q,
+            block_tables,
+            encoder_batch_ids,
+            encoder_tile_ids_per_batch,
+            encoder_num_blocks,
+            kv_batch_ids,
+            kv_tile_ids_per_batch,
+            kv_num_blocks,
+            decoder_batch_ids,
+            decoder_tile_ids_per_batch,
+            decoder_num_blocks,
+            set_max_lengths,
+            max_len_kv,
+            out,
+            rotary_embs,
+            attn_mask,
+            qkv_bias,
+            qkv_scale,
+            k_quant_scale,
+            v_quant_scale,
+            k_dequant_scale,
+            v_dequant_scale,
+            cache_k_zp,
+            cache_v_zp,
+            linear_shift,
+            linear_smooth,
+            mask_offset,
+            kv_signal_data,
+            q_norm_weight,
+            k_norm_weight,
+            rms_norm_eps,
+            compute_type,
+            cache_quant_type,
+            use_neox_rotary_style,
+            rope_3d,
+            max_input_length,
+            quant_max_bound,
+            quant_min_bound,
+            out_linear_in_scale,
+            encoder_block_shape_q,
+            decoder_block_shape_q,
+            max_partition_size,
+            encoder_max_partition_size,
+            speculate_max_draft_token_num,
+            causal,
+            speculate_decoder,
+        )
+    else:
+        raise NotImplementedError
@@ -532,7 +532,7 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase):
             speculate_max_draft_token_num + 1,  # speculate_max_draft_token_num
             True,  # causal
             False,  # speculate_decoder
-        )[0]
+        )
         paddle.device.synchronize()
         end_time = time.time()
         print(f"[append-attn ut] cost_time:{(end_time - start_time) / RUN_TIME * 1000}ms")

tests/layers/test_append_attention_with_output.py (new file, 639 lines)
@@ -0,0 +1,639 @@
|
|||||||
|
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import time
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import paddle
|
||||||
|
from paddle.incubate.nn.functional import fused_rms_norm
|
||||||
|
|
||||||
|
paddle.seed(10)
|
||||||
|
|
||||||
|
|
||||||
|
class RopeEmbedding:
|
||||||
|
def __init__(self, use_neox_rotary_style=False):
|
||||||
|
self.use_neox_rotary_style = use_neox_rotary_style
|
||||||
|
self.base = 10000
|
||||||
|
|
||||||
|
def get_neox_style_position_embedding(self, position_ids, head_dim):
|
||||||
|
bsz, max_seq_len = position_ids.shape[:2]
|
||||||
|
rot_emb = paddle.zeros((2, bsz, max_seq_len, 1, head_dim), dtype="float32")
|
||||||
|
inv_freq = self.base ** (-paddle.arange(0, head_dim, 2, dtype="float32") / head_dim)
|
||||||
|
|
||||||
|
# shape: [B, S, D/2]
|
||||||
|
freqs = paddle.einsum("ij,k->ijk", position_ids.cast("float32"), inv_freq)
|
||||||
|
# shape: [B, S, 1, D]
|
||||||
|
emb = paddle.concat([freqs, freqs], axis=-1).reshape((bsz, max_seq_len, 1, head_dim))
|
||||||
|
|
||||||
|
rot_emb[0] = paddle.cos(emb)
|
||||||
|
rot_emb[1] = paddle.sin(emb)
|
||||||
|
return rot_emb
|
||||||
|
|
||||||
|
def get_rotary_position_embedding(self, position_ids, head_dim):
|
||||||
|
bsz, max_seq_len = position_ids.shape[:2]
|
||||||
|
rot_emb = paddle.zeros((2, bsz, max_seq_len, 1, head_dim // 2), dtype="float32")
|
||||||
|
inv_freq = self.base ** (-paddle.arange(0, head_dim, 2, dtype="float32") / head_dim)
|
||||||
|
|
||||||
|
# shape: [B, S, D/2]
|
||||||
|
freqs = paddle.einsum("ij,k->ijk", position_ids.cast("float32"), inv_freq)
|
||||||
|
# shape: [B, S, D/2]
|
||||||
|
emb = paddle.stack([freqs], axis=-1).reshape((bsz, max_seq_len, head_dim // 2))
|
||||||
|
# shape: [B, S, 1, D]
|
||||||
|
emb = paddle.unsqueeze(emb, 2)
|
||||||
|
|
||||||
|
rot_emb[0] = paddle.cos(emb)
|
||||||
|
rot_emb[1] = paddle.sin(emb)
|
||||||
|
return rot_emb
|
||||||
|
|
||||||
|
def _apply_rope(self, rotary_emb, q, k, v=None, causal=False):
|
||||||
|
# sin [sequence_length, embed_size_per_head//2]
|
||||||
|
# cos [sequence_length, embed_size_per_head//2]
|
||||||
|
# sin, cos = paddle.chunk(rp, 2, axis=-1)
|
||||||
|
seq, head_dim = q.shape[2], q.shape[3]
|
||||||
|
cos, sin = paddle.chunk(rotary_emb, 2, axis=0)
|
||||||
|
cos = paddle.squeeze(cos, axis=0).transpose([0, 2, 1, 3])[:, :, :seq, :]
|
||||||
|
sin = paddle.squeeze(sin, axis=0).transpose([0, 2, 1, 3])[:, :, :seq, :]
|
||||||
|
# sin [θ0,θ1,θ2......θd/2-1] -> sin_pos [θ0,θ0,θ1,θ1,θ2,θ2......θd/2-1,θd/2-1]
|
||||||
|
|
||||||
|
if self.use_neox_rotary_style:
|
||||||
|
sin_pos = sin
|
||||||
|
cos_pos = cos
|
||||||
|
# NeoX Stype:前后半部分分块旋转
|
||||||
|
rotate_half_q = paddle.reshape(
|
||||||
|
paddle.stack(
|
||||||
|
[
|
||||||
|
-q[:, :, :, q.shape[-1] // 2 :],
|
||||||
|
q[:, :, :, : q.shape[-1] // 2],
|
||||||
|
],
|
||||||
|
axis=-1,
|
||||||
|
),
|
||||||
|
paddle.shape(q),
|
||||||
|
)
|
||||||
|
rotate_half_k = paddle.reshape(
|
||||||
|
paddle.stack(
|
||||||
|
[
|
||||||
|
-k[:, :, :, k.shape[-1] // 2 :],
|
||||||
|
k[:, :, :, : k.shape[-1] // 2],
|
||||||
|
],
|
||||||
|
axis=-1,
|
||||||
|
),
|
||||||
|
paddle.shape(k),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# import pdb;pdb.set_trace()
|
||||||
|
sin_pos = paddle.reshape(paddle.stack([sin, sin], axis=-1), [1, 1, seq, head_dim])
|
||||||
|
# cos [θ0,θ1,θ2......θd/2-1] -> cos_pos [θ0,θ0,θ1,θ1,θ2,θ2......θd/2-1,θd/2-1]
|
||||||
|
cos_pos = paddle.reshape(paddle.stack([cos, cos], axis=-1), [1, 1, seq, head_dim])
|
||||||
|
# GPT Stype:奇偶位置分块旋转
|
||||||
|
rotate_half_q = paddle.reshape(
|
||||||
|
paddle.stack([-q[:, :, :, 1::2], q[:, :, :, 0::2]], axis=-1),
|
||||||
|
paddle.shape(q),
|
||||||
|
)
|
||||||
|
rotate_half_k = paddle.reshape(
|
||||||
|
paddle.stack([-k[:, :, :, 1::2], k[:, :, :, 0::2]], axis=-1),
|
||||||
|
paddle.shape(k),
|
||||||
|
)
|
||||||
|
|
||||||
|
query = paddle.add(paddle.multiply(q, cos_pos), paddle.multiply(rotate_half_q, sin_pos))
|
||||||
|
|
||||||
|
key = paddle.add(paddle.multiply(k, cos_pos), paddle.multiply(rotate_half_k, sin_pos))
|
||||||
|
|
||||||
|
return paddle.cast(query, q.dtype), paddle.cast(key, k.dtype)
|
||||||
|
|
||||||
|
|
||||||
|
def create_attn_mask(
|
||||||
|
mask_type,
|
||||||
|
batch_size,
|
||||||
|
seq_lens,
|
||||||
|
pre_cache_length=0,
|
||||||
|
):
|
||||||
|
max_seq_len = max(seq_lens)
|
||||||
|
mask = paddle.zeros(
|
||||||
|
# [batch_size, 1, max_seq_len, max_seq_len + pre_cache_length],
|
||||||
|
[batch_size, 1, max_seq_len, max_seq_len],
|
||||||
|
dtype=mask_type,
|
||||||
|
)
|
||||||
|
mask[:, :, :, :pre_cache_length] = 1
|
||||||
|
for i in range(batch_size):
|
||||||
|
seq_len = seq_lens[i]
|
||||||
|
mask[i, 0, :seq_len, :seq_len] = (
|
||||||
|
paddle.tril(paddle.ones(shape=(seq_len, seq_len), dtype=mask_type)) - 1
|
||||||
|
) * 1e4
|
||||||
|
return mask
|
||||||
|
|
||||||
|
|
||||||
|
def block_cache_to_naive_cache(cache_k, cache_v, bsz, block_tables, cache_seq_len):
|
||||||
|
_, num_head, blocksize, dim_head = cache_k.shape
|
||||||
|
out_cache_k = paddle.zeros(shape=[bsz, num_head, cache_seq_len, dim_head], dtype=cache_k.dtype)
|
||||||
|
out_cache_v = paddle.zeros(shape=[bsz, num_head, cache_seq_len, dim_head], dtype=cache_v.dtype)
|
||||||
|
for i in range(bsz):
|
||||||
|
for j in range(cache_seq_len):
|
||||||
|
out_cache_k[i, :, j, :] = cache_k[block_tables[i, j // blocksize], :, j % blocksize, :]
|
||||||
|
out_cache_v[i, :, j, :] = cache_v[block_tables[i, j // blocksize], :, j % blocksize, :]
|
||||||
|
return out_cache_k, out_cache_v
|
||||||
|
|
||||||
|
|
||||||
|
def naive_attention_impl(
    query,
    key,
    value,
    cache_k=None,
    cache_v=None,
    pre_cache_k=None,
    pre_cache_v=None,
    mask=None,
    scale=1.0,
    cache_k_dequant_scales=None,
    cache_v_dequant_scales=None,
    use_cachekv_int8="None",
    q_norm_weight=None,
    k_norm_weight=None,
):
    batch = query.shape[0]
    heads = query.shape[1]
    seq_len = query.shape[2]
    head_dim = query.shape[3]
    kv_head = key.shape[1]

    key = key.reshape([batch, kv_head, 1, seq_len, head_dim])
    key = paddle.tile(key, [1, 1, heads // kv_head, 1, 1])
    key = key.reshape([batch, heads, seq_len, head_dim])

    if cache_k is not None:
        cache_k = cache_k.reshape([batch, kv_head, 1, -1, head_dim])
        cache_k = paddle.tile(cache_k, [1, 1, heads // kv_head, 1, 1])
        cache_k = cache_k.reshape([batch, heads, -1, head_dim])
        key = paddle.concat([cache_k, key], axis=2)

    value = value.reshape([batch, kv_head, 1, seq_len, head_dim])
    value = paddle.tile(value, [1, 1, heads // kv_head, 1, 1])
    value = value.reshape([batch, heads, seq_len, head_dim])

    if cache_v is not None:
        cache_v = cache_v.reshape([batch, kv_head, 1, -1, head_dim])
        cache_v = paddle.tile(cache_v, [1, 1, heads // kv_head, 1, 1])
        cache_v = cache_v.reshape([batch, heads, -1, head_dim])
        value = paddle.concat([cache_v, value], axis=2)

    qk_res = paddle.matmul(query, key, transpose_y=True)
    attention = qk_res * scale
    if mask is not None:
        attention = attention + mask
    softmax_result = paddle.nn.functional.softmax(attention, -1)
    result = paddle.matmul(paddle.cast(softmax_result, dtype=value.dtype), value)
    return result


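# naive_attention_impl is the grouped-query attention reference: each of the kv_head key/value
# heads is tiled heads // kv_head times to match the query heads, any cached K/V is prepended
# along the sequence axis, and the result is softmax(Q @ K^T * scale + mask) @ V computed with
# plain Paddle ops, which the fused kernel output is compared against later.

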
def get_padding_offset(bsz, max_seq_len, seq_lens_this_time):
    cum_offsets_now = paddle.cumsum(max_seq_len - seq_lens_this_time)
    cum_offsets = paddle.zeros(shape=(bsz + 1), dtype="int32")
    cum_offsets[1:] = cum_offsets_now
    token_num = paddle.sum(seq_lens_this_time)
    padding_offsets = paddle.zeros(shape=(token_num), dtype="int32")
    cu_seqlens_q = paddle.zeros(shape=(bsz + 1), dtype="int32")
    cu_seqlens_k = paddle.zeros(shape=(bsz + 1), dtype="int32")
    for i in range(bsz):
        seq_len_now = seq_lens_this_time[i]
        cum_offset = cum_offsets[i]
        for j in range(seq_len_now):
            padding_offsets[i * max_seq_len - cum_offset + j] = cum_offset
        cum_seq_len = (i + 1) * max_seq_len - cum_offsets[i + 1]
        cu_seqlens_q[i + 1] = cum_seq_len
        cu_seqlens_k[i + 1] = cum_seq_len
    return padding_offsets, cum_offsets[:-1], cu_seqlens_q, cu_seqlens_k


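# get_padding_offset emulates, in Python, the padding-free ("varlen") bookkeeping handed to the
# fused op: padding_offsets[t] stores how many padding slots precede packed token t (so its padded
# index is t + padding_offsets[t]), cum_offsets counts padding per batch entry, and
# cu_seqlens_q / cu_seqlens_k are cumulative sequence lengths delimiting each sequence in the
# packed layout.

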
def remove_padding(seq_lens, cu_seq_lens, inputs, token_num):
    bsz, num_head, seq_len, dim_head = inputs.shape
    output = paddle.zeros(shape=[token_num, num_head * dim_head], dtype=inputs.dtype)
    inputs = inputs.transpose([0, 2, 1, 3]).reshape([bsz, seq_len, -1])
    for i in range(bsz):
        seq_len_now = seq_lens[i]
        start_idx = cu_seq_lens[i]
        end_idx = cu_seq_lens[i + 1]
        output[start_idx:end_idx, :] = inputs[i, :seq_len_now, :]
    return output


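# remove_padding flattens the padded [bsz, num_head, seq_len, dim_head] reference output into the
# packed [token_num, num_head * dim_head] layout, using cu_seq_lens to place each sequence, so it
# can be compared element-wise with the kernel output.

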
def get_qkv_and_qkv_concat_tensor(bs, q_num_head, kv_num_head, seq_len, dim_head, place, dtype):
    query = np.random.random([bs, q_num_head, seq_len, dim_head]) / 10
    q = paddle.to_tensor(query, place=place, dtype=dtype, stop_gradient=False)
    key = np.random.random([bs, kv_num_head, seq_len, dim_head]) / 10
    k = paddle.to_tensor(key, place=place, dtype=dtype, stop_gradient=False)
    value = np.random.random([bs, kv_num_head, seq_len, dim_head]) / 10
    v = paddle.to_tensor(value, place=place, dtype=dtype, stop_gradient=False)
    token_num = bs * seq_len

    qkv = paddle.concat(
        [
            q.transpose([0, 2, 1, 3]).reshape([token_num, q_num_head * dim_head]),
            k.transpose([0, 2, 1, 3]).reshape([token_num, kv_num_head * dim_head]),
            v.transpose([0, 2, 1, 3]).reshape([token_num, kv_num_head * dim_head]),
        ],
        axis=1,
    ).reshape([token_num, -1])
    return q, k, v, qkv


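# get_qkv_and_qkv_concat_tensor returns random per-head q/k/v for the reference path plus the
# packed activation the fused op consumes, of shape
# [token_num, (q_num_head + 2 * kv_num_head) * dim_head], built by flattening each of Q, K, V to
# [token_num, heads * dim_head] and concatenating along the hidden dimension.

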
def apply_qk_norm(head_dim, dtype, q, k):
    q_norm_weight = np.random.random([head_dim]) / 10
    k_norm_weight = np.random.random([head_dim]) / 10
    q_norm_weight_tensor = paddle.to_tensor(q_norm_weight, dtype=dtype)
    k_norm_weight_tensor = paddle.to_tensor(k_norm_weight, dtype=dtype)
    print("q:", q.shape)
    print("k:", k.shape)
    bs, q_num_head, seq_len, dim_head = q.shape
    _, kv_num_head, _, _ = k.shape

    q = q.reshape([-1, head_dim])
    k = k.reshape([-1, head_dim])
    print("q:", q)
    q = fused_rms_norm(q, q_norm_weight_tensor, None, 1e-5)[0]
    print("q after norm:", q)
    k = fused_rms_norm(k, k_norm_weight_tensor, None, 1e-5)[0]
    q = q.reshape([-1, q_num_head, seq_len, dim_head])
    k = k.reshape([-1, kv_num_head, seq_len, dim_head])
    return q, k, q_norm_weight_tensor, k_norm_weight_tensor


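# apply_qk_norm applies RMSNorm per head over the last dim_head elements via fused_rms_norm
# (epsilon 1e-5) and also returns the weight tensors, so the same q_norm_weight / k_norm_weight
# can be fed to append_attention_with_output and both paths normalize identically.

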
def split_query_by_phase(
    query,
    seq_lens_encoder,
    seq_lens_decoder,
    seq_lens_this_time,
    q_dim,
    k_dim,
    v_dim,
):
    """
    Split the packed query into encoder and decoder Q/K/V.
    """

    batch = seq_lens_encoder.shape[0]
    max_seq = query.shape[0] // batch

    # Restore query to [batch, seq, dim]
    total_dim = q_dim + k_dim + v_dim
    query = paddle.reshape(query, [batch, max_seq, total_dim])

    # Per-batch masks indicating whether each entry is in the encoder or decoder phase
    is_encoder = (seq_lens_encoder > 0).astype("bool").reshape([-1])  # [batch]
    is_decoder = (seq_lens_decoder > 0).astype("bool").reshape([-1])  # [batch]

    # Prepare output lists
    enc_qs, enc_ks, enc_vs = [], [], []
    dec_qs, dec_ks, dec_vs = [], [], []

    for i in range(batch):
        real_len = int(seq_lens_this_time[i])  # valid length of the current batch entry
        cur_query = query[i, :real_len, :]  # [seq_i, q+k+v]

        q, k, v = paddle.split(cur_query, [q_dim, k_dim, v_dim], axis=-1)

        if is_encoder[i]:
            enc_qs.append(q)
            enc_ks.append(k)
            enc_vs.append(v)
        elif is_decoder[i]:
            dec_qs.append(q)
            dec_ks.append(k)
            dec_vs.append(v)

    if enc_qs:
        enc_q = paddle.concat(enc_qs, axis=0)
        enc_k = paddle.concat(enc_ks, axis=0)
        enc_v = paddle.concat(enc_vs, axis=0)
    else:
        enc_q = enc_k = enc_v = paddle.zeros([0, q_dim], dtype=query.dtype)

    if dec_qs:
        dec_q = paddle.concat(dec_qs, axis=0)
        dec_k = paddle.concat(dec_ks, axis=0)
        dec_v = paddle.concat(dec_vs, axis=0)
    else:
        dec_q = dec_k = dec_v = paddle.zeros([0, q_dim], dtype=query.dtype)

    return (enc_q, enc_k, enc_v), (dec_q, dec_k, dec_v)


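# split_query_by_phase is a standalone helper that appears not to be exercised by the tests below:
# it reshapes a packed [batch * max_seq, q+k+v] activation and slices it into per-phase Q/K/V
# lists, routing each batch entry to the encoder or decoder group depending on which of
# seq_lens_encoder / seq_lens_decoder is non-zero.

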
class TestAppendGroupQueryAttnWithRope(unittest.TestCase):
    def setUp(self):
        paddle.disable_static()
        self.name = "TestAppendGroupQueryAttnWithRope"
        self.place = paddle.CUDAPlace(0)
        self.batch_size = 1
        self.q_num_head = 12
        self.kv_num_head = 2
        self.seq_len = 64
        self.max_dec_len = 64
        self.dim_head = 128
        self.q_hid_dim = self.q_num_head * self.dim_head
        self.kv_hid_dim = self.kv_num_head * self.dim_head
        self.blocksize = 64
        self.use_neox_rotary_style = False
        # max_seq_len = self.seq_len + self.max_dec_len
        self.max_seq_len = self.seq_len + self.max_dec_len
        self.softmax_scale = self.dim_head**-0.5
        self.rope_theta = 10000
        self.dtype = "float16"
        self.use_qk_norm = True
        self.use_mask_offset = False
        self.init_tensor()

    def init_tensor(self):
        self.block_num_per_seq = (self.seq_len + self.max_dec_len + self.blocksize - 1) // self.blocksize
        self.rope = RopeEmbedding(self.use_neox_rotary_style)
        self.max_block_num = self.block_num_per_seq * self.batch_size
        self.free_list = list(range(self.max_block_num - 1, -1, -1))

        self.seq_lens_enc = [
            self.seq_len,
        ] * self.batch_size
        self.seq_lens_dec = [
            0,
        ] * self.batch_size
        self.max_enc_len_this_time = max(self.seq_lens_enc)
        self.max_dec_len_this_time = max(self.seq_lens_dec)
        self.seq_lens_encoder = paddle.to_tensor(
            self.seq_lens_enc,
            "int32",
        )
        self.seq_lens_decoder = paddle.to_tensor(
            self.seq_lens_dec,
            "int32",
        )
        self.max_enc_len_this_time = paddle.to_tensor([self.max_enc_len_this_time], "int32", place=paddle.CPUPlace())
        self.max_dec_len_this_time = paddle.to_tensor([self.max_dec_len_this_time], "int32", place=paddle.CPUPlace())
        self.seq_lens_this_time = self.seq_lens_encoder

        self.decoder_batch_ids = paddle.full([self.batch_size], 0, dtype="int32")
        self.decoder_tile_ids_per_batch = paddle.full([self.batch_size], 0, dtype="int32")
        self.decoder_num_blocks_cpu = paddle.full([1], 0, dtype="int32").pin_memory()
        self.max_len_tensor_cpu = paddle.full([8], 0, dtype="int32").cpu()

        self.cache_shape = (
            self.max_block_num,
            self.kv_num_head,
            self.blocksize,
            self.dim_head,
        )

        self.scale = 1.0 / np.sqrt(self.dim_head)
        self.cache_k = paddle.zeros(shape=self.cache_shape, dtype=self.dtype)
        self.cache_v = paddle.zeros(shape=self.cache_shape, dtype=self.dtype)
        self.block_tables = paddle.zeros(shape=(self.batch_size, self.block_num_per_seq), dtype="int32")
        for i in range(self.batch_size):
            need_block_num = (self.seq_len + self.max_dec_len + self.blocksize - 1) // self.blocksize
            for j in range(need_block_num):
                self.block_tables[i, j] = self.free_list.pop()
        (
            self.padding_offset,
            self.cum_offset,
            self.cu_seqlens_q,
            self.cu_seqlens_k,
        ) = get_padding_offset(self.batch_size, self.seq_len, self.seq_lens_this_time)
        self.token_num = self.padding_offset.shape[0]
        self.mask_offset = None
        if self.use_mask_offset:
            self.mask_offset = paddle.full(self.seq_len * self.batch_size, 0, "int32")
            for i in range(self.batch_size):
                for j in range(self.seq_len):
                    self.mask_offset[i * self.seq_len + j] = j

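    # init_tensor above sets up the paged KV cache: block_tables[i, j] maps logical block j of
    # sequence i to a physical block popped from free_list, and the caches are shaped
    # [max_block_num, kv_num_head, blocksize, dim_head] so prefill plus max_dec_len tokens fit.
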
    def cmp_append_attention(self, naive_cache_k=None, naive_cache_v=None, attn_mask=None):
        paddle.disable_static()
        self.token_num = self.seq_len * self.batch_size
        q, k, v, qkv = get_qkv_and_qkv_concat_tensor(
            self.batch_size,
            self.q_num_head,
            self.kv_num_head,
            self.seq_len,
            self.dim_head,
            self.place,
            self.dtype,
        )

        q, k = self.rope._apply_rope(self.rope_emb, q, k, causal=True)
        if self.use_qk_norm:
            q, k, q_norm_weight, k_norm_weight = apply_qk_norm(self.dim_head, self.dtype, q, k)
        else:
            q_norm_weight = None
            k_norm_weight = None
        out_ = naive_attention_impl(
            q,
            k,
            v,
            naive_cache_k,
            naive_cache_v,
            None,
            None,
            attn_mask,
            self.scale,
        )
        out_ = remove_padding(self.seq_lens_this_time, self.cu_seqlens_q, out_, self.token_num)
        speculate_max_draft_token_num = 1
        from fastdeploy.model_executor.layers.attention.ops import (
            append_attention_with_output,
            get_block_shape_and_split_kv_block,
        )

        (
            encoder_batch_ids,
            encoder_tile_ids_per_batch,
            encoder_num_blocks,
            kv_batch_ids,
            kv_tile_ids_per_batch,
            kv_num_blocks,
            max_len_kv,
        ) = get_block_shape_and_split_kv_block(
            self.seq_lens_encoder,
            self.seq_lens_decoder,
            self.seq_lens_this_time,
            self.decoder_batch_ids,
            self.decoder_tile_ids_per_batch,
            self.decoder_num_blocks_cpu,
            self.max_len_tensor_cpu,
            64,
            12,
            (self.q_num_head + 2 * self.kv_num_head) // self.kv_num_head,
            self.blocksize,
            speculate_max_draft_token_num + 1,
        )

        # Warm up
        WARM_UP = 1
        RUN_TIME = 2
        out = paddle.zeros((qkv.shape[0], self.q_hid_dim), dtype=q.dtype).to(q.place)
        for i in range(WARM_UP + RUN_TIME):
            if i == WARM_UP:
                paddle.device.synchronize()
                start_time = time.time()
            append_attention_with_output(
                qkv,
                self.cache_k,
                self.cache_v,
                self.seq_lens_encoder,
                self.seq_lens_decoder,
                self.seq_lens_this_time,
                self.padding_offset,
                self.cum_offset,
                self.block_tables,
                encoder_batch_ids,
                encoder_tile_ids_per_batch,
                encoder_num_blocks,
                kv_batch_ids,
                kv_tile_ids_per_batch,
                kv_num_blocks,
                self.decoder_batch_ids,
                self.decoder_tile_ids_per_batch,
                self.decoder_num_blocks_cpu,
                self.max_len_tensor_cpu,
                max_len_kv,
                out,
                self.rope_emb,  # rope_emb
                None,  # attn_mask
                None,  # qkv_bias
                None,  # qkv_out_scales
                None,  # cache_k_quant_scales
                None,  # cache_v_quant_scales
                None,  # cache_k_dequant_scales
                None,  # cache_v_dequant_scales
                None,  # cache_k_zp
                None,  # cache_v_zp
                None,  # linear_shift
                None,  # linear_smooth
                self.mask_offset,  # mask_offset
                None,  # kv_signal_data
                q_norm_weight,  # q_norm_weight
                k_norm_weight,  # k_norm_weight
                1e-6,
                "fp16",
                "none",  # cache_quant_type
                self.use_neox_rotary_style,
                False,
                self.max_seq_len,
                0.0,  # quant_min_bound
                0.0,  # quant_max_bound
                -1,  # out_linear_in_scale
                64,  # encoder_block_shape_q
                16,  # decoder_block_shape_q
                32768,  # max_partition_size
                32768,  # encoder_max_partition_size
                speculate_max_draft_token_num + 1,  # speculate_max_draft_token_num
                True,  # causal
                False,  # speculate_decoder
            )
        paddle.device.synchronize()
        end_time = time.time()
        print(f"[append-attn ut] cost_time:{(end_time - start_time) / RUN_TIME * 1000}ms")
        naive_cache_k, naive_cache_v = block_cache_to_naive_cache(
            self.cache_k,
            self.cache_v,
            self.batch_size,
            self.block_tables,
            self.seq_len,
        )
        np.testing.assert_allclose(
            out.numpy(),
            out_.numpy(),
            rtol=1e-02,
            atol=1e-02,
        )

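    # Note: append_attention_with_output writes its result into the caller-provided buffer; the
    # test above preallocates `out` as [token_num, q_num_head * dim_head] zeros, passes it to the
    # op, and then checks it against the naive reference with rtol = atol = 1e-2.
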
    def test_all(self):
        tmp_position_ids = paddle.arange(self.seq_len + self.max_dec_len).reshape((1, -1))
        # append-attn is given the full maximum sequence length (prefill + decode)
        if self.use_neox_rotary_style:
            self.rope_emb = self.rope.get_neox_style_position_embedding(tmp_position_ids, self.dim_head)
        else:
            self.rope_emb = self.rope.get_rotary_position_embedding(tmp_position_ids, self.dim_head)
        self.attention_mask = create_attn_mask(
            self.dtype,
            self.batch_size,
            [
                self.seq_len,
            ]
            * self.batch_size,
        )
        # encoder
        # self.seq_lens_encoder,self.seq_lens_decoder,self.max_enc_len_this_time,self.max_dec_len_this_time=get_encoder_decoder_len(self.batch_size,self.seq_len)
        self.seq_lens_this_time = self.seq_lens_encoder
        if self.use_mask_offset:
            print("encoder mask_offset: ", self.mask_offset)
        self.cmp_append_attention(attn_mask=self.attention_mask)
        naive_cache_k, naive_cache_v = block_cache_to_naive_cache(
            self.cache_k,
            self.cache_v,
            self.batch_size,
            self.block_tables,
            self.seq_len,
        )
        # decoder
        self.seq_lens_decoder[:] = self.seq_lens_encoder
        self.seq_lens_encoder[:] = 0
        self.seq_lens_this_time[:] = 1
        self.seq_lens_enc = [
            0,
        ] * self.batch_size
        self.seq_lens_dec = [
            self.seq_len,
        ] * self.batch_size
        self.max_enc_len_this_time = max(self.seq_lens_enc)
        self.max_dec_len_this_time = max(self.seq_lens_dec)
        self.max_enc_len_this_time = paddle.to_tensor([self.max_enc_len_this_time], "int32", place=paddle.CPUPlace())
        self.max_dec_len_this_time = paddle.to_tensor([self.max_dec_len_this_time], "int32", place=paddle.CPUPlace())

        self.seq_len = 1
        (
            self.padding_offset,
            self.cum_offset,
            self.cu_seqlens_q,
            self.cu_seqlens_k,
        ) = get_padding_offset(self.batch_size, 1, self.seq_lens_this_time)
        if self.use_mask_offset:
            self.mask_offset = paddle.full(self.batch_size, 0, "int32")
            for i in range(self.batch_size):
                self.mask_offset[i] = self.seq_lens_dec[i]
            print("decoder mask_offset: ", self.mask_offset)
        self.cmp_append_attention(naive_cache_k, naive_cache_v, None)


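# test_all first runs the prefill (encoder) phase against the causal mask, converts the paged cache
# to a contiguous one, then rewrites the sequence-length bookkeeping so the same comparison is
# repeated for a single decode step per sequence. The subclass below reruns everything with
# NeoX-style rotary embeddings and flipped qk-norm / mask_offset switches.

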
class TestAppendGroupQueryAttnWithNeoXRope(TestAppendGroupQueryAttnWithRope):
    def setUp(self):
        paddle.disable_static()
        self.name = "TestAppendGroupQueryAttnWithNeoXRope"
        self.place = paddle.CUDAPlace(0)
        self.batch_size = 1
        self.q_num_head = 12
        self.kv_num_head = 2
        self.seq_len = 64
        self.max_dec_len = 64
        self.dim_head = 128
        self.q_hid_dim = self.q_num_head * self.dim_head
        self.kv_hid_dim = self.kv_num_head * self.dim_head
        self.blocksize = 64
        self.use_neox_rotary_style = True
        # max_seq_len = self.seq_len + self.max_dec_len
        self.max_seq_len = self.seq_len + self.max_dec_len
        self.softmax_scale = self.dim_head**-0.5
        self.rope_theta = 10000
        self.dtype = "float16"
        self.use_qk_norm = False
        self.use_mask_offset = True
        self.init_tensor()


if __name__ == "__main__":
    unittest.main()