Mirror of https://github.com/PaddlePaddle/FastDeploy.git, synced 2025-12-24 13:28:13 +08:00

[Others]get_block_shape_and_split_kv_block clean code (#5123)
@@ -14,8 +14,8 @@
 #include "append_attn/append_attention_kernel.h"
 #include "append_attn/decoder_write_cache_with_rope_kernel.h"
-#include "append_attn/speculate_write_cache_with_rope_kernel.h"
 #include "append_attn/encoder_write_cache_with_rope_kernel.h"
+#include "append_attn/speculate_write_cache_with_rope_kernel.h"

 #ifndef PD_BUILD_STATIC_OP
 #define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
@@ -26,17 +26,16 @@ class type2value;

 template <>
 class type2value<phi::dtype::bfloat16> {
-  public:
-    static constexpr paddle::DataType value = paddle::DataType::BFLOAT16;
+ public:
+  static constexpr paddle::DataType value = paddle::DataType::BFLOAT16;
 };

 template <>
 class type2value<phi::dtype::float16> {
-  public:
-    static constexpr paddle::DataType value = paddle::DataType::FLOAT16;
+ public:
+  static constexpr paddle::DataType value = paddle::DataType::FLOAT16;
 };

-
 template <paddle::DataType D>
 void AppendAttentionKernel(
     const AppendAttnMetaData& meta_data,
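For readers following the dispatch machinery: type2value is a small trait that maps an element type to its paddle::DataType enum value at compile time, so a plain value can act as a pure type tag. A minimal self-contained sketch of the pattern, using stand-in names rather than the real paddle types:

    // Stand-ins for paddle's dtype enum and scalar wrappers (illustration only).
    enum class DataType { FLOAT16, BFLOAT16 };
    struct float16 {};
    struct bfloat16 {};

    template <typename T> struct type2value;  // primary template, left undefined

    template <> struct type2value<float16> {
      static constexpr DataType value = DataType::FLOAT16;
    };
    template <> struct type2value<bfloat16> {
      static constexpr DataType value = DataType::BFLOAT16;
    };

    // An unsupported T fails to compile instead of silently mis-dispatching.
    static_assert(type2value<float16>::value == DataType::FLOAT16, "trait check");

Leaving the primary template undefined is the design point: instantiating type2value<T> for a type without a specialization is a compile-time error.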
@@ -96,14 +95,12 @@ void AppendAttentionKernel(
   typedef typename traits_::DataType DataType_;
   typedef typename traits_::data_t data_t;

-  // set_max_lengths: max_len_this_time, max_enc_len_this_time, max_dec_len_this_time, max_enc_dec_len_this_time,
-  // max_just_dec_len_this_time, max_just_dec_merged_len_this_time, max_system_len, max_just_dec_len_without_system
-  int max_len_this_time = set_max_lengths.data<int>()[0];
-  int max_enc_len_this_time = set_max_lengths.data<int>()[1];
-  int max_dec_len_this_time = set_max_lengths.data<int>()[2];
-  int max_enc_dec_len_this_time = set_max_lengths.data<int>()[3];
-  int max_just_dec_len_this_time = set_max_lengths.data<int>()[4];
-  int max_kv_len_this_time = set_max_lengths.data<int>()[8];
+  const int max_len_this_time = set_max_lengths.data<int>()[0];
+  const int max_enc_len_this_time = set_max_lengths.data<int>()[1];
+  const int max_dec_len_this_time = set_max_lengths.data<int>()[2];
+  const int max_enc_dec_len_this_time = set_max_lengths.data<int>()[3];
+  const int max_just_dec_len_this_time = set_max_lengths.data<int>()[4];
+  const int max_kv_len_this_time = set_max_lengths.data<int>()[5];

   auto main_stream = qkv.stream();
   static cudaEvent_t main_event;
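The index change on the last line above ([8] to [5]) pairs with the GetMaxLenKernel hunk further down in this diff: once the merged/system-length slots stop being consumed, max_kv_len_this_time compacts into slot 5. A hedged sketch of how the slot layout could be named to keep the device writer and this host reader in sync; the enum below is illustrative only and does not exist in the FastDeploy sources:

    // Hypothetical named slots for the set_max_lengths tensor after this commit.
    enum MaxLenSlot : int {
      kMaxLenThisTime = 0,
      kMaxEncLenThisTime = 1,
      kMaxDecLenThisTime = 2,
      kMaxEncDecLenThisTime = 3,
      kMaxJustDecLenThisTime = 4,
      kMaxKvLenThisTime = 5,  // moved here from index 8 by this commit
    };

    // const int max_kv_len_this_time =
    //     set_max_lengths.data<int>()[kMaxKvLenThisTime];

GetMaxLenKernel (which fills the tensor on device) and this function (which reads it on host) would then share one definition instead of bare literals.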
@@ -125,54 +122,56 @@ void AppendAttentionKernel(
     qkv_out = qkv;
   }

-  auto dispatch_CascadeAppendAttentionKernel = [&](auto temp_args,
-      const paddle::Tensor& lambda_batch_ids,
-      const paddle::Tensor& lambda_tile_ids_per_batch,
-      const int lambda_num_blocks_data,
-      const int lambda_block_shape_q,
-      const int lambda_max_dec_len,
-      const bool lambda_is_decoder,
-      const bool lambda_enable_prefill,
-      cudaStream_t& lambda_stream
-      ) -> void {
-    CascadeAppendAttentionKernel<data_t, decltype(temp_args)>(
-        meta_data,
-        qkv_out,
-        key_cache,
-        value_cache,
-        attn_mask,
-        cache_quant_type_str == "block_wise_fp8" ? cache_k_quant_scales : cache_k_dequant_scales,
-        cache_quant_type_str == "block_wise_fp8" ? cache_v_quant_scales : cache_v_dequant_scales,
-        cache_k_zp,
-        cache_v_zp,
-        out_linear_shifts,
-        out_linear_smooths,
-        sinks,
-        seq_lens_this_time,
-        seq_lens_decoder,
-        seq_lens_encoder,
-        batch_id_per_token,
-        cu_seqlens_q,
-        block_tables,
-        lambda_batch_ids,
-        lambda_tile_ids_per_batch,
-        cache_quant_type_str,
-        lambda_num_blocks_data,
-        lambda_block_shape_q,
-        max_input_length,
-        lambda_max_dec_len,
-        quant_max_bound,
-        quant_min_bound,
-        out_linear_in_scale,
-        max_partition_size,
-        encoder_max_partition_size,
-        speculate_max_draft_token_num,
-        causal,
-        lambda_is_decoder,
-        lambda_enable_prefill,
-        lambda_stream,
-        &fmha_out,
-        sliding_window);
+  auto dispatch_CascadeAppendAttentionKernel =
+      [&](auto temp_args,
+          const paddle::Tensor& lambda_batch_ids,
+          const paddle::Tensor& lambda_tile_ids_per_batch,
+          const int lambda_num_blocks_data,
+          const int lambda_block_shape_q,
+          const int lambda_max_dec_len,
+          const bool lambda_is_decoder,
+          const bool lambda_enable_prefill,
+          cudaStream_t& lambda_stream) -> void {
+        CascadeAppendAttentionKernel<data_t, decltype(temp_args)>(
+            meta_data,
+            qkv_out,
+            key_cache,
+            value_cache,
+            attn_mask,
+            cache_quant_type_str == "block_wise_fp8" ? cache_k_quant_scales
+                                                     : cache_k_dequant_scales,
+            cache_quant_type_str == "block_wise_fp8" ? cache_v_quant_scales
+                                                     : cache_v_dequant_scales,
+            cache_k_zp,
+            cache_v_zp,
+            out_linear_shifts,
+            out_linear_smooths,
+            sinks,
+            seq_lens_this_time,
+            seq_lens_decoder,
+            seq_lens_encoder,
+            batch_id_per_token,
+            cu_seqlens_q,
+            block_tables,
+            lambda_batch_ids,
+            lambda_tile_ids_per_batch,
+            cache_quant_type_str,
+            lambda_num_blocks_data,
+            lambda_block_shape_q,
+            max_input_length,
+            lambda_max_dec_len,
+            quant_max_bound,
+            quant_min_bound,
+            out_linear_in_scale,
+            max_partition_size,
+            encoder_max_partition_size,
+            speculate_max_draft_token_num,
+            causal,
+            lambda_is_decoder,
+            lambda_enable_prefill,
+            lambda_stream,
+            &fmha_out,
+            sliding_window);
   };

   if (max_enc_len_this_time > 0) {
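The reformatted lambda above also shows the dispatch idiom used throughout this file: a generic lambda receives a never-read value whose static type selects the kernel template instantiation. A minimal compilable sketch of the idiom, with placeholder names (RunKernel stands in for CascadeAppendAttentionKernel):

    #include <cstdint>
    #include <cstdio>

    template <typename OutT>
    void RunKernel() {  // placeholder for the real templated kernel
      std::printf("out element size = %zu\n", sizeof(OutT));
    }

    int main() {
      // decltype(tag) recovers the tag's type as the template argument.
      auto dispatch = [&](auto tag) { RunKernel<decltype(tag)>(); };

      bool quantize_to_int8 = true;  // stands in for the runtime dtype switch
      if (quantize_to_int8) {
        int8_t tmp = 0;  // value unused; only its type matters
        dispatch(tmp);
      } else {
        float tmp = 0.f;
        dispatch(tmp);
      }
      return 0;
    }

This keeps the long argument list in one place, while the switch statements below supply only the varying type tag and tile arguments.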
@@ -182,8 +181,9 @@ void AppendAttentionKernel(
     int encoder_num_blocks_data = encoder_num_blocks.data<int>()[0];
     int kv_num_blocks_data = kv_num_blocks.data<int>()[0];

-    auto dispatch_EncoderWriteCacheWithRopeKernel = [&](auto temp_args) -> void {
-      EncoderWriteCacheWithRopeKernel<data_t, decltype(temp_args)>(
+    auto dispatch_EncoderWriteCacheWithRopeKernel =
+        [&](auto temp_args) -> void {
+          EncoderWriteCacheWithRopeKernel<data_t, decltype(temp_args)>(
           meta_data,
           qkv,
           seq_lens_this_time,
@@ -225,24 +225,50 @@ void AppendAttentionKernel(
     }
     if (out_linear_in_scale > 0.0) {
       switch (fmha_out.dtype()) {
-        case paddle::DataType::INT8:{
+        case paddle::DataType::INT8: {
           int8_t tmp;
-          dispatch_CascadeAppendAttentionKernel(tmp, encoder_batch_ids, encoder_tile_ids_per_batch, encoder_num_blocks_data, encoder_block_shape_q, max_enc_dec_len_this_time, false, true, main_stream);
+          dispatch_CascadeAppendAttentionKernel(tmp,
+                                                encoder_batch_ids,
+                                                encoder_tile_ids_per_batch,
+                                                encoder_num_blocks_data,
+                                                encoder_block_shape_q,
+                                                max_enc_dec_len_this_time,
+                                                false,
+                                                true,
+                                                main_stream);
           break;
         }
-        case paddle::DataType::FLOAT8_E4M3FN:{
+        case paddle::DataType::FLOAT8_E4M3FN: {
           phi::dtype::float8_e4m3fn tmp;
-          dispatch_CascadeAppendAttentionKernel(tmp, encoder_batch_ids, encoder_tile_ids_per_batch, encoder_num_blocks_data, encoder_block_shape_q, max_enc_dec_len_this_time, false, true, main_stream);
+          dispatch_CascadeAppendAttentionKernel(tmp,
+                                                encoder_batch_ids,
+                                                encoder_tile_ids_per_batch,
+                                                encoder_num_blocks_data,
+                                                encoder_block_shape_q,
+                                                max_enc_dec_len_this_time,
+                                                false,
+                                                true,
+                                                main_stream);
           break;
         }
-        default:{
-          PD_THROW("Only supported output fmha_out of quant dtype in ['int8', 'FLOAT8_E4M3FN'].");
+        default: {
+          PD_THROW(
+              "Only supported output fmha_out of quant dtype in ['int8', "
+              "'FLOAT8_E4M3FN'].");
           break;
         }
       }
     } else {
       data_t tmp;
-      dispatch_CascadeAppendAttentionKernel(tmp, encoder_batch_ids, encoder_tile_ids_per_batch, encoder_num_blocks_data, encoder_block_shape_q, max_enc_dec_len_this_time, false, true, main_stream);
+      dispatch_CascadeAppendAttentionKernel(tmp,
+                                            encoder_batch_ids,
+                                            encoder_tile_ids_per_batch,
+                                            encoder_num_blocks_data,
+                                            encoder_block_shape_q,
+                                            max_enc_dec_len_this_time,
+                                            false,
+                                            true,
+                                            main_stream);
     }
   }
@@ -370,23 +396,44 @@ void AppendAttentionKernel(

     if (out_linear_in_scale > 0.0) {
       switch (fmha_out.dtype()) {
-        case paddle::DataType::INT8:{
-          int8_t tmp;
-          dispatch_CascadeAppendAttentionKernel(tmp, decoder_batch_ids, decoder_tile_ids_per_batch, decoder_num_blocks_data,
-            decoder_block_shape_q, max_kv_len_this_time, !speculate_decoder, !speculate_decoder, exec_stream);
+        case paddle::DataType::INT8: {
+          int8_t tmp;
+          dispatch_CascadeAppendAttentionKernel(tmp,
+                                                decoder_batch_ids,
+                                                decoder_tile_ids_per_batch,
+                                                decoder_num_blocks_data,
+                                                decoder_block_shape_q,
+                                                max_kv_len_this_time,
+                                                !speculate_decoder,
+                                                !speculate_decoder,
+                                                exec_stream);
           break;
         }
-        case paddle::DataType::FLOAT8_E4M3FN:{
-          phi::dtype::float8_e4m3fn tmp;
-          dispatch_CascadeAppendAttentionKernel(tmp, decoder_batch_ids, decoder_tile_ids_per_batch, decoder_num_blocks_data,
-            decoder_block_shape_q, max_kv_len_this_time, !speculate_decoder, !speculate_decoder, exec_stream);
+        case paddle::DataType::FLOAT8_E4M3FN: {
+          phi::dtype::float8_e4m3fn tmp;
+          dispatch_CascadeAppendAttentionKernel(tmp,
+                                                decoder_batch_ids,
+                                                decoder_tile_ids_per_batch,
+                                                decoder_num_blocks_data,
+                                                decoder_block_shape_q,
+                                                max_kv_len_this_time,
+                                                !speculate_decoder,
+                                                !speculate_decoder,
+                                                exec_stream);
           break;
         }
       }
     } else {
-      data_t tmp;
-      dispatch_CascadeAppendAttentionKernel(tmp, decoder_batch_ids, decoder_tile_ids_per_batch, decoder_num_blocks_data,
-        decoder_block_shape_q, max_kv_len_this_time, !speculate_decoder, !speculate_decoder, exec_stream);
+      data_t tmp;
+      dispatch_CascadeAppendAttentionKernel(tmp,
+                                            decoder_batch_ids,
+                                            decoder_tile_ids_per_batch,
+                                            decoder_num_blocks_data,
+                                            decoder_block_shape_q,
+                                            max_kv_len_this_time,
+                                            !speculate_decoder,
+                                            !speculate_decoder,
+                                            exec_stream);
     }
     if (max_enc_len_this_time > 0) {
       cudaEventRecord(decoder_event, exec_stream);
@@ -471,8 +518,14 @@ std::vector<paddle::Tensor> AppendAttention(
   // template dtype generation
   phi::DataType dtype_id;
   switch (qkv.dtype()) {
-    case paddle::DataType::FLOAT16: {dtype_id = phi::DataType::FLOAT16; break;}
-    case paddle::DataType::BFLOAT16: {dtype_id = phi::DataType::BFLOAT16; break;}
+    case paddle::DataType::FLOAT16: {
+      dtype_id = phi::DataType::FLOAT16;
+      break;
+    }
+    case paddle::DataType::BFLOAT16: {
+      dtype_id = phi::DataType::BFLOAT16;
+      break;
+    }
     case paddle::DataType::INT32: {
       if (compute_dtype == "bf16") {
         dtype_id = phi::DataType::BFLOAT16;
@@ -498,15 +551,15 @@ std::vector<paddle::Tensor> AppendAttention(
   if (out_linear_in_scale > 0.0) {
    if (fabs(quant_max_bound - 127.0f) < 0.000001) {
      fmha_out = paddle::zeros(
-        {meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims},
-        paddle::DataType::INT8,
-        qkv.place());
+          {meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims},
+          paddle::DataType::INT8,
+          qkv.place());
    } else if (fabs(quant_max_bound - 448.0f) < 0.000001) {
      fmha_out = paddle::zeros(
-        {meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims},
-        paddle::DataType::FLOAT8_E4M3FN,
-        qkv.place());
-    } else{
+          {meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims},
+          paddle::DataType::FLOAT8_E4M3FN,
+          qkv.place());
+    } else {
      PD_THROW("Only supported attr of quant_max_bound in ['127', '448'].");
    }
  } else {
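The 127/448 thresholds above reappear in the InferDtype functions near the end of this diff: a quant_max_bound of 127 selects int8 output, 448 selects float8 e4m3. As a hedged sketch only (no such helper exists in this commit, and it assumes the paddle extension headers this file already uses), the repeated check could be factored into one mapping:

    #include <cmath>

    // Illustrative helper mirroring the fabs(...) < 0.000001 checks in this file.
    paddle::DataType QuantFmhaOutDtype(float quant_max_bound) {
      if (std::fabs(quant_max_bound - 127.0f) < 0.000001f) {
        return paddle::DataType::INT8;  // int8-quantized attention output
      }
      if (std::fabs(quant_max_bound - 448.0f) < 0.000001f) {
        return paddle::DataType::FLOAT8_E4M3FN;  // fp8 e4m3-quantized output
      }
      PD_THROW("Only supported attr of quant_max_bound in ['127.0', '448.0'].");
    }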
@@ -521,79 +574,78 @@ std::vector<paddle::Tensor> AppendAttention(
   }

   auto dispatch_by_template = [&](auto temp_args) -> void {
-    AppendAttentionKernel<type2value<decltype(temp_args)>::value>(
-      meta_data,
-      qkv,
-      key_cache,
-      value_cache,
-      seq_lens_encoder,
-      seq_lens_decoder,
-      seq_lens_this_time,
-      batch_id_per_token,
-      cu_seqlens_q,
-      block_tables,
-      encoder_batch_ids,
-      encoder_tile_ids_per_batch,
-      encoder_num_blocks,
-      kv_batch_ids,
-      kv_tile_ids_per_batch,
-      kv_num_blocks,
-      decoder_batch_ids,
-      decoder_tile_ids_per_batch,
-      decoder_num_blocks,
-      set_max_lengths,
-      fmha_out,
-      rotary_embs,
-      attn_mask,
-      qkv_bias,
-      qkv_out_scales,
-      cache_k_quant_scales,
-      cache_v_quant_scales,
-      cache_k_dequant_scales,
-      cache_v_dequant_scales,
-      cache_k_zp,
-      cache_v_zp,
-      out_linear_shifts,
-      out_linear_smooths,
-      kv_signal_data,
-      q_norm_weight,
-      k_norm_weight,
-      sinks,
-      rms_norm_eps,
-      cache_quant_type_str,
-      use_neox_rotary_style,
-      rope_3d,
-      max_input_length,
-      quant_max_bound,
-      quant_min_bound,
-      out_linear_in_scale,
-      encoder_block_shape_q,
-      decoder_block_shape_q,
-      max_partition_size,
-      encoder_max_partition_size,
-      speculate_max_draft_token_num,
-      causal,
-      speculate_decoder,
-      sliding_window);
+    AppendAttentionKernel<type2value<decltype(temp_args)>::value>(
+        meta_data,
+        qkv,
+        key_cache,
+        value_cache,
+        seq_lens_encoder,
+        seq_lens_decoder,
+        seq_lens_this_time,
+        batch_id_per_token,
+        cu_seqlens_q,
+        block_tables,
+        encoder_batch_ids,
+        encoder_tile_ids_per_batch,
+        encoder_num_blocks,
+        kv_batch_ids,
+        kv_tile_ids_per_batch,
+        kv_num_blocks,
+        decoder_batch_ids,
+        decoder_tile_ids_per_batch,
+        decoder_num_blocks,
+        set_max_lengths,
+        fmha_out,
+        rotary_embs,
+        attn_mask,
+        qkv_bias,
+        qkv_out_scales,
+        cache_k_quant_scales,
+        cache_v_quant_scales,
+        cache_k_dequant_scales,
+        cache_v_dequant_scales,
+        cache_k_zp,
+        cache_v_zp,
+        out_linear_shifts,
+        out_linear_smooths,
+        kv_signal_data,
+        q_norm_weight,
+        k_norm_weight,
+        sinks,
+        rms_norm_eps,
+        cache_quant_type_str,
+        use_neox_rotary_style,
+        rope_3d,
+        max_input_length,
+        quant_max_bound,
+        quant_min_bound,
+        out_linear_in_scale,
+        encoder_block_shape_q,
+        decoder_block_shape_q,
+        max_partition_size,
+        encoder_max_partition_size,
+        speculate_max_draft_token_num,
+        causal,
+        speculate_decoder,
+        sliding_window);
   };

-
   phi::dtype::float16 fp16_dtype;
   phi::dtype::bfloat16 bp16_dtype;
-  switch (dtype_id){
-    case phi::DataType::FLOAT16: {
-      dispatch_by_template(fp16_dtype);
-      return {fmha_out};
-    }
-    case phi::DataType::BFLOAT16: {
-      dispatch_by_template(bp16_dtype);
-      return {fmha_out};
-    }
-    default:
-      PD_THROW(
+  switch (dtype_id) {
+    case phi::DataType::FLOAT16: {
+      dispatch_by_template(fp16_dtype);
+      return {fmha_out};
+    }
+    case phi::DataType::BFLOAT16: {
+      dispatch_by_template(bp16_dtype);
+      return {fmha_out};
+    }
+    default:
+      PD_THROW(
           "NOT supported data type. "
           "Only float16 and bfloat16 are supported. ");
-    break;
+      break;
   }

   return {paddle::Tensor{}};
@@ -678,60 +730,60 @@ std::vector<paddle::Tensor> AppendAttentionWithOutput(
   }

   auto dispatch_by_template = [&](auto temp_args) -> void {
-    AppendAttentionKernel<type2value<decltype(temp_args)>::value>(
-      meta_data,
-      qkv,
-      key_cache,
-      value_cache,
-      seq_lens_encoder,
-      seq_lens_decoder,
-      seq_lens_this_time,
-      batch_id_per_token,
-      cu_seqlens_q,
-      block_tables,
-      encoder_batch_ids,
-      encoder_tile_ids_per_batch,
-      encoder_num_blocks,
-      kv_batch_ids,
-      kv_tile_ids_per_batch,
-      kv_num_blocks,
-      decoder_batch_ids,
-      decoder_tile_ids_per_batch,
-      decoder_num_blocks,
-      set_max_lengths,
-      fmha_out,
-      rotary_embs,
-      attn_mask,
-      qkv_bias,
-      qkv_out_scales,
-      cache_k_quant_scales,
-      cache_v_quant_scales,
-      cache_k_dequant_scales,
-      cache_v_dequant_scales,
-      cache_k_zp,
-      cache_v_zp,
-      out_linear_shifts,
-      out_linear_smooths,
-      kv_signal_data,
-      q_norm_weight,
-      k_norm_weight,
-      sinks,
-      rms_norm_eps,
-      cache_quant_type_str,
-      use_neox_rotary_style,
-      rope_3d,
-      max_input_length,
-      quant_max_bound,
-      quant_min_bound,
-      out_linear_in_scale,
-      encoder_block_shape_q,
-      decoder_block_shape_q,
-      max_partition_size,
-      encoder_max_partition_size,
-      speculate_max_draft_token_num,
-      causal,
-      speculate_decoder,
-      sliding_window);
+    AppendAttentionKernel<type2value<decltype(temp_args)>::value>(
+        meta_data,
+        qkv,
+        key_cache,
+        value_cache,
+        seq_lens_encoder,
+        seq_lens_decoder,
+        seq_lens_this_time,
+        batch_id_per_token,
+        cu_seqlens_q,
+        block_tables,
+        encoder_batch_ids,
+        encoder_tile_ids_per_batch,
+        encoder_num_blocks,
+        kv_batch_ids,
+        kv_tile_ids_per_batch,
+        kv_num_blocks,
+        decoder_batch_ids,
+        decoder_tile_ids_per_batch,
+        decoder_num_blocks,
+        set_max_lengths,
+        fmha_out,
+        rotary_embs,
+        attn_mask,
+        qkv_bias,
+        qkv_out_scales,
+        cache_k_quant_scales,
+        cache_v_quant_scales,
+        cache_k_dequant_scales,
+        cache_v_dequant_scales,
+        cache_k_zp,
+        cache_v_zp,
+        out_linear_shifts,
+        out_linear_smooths,
+        kv_signal_data,
+        q_norm_weight,
+        k_norm_weight,
+        sinks,
+        rms_norm_eps,
+        cache_quant_type_str,
+        use_neox_rotary_style,
+        rope_3d,
+        max_input_length,
+        quant_max_bound,
+        quant_min_bound,
+        out_linear_in_scale,
+        encoder_block_shape_q,
+        decoder_block_shape_q,
+        max_partition_size,
+        encoder_max_partition_size,
+        speculate_max_draft_token_num,
+        causal,
+        speculate_decoder,
+        sliding_window);
   };

   phi::dtype::float16 fp16_dtype;
@@ -769,7 +821,6 @@ std::vector<paddle::Tensor> AppendAttentionWithOutput(
   return {fmha_out};
 }

-
 std::vector<std::vector<int64_t>> AppendAttentionInferShape(
     const std::vector<int64_t>& qkv_shape,
     const std::vector<int64_t>& key_cache_shape,
@@ -895,8 +946,9 @@ std::vector<paddle::DataType> AppendAttentionInferDtype(
       return {paddle::DataType::INT8};
     } else if (fabs(quant_max_bound - 448.0f) < 0.000001) {
       return {paddle::DataType::FLOAT8_E4M3FN};
-    }else{
-      PD_THROW("Only supported attr of quant_max_bound in ['127.0', '448.0'].");
+    } else {
+      PD_THROW(
+          "Only supported attr of quant_max_bound in ['127.0', '448.0'].");
     }
   } else {
     return {paddle::DataType::BFLOAT16};
@@ -907,8 +959,9 @@ std::vector<paddle::DataType> AppendAttentionInferDtype(
      return {paddle::DataType::INT8};
    } else if (fabs(quant_max_bound - 448.0f) < 0.000001) {
      return {paddle::DataType::FLOAT8_E4M3FN};
-    }else{
-      PD_THROW("Only supported attr of quant_max_bound in ['127.0', '448.0'].");
+    } else {
+      PD_THROW(
+          "Only supported attr of quant_max_bound in ['127.0', '448.0'].");
    }
  } else {
    return {paddle::DataType::FLOAT16};
@@ -1034,8 +1087,6 @@ std::vector<paddle::DataType> AppendAttentionWithOutputInferDtype(
   return {fmha_out_dtype};
 }

-
-
 PD_BUILD_STATIC_OP(append_attention)
     .Inputs({"qkv",
              "key_cache",
@@ -1074,24 +1125,25 @@ PD_BUILD_STATIC_OP(append_attention)
              paddle::Optional("k_norm_weight"),
              paddle::Optional("sinks")})
     .Outputs({"fmha_out"})
-    .Attrs({"rms_norm_eps: float",
-            "compute_type: std::string",
-            "cache_quant_type: std::string",
-            "use_neox_rotary_style: bool",
-            "rope_3d: bool",
-            "max_input_length: int",
-            "quant_max_bound: float",
-            "quant_min_bound: float",
-            "out_linear_in_scale: float",
-            "encoder_block_shape_q: int",
-            "decoder_block_shape_q: int",
-            "max_partition_size: int",
-            "encoder_max_partition_size: int",
-            "speculate_max_draft_token_num: int",
-            "causal: bool",
-            "speculate_decoder: bool",
-            "sliding_window: int",
-            })
+    .Attrs({
+        "rms_norm_eps: float",
+        "compute_type: std::string",
+        "cache_quant_type: std::string",
+        "use_neox_rotary_style: bool",
+        "rope_3d: bool",
+        "max_input_length: int",
+        "quant_max_bound: float",
+        "quant_min_bound: float",
+        "out_linear_in_scale: float",
+        "encoder_block_shape_q: int",
+        "decoder_block_shape_q: int",
+        "max_partition_size: int",
+        "encoder_max_partition_size: int",
+        "speculate_max_draft_token_num: int",
+        "causal: bool",
+        "speculate_decoder: bool",
+        "sliding_window: int",
+    })
     .SetKernelFn(PD_KERNEL(AppendAttention))
     .SetInferShapeFn(PD_INFER_SHAPE(AppendAttentionInferShape))
     .SetInferDtypeFn(PD_INFER_DTYPE(AppendAttentionInferDtype));
@@ -1136,24 +1188,25 @@ PD_BUILD_STATIC_OP(append_attention_with_output)
              paddle::Optional("sinks")})
     .Outputs({"fmha_out_out"})
     .SetInplaceMap({{"fmha_out", "fmha_out_out"}})
-    .Attrs({"rms_norm_eps: float",
-            "compute_type: std::string",
-            "cache_quant_type: std::string",
-            "use_neox_rotary_style: bool",
-            "rope_3d: bool",
-            "max_input_length: int",
-            "quant_max_bound: float",
-            "quant_min_bound: float",
-            "out_linear_in_scale: float",
-            "encoder_block_shape_q: int",
-            "decoder_block_shape_q: int",
-            "max_partition_size: int",
-            "encoder_max_partition_size: int",
-            "speculate_max_draft_token_num: int",
-            "causal: bool",
-            "speculate_decoder: bool",
-            "sliding_window: int",
-            })
+    .Attrs({
+        "rms_norm_eps: float",
+        "compute_type: std::string",
+        "cache_quant_type: std::string",
+        "use_neox_rotary_style: bool",
+        "rope_3d: bool",
+        "max_input_length: int",
+        "quant_max_bound: float",
+        "quant_min_bound: float",
+        "out_linear_in_scale: float",
+        "encoder_block_shape_q: int",
+        "decoder_block_shape_q: int",
+        "max_partition_size: int",
+        "encoder_max_partition_size: int",
+        "speculate_max_draft_token_num: int",
+        "causal: bool",
+        "speculate_decoder: bool",
+        "sliding_window: int",
+    })
     .SetKernelFn(PD_KERNEL(AppendAttentionWithOutput))
     .SetInferShapeFn(PD_INFER_SHAPE(AppendAttentionWithOutputInferShape))
     .SetInferDtypeFn(PD_INFER_DTYPE(AppendAttentionWithOutputInferDtype));
@@ -79,7 +79,7 @@ __global__ void GetMaxLenKernel(const int *seq_lens_decoder,
     max_lens[2] = total_max_len_decoder;
     max_lens[3] = total;
     max_lens[4] = total_just_dec;
-    max_lens[8] = total_max_len_kv;
+    max_lens[5] = total_max_len_kv;
   }
 }
@@ -273,8 +273,7 @@ void GetBlockShapeAndSplitKVBlock(
     const int encoder_block_shape_q,
     const int decoder_block_shape_q,
     const int group_size,
-    const int block_size,
-    const int decoder_step_token_num) {
+    const int block_size) {
   auto stream = seq_lens_encoder.stream();
   int bsz = seq_lens_this_time.shape()[0];
@@ -302,10 +301,9 @@ void GetBlockShapeAndSplitKVBlock(
   int max_dec_len_this_time = max_len_cpu_ptr[2];
   int max_enc_dec_len_this_time = max_len_cpu_ptr[3];
   int max_just_dec_len_this_time = max_len_cpu_ptr[4];
-  int max_just_dec_merged_len_this_time = max_len_cpu_ptr[5];
   int max_system_len = max_len_cpu_ptr[6];
   int max_just_dec_len_without_system = max_len_cpu_ptr[7];
-  int max_kv_len_this_time = max_len_cpu_ptr[8];
+  int max_kv_len_this_time = max_len_cpu_ptr[5];

   const uint32_t decoder_batch_ele_num = decoder_batch_ids.shape()[0];

   // decoder
   if (max_dec_len_this_time > 0) {
@@ -343,25 +341,15 @@ void GetBlockShapeAndSplitKVBlock(
         decoder_chunk_size_device.copy_to(paddle::CPUPlace(), false);
     const int chunk_size = decoder_chunk_size_cpu.data<int>()[0];

-    // NOTE: (changwenbin) When using auto_chunk,
-    // decode_max_tile_size must take into account the maximum case, where *
-    // 1024 can cover 128K. const uint32_t decoder_batch_shape =
-    // seq_lens_decoder.dims()[0] * 1024;
-
-    const uint32_t decoder_max_tile_size_per_bs_q =
-        div_up((decoder_step_token_num * group_size), decoder_block_shape_q);
-    const uint32_t decoder_batch_shape =
-        bsz * 1024 * decoder_max_tile_size_per_bs_q;
-
     PADDLE_ENFORCE_GPU_SUCCESS(
         cudaMemsetAsync(decoder_batch_ids.data<int>(),
                         0,
-                        decoder_batch_shape * sizeof(int32_t),
+                        decoder_batch_ele_num * sizeof(int32_t),
                         stream));
     PADDLE_ENFORCE_GPU_SUCCESS(
         cudaMemsetAsync(decoder_tile_ids_per_batch.data<int>(),
                         0,
-                        decoder_batch_shape * sizeof(int32_t),
+                        decoder_batch_ele_num * sizeof(int32_t),
                         stream));

     split_block_for_mla<<<1, 32, 0, stream>>>(
@@ -374,22 +362,15 @@ void GetBlockShapeAndSplitKVBlock(
         chunk_size);

   } else {
-    // Note:(changwenbin)In order to adapt to cudagraph, the maximum value
-    // should be taken here
-    const uint32_t decoder_max_tile_size_per_bs_q =
-        div_up((decoder_step_token_num * group_size), decoder_block_shape_q);
-    const uint32_t decoder_batch_shape =
-        bsz * 1024 * decoder_max_tile_size_per_bs_q;
-
     PADDLE_ENFORCE_GPU_SUCCESS(
         cudaMemsetAsync(decoder_batch_ids.data<int>(),
                         0,
-                        decoder_batch_shape * sizeof(int32_t),
+                        decoder_batch_ele_num * sizeof(int32_t),
                         stream));
     PADDLE_ENFORCE_GPU_SUCCESS(
         cudaMemsetAsync(decoder_tile_ids_per_batch.data<int>(),
                         0,
-                        decoder_batch_shape * sizeof(int32_t),
+                        decoder_batch_ele_num * sizeof(int32_t),
                         stream));
     PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
         decoder_num_blocks_device.data<int>(), 0, sizeof(int32_t), stream));
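Both memset branches above now size the clear by decoder_batch_ele_num, i.e. decoder_batch_ids.shape()[0], the buffer's allocated extent, instead of recomputing an upper bound from decoder_step_token_num (which this commit removes from the op's interface entirely). A minimal sketch of the idea, with illustrative names, assuming the caller preallocates the buffer once at its CUDA-graph-safe maximum:

    #include <cuda_runtime.h>
    #include <cstddef>
    #include <cstdint>

    // Clear a preallocated index buffer over its full element count. The size
    // comes from the allocation, not from per-step launch parameters, so the
    // call is identical on every run (e.g. safe to capture in a CUDA graph).
    inline cudaError_t ClearIndexBuffer(int32_t* buf,
                                        size_t allocated_elems,
                                        cudaStream_t stream) {
      return cudaMemsetAsync(buf, 0, allocated_elems * sizeof(int32_t), stream);
    }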
@@ -413,13 +394,6 @@ void GetBlockShapeAndSplitKVBlock(
       PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
           decoder_chunk_size_device.data<int>(), 64, sizeof(int32_t), stream));
     }
-  } else {
-    PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
-        decoder_chunk_size_device.data<int>(), 64, sizeof(int32_t), stream));
-    PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
-        decoder_num_blocks_device.data<int>(), 0, sizeof(int32_t), stream));
-    decoder_num_blocks_cpu.copy_(
-        decoder_num_blocks_device, decoder_num_blocks_cpu.place(), false);
-  }
+  }

   // encoder
@@ -486,8 +460,7 @@ std::vector<std::vector<int64_t>> GetBlockShapeAndSplitKVBlockInferShape(
     const int encoder_block_shape_q,
     const int decoder_block_shape_q,
     const int group_size,
-    const int block_size,
-    const int decoder_step_token_num) {
+    const int block_size) {
   return {};
 }
@@ -498,8 +471,7 @@ std::vector<paddle::DataType> GetBlockShapeAndSplitKVBlockInferDtype(
     const int encoder_block_shape_q,
     const int decoder_block_shape_q,
     const int group_size,
-    const int block_size,
-    const int decoder_step_token_num) {
+    const int block_size) {
   return {};
 }
@@ -527,8 +499,7 @@ PD_BUILD_STATIC_OP(get_block_shape_and_split_kv_block)
     .Attrs({"encoder_block_shape_q: int",
             "decoder_block_shape_q: int",
             "group_size: int",
-            "block_size: int",
-            "decoder_step_token_num: int"})
+            "block_size: int"})
     .SetKernelFn(PD_KERNEL(GetBlockShapeAndSplitKVBlock))
     .SetInferShapeFn(PD_INFER_SHAPE(GetBlockShapeAndSplitKVBlockInferShape))
     .SetInferDtypeFn(PD_INFER_DTYPE(GetBlockShapeAndSplitKVBlockInferDtype));
@@ -381,8 +381,7 @@ void GetBlockShapeAndSplitKVBlock(
     const int encoder_block_shape_q,
     const int decoder_block_shape_q,
     const int group_size,
-    const int block_size,
-    const int decoder_step_token_num);
+    const int block_size);

 std::vector<paddle::Tensor> GetPaddingOffset(const paddle::Tensor& input_ids,
                                              const paddle::Tensor& token_num,
@@ -54,9 +54,6 @@ class AppendAttentionMetadata(AttentionMetadata):
     _dtype: paddle.dtype = paddle.bfloat16
     encoder_max_partition_size: int = 32768
     max_partition_size: int = 32768
-    block_tables: Optional[paddle.Tensor] = None
-    rotary_embs: Optional[paddle.Tensor] = None
-    attn_mask: Optional[paddle.Tensor] = None
     _fuse_kernel_compute_dtype: str = "bf16"

     # pd_disaggregation
@@ -101,7 +98,6 @@ def allocate_launch_related_buffer(
     res["kv_batch_ids"] = paddle.full([int(kv_max_tile_size)], 0, dtype="int32")
     res["kv_tile_ids_per_batch"] = paddle.full([int(kv_max_tile_size)], 0, dtype="int32")
     res["kv_num_blocks_x_cpu"] = paddle.full([1], 0, dtype="int32").cpu()
-
     return res
@@ -175,10 +171,6 @@ class AppendAttentionBackend(AttentionBackend):
             metadata._fuse_kernel_compute_dtype = "fp16"
         elif metadata._dtype == "float32":
             metadata._fuse_kernel_compute_dtype = "fp32"
-        metadata.block_tables = forward_meta.block_tables
-        metadata.rotary_embs = forward_meta.rotary_embs
-        metadata.attn_mask = forward_meta.attn_mask
-        metadata.pre_caches_length = forward_meta.pre_caches_length

         # pd_disaggregation
         metadata.kv_signal_data_list = [None] * self.num_layers
@@ -263,6 +255,7 @@ class AppendAttentionBackend(AttentionBackend):
         cache_v_scales = getattr(layer, "cache_v_scale", None)

         if layer.layer_id == 0:
+            # print(forward_meta.seq_lens_this_time)
             get_block_shape_and_split_kv_block(
                 forward_meta.seq_lens_encoder,
                 forward_meta.seq_lens_decoder,
@@ -283,7 +276,6 @@ class AppendAttentionBackend(AttentionBackend):
                 self.decoder_block_shape_q,
                 self.group_size,
                 self.block_size,
-                self.speculate_max_draft_token_num + 1,
             )

         if self.use_output:
@@ -330,7 +322,7 @@ class AppendAttentionBackend(AttentionBackend):
                 forward_meta.seq_lens_this_time,
                 forward_meta.batch_id_per_token,
                 forward_meta.cu_seqlens_q,
-                metadata.block_tables,
+                forward_meta.block_tables,
                 forward_meta.encoder_batch_ids,
                 forward_meta.encoder_tile_ids_per_batch,
                 forward_meta.encoder_num_blocks_x_cpu,
@@ -342,8 +334,8 @@ class AppendAttentionBackend(AttentionBackend):
                 forward_meta.decoder_num_blocks_cpu,
                 forward_meta.max_len_tensor_cpu,
                 res,
-                metadata.rotary_embs,
-                metadata.attn_mask,
+                forward_meta.rotary_embs,
+                forward_meta.attn_mask,
                 layer.qkv_bias,
                 layer.qkv_scale,
                 cache_k_scales,
@@ -387,7 +379,7 @@ class AppendAttentionBackend(AttentionBackend):
                 forward_meta.seq_lens_this_time,
                 forward_meta.batch_id_per_token,
                 forward_meta.cu_seqlens_q,
-                metadata.block_tables,
+                forward_meta.block_tables,
                 forward_meta.encoder_batch_ids,
                 forward_meta.encoder_tile_ids_per_batch,
                 forward_meta.encoder_num_blocks_x_cpu,
@@ -398,8 +390,8 @@ class AppendAttentionBackend(AttentionBackend):
                 forward_meta.decoder_tile_ids_per_batch,
                 forward_meta.decoder_num_blocks_cpu,
                 forward_meta.max_len_tensor_cpu,
-                metadata.rotary_embs,
-                metadata.attn_mask,
+                forward_meta.rotary_embs,
+                forward_meta.attn_mask,
                 layer.qkv_bias,
                 layer.qkv_scale,
                 cache_k_scales,
@@ -213,7 +213,6 @@ class FlashAttentionBackend(AttentionBackend):
                 self.decoder_block_shape_q,
                 self.group_size,
                 self.block_size,
-                self.speculate_max_draft_token_num + 1,
             )

             (
@@ -204,13 +204,12 @@ class MLAAttentionBackend(AttentionBackend):
                 self.decoder_block_shape_q,
                 self.group_size,
                 self.block_size,
-                self.speculate_max_draft_token_num + 1,
             )

         # MLA
         metadata.max_enc_len_this_time = forward_meta.max_len_tensor_cpu[1]
         metadata.max_dec_len_this_time = forward_meta.max_len_tensor_cpu[2]
-        metadata.max_kv_len_this_time = forward_meta.max_len_tensor_cpu[8]
+        metadata.max_kv_len_this_time = forward_meta.max_len_tensor_cpu[5]

         # pd_disaggregation
         metadata.kv_signal_data_list = [None] * self.num_layers
@@ -44,7 +44,6 @@ def get_block_shape_and_split_kv_block(
     decoder_block_shape_q: int,
     group_size: int,
     block_size: int,
-    decoder_step_token_num: int,
 ):
     """
     get_block_shape_and_split_kv_block
@@ -70,7 +69,6 @@ def get_block_shape_and_split_kv_block(
             decoder_block_shape_q,
             group_size,
             block_size,
-            decoder_step_token_num,
         )

     else:
@@ -179,13 +179,12 @@ class MetaxMLAAttentionBackend(AttentionBackend):
                 self.decoder_block_shape_q,
                 self.group_size,
                 self.block_size,
-                self.speculate_max_draft_token_num + 1,
             )

         # MLA
         metadata.max_enc_len_this_time = forward_meta.max_len_tensor_cpu[1].item()
         metadata.max_dec_len_this_time = forward_meta.max_len_tensor_cpu[2]
-        metadata.max_kv_len_this_time = forward_meta.max_len_tensor_cpu[8]
+        metadata.max_kv_len_this_time = forward_meta.max_len_tensor_cpu[5]

         # pd_disaggregation
         metadata.kv_signal_data_list = [None] * self.num_layers
@@ -628,7 +628,6 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase):
             12,
             (self.q_num_head + 2 * self.kv_num_head) // self.kv_num_head,
             self.blocksize,
-            speculate_max_draft_token_num + 1,
         )
         if self.use_dynamic_quant:
             cache_quant_type = "block_wise_fp8"
@@ -479,7 +479,6 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase):
             12,
             (self.q_num_head + 2 * self.kv_num_head) // self.kv_num_head,
             self.blocksize,
-            speculate_max_draft_token_num + 1,
         )

         # Warm up
@@ -121,10 +121,10 @@ class TestAttentionPerformance(unittest.TestCase):
             "dtype": "bfloat16",
             "hidden_size": 4096,
             "max_position_embeddings": 131072,
-            "max_model_len": 5500,
+            "max_model_len": 36 * 1024 + 1024,
             "num_attention_heads": 32,
             "num_key_value_heads": 4,
-            "num_hidden_layers": 5,
+            "num_hidden_layers": 57,
         }
         model_dir = tempfile.mkdtemp(prefix="tmp_model_config_")
         config_path = os.path.join(model_dir, "config.json")
@@ -223,7 +223,7 @@ class TestAttentionPerformance(unittest.TestCase):
             max_model_len=fd_config.model_config.max_model_len,
             encoder_block_shape_q=64,
             decoder_block_shape_q=16,
-            decoder_step_token_num=1,
+            decoder_step_token_num=fd_config.speculative_config.num_speculative_tokens + 1,
             num_heads=fd_config.model_config.num_attention_heads,
             kv_num_heads=fd_config.model_config.num_key_value_heads,
             block_size=fd_config.cache_config.block_size,
@@ -294,29 +294,30 @@ class TestAttentionPerformance(unittest.TestCase):
     def test_decode_performance_with_prefill(self):
         # Test parameters
         test_steps = 100
-        prefill_batch_size = 1
-        prefill_seq_len = 4096
+        # prefill_batch_size = 1
+        # prefill_seq_len = 4096
         use_dynamic_quant = True
         act_tensor_dtype = paddle.bfloat16

-        prefill_hidden_states = paddle.randn(
-            [prefill_batch_size * prefill_seq_len, self.fd_config.model_config.hidden_size],
-            dtype=act_tensor_dtype,
-        )
+        # prefill_hidden_states = paddle.randn(
+        #     [prefill_batch_size * prefill_seq_len, self.fd_config.model_config.hidden_size],
+        #     dtype=act_tensor_dtype,
+        # )

-        forward_meta = self.create_forward_meta(
-            batch_size=prefill_batch_size,
-            seq_len=prefill_seq_len,
-            mode=ForwardMode.EXTEND,
-            fd_config=self.fd_config,
-            attn_backend=self.attn_backend,
-            use_dynamic_quant=use_dynamic_quant,
-        )
+        # forward_meta = self.create_forward_meta(
+        #     batch_size=prefill_batch_size,
+        #     seq_len=prefill_seq_len,
+        #     mode=ForwardMode.EXTEND,
+        #     fd_config=self.fd_config,
+        #     attn_backend=self.attn_backend,
+        #     use_dynamic_quant=use_dynamic_quant,
+        # )

-        self.attn_backend.init_attention_metadata(forward_meta)
-        self.attn_forward(forward_meta, prefill_hidden_states)
+        # self.attn_backend.init_attention_metadata(forward_meta)
+        # self.attn_forward(forward_meta, prefill_hidden_states)

-        paddle.device.synchronize()
+        # paddle.device.synchronize()

         # import paddle.profiler as profiler
         # p = profiler.Profiler(
@@ -326,18 +327,18 @@ class TestAttentionPerformance(unittest.TestCase):
         # p.start()
         # p.step()

-        start_events = [paddle.device.cuda.Event(enable_timing=True) for _ in range(test_steps)]
-        end_events = [paddle.device.cuda.Event(enable_timing=True) for _ in range(test_steps)]
-        for i in range(test_steps):
-            start_events[i].record()
+        # start_events = [paddle.device.cuda.Event(enable_timing=True) for _ in range(test_steps)]
+        # end_events = [paddle.device.cuda.Event(enable_timing=True) for _ in range(test_steps)]
+        # for i in range(test_steps):
+        #     start_events[i].record()

-            self.attn_forward(forward_meta, prefill_hidden_states)
+        #     self.attn_forward(forward_meta, prefill_hidden_states)

-            end_events[i].record()
-        paddle.device.synchronize()
+        #     end_events[i].record()
+        # paddle.device.synchronize()

-        times = np.array([round(s.elapsed_time(e), 1) for s, e in zip(start_events, end_events)])[1:]
-        print(times[-5:])
+        # times = np.array([round(s.elapsed_time(e), 1) for s, e in zip(start_events, end_events)])[1:]
+        # print(times[-5:])

         # p.stop()
@@ -349,14 +350,14 @@ class TestAttentionPerformance(unittest.TestCase):
         # p.start()
         # p.step()

-        for decode_batch_size in [10, 20, 40, 60, 80, 100, 128]:
+        for decode_batch_size in [32, 16, 8, 4, 2]:
             decode_hidden_states = paddle.randn(
                 [decode_batch_size, self.fd_config.model_config.hidden_size], dtype=act_tensor_dtype
             )

             forward_meta = self.create_forward_meta(
                 batch_size=decode_batch_size,
-                seq_len=5000,
+                seq_len=36 * 1024,
                 mode=ForwardMode.DECODE,
                 fd_config=self.fd_config,
                 attn_backend=self.attn_backend,
@@ -383,7 +384,6 @@ class TestAttentionPerformance(unittest.TestCase):
                 start_events[i].record()

                 attn_cuda_graphs.replay()
-                # self.attn_forward(forward_meta, decode_hidden_states)

                 end_events[i].record()
             paddle.device.synchronize()
@@ -391,6 +391,8 @@ class TestAttentionPerformance(unittest.TestCase):
             times = np.array([round(s.elapsed_time(e), 1) for s, e in zip(start_events, end_events)])[1:]
             print(times[-5:])

+            del forward_meta
+
         # p.stop()
@@ -254,7 +254,6 @@ class TestTreeMask(unittest.TestCase):
             decoder_block_shape_q,
             self.num_q_head // self.num_kv_head,
             self.block_size,
-            decoder_step_token_num,
         )
         s_time = 0
         for i in range(self.run_time + self.warm_up):