mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
Add with_output version AppendAttention (#3302)
* get use_output from fd_config * add clear TODO description * add mask_offset para to align with develop * fix bug * fix use_output logic * fix sot bug
This commit is contained in:
@@ -38,7 +38,7 @@ class type2value<phi::dtype::float16> {
|
||||
|
||||
|
||||
template <paddle::DataType D>
|
||||
std::vector<paddle::Tensor> AppendAttentionKernel(
|
||||
void AppendAttentionKernel(
|
||||
const AppendAttnMetaData& meta_data,
|
||||
const paddle::Tensor& qkv,
|
||||
const paddle::Tensor& key_cache,
|
||||
@@ -60,6 +60,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
|
||||
const paddle::Tensor& decoder_num_blocks,
|
||||
const paddle::Tensor& set_max_lengths,
|
||||
const paddle::Tensor& max_len_kv,
|
||||
paddle::Tensor& fmha_out,
|
||||
const paddle::optional<paddle::Tensor>& rotary_embs,
|
||||
const paddle::optional<paddle::Tensor>& attn_mask,
|
||||
const paddle::optional<paddle::Tensor>& qkv_bias,
|
||||
@@ -122,27 +123,6 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
|
||||
} else {
|
||||
qkv_out = qkv;
|
||||
}
|
||||
paddle::Tensor fmha_out;
|
||||
if (out_linear_in_scale > 0.0) {
|
||||
if (fabs(quant_max_bound - 127.0f) < 0.000001) {
|
||||
fmha_out = GetEmptyTensor(
|
||||
{meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims},
|
||||
paddle::DataType::INT8,
|
||||
qkv.place());
|
||||
} else if (fabs(quant_max_bound - 448.0f) < 0.000001) {
|
||||
fmha_out = GetEmptyTensor(
|
||||
{meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims},
|
||||
paddle::DataType::FLOAT8_E4M3FN,
|
||||
qkv.place());
|
||||
}else{
|
||||
PD_THROW("Only supported attr of quant_max_bound in ['127', '448'].");
|
||||
}
|
||||
} else {
|
||||
fmha_out = GetEmptyTensor(
|
||||
{meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims},
|
||||
D,
|
||||
qkv.place());
|
||||
}
|
||||
|
||||
auto dispatch_CascadeAppendAttentionKernel = [&](auto temp_args,
|
||||
const paddle::Tensor& lambda_batch_ids,
|
||||
@@ -405,8 +385,6 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
|
||||
cudaStreamWaitEvent(main_stream, decoder_event);
|
||||
}
|
||||
}
|
||||
|
||||
return {fmha_out, qkv_out};
|
||||
}
|
||||
|
||||
std::vector<paddle::Tensor> AppendAttention(
|
||||
@@ -481,12 +459,60 @@ std::vector<paddle::Tensor> AppendAttention(
|
||||
meta_data.block_size = key_cache.dims()[2];
|
||||
meta_data.batch_size = seq_lens_this_time.dims()[0];
|
||||
|
||||
// template dtype generation
|
||||
phi::DataType dtype_id;
|
||||
switch (qkv.dtype()) {
|
||||
case paddle::DataType::FLOAT16: {dtype_id = phi::DataType::FLOAT16; break;}
|
||||
case paddle::DataType::BFLOAT16: {dtype_id = phi::DataType::BFLOAT16; break;}
|
||||
case paddle::DataType::INT32: {
|
||||
if (compute_dtype == "bf16") {
|
||||
dtype_id = phi::DataType::BFLOAT16;
|
||||
break;
|
||||
} else if (compute_dtype == "fp16") {
|
||||
dtype_id = phi::DataType::FLOAT16;
|
||||
break;
|
||||
} else {
|
||||
PD_THROW("Only supported attr of compute_dtype in ['fp16', 'bf16'].");
|
||||
break;
|
||||
}
|
||||
}
|
||||
default: {
|
||||
PD_THROW(
|
||||
"NOT supported data type. "
|
||||
"Only float16 and bfloat16 are supported. ");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// fmha_out generation, rewrite from AppendAttentionKernel
|
||||
paddle::Tensor fmha_out;
|
||||
if (out_linear_in_scale > 0.0) {
|
||||
if (fabs(quant_max_bound - 127.0f) < 0.000001) {
|
||||
fmha_out = GetEmptyTensor(
|
||||
{meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims},
|
||||
paddle::DataType::INT8,
|
||||
qkv.place());
|
||||
} else if (fabs(quant_max_bound - 448.0f) < 0.000001) {
|
||||
fmha_out = GetEmptyTensor(
|
||||
{meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims},
|
||||
paddle::DataType::FLOAT8_E4M3FN,
|
||||
qkv.place());
|
||||
} else{
|
||||
PD_THROW("Only supported attr of quant_max_bound in ['127', '448'].");
|
||||
}
|
||||
} else {
|
||||
fmha_out = GetEmptyTensor(
|
||||
{meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims},
|
||||
dtype_id,
|
||||
qkv.place());
|
||||
}
|
||||
|
||||
if (mask_offset) {
|
||||
meta_data.mask_offset = mask_offset.get().data<int>();
|
||||
}
|
||||
|
||||
auto dispatch_by_template = [&](auto temp_args) -> std::vector<paddle::Tensor> {
|
||||
return AppendAttentionKernel<type2value<decltype(temp_args)>::value>(
|
||||
auto dispatch_by_template = [&](auto temp_args) -> void {
|
||||
AppendAttentionKernel<type2value<decltype(temp_args)>::value>(
|
||||
meta_data,
|
||||
qkv,
|
||||
key_cache,
|
||||
@@ -508,6 +534,7 @@ std::vector<paddle::Tensor> AppendAttention(
|
||||
decoder_num_blocks,
|
||||
set_max_lengths,
|
||||
max_len_kv,
|
||||
fmha_out,
|
||||
rotary_embs,
|
||||
attn_mask,
|
||||
qkv_bias,
|
||||
@@ -539,20 +566,183 @@ std::vector<paddle::Tensor> AppendAttention(
|
||||
speculate_max_draft_token_num,
|
||||
causal,
|
||||
speculate_decoder);
|
||||
};
|
||||
|
||||
|
||||
phi::dtype::float16 fp16_dtype;
|
||||
phi::dtype::bfloat16 bp16_dtype;
|
||||
switch (dtype_id){
|
||||
case phi::DataType::FLOAT16: {
|
||||
dispatch_by_template(fp16_dtype);
|
||||
return {fmha_out};
|
||||
}
|
||||
case phi::DataType::BFLOAT16: {
|
||||
dispatch_by_template(bp16_dtype);
|
||||
return {fmha_out};
|
||||
}
|
||||
default:
|
||||
PD_THROW(
|
||||
"NOT supported data type. "
|
||||
"Only float16 and bfloat16 are supported. ");
|
||||
break;
|
||||
}
|
||||
|
||||
return {paddle::Tensor{}};
|
||||
}
|
||||
|
||||
void AppendAttentionWithOutput(
|
||||
const paddle::Tensor& qkv,
|
||||
const paddle::Tensor& key_cache,
|
||||
const paddle::Tensor& value_cache,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& seq_lens_this_time,
|
||||
const paddle::Tensor& batch_id_per_token,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& block_tables,
|
||||
const paddle::Tensor& encoder_batch_ids,
|
||||
const paddle::Tensor& encoder_tile_ids_per_batch,
|
||||
const paddle::Tensor& encoder_num_blocks,
|
||||
const paddle::Tensor& kv_batch_ids,
|
||||
const paddle::Tensor& kv_tile_ids_per_batch,
|
||||
const paddle::Tensor& kv_num_blocks,
|
||||
const paddle::Tensor& decoder_batch_ids,
|
||||
const paddle::Tensor& decoder_tile_ids_per_batch,
|
||||
const paddle::Tensor& decoder_num_blocks,
|
||||
const paddle::Tensor& set_max_lengths,
|
||||
const paddle::Tensor& max_len_kv,
|
||||
paddle::Tensor& fmha_out,
|
||||
const paddle::optional<paddle::Tensor>& rotary_embs,
|
||||
const paddle::optional<paddle::Tensor>& attn_mask,
|
||||
const paddle::optional<paddle::Tensor>& qkv_bias,
|
||||
const paddle::optional<paddle::Tensor>& qkv_out_scales,
|
||||
const paddle::optional<paddle::Tensor>& cache_k_quant_scales,
|
||||
const paddle::optional<paddle::Tensor>& cache_v_quant_scales,
|
||||
const paddle::optional<paddle::Tensor>& cache_k_dequant_scales,
|
||||
const paddle::optional<paddle::Tensor>& cache_v_dequant_scales,
|
||||
const paddle::optional<paddle::Tensor>& cache_k_zp,
|
||||
const paddle::optional<paddle::Tensor>& cache_v_zp,
|
||||
const paddle::optional<paddle::Tensor>& out_linear_shifts,
|
||||
const paddle::optional<paddle::Tensor>& out_linear_smooths,
|
||||
const paddle::optional<paddle::Tensor>& mask_offset,
|
||||
const paddle::optional<paddle::Tensor>& kv_signal_data,
|
||||
const paddle::optional<paddle::Tensor>& q_norm_weight,
|
||||
const paddle::optional<paddle::Tensor>& k_norm_weight,
|
||||
const float rms_norm_eps,
|
||||
const std::string& compute_dtype,
|
||||
const std::string& cache_quant_type_str,
|
||||
const bool use_neox_rotary_style,
|
||||
const bool rope_3d,
|
||||
const int max_input_length,
|
||||
const float quant_max_bound,
|
||||
const float quant_min_bound,
|
||||
const float out_linear_in_scale,
|
||||
const int encoder_block_shape_q,
|
||||
const int decoder_block_shape_q,
|
||||
const int max_partition_size,
|
||||
const int encoder_max_partition_size,
|
||||
const int speculate_max_draft_token_num,
|
||||
const bool causal,
|
||||
const bool speculate_decoder) {
|
||||
AppendAttnMetaData meta_data;
|
||||
|
||||
const auto& qkv_dims = qkv.dims();
|
||||
const auto& key_cache_dims = key_cache.dims();
|
||||
meta_data.token_nums = qkv_dims[0];
|
||||
meta_data.kv_num_heads = key_cache_dims[1];
|
||||
meta_data.head_dims = key_cache_dims[3];
|
||||
// TODO: trick method support c4, add attr head_dims in the future
|
||||
if (cache_quant_type_str == "cache_int4_zp") {
|
||||
meta_data.head_dims *= 2;
|
||||
}
|
||||
const int total_num_head =
|
||||
qkv_dims[qkv_dims.size() - 1] / meta_data.head_dims;
|
||||
meta_data.q_num_heads = total_num_head - 2 * meta_data.kv_num_heads;
|
||||
|
||||
meta_data.max_blocks_per_seq = block_tables.dims()[1];
|
||||
meta_data.block_size = key_cache.dims()[2];
|
||||
meta_data.batch_size = seq_lens_this_time.dims()[0];
|
||||
|
||||
if (mask_offset) {
|
||||
meta_data.mask_offset = mask_offset.get().data<int>();
|
||||
}
|
||||
|
||||
auto dispatch_by_template = [&](auto temp_args) -> void {
|
||||
AppendAttentionKernel<type2value<decltype(temp_args)>::value>(
|
||||
meta_data,
|
||||
qkv,
|
||||
key_cache,
|
||||
value_cache,
|
||||
seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
seq_lens_this_time,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
block_tables,
|
||||
encoder_batch_ids,
|
||||
encoder_tile_ids_per_batch,
|
||||
encoder_num_blocks,
|
||||
kv_batch_ids,
|
||||
kv_tile_ids_per_batch,
|
||||
kv_num_blocks,
|
||||
decoder_batch_ids,
|
||||
decoder_tile_ids_per_batch,
|
||||
decoder_num_blocks,
|
||||
set_max_lengths,
|
||||
max_len_kv,
|
||||
fmha_out,
|
||||
rotary_embs,
|
||||
attn_mask,
|
||||
qkv_bias,
|
||||
qkv_out_scales,
|
||||
cache_k_quant_scales,
|
||||
cache_v_quant_scales,
|
||||
cache_k_dequant_scales,
|
||||
cache_v_dequant_scales,
|
||||
cache_k_zp,
|
||||
cache_v_zp,
|
||||
out_linear_shifts,
|
||||
out_linear_smooths,
|
||||
mask_offset,
|
||||
kv_signal_data,
|
||||
q_norm_weight,
|
||||
k_norm_weight,
|
||||
rms_norm_eps,
|
||||
cache_quant_type_str,
|
||||
use_neox_rotary_style,
|
||||
rope_3d,
|
||||
max_input_length,
|
||||
quant_max_bound,
|
||||
quant_min_bound,
|
||||
out_linear_in_scale,
|
||||
encoder_block_shape_q,
|
||||
decoder_block_shape_q,
|
||||
max_partition_size,
|
||||
encoder_max_partition_size,
|
||||
speculate_max_draft_token_num,
|
||||
causal,
|
||||
speculate_decoder);
|
||||
};
|
||||
|
||||
phi::dtype::float16 fp16_dtype;
|
||||
phi::dtype::bfloat16 bp16_dtype;
|
||||
|
||||
switch (qkv.dtype()) {
|
||||
case paddle::DataType::FLOAT16: return dispatch_by_template(fp16_dtype);
|
||||
case paddle::DataType::BFLOAT16: return dispatch_by_template(bp16_dtype);
|
||||
case paddle::DataType::FLOAT16: {
|
||||
dispatch_by_template(fp16_dtype);
|
||||
break;
|
||||
}
|
||||
case paddle::DataType::BFLOAT16: {
|
||||
dispatch_by_template(bp16_dtype);
|
||||
break;
|
||||
}
|
||||
case paddle::DataType::INT32: {
|
||||
if (compute_dtype == "bf16") {
|
||||
return dispatch_by_template(bp16_dtype);
|
||||
dispatch_by_template(bp16_dtype);
|
||||
break;
|
||||
} else if (compute_dtype == "fp16") {
|
||||
return dispatch_by_template(fp16_dtype);
|
||||
dispatch_by_template(fp16_dtype);
|
||||
break;
|
||||
} else {
|
||||
PD_THROW("Only supported attr of compute_dtype in ['fp16', 'bf16'].");
|
||||
break;
|
||||
@@ -565,9 +755,9 @@ std::vector<paddle::Tensor> AppendAttention(
|
||||
break;
|
||||
}
|
||||
}
|
||||
return {paddle::Tensor{}};
|
||||
}
|
||||
|
||||
|
||||
std::vector<std::vector<int64_t>> AppendAttentionInferShape(
|
||||
const std::vector<int64_t>& qkv_shape,
|
||||
const std::vector<int64_t>& key_cache_shape,
|
||||
@@ -629,7 +819,7 @@ std::vector<std::vector<int64_t>> AppendAttentionInferShape(
|
||||
}
|
||||
const int total_num_head = qkv_shape[qkv_shape.size() - 1] / head_dim;
|
||||
const int num_heads = total_num_head - 2 * kv_num_heads;
|
||||
return {{token_num, num_heads * head_dim}, qkv_shape};
|
||||
return {{token_num, num_heads * head_dim}};
|
||||
}
|
||||
|
||||
std::vector<paddle::DataType> AppendAttentionInferDtype(
|
||||
@@ -688,32 +878,148 @@ std::vector<paddle::DataType> AppendAttentionInferDtype(
|
||||
if (compute_dtype == "bf16") {
|
||||
if (out_linear_in_scale > 0.0) {
|
||||
if (fabs(quant_max_bound - 127.0f) < 0.000001) {
|
||||
return {paddle::DataType::INT8, paddle::DataType::BFLOAT16};
|
||||
return {paddle::DataType::INT8};
|
||||
} else if (fabs(quant_max_bound - 448.0f) < 0.000001) {
|
||||
return {paddle::DataType::FLOAT8_E4M3FN, paddle::DataType::BFLOAT16};
|
||||
return {paddle::DataType::FLOAT8_E4M3FN};
|
||||
}else{
|
||||
PD_THROW("Only supported attr of quant_max_bound in ['127.0', '448.0'].");
|
||||
}
|
||||
} else {
|
||||
return {paddle::DataType::BFLOAT16, paddle::DataType::BFLOAT16};
|
||||
return {paddle::DataType::BFLOAT16};
|
||||
}
|
||||
} else if (compute_dtype == "fp16") {
|
||||
if (out_linear_in_scale > 0.0) {
|
||||
if (fabs(quant_max_bound - 127.0f) < 0.000001) {
|
||||
return {paddle::DataType::INT8, paddle::DataType::FLOAT16};
|
||||
return {paddle::DataType::INT8};
|
||||
} else if (fabs(quant_max_bound - 448.0f) < 0.000001) {
|
||||
return {paddle::DataType::FLOAT8_E4M3FN, paddle::DataType::FLOAT16};
|
||||
return {paddle::DataType::FLOAT8_E4M3FN};
|
||||
}else{
|
||||
PD_THROW("Only supported attr of quant_max_bound in ['127.0', '448.0'].");
|
||||
}
|
||||
} else {
|
||||
return {paddle::DataType::FLOAT16, paddle::DataType::FLOAT16};
|
||||
return {paddle::DataType::FLOAT16};
|
||||
}
|
||||
} else {
|
||||
PD_THROW("Only supported attr of compute_dtype in ['fp16', 'bf16'].");
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<std::vector<int64_t>> AppendAttentionWithOutputInferShape(
|
||||
const std::vector<int64_t>& qkv_shape,
|
||||
const std::vector<int64_t>& key_cache_shape,
|
||||
const std::vector<int64_t>& value_cache_shape,
|
||||
const std::vector<int64_t>& seq_lens_encoder_shape,
|
||||
const std::vector<int64_t>& seq_lens_decoder_shape,
|
||||
const std::vector<int64_t>& seq_lens_this_time_shape,
|
||||
const std::vector<int64_t>& batch_id_per_token_shape,
|
||||
const std::vector<int64_t>& cu_seqlens_q_shape,
|
||||
const std::vector<int64_t>& block_tables_shape,
|
||||
const std::vector<int64_t>& encoder_batch_ids_shape,
|
||||
const std::vector<int64_t>& encoder_tile_ids_per_batch_shape,
|
||||
const std::vector<int64_t>& encoder_num_blocks_shape,
|
||||
const std::vector<int64_t>& kv_batch_ids_shape,
|
||||
const std::vector<int64_t>& kv_tile_ids_per_batch_shape,
|
||||
const std::vector<int64_t>& kv_num_blocks_shape,
|
||||
const std::vector<int64_t>& decoder_batch_ids_shape,
|
||||
const std::vector<int64_t>& decoder_tile_ids_per_batch_shape,
|
||||
const std::vector<int64_t>& decoder_num_blocks_shape,
|
||||
const std::vector<int64_t>& set_max_lengths_shape,
|
||||
const std::vector<int64_t>& max_len_kv_shape,
|
||||
const std::vector<int64_t>& fmha_out_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& rotary_embs_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& attn_mask_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& qkv_bias_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& qkv_out_scales_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& cache_k_quant_scales_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& cache_v_quant_scales_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& cache_k_dequant_scales_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& cache_v_dequant_scales_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& cache_k_zp_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& cache_v_zp_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& out_linear_shifts_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& out_linear_smooths_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& mask_offset_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& kv_signal_data_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& q_norm_weight_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& k_norm_weight_shape,
|
||||
const float rms_norm_eps,
|
||||
const std::string& compute_dtype,
|
||||
const std::string& cache_quant_type_str,
|
||||
const bool use_neox_rotary_style,
|
||||
const bool rope_3d,
|
||||
const int max_input_length,
|
||||
const float quant_max_bound,
|
||||
const float quant_min_bound,
|
||||
const float out_linear_in_scale,
|
||||
const int encoder_block_shape_q,
|
||||
const int decoder_block_shape_q,
|
||||
const int max_partition_size,
|
||||
const int encoder_max_partition_size,
|
||||
const int speculate_max_draft_token_num,
|
||||
const bool causal,
|
||||
const bool speculate_decoder) {
|
||||
return {fmha_out_shape};
|
||||
}
|
||||
|
||||
std::vector<paddle::DataType> AppendAttentionWithOutputInferDtype(
|
||||
const paddle::DataType& qkv_dtype,
|
||||
const paddle::DataType& key_cache_dtype,
|
||||
const paddle::DataType& value_cache_dtype,
|
||||
const paddle::DataType& seq_lens_encoder_dtype,
|
||||
const paddle::DataType& seq_lens_decoder_dtype,
|
||||
const paddle::DataType& seq_lens_this_time_dtype,
|
||||
const paddle::DataType& batch_id_per_token_dtype,
|
||||
const paddle::DataType& cu_seqlens_q_dtype,
|
||||
const paddle::DataType& block_tables_dtype,
|
||||
const paddle::DataType& encoder_batch_ids_dtype,
|
||||
const paddle::DataType& encoder_tile_ids_per_batch_dtype,
|
||||
const paddle::DataType& encoder_num_blocks_dtype,
|
||||
const paddle::DataType& kv_batch_ids_dtype,
|
||||
const paddle::DataType& kv_tile_ids_per_batch_dtype,
|
||||
const paddle::DataType& kv_num_blocks_dtype,
|
||||
const paddle::DataType& decoder_batch_ids_dtype,
|
||||
const paddle::DataType& decoder_tile_ids_per_batch_dtype,
|
||||
const paddle::DataType& decoder_num_blocks_dtype,
|
||||
const paddle::DataType& set_max_lengths_dtype,
|
||||
const paddle::DataType& max_len_kv_dtype,
|
||||
const paddle::DataType& fmha_out_dtype,
|
||||
const paddle::optional<paddle::DataType>& rotary_embs_dtype,
|
||||
const paddle::optional<paddle::DataType>& attn_mask_dtype,
|
||||
const paddle::optional<paddle::DataType>& qkv_bias_dtype,
|
||||
const paddle::optional<paddle::DataType>& qkv_out_scales_dtype,
|
||||
const paddle::optional<paddle::DataType>& cache_k_quant_scales_dtype,
|
||||
const paddle::optional<paddle::DataType>& cache_v_quant_scales_dtype,
|
||||
const paddle::optional<paddle::DataType>& cache_k_dequant_scales_dtype,
|
||||
const paddle::optional<paddle::DataType>& cache_v_dequant_scales_dtype,
|
||||
const paddle::optional<paddle::DataType>& cache_k_zp_dtype,
|
||||
const paddle::optional<paddle::DataType>& cache_v_zp_dtype,
|
||||
const paddle::optional<paddle::DataType>& out_linear_shifts_dtype,
|
||||
const paddle::optional<paddle::DataType>& out_linear_smooths_dtype,
|
||||
const paddle::optional<paddle::DataType>& mask_offset_dtype,
|
||||
const paddle::optional<paddle::DataType>& kv_signal_data_dtype,
|
||||
const paddle::optional<paddle::DataType>& q_norm_weight_dtype,
|
||||
const paddle::optional<paddle::DataType>& k_norm_weight_dtype,
|
||||
const float rms_norm_eps,
|
||||
const std::string& compute_dtype,
|
||||
const std::string& cache_quant_type_str,
|
||||
const bool use_neox_rotary_style,
|
||||
const bool rope_3d,
|
||||
const int max_input_length,
|
||||
const float quant_max_bound,
|
||||
const float quant_min_bound,
|
||||
const float out_linear_in_scale,
|
||||
const int encoder_block_shape_q,
|
||||
const int decoder_block_shape_q,
|
||||
const int max_partition_size,
|
||||
const int encoder_max_partition_size,
|
||||
const int speculate_max_draft_token_num,
|
||||
const bool causal,
|
||||
const bool speculate_decoder) {
|
||||
return {fmha_out_dtype};
|
||||
}
|
||||
|
||||
|
||||
|
||||
PD_BUILD_STATIC_OP(append_attention)
|
||||
.Inputs({"qkv",
|
||||
"key_cache",
|
||||
@@ -751,7 +1057,7 @@ PD_BUILD_STATIC_OP(append_attention)
|
||||
paddle::Optional("kv_signal_data"),
|
||||
paddle::Optional("q_norm_weight"),
|
||||
paddle::Optional("k_norm_weight")})
|
||||
.Outputs({"fmha_out", "qkv_out", "key_cache_out", "value_cache_out"})
|
||||
.Outputs({"fmha_out", "key_cache_out", "value_cache_out"})
|
||||
.SetInplaceMap({{"key_cache", "key_cache_out"},
|
||||
{"value_cache", "value_cache_out"}})
|
||||
.Attrs({"rms_norm_eps: float",
|
||||
@@ -774,3 +1080,66 @@ PD_BUILD_STATIC_OP(append_attention)
|
||||
.SetKernelFn(PD_KERNEL(AppendAttention))
|
||||
.SetInferShapeFn(PD_INFER_SHAPE(AppendAttentionInferShape))
|
||||
.SetInferDtypeFn(PD_INFER_DTYPE(AppendAttentionInferDtype));
|
||||
|
||||
PD_BUILD_STATIC_OP(append_attention_with_output)
|
||||
.Inputs({"qkv",
|
||||
"key_cache",
|
||||
"value_cache",
|
||||
"seq_lens_encoder",
|
||||
"seq_lens_decoder",
|
||||
"seq_lens_this_time",
|
||||
"batch_id_per_token",
|
||||
"cu_seqlens_q",
|
||||
"block_tables",
|
||||
"encoder_batch_ids",
|
||||
"encoder_tile_ids_per_batch",
|
||||
"encoder_num_blocks",
|
||||
"kv_batch_ids",
|
||||
"kv_tile_ids_per_batch",
|
||||
"kv_num_blocks",
|
||||
"decoder_batch_ids",
|
||||
"decoder_tile_ids_per_batch",
|
||||
"decoder_num_blocks",
|
||||
"set_max_lengths",
|
||||
"max_len_kv",
|
||||
"fmha_out",
|
||||
paddle::Optional("rotary_embs"),
|
||||
paddle::Optional("attn_mask"),
|
||||
paddle::Optional("qkv_bias"),
|
||||
paddle::Optional("qkv_out_scales"),
|
||||
paddle::Optional("cache_k_quant_scales"),
|
||||
paddle::Optional("cache_v_quant_scales"),
|
||||
paddle::Optional("cache_k_dequant_scales"),
|
||||
paddle::Optional("cache_v_dequant_scales"),
|
||||
paddle::Optional("cache_k_zp"),
|
||||
paddle::Optional("cache_v_zp"),
|
||||
paddle::Optional("out_linear_shifts"),
|
||||
paddle::Optional("out_linear_smooths"),
|
||||
paddle::Optional("mask_offset"),
|
||||
paddle::Optional("kv_signal_data"),
|
||||
paddle::Optional("q_norm_weight"),
|
||||
paddle::Optional("k_norm_weight")})
|
||||
.Outputs({"fmha_out_out", "qkv_out", "key_cache_out", "value_cache_out"})
|
||||
.SetInplaceMap({{"fmha_out", "fmha_out_out"},
|
||||
{"key_cache", "key_cache_out"},
|
||||
{"value_cache", "value_cache_out"}})
|
||||
.Attrs({"rms_norm_eps: float",
|
||||
"compute_type: std::string",
|
||||
"cache_quant_type: std::string",
|
||||
"use_neox_rotary_style: bool",
|
||||
"rope_3d: bool",
|
||||
"max_input_length: int",
|
||||
"quant_max_bound: float",
|
||||
"quant_min_bound: float",
|
||||
"out_linear_in_scale: float",
|
||||
"encoder_block_shape_q: int",
|
||||
"decoder_block_shape_q: int",
|
||||
"max_partition_size: int",
|
||||
"encoder_max_partition_size: int",
|
||||
"speculate_max_draft_token_num: int",
|
||||
"causal: bool",
|
||||
"speculate_decoder: bool",
|
||||
})
|
||||
.SetKernelFn(PD_KERNEL(AppendAttentionWithOutput))
|
||||
.SetInferShapeFn(PD_INFER_SHAPE(AppendAttentionWithOutputInferShape))
|
||||
.SetInferDtypeFn(PD_INFER_DTYPE(AppendAttentionWithOutputInferDtype));
|
||||
|
||||
@@ -91,6 +91,49 @@ std::vector<paddle::Tensor> AppendAttention(
|
||||
const int speculate_max_draft_token_num, const bool causal,
|
||||
const bool speculate_decoder);
|
||||
|
||||
void AppendAttentionWithOutput(
|
||||
const paddle::Tensor &qkv, const paddle::Tensor &key_cache,
|
||||
const paddle::Tensor &value_cache, const paddle::Tensor &seq_lens_encoder,
|
||||
const paddle::Tensor &seq_lens_decoder,
|
||||
const paddle::Tensor &seq_lens_this_time,
|
||||
const paddle::Tensor &batch_id_per_token, const paddle::Tensor &cu_seqlens_q,
|
||||
const paddle::Tensor &block_tables, const paddle::Tensor &encoder_batch_ids,
|
||||
const paddle::Tensor &encoder_tile_ids_per_batch,
|
||||
const paddle::Tensor &encoder_num_blocks,
|
||||
const paddle::Tensor &kv_batch_ids,
|
||||
const paddle::Tensor &kv_tile_ids_per_batch,
|
||||
const paddle::Tensor &kv_num_blocks,
|
||||
const paddle::Tensor &decoder_batch_ids,
|
||||
const paddle::Tensor &decoder_tile_ids_per_batch,
|
||||
const paddle::Tensor &decoder_num_blocks,
|
||||
const paddle::Tensor &set_max_lengths, const paddle::Tensor &max_len_kv,
|
||||
paddle::Tensor &fmha_out,
|
||||
const paddle::optional<paddle::Tensor> &rotary_embs,
|
||||
const paddle::optional<paddle::Tensor> &attn_mask,
|
||||
const paddle::optional<paddle::Tensor> &qkv_bias,
|
||||
const paddle::optional<paddle::Tensor> &qkv_out_scales,
|
||||
const paddle::optional<paddle::Tensor> &cache_k_quant_scales,
|
||||
const paddle::optional<paddle::Tensor> &cache_v_quant_scales,
|
||||
const paddle::optional<paddle::Tensor> &cache_k_dequant_scales,
|
||||
const paddle::optional<paddle::Tensor> &cache_v_dequant_scales,
|
||||
const paddle::optional<paddle::Tensor> &cache_k_zp,
|
||||
const paddle::optional<paddle::Tensor> &cache_v_zp,
|
||||
const paddle::optional<paddle::Tensor> &out_linear_shifts,
|
||||
const paddle::optional<paddle::Tensor> &out_linear_smooths,
|
||||
const paddle::optional<paddle::Tensor> &mask_offset,
|
||||
const paddle::optional<paddle::Tensor> &kv_signal_data,
|
||||
const paddle::optional<paddle::Tensor>& q_norm_weight,
|
||||
const paddle::optional<paddle::Tensor>& k_norm_weight,
|
||||
const float rms_norm_eps,
|
||||
const std::string &compute_dtype, const std::string &cache_quant_type_str,
|
||||
const bool use_neox_rotary_style, const bool rope_3d,
|
||||
const int max_input_length, const float quant_max_bound,
|
||||
const float quant_min_bound, const float out_linear_in_scale,
|
||||
const int encoder_block_shape_q, const int decoder_block_shape_q,
|
||||
const int max_partition_size, const int encoder_max_partition_size,
|
||||
const int speculate_max_draft_token_num, const bool causal,
|
||||
const bool speculate_decoder);
|
||||
|
||||
std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
|
||||
const paddle::Tensor &qkv, const paddle::Tensor &key_cache,
|
||||
const paddle::Tensor &value_cache, const paddle::Tensor &cu_seqlens_q,
|
||||
@@ -881,6 +924,7 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
|
||||
* append_attention
|
||||
*/
|
||||
m.def("append_attention", &AppendAttention, "append attention function");
|
||||
m.def("append_attention_with_output", &AppendAttentionWithOutput, "append attention with output function");
|
||||
/**
|
||||
* gqa_rope_write_cache.cu
|
||||
* gqa_rope_write_cache
|
||||
|
||||
Reference in New Issue
Block a user