[Others] get_block_shape_and_split_kv_block clean code (#5123)

周周周
2025-11-20 16:40:04 +08:00
committed by GitHub
parent af715db763
commit 6fa34102e8
12 changed files with 364 additions and 355 deletions


@@ -14,8 +14,8 @@
#include "append_attn/append_attention_kernel.h"
#include "append_attn/decoder_write_cache_with_rope_kernel.h"
#include "append_attn/speculate_write_cache_with_rope_kernel.h"
#include "append_attn/encoder_write_cache_with_rope_kernel.h"
#include "append_attn/speculate_write_cache_with_rope_kernel.h"
#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
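
For readers tracing the registration names: PD_BUILD_STATIC_OP pastes a static_op_ prefix onto the op name before delegating to PD_BUILD_OP. A minimal, self-contained sketch of the expansion, with PD_BUILD_OP stubbed as a hypothetical function-defining macro purely so the token pasting is observable:

#include <cstdio>

#define PD_BUILD_OP(name) void name() { std::puts(#name); }
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)

PD_BUILD_STATIC_OP(append_attention)  // defines static_op_append_attention()

int main() { static_op_append_attention(); }  // prints "static_op_append_attention"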
@@ -26,17 +26,16 @@ class type2value;
template <>
class type2value<phi::dtype::bfloat16> {
public:
static constexpr paddle::DataType value = paddle::DataType::BFLOAT16;
public:
static constexpr paddle::DataType value = paddle::DataType::BFLOAT16;
};
template <>
class type2value<phi::dtype::float16> {
public:
static constexpr paddle::DataType value = paddle::DataType::FLOAT16;
public:
static constexpr paddle::DataType value = paddle::DataType::FLOAT16;
};
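
The type2value trait whose formatting changes above maps a compile-time tag type to the runtime dtype enum, so a dummy value can drive a non-type template parameter. A minimal compilable sketch of the pattern, using hypothetical stand-ins for paddle::DataType and the phi dtypes:

#include <cstdio>

enum class DataType { FLOAT16, BFLOAT16 };  // stand-in for paddle::DataType
struct float16 {};                          // stand-in for phi::dtype::float16
struct bfloat16 {};                         // stand-in for phi::dtype::bfloat16

template <typename T> class type2value;
template <> class type2value<float16> {
 public:
  static constexpr DataType value = DataType::FLOAT16;
};
template <> class type2value<bfloat16> {
 public:
  static constexpr DataType value = DataType::BFLOAT16;
};

template <DataType D> void Kernel() { std::printf("got dtype %d\n", (int)D); }

int main() {
  // A dummy tag value picks the instantiation, as dispatch_by_template
  // does further down in this file.
  float16 fp16_dtype;
  [&](auto tag) { Kernel<type2value<decltype(tag)>::value>(); }(fp16_dtype);
  return 0;
}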
template <paddle::DataType D>
void AppendAttentionKernel(
const AppendAttnMetaData& meta_data,
@@ -96,14 +95,12 @@ void AppendAttentionKernel(
typedef typename traits_::DataType DataType_;
typedef typename traits_::data_t data_t;
// set_max_lengths: [0] max_len_this_time, [1] max_enc_len_this_time, [2] max_dec_len_this_time,
// [3] max_enc_dec_len_this_time, [4] max_just_dec_len_this_time, [5] max_kv_len_this_time
int max_len_this_time = set_max_lengths.data<int>()[0];
int max_enc_len_this_time =set_max_lengths.data<int>()[1];
int max_dec_len_this_time = set_max_lengths.data<int>()[2];
int max_enc_dec_len_this_time = set_max_lengths.data<int>()[3];
int max_just_dec_len_this_time = set_max_lengths.data<int>()[4];
int max_kv_len_this_time = set_max_lengths.data<int>()[8];
const int max_len_this_time = set_max_lengths.data<int>()[0];
const int max_enc_len_this_time = set_max_lengths.data<int>()[1];
const int max_dec_len_this_time = set_max_lengths.data<int>()[2];
const int max_enc_dec_len_this_time = set_max_lengths.data<int>()[3];
const int max_just_dec_len_this_time = set_max_lengths.data<int>()[4];
const int max_kv_len_this_time = set_max_lengths.data<int>()[5];
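
max_kv_len_this_time moves from index 8 to index 5 because GetMaxLenKernel (later in this diff) no longer fills the just-dec-merged/system slots. A sketch of the compacted layout implied by the reads above; the enum and its names are editorial, the source itself uses raw indices:

enum MaxLenSlot : int {
  kMaxLenThisTime        = 0,
  kMaxEncLenThisTime     = 1,
  kMaxDecLenThisTime     = 2,
  kMaxEncDecLenThisTime  = 3,
  kMaxJustDecLenThisTime = 4,
  kMaxKVLenThisTime      = 5,  // was index 8 before this commit
};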
auto main_stream = qkv.stream();
static cudaEvent_t main_event;
@@ -125,54 +122,56 @@ void AppendAttentionKernel(
qkv_out = qkv;
}
auto dispatch_CascadeAppendAttentionKernel = [&](auto temp_args,
const paddle::Tensor& lambda_batch_ids,
const paddle::Tensor& lambda_tile_ids_per_batch,
const int lambda_num_blocks_data,
const int lambda_block_shape_q,
const int lambda_max_dec_len,
const bool lambda_is_decoder,
const bool lambda_enable_prefill,
cudaStream_t& lambda_stream
) -> void {
CascadeAppendAttentionKernel<data_t, decltype(temp_args)>(
meta_data,
qkv_out,
key_cache,
value_cache,
attn_mask,
cache_quant_type_str == "block_wise_fp8" ? cache_k_quant_scales : cache_k_dequant_scales,
cache_quant_type_str == "block_wise_fp8" ? cache_v_quant_scales : cache_v_dequant_scales,
cache_k_zp,
cache_v_zp,
out_linear_shifts,
out_linear_smooths,
sinks,
seq_lens_this_time,
seq_lens_decoder,
seq_lens_encoder,
batch_id_per_token,
cu_seqlens_q,
block_tables,
lambda_batch_ids,
lambda_tile_ids_per_batch,
cache_quant_type_str,
lambda_num_blocks_data,
lambda_block_shape_q,
max_input_length,
lambda_max_dec_len,
quant_max_bound,
quant_min_bound,
out_linear_in_scale,
max_partition_size,
encoder_max_partition_size,
speculate_max_draft_token_num,
causal,
lambda_is_decoder,
lambda_enable_prefill,
lambda_stream,
&fmha_out,
sliding_window);
auto dispatch_CascadeAppendAttentionKernel =
[&](auto temp_args,
const paddle::Tensor& lambda_batch_ids,
const paddle::Tensor& lambda_tile_ids_per_batch,
const int lambda_num_blocks_data,
const int lambda_block_shape_q,
const int lambda_max_dec_len,
const bool lambda_is_decoder,
const bool lambda_enable_prefill,
cudaStream_t& lambda_stream) -> void {
CascadeAppendAttentionKernel<data_t, decltype(temp_args)>(
meta_data,
qkv_out,
key_cache,
value_cache,
attn_mask,
cache_quant_type_str == "block_wise_fp8" ? cache_k_quant_scales
: cache_k_dequant_scales,
cache_quant_type_str == "block_wise_fp8" ? cache_v_quant_scales
: cache_v_dequant_scales,
cache_k_zp,
cache_v_zp,
out_linear_shifts,
out_linear_smooths,
sinks,
seq_lens_this_time,
seq_lens_decoder,
seq_lens_encoder,
batch_id_per_token,
cu_seqlens_q,
block_tables,
lambda_batch_ids,
lambda_tile_ids_per_batch,
cache_quant_type_str,
lambda_num_blocks_data,
lambda_block_shape_q,
max_input_length,
lambda_max_dec_len,
quant_max_bound,
quant_min_bound,
out_linear_in_scale,
max_partition_size,
encoder_max_partition_size,
speculate_max_draft_token_num,
causal,
lambda_is_decoder,
lambda_enable_prefill,
lambda_stream,
&fmha_out,
sliding_window);
};
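
The substance of this refactor is argument consolidation: the roughly thirty tensors and attributes that every call site passed identically are captured by reference once, and only the values that genuinely vary per call site remain parameters, with decltype(temp_args) carrying the output dtype into the template. A compilable sketch of the pattern under hypothetical names:

#include <cstdio>

template <typename OutT>
void HeavyKernel(int shared_a, int shared_b, int num_blocks, bool is_decoder) {
  std::printf("num_blocks=%d is_decoder=%d\n", num_blocks, (int)is_decoder);
}

int main() {
  int shared_a = 0, shared_b = 1;  // stand-ins for the shared tensors/attrs
  auto dispatch = [&](auto out_tag, int num_blocks, bool is_decoder) {
    HeavyKernel<decltype(out_tag)>(shared_a, shared_b, num_blocks, is_decoder);
  };
  dispatch(int8_t{0}, 8, false);  // encoder-style call site
  dispatch(0.0f, 4, true);        // decoder-style call site
  return 0;
}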
if (max_enc_len_this_time > 0) {
@@ -182,8 +181,9 @@ void AppendAttentionKernel(
int encoder_num_blocks_data = encoder_num_blocks.data<int>()[0];
int kv_num_blocks_data = kv_num_blocks.data<int>()[0];
auto dispatch_EncoderWriteCacheWithRopeKernel = [&](auto temp_args) -> void {
EncoderWriteCacheWithRopeKernel<data_t, decltype(temp_args)>(
auto dispatch_EncoderWriteCacheWithRopeKernel =
[&](auto temp_args) -> void {
EncoderWriteCacheWithRopeKernel<data_t, decltype(temp_args)>(
meta_data,
qkv,
seq_lens_this_time,
@@ -225,24 +225,50 @@ void AppendAttentionKernel(
}
if (out_linear_in_scale > 0.0) {
switch (fmha_out.dtype()) {
case paddle::DataType::INT8:{
case paddle::DataType::INT8: {
int8_t tmp;
dispatch_CascadeAppendAttentionKernel(tmp, encoder_batch_ids, encoder_tile_ids_per_batch, encoder_num_blocks_data, encoder_block_shape_q, max_enc_dec_len_this_time, false, true, main_stream);
dispatch_CascadeAppendAttentionKernel(tmp,
encoder_batch_ids,
encoder_tile_ids_per_batch,
encoder_num_blocks_data,
encoder_block_shape_q,
max_enc_dec_len_this_time,
false,
true,
main_stream);
break;
}
case paddle::DataType::FLOAT8_E4M3FN:{
case paddle::DataType::FLOAT8_E4M3FN: {
phi::dtype::float8_e4m3fn tmp;
dispatch_CascadeAppendAttentionKernel(tmp, encoder_batch_ids, encoder_tile_ids_per_batch, encoder_num_blocks_data, encoder_block_shape_q, max_enc_dec_len_this_time, false, true, main_stream);
dispatch_CascadeAppendAttentionKernel(tmp,
encoder_batch_ids,
encoder_tile_ids_per_batch,
encoder_num_blocks_data,
encoder_block_shape_q,
max_enc_dec_len_this_time,
false,
true,
main_stream);
break;
}
default:{
PD_THROW("Only supported output fmha_out of quant dtype in ['int8', 'FLOAT8_E4M3FN'].");
default: {
PD_THROW(
"Only supported output fmha_out of quant dtype in ['int8', "
"'FLOAT8_E4M3FN'].");
break;
}
}
} else {
data_t tmp;
dispatch_CascadeAppendAttentionKernel(tmp, encoder_batch_ids, encoder_tile_ids_per_batch, encoder_num_blocks_data, encoder_block_shape_q, max_enc_dec_len_this_time, false, true, main_stream);
dispatch_CascadeAppendAttentionKernel(tmp,
encoder_batch_ids,
encoder_tile_ids_per_batch,
encoder_num_blocks_data,
encoder_block_shape_q,
max_enc_dec_len_this_time,
false,
true,
main_stream);
}
}
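
For orientation: whether fmha_out is quantized at all is gated by out_linear_in_scale > 0, and its dtype is keyed off quant_max_bound elsewhere in this file (127 matches int8's symmetric range, 448 matches fp8-e4m3's largest finite value). A minimal sketch of that selection under hypothetical names:

#include <cmath>
#include <stdexcept>

enum class FmhaOutDType { INT8, FLOAT8_E4M3FN, COMPUTE_DTYPE };

FmhaOutDType SelectFmhaOutDtype(float out_linear_in_scale, float quant_max_bound) {
  if (out_linear_in_scale <= 0.0f) return FmhaOutDType::COMPUTE_DTYPE;
  if (std::fabs(quant_max_bound - 127.0f) < 1e-6f) return FmhaOutDType::INT8;
  if (std::fabs(quant_max_bound - 448.0f) < 1e-6f) return FmhaOutDType::FLOAT8_E4M3FN;
  throw std::runtime_error("Only supported quant_max_bound in ['127', '448'].");
}

int main() {
  return SelectFmhaOutDtype(1.0f, 448.0f) == FmhaOutDType::FLOAT8_E4M3FN ? 0 : 1;
}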
@@ -370,23 +396,44 @@ void AppendAttentionKernel(
if (out_linear_in_scale > 0.0) {
switch (fmha_out.dtype()) {
case paddle::DataType::INT8:{
int8_t tmp;
dispatch_CascadeAppendAttentionKernel(tmp, decoder_batch_ids, decoder_tile_ids_per_batch, decoder_num_blocks_data,
decoder_block_shape_q, max_kv_len_this_time, !speculate_decoder, !speculate_decoder, exec_stream);
case paddle::DataType::INT8: {
int8_t tmp;
dispatch_CascadeAppendAttentionKernel(tmp,
decoder_batch_ids,
decoder_tile_ids_per_batch,
decoder_num_blocks_data,
decoder_block_shape_q,
max_kv_len_this_time,
!speculate_decoder,
!speculate_decoder,
exec_stream);
break;
}
case paddle::DataType::FLOAT8_E4M3FN:{
phi::dtype::float8_e4m3fn tmp;
dispatch_CascadeAppendAttentionKernel(tmp, decoder_batch_ids, decoder_tile_ids_per_batch, decoder_num_blocks_data,
decoder_block_shape_q, max_kv_len_this_time, !speculate_decoder, !speculate_decoder, exec_stream);
case paddle::DataType::FLOAT8_E4M3FN: {
phi::dtype::float8_e4m3fn tmp;
dispatch_CascadeAppendAttentionKernel(tmp,
decoder_batch_ids,
decoder_tile_ids_per_batch,
decoder_num_blocks_data,
decoder_block_shape_q,
max_kv_len_this_time,
!speculate_decoder,
!speculate_decoder,
exec_stream);
break;
}
}
} else {
data_t tmp;
dispatch_CascadeAppendAttentionKernel(tmp, decoder_batch_ids, decoder_tile_ids_per_batch, decoder_num_blocks_data,
decoder_block_shape_q, max_kv_len_this_time, !speculate_decoder, !speculate_decoder, exec_stream);
data_t tmp;
dispatch_CascadeAppendAttentionKernel(tmp,
decoder_batch_ids,
decoder_tile_ids_per_batch,
decoder_num_blocks_data,
decoder_block_shape_q,
max_kv_len_this_time,
!speculate_decoder,
!speculate_decoder,
exec_stream);
}
if (max_enc_len_this_time > 0) {
cudaEventRecord(decoder_event, exec_stream);
@@ -471,8 +518,14 @@ std::vector<paddle::Tensor> AppendAttention(
// template dtype generation
phi::DataType dtype_id;
switch (qkv.dtype()) {
case paddle::DataType::FLOAT16: {dtype_id = phi::DataType::FLOAT16; break;}
case paddle::DataType::BFLOAT16: {dtype_id = phi::DataType::BFLOAT16; break;}
case paddle::DataType::FLOAT16: {
dtype_id = phi::DataType::FLOAT16;
break;
}
case paddle::DataType::BFLOAT16: {
dtype_id = phi::DataType::BFLOAT16;
break;
}
case paddle::DataType::INT32: {
if (compute_dtype == "bf16") {
dtype_id = phi::DataType::BFLOAT16;
@@ -498,15 +551,15 @@ std::vector<paddle::Tensor> AppendAttention(
if (out_linear_in_scale > 0.0) {
if (fabs(quant_max_bound - 127.0f) < 0.000001) {
fmha_out = paddle::zeros(
{meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims},
paddle::DataType::INT8,
qkv.place());
{meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims},
paddle::DataType::INT8,
qkv.place());
} else if (fabs(quant_max_bound - 448.0f) < 0.000001) {
fmha_out = paddle::zeros(
{meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims},
paddle::DataType::FLOAT8_E4M3FN,
qkv.place());
} else{
{meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims},
paddle::DataType::FLOAT8_E4M3FN,
qkv.place());
} else {
PD_THROW("Only supported attr of quant_max_bound in ['127', '448'].");
}
} else {
@@ -521,79 +574,78 @@ std::vector<paddle::Tensor> AppendAttention(
}
auto dispatch_by_template = [&](auto temp_args) -> void {
AppendAttentionKernel<type2value<decltype(temp_args)>::value>(
meta_data,
qkv,
key_cache,
value_cache,
seq_lens_encoder,
seq_lens_decoder,
seq_lens_this_time,
batch_id_per_token,
cu_seqlens_q,
block_tables,
encoder_batch_ids,
encoder_tile_ids_per_batch,
encoder_num_blocks,
kv_batch_ids,
kv_tile_ids_per_batch,
kv_num_blocks,
decoder_batch_ids,
decoder_tile_ids_per_batch,
decoder_num_blocks,
set_max_lengths,
fmha_out,
rotary_embs,
attn_mask,
qkv_bias,
qkv_out_scales,
cache_k_quant_scales,
cache_v_quant_scales,
cache_k_dequant_scales,
cache_v_dequant_scales,
cache_k_zp,
cache_v_zp,
out_linear_shifts,
out_linear_smooths,
kv_signal_data,
q_norm_weight,
k_norm_weight,
sinks,
rms_norm_eps,
cache_quant_type_str,
use_neox_rotary_style,
rope_3d,
max_input_length,
quant_max_bound,
quant_min_bound,
out_linear_in_scale,
encoder_block_shape_q,
decoder_block_shape_q,
max_partition_size,
encoder_max_partition_size,
speculate_max_draft_token_num,
causal,
speculate_decoder,
sliding_window);
AppendAttentionKernel<type2value<decltype(temp_args)>::value>(
meta_data,
qkv,
key_cache,
value_cache,
seq_lens_encoder,
seq_lens_decoder,
seq_lens_this_time,
batch_id_per_token,
cu_seqlens_q,
block_tables,
encoder_batch_ids,
encoder_tile_ids_per_batch,
encoder_num_blocks,
kv_batch_ids,
kv_tile_ids_per_batch,
kv_num_blocks,
decoder_batch_ids,
decoder_tile_ids_per_batch,
decoder_num_blocks,
set_max_lengths,
fmha_out,
rotary_embs,
attn_mask,
qkv_bias,
qkv_out_scales,
cache_k_quant_scales,
cache_v_quant_scales,
cache_k_dequant_scales,
cache_v_dequant_scales,
cache_k_zp,
cache_v_zp,
out_linear_shifts,
out_linear_smooths,
kv_signal_data,
q_norm_weight,
k_norm_weight,
sinks,
rms_norm_eps,
cache_quant_type_str,
use_neox_rotary_style,
rope_3d,
max_input_length,
quant_max_bound,
quant_min_bound,
out_linear_in_scale,
encoder_block_shape_q,
decoder_block_shape_q,
max_partition_size,
encoder_max_partition_size,
speculate_max_draft_token_num,
causal,
speculate_decoder,
sliding_window);
};
phi::dtype::float16 fp16_dtype;
phi::dtype::bfloat16 bp16_dtype;
switch (dtype_id){
case phi::DataType::FLOAT16: {
dispatch_by_template(fp16_dtype);
return {fmha_out};
}
case phi::DataType::BFLOAT16: {
dispatch_by_template(bp16_dtype);
return {fmha_out};
}
default:
PD_THROW(
switch (dtype_id) {
case phi::DataType::FLOAT16: {
dispatch_by_template(fp16_dtype);
return {fmha_out};
}
case phi::DataType::BFLOAT16: {
dispatch_by_template(bp16_dtype);
return {fmha_out};
}
default:
PD_THROW(
"NOT supported data type. "
"Only float16 and bfloat16 are supported. ");
break;
break;
}
return {paddle::Tensor{}};
@@ -678,60 +730,60 @@ std::vector<paddle::Tensor> AppendAttentionWithOutput(
}
auto dispatch_by_template = [&](auto temp_args) -> void {
AppendAttentionKernel<type2value<decltype(temp_args)>::value>(
meta_data,
qkv,
key_cache,
value_cache,
seq_lens_encoder,
seq_lens_decoder,
seq_lens_this_time,
batch_id_per_token,
cu_seqlens_q,
block_tables,
encoder_batch_ids,
encoder_tile_ids_per_batch,
encoder_num_blocks,
kv_batch_ids,
kv_tile_ids_per_batch,
kv_num_blocks,
decoder_batch_ids,
decoder_tile_ids_per_batch,
decoder_num_blocks,
set_max_lengths,
fmha_out,
rotary_embs,
attn_mask,
qkv_bias,
qkv_out_scales,
cache_k_quant_scales,
cache_v_quant_scales,
cache_k_dequant_scales,
cache_v_dequant_scales,
cache_k_zp,
cache_v_zp,
out_linear_shifts,
out_linear_smooths,
kv_signal_data,
q_norm_weight,
k_norm_weight,
sinks,
rms_norm_eps,
cache_quant_type_str,
use_neox_rotary_style,
rope_3d,
max_input_length,
quant_max_bound,
quant_min_bound,
out_linear_in_scale,
encoder_block_shape_q,
decoder_block_shape_q,
max_partition_size,
encoder_max_partition_size,
speculate_max_draft_token_num,
causal,
speculate_decoder,
sliding_window);
AppendAttentionKernel<type2value<decltype(temp_args)>::value>(
meta_data,
qkv,
key_cache,
value_cache,
seq_lens_encoder,
seq_lens_decoder,
seq_lens_this_time,
batch_id_per_token,
cu_seqlens_q,
block_tables,
encoder_batch_ids,
encoder_tile_ids_per_batch,
encoder_num_blocks,
kv_batch_ids,
kv_tile_ids_per_batch,
kv_num_blocks,
decoder_batch_ids,
decoder_tile_ids_per_batch,
decoder_num_blocks,
set_max_lengths,
fmha_out,
rotary_embs,
attn_mask,
qkv_bias,
qkv_out_scales,
cache_k_quant_scales,
cache_v_quant_scales,
cache_k_dequant_scales,
cache_v_dequant_scales,
cache_k_zp,
cache_v_zp,
out_linear_shifts,
out_linear_smooths,
kv_signal_data,
q_norm_weight,
k_norm_weight,
sinks,
rms_norm_eps,
cache_quant_type_str,
use_neox_rotary_style,
rope_3d,
max_input_length,
quant_max_bound,
quant_min_bound,
out_linear_in_scale,
encoder_block_shape_q,
decoder_block_shape_q,
max_partition_size,
encoder_max_partition_size,
speculate_max_draft_token_num,
causal,
speculate_decoder,
sliding_window);
};
phi::dtype::float16 fp16_dtype;
@@ -769,7 +821,6 @@ std::vector<paddle::Tensor> AppendAttentionWithOutput(
return {fmha_out};
}
std::vector<std::vector<int64_t>> AppendAttentionInferShape(
const std::vector<int64_t>& qkv_shape,
const std::vector<int64_t>& key_cache_shape,
@@ -895,8 +946,9 @@ std::vector<paddle::DataType> AppendAttentionInferDtype(
return {paddle::DataType::INT8};
} else if (fabs(quant_max_bound - 448.0f) < 0.000001) {
return {paddle::DataType::FLOAT8_E4M3FN};
}else{
PD_THROW("Only supported attr of quant_max_bound in ['127.0', '448.0'].");
} else {
PD_THROW(
"Only supported attr of quant_max_bound in ['127.0', '448.0'].");
}
} else {
return {paddle::DataType::BFLOAT16};
@@ -907,8 +959,9 @@ std::vector<paddle::DataType> AppendAttentionInferDtype(
return {paddle::DataType::INT8};
} else if (fabs(quant_max_bound - 448.0f) < 0.000001) {
return {paddle::DataType::FLOAT8_E4M3FN};
}else{
PD_THROW("Only supported attr of quant_max_bound in ['127.0', '448.0'].");
} else {
PD_THROW(
"Only supported attr of quant_max_bound in ['127.0', '448.0'].");
}
} else {
return {paddle::DataType::FLOAT16};
@@ -1034,8 +1087,6 @@ std::vector<paddle::DataType> AppendAttentionWithOutputInferDtype(
return {fmha_out_dtype};
}
PD_BUILD_STATIC_OP(append_attention)
.Inputs({"qkv",
"key_cache",
@@ -1074,24 +1125,25 @@ PD_BUILD_STATIC_OP(append_attention)
paddle::Optional("k_norm_weight"),
paddle::Optional("sinks")})
.Outputs({"fmha_out"})
.Attrs({"rms_norm_eps: float",
"compute_type: std::string",
"cache_quant_type: std::string",
"use_neox_rotary_style: bool",
"rope_3d: bool",
"max_input_length: int",
"quant_max_bound: float",
"quant_min_bound: float",
"out_linear_in_scale: float",
"encoder_block_shape_q: int",
"decoder_block_shape_q: int",
"max_partition_size: int",
"encoder_max_partition_size: int",
"speculate_max_draft_token_num: int",
"causal: bool",
"speculate_decoder: bool",
"sliding_window: int",
})
.Attrs({
"rms_norm_eps: float",
"compute_type: std::string",
"cache_quant_type: std::string",
"use_neox_rotary_style: bool",
"rope_3d: bool",
"max_input_length: int",
"quant_max_bound: float",
"quant_min_bound: float",
"out_linear_in_scale: float",
"encoder_block_shape_q: int",
"decoder_block_shape_q: int",
"max_partition_size: int",
"encoder_max_partition_size: int",
"speculate_max_draft_token_num: int",
"causal: bool",
"speculate_decoder: bool",
"sliding_window: int",
})
.SetKernelFn(PD_KERNEL(AppendAttention))
.SetInferShapeFn(PD_INFER_SHAPE(AppendAttentionInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(AppendAttentionInferDtype));
@@ -1136,24 +1188,25 @@ PD_BUILD_STATIC_OP(append_attention_with_output)
paddle::Optional("sinks")})
.Outputs({"fmha_out_out"})
.SetInplaceMap({{"fmha_out", "fmha_out_out"}})
.Attrs({"rms_norm_eps: float",
"compute_type: std::string",
"cache_quant_type: std::string",
"use_neox_rotary_style: bool",
"rope_3d: bool",
"max_input_length: int",
"quant_max_bound: float",
"quant_min_bound: float",
"out_linear_in_scale: float",
"encoder_block_shape_q: int",
"decoder_block_shape_q: int",
"max_partition_size: int",
"encoder_max_partition_size: int",
"speculate_max_draft_token_num: int",
"causal: bool",
"speculate_decoder: bool",
"sliding_window: int",
})
.Attrs({
"rms_norm_eps: float",
"compute_type: std::string",
"cache_quant_type: std::string",
"use_neox_rotary_style: bool",
"rope_3d: bool",
"max_input_length: int",
"quant_max_bound: float",
"quant_min_bound: float",
"out_linear_in_scale: float",
"encoder_block_shape_q: int",
"decoder_block_shape_q: int",
"max_partition_size: int",
"encoder_max_partition_size: int",
"speculate_max_draft_token_num: int",
"causal: bool",
"speculate_decoder: bool",
"sliding_window: int",
})
.SetKernelFn(PD_KERNEL(AppendAttentionWithOutput))
.SetInferShapeFn(PD_INFER_SHAPE(AppendAttentionWithOutputInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(AppendAttentionWithOutputInferDtype));


@@ -79,7 +79,7 @@ __global__ void GetMaxLenKernel(const int *seq_lens_decoder,
max_lens[2] = total_max_len_decoder;
max_lens[3] = total;
max_lens[4] = total_just_dec;
max_lens[8] = total_max_len_kv;
max_lens[5] = total_max_len_kv;
}
}
@@ -273,8 +273,7 @@ void GetBlockShapeAndSplitKVBlock(
const int encoder_block_shape_q,
const int decoder_block_shape_q,
const int group_size,
const int block_size,
const int decoder_step_token_num) {
const int block_size) {
auto stream = seq_lens_encoder.stream();
int bsz = seq_lens_this_time.shape()[0];
@@ -302,10 +301,9 @@ void GetBlockShapeAndSplitKVBlock(
int max_dec_len_this_time = max_len_cpu_ptr[2];
int max_enc_dec_len_this_time = max_len_cpu_ptr[3];
int max_just_dec_len_this_time = max_len_cpu_ptr[4];
int max_just_dec_merged_len_this_time = max_len_cpu_ptr[5];
int max_system_len = max_len_cpu_ptr[6];
int max_just_dec_len_without_system = max_len_cpu_ptr[7];
int max_kv_len_this_time = max_len_cpu_ptr[8];
int max_kv_len_this_time = max_len_cpu_ptr[5];
const uint32_t decoder_batch_ele_num = decoder_batch_ids.shape()[0];
// decoder
if (max_dec_len_this_time > 0) {
@@ -343,25 +341,15 @@ void GetBlockShapeAndSplitKVBlock(
decoder_chunk_size_device.copy_to(paddle::CPUPlace(), false);
const int chunk_size = decoder_chunk_size_cpu.data<int>()[0];
// NOTE: (changwenbin) When using auto_chunk,
// decode_max_tile_size must take into account the maximum case, where *
// 1024 can cover 128K. const uint32_t decoder_batch_shape =
// seq_lens_decoder.dims()[0] * 1024;
const uint32_t decoder_max_tile_size_per_bs_q =
div_up((decoder_step_token_num * group_size), decoder_block_shape_q);
const uint32_t decoder_batch_shape =
bsz * 1024 * decoder_max_tile_size_per_bs_q;
PADDLE_ENFORCE_GPU_SUCCESS(
cudaMemsetAsync(decoder_batch_ids.data<int>(),
0,
decoder_batch_shape * sizeof(int32_t),
decoder_batch_ele_num * sizeof(int32_t),
stream));
PADDLE_ENFORCE_GPU_SUCCESS(
cudaMemsetAsync(decoder_tile_ids_per_batch.data<int>(),
0,
decoder_batch_shape * sizeof(int32_t),
decoder_batch_ele_num * sizeof(int32_t),
stream));
split_block_for_mla<<<1, 32, 0, stream>>>(
@@ -374,22 +362,15 @@ void GetBlockShapeAndSplitKVBlock(
chunk_size);
} else {
// Note:(changwenbin)In order to adapt to cudagraph, the maximum value
// should be taken here
const uint32_t decoder_max_tile_size_per_bs_q =
div_up((decoder_step_token_num * group_size), decoder_block_shape_q);
const uint32_t decoder_batch_shape =
bsz * 1024 * decoder_max_tile_size_per_bs_q;
PADDLE_ENFORCE_GPU_SUCCESS(
cudaMemsetAsync(decoder_batch_ids.data<int>(),
0,
decoder_batch_shape * sizeof(int32_t),
decoder_batch_ele_num * sizeof(int32_t),
stream));
PADDLE_ENFORCE_GPU_SUCCESS(
cudaMemsetAsync(decoder_tile_ids_per_batch.data<int>(),
0,
decoder_batch_shape * sizeof(int32_t),
decoder_batch_ele_num * sizeof(int32_t),
stream));
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
decoder_num_blocks_device.data<int>(), 0, sizeof(int32_t), stream));
@@ -413,13 +394,6 @@ void GetBlockShapeAndSplitKVBlock(
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
decoder_chunk_size_device.data<int>(), 64, sizeof(int32_t), stream));
}
} else {
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
decoder_chunk_size_device.data<int>(), 64, sizeof(int32_t), stream));
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
decoder_num_blocks_device.data<int>(), 0, sizeof(int32_t), stream));
decoder_num_blocks_cpu.copy_(
decoder_num_blocks_device, decoder_num_blocks_cpu.place(), false);
}
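
The sizing change in this hunk is the core of the cleanup: instead of recomputing a worst case from decoder_step_token_num, group_size, and decoder_block_shape_q, the zero-fills now cover decoder_batch_ele_num, the buffer's preallocated extent, so the cleared region is the same on every step regardless of the batch. A sketch of the idea, assuming a CUDA toolchain; the buffer struct is a hypothetical stand-in for a preallocated paddle::Tensor:

#include <cstdint>
#include <cuda_runtime.h>

struct IntBuffer {
  int32_t* data;
  size_t element_count;  // == shape()[0], fixed at allocation time
};

// Clear the whole allocation rather than a per-step recomputed size; the
// extent is a property of the allocation, not of this step's inputs, which
// keeps the call stable under CUDA graph capture and replay.
void ZeroAll(const IntBuffer& buf, cudaStream_t stream) {
  cudaMemsetAsync(buf.data, 0, buf.element_count * sizeof(int32_t), stream);
}

int main() {
  int32_t* d = nullptr;
  cudaMalloc(reinterpret_cast<void**>(&d), 128 * sizeof(int32_t));
  ZeroAll(IntBuffer{d, 128}, /*stream=*/0);
  cudaDeviceSynchronize();
  cudaFree(d);
  return 0;
}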
// encoder
@@ -486,8 +460,7 @@ std::vector<std::vector<int64_t>> GetBlockShapeAndSplitKVBlockInferShape(
const int encoder_block_shape_q,
const int decoder_block_shape_q,
const int group_size,
const int block_size,
const int decoder_step_token_num) {
const int block_size) {
return {};
}
@@ -498,8 +471,7 @@ std::vector<paddle::DataType> GetBlockShapeAndSplitKVBlockInferDtype(
const int encoder_block_shape_q,
const int decoder_block_shape_q,
const int group_size,
const int block_size,
const int decoder_step_token_num) {
const int block_size) {
return {};
}
@@ -527,8 +499,7 @@ PD_BUILD_STATIC_OP(get_block_shape_and_split_kv_block)
.Attrs({"encoder_block_shape_q: int",
"decoder_block_shape_q: int",
"group_size: int",
"block_size: int",
"decoder_step_token_num: int"})
"block_size: int"})
.SetKernelFn(PD_KERNEL(GetBlockShapeAndSplitKVBlock))
.SetInferShapeFn(PD_INFER_SHAPE(GetBlockShapeAndSplitKVBlockInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(GetBlockShapeAndSplitKVBlockInferDtype));


@@ -381,8 +381,7 @@ void GetBlockShapeAndSplitKVBlock(
const int encoder_block_shape_q,
const int decoder_block_shape_q,
const int group_size,
const int block_size,
const int decoder_step_token_num);
const int block_size);
std::vector<paddle::Tensor> GetPaddingOffset(const paddle::Tensor& input_ids,
const paddle::Tensor& token_num,


@@ -54,9 +54,6 @@ class AppendAttentionMetadata(AttentionMetadata):
_dtype: paddle.dtype = paddle.bfloat16
encoder_max_partition_size: int = 32768
max_partition_size: int = 32768
block_tables: Optional[paddle.Tensor] = None
rotary_embs: Optional[paddle.Tensor] = None
attn_mask: Optional[paddle.Tensor] = None
_fuse_kernel_compute_dtype: str = "bf16"
# pd_disaggregation
@@ -101,7 +98,6 @@ def allocate_launch_related_buffer(
res["kv_batch_ids"] = paddle.full([int(kv_max_tile_size)], 0, dtype="int32")
res["kv_tile_ids_per_batch"] = paddle.full([int(kv_max_tile_size)], 0, dtype="int32")
res["kv_num_blocks_x_cpu"] = paddle.full([1], 0, dtype="int32").cpu()
return res
@@ -175,10 +171,6 @@ class AppendAttentionBackend(AttentionBackend):
metadata._fuse_kernel_compute_dtype = "fp16"
elif metadata._dtype == "float32":
metadata._fuse_kernel_compute_dtype = "fp32"
metadata.block_tables = forward_meta.block_tables
metadata.rotary_embs = forward_meta.rotary_embs
metadata.attn_mask = forward_meta.attn_mask
metadata.pre_caches_length = forward_meta.pre_caches_length
# pd_disaggregation
metadata.kv_signal_data_list = [None] * self.num_layers
@@ -263,6 +255,7 @@ class AppendAttentionBackend(AttentionBackend):
cache_v_scales = getattr(layer, "cache_v_scale", None)
if layer.layer_id == 0:
# print(forward_meta.seq_lens_this_time)
get_block_shape_and_split_kv_block(
forward_meta.seq_lens_encoder,
forward_meta.seq_lens_decoder,
@@ -283,7 +276,6 @@ class AppendAttentionBackend(AttentionBackend):
self.decoder_block_shape_q,
self.group_size,
self.block_size,
self.speculate_max_draft_token_num + 1,
)
if self.use_output:
@@ -330,7 +322,7 @@ class AppendAttentionBackend(AttentionBackend):
forward_meta.seq_lens_this_time,
forward_meta.batch_id_per_token,
forward_meta.cu_seqlens_q,
metadata.block_tables,
forward_meta.block_tables,
forward_meta.encoder_batch_ids,
forward_meta.encoder_tile_ids_per_batch,
forward_meta.encoder_num_blocks_x_cpu,
@@ -342,8 +334,8 @@ class AppendAttentionBackend(AttentionBackend):
forward_meta.decoder_num_blocks_cpu,
forward_meta.max_len_tensor_cpu,
res,
metadata.rotary_embs,
metadata.attn_mask,
forward_meta.rotary_embs,
forward_meta.attn_mask,
layer.qkv_bias,
layer.qkv_scale,
cache_k_scales,
@@ -387,7 +379,7 @@ class AppendAttentionBackend(AttentionBackend):
forward_meta.seq_lens_this_time,
forward_meta.batch_id_per_token,
forward_meta.cu_seqlens_q,
metadata.block_tables,
forward_meta.block_tables,
forward_meta.encoder_batch_ids,
forward_meta.encoder_tile_ids_per_batch,
forward_meta.encoder_num_blocks_x_cpu,
@@ -398,8 +390,8 @@ class AppendAttentionBackend(AttentionBackend):
forward_meta.decoder_tile_ids_per_batch,
forward_meta.decoder_num_blocks_cpu,
forward_meta.max_len_tensor_cpu,
metadata.rotary_embs,
metadata.attn_mask,
forward_meta.rotary_embs,
forward_meta.attn_mask,
layer.qkv_bias,
layer.qkv_scale,
cache_k_scales,


@@ -213,7 +213,6 @@ class FlashAttentionBackend(AttentionBackend):
self.decoder_block_shape_q,
self.group_size,
self.block_size,
self.speculate_max_draft_token_num + 1,
)
(


@@ -204,13 +204,12 @@ class MLAAttentionBackend(AttentionBackend):
self.decoder_block_shape_q,
self.group_size,
self.block_size,
self.speculate_max_draft_token_num + 1,
)
# MLA
metadata.max_enc_len_this_time = forward_meta.max_len_tensor_cpu[1]
metadata.max_dec_len_this_time = forward_meta.max_len_tensor_cpu[2]
metadata.max_kv_len_this_time = forward_meta.max_len_tensor_cpu[8]
metadata.max_kv_len_this_time = forward_meta.max_len_tensor_cpu[5]
# pd_disaggregation
metadata.kv_signal_data_list = [None] * self.num_layers


@@ -44,7 +44,6 @@ def get_block_shape_and_split_kv_block(
decoder_block_shape_q: int,
group_size: int,
block_size: int,
decoder_step_token_num: int,
):
"""
get_block_shape_and_split_kv_block
@@ -70,7 +69,6 @@ def get_block_shape_and_split_kv_block(
decoder_block_shape_q,
group_size,
block_size,
decoder_step_token_num,
)
else:


@@ -179,13 +179,12 @@ class MetaxMLAAttentionBackend(AttentionBackend):
self.decoder_block_shape_q,
self.group_size,
self.block_size,
self.speculate_max_draft_token_num + 1,
)
# MLA
metadata.max_enc_len_this_time = forward_meta.max_len_tensor_cpu[1].item()
metadata.max_dec_len_this_time = forward_meta.max_len_tensor_cpu[2]
metadata.max_kv_len_this_time = forward_meta.max_len_tensor_cpu[8]
metadata.max_kv_len_this_time = forward_meta.max_len_tensor_cpu[5]
# pd_disaggregation
metadata.kv_signal_data_list = [None] * self.num_layers


@@ -628,7 +628,6 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase):
12,
(self.q_num_head + 2 * self.kv_num_head) // self.kv_num_head,
self.blocksize,
speculate_max_draft_token_num + 1,
)
if self.use_dynamic_quant:
cache_quant_type = "block_wise_fp8"


@@ -479,7 +479,6 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase):
12,
(self.q_num_head + 2 * self.kv_num_head) // self.kv_num_head,
self.blocksize,
speculate_max_draft_token_num + 1,
)
# Warm up


@@ -121,10 +121,10 @@ class TestAttentionPerformance(unittest.TestCase):
"dtype": "bfloat16",
"hidden_size": 4096,
"max_position_embeddings": 131072,
"max_model_len": 5500,
"max_model_len": 36 * 1024 + 1024,
"num_attention_heads": 32,
"num_key_value_heads": 4,
"num_hidden_layers": 5,
"num_hidden_layers": 57,
}
model_dir = tempfile.mkdtemp(prefix="tmp_model_config_")
config_path = os.path.join(model_dir, "config.json")
@@ -223,7 +223,7 @@ class TestAttentionPerformance(unittest.TestCase):
max_model_len=fd_config.model_config.max_model_len,
encoder_block_shape_q=64,
decoder_block_shape_q=16,
decoder_step_token_num=1,
decoder_step_token_num=fd_config.speculative_config.num_speculative_tokens + 1,
num_heads=fd_config.model_config.num_attention_heads,
kv_num_heads=fd_config.model_config.num_key_value_heads,
block_size=fd_config.cache_config.block_size,
@@ -294,29 +294,30 @@ class TestAttentionPerformance(unittest.TestCase):
def test_decode_performance_with_prefill(self):
# Test parameters
test_steps = 100
prefill_batch_size = 1
prefill_seq_len = 4096
use_dynamic_quant = True
act_tensor_dtype = paddle.bfloat16
prefill_hidden_states = paddle.randn(
[prefill_batch_size * prefill_seq_len, self.fd_config.model_config.hidden_size],
dtype=act_tensor_dtype,
)
# prefill_batch_size = 1
# prefill_seq_len = 4096
forward_meta = self.create_forward_meta(
batch_size=prefill_batch_size,
seq_len=prefill_seq_len,
mode=ForwardMode.EXTEND,
fd_config=self.fd_config,
attn_backend=self.attn_backend,
use_dynamic_quant=use_dynamic_quant,
)
# prefill_hidden_states = paddle.randn(
# [prefill_batch_size * prefill_seq_len, self.fd_config.model_config.hidden_size],
# dtype=act_tensor_dtype,
# )
self.attn_backend.init_attention_metadata(forward_meta)
self.attn_forward(forward_meta, prefill_hidden_states)
# forward_meta = self.create_forward_meta(
# batch_size=prefill_batch_size,
# seq_len=prefill_seq_len,
# mode=ForwardMode.EXTEND,
# fd_config=self.fd_config,
# attn_backend=self.attn_backend,
# use_dynamic_quant=use_dynamic_quant,
# )
paddle.device.synchronize()
# self.attn_backend.init_attention_metadata(forward_meta)
# self.attn_forward(forward_meta, prefill_hidden_states)
# paddle.device.synchronize()
# import paddle.profiler as profiler
# p = profiler.Profiler(
@@ -326,18 +327,18 @@ class TestAttentionPerformance(unittest.TestCase):
# p.start()
# p.step()
start_events = [paddle.device.cuda.Event(enable_timing=True) for _ in range(test_steps)]
end_events = [paddle.device.cuda.Event(enable_timing=True) for _ in range(test_steps)]
for i in range(test_steps):
start_events[i].record()
# start_events = [paddle.device.cuda.Event(enable_timing=True) for _ in range(test_steps)]
# end_events = [paddle.device.cuda.Event(enable_timing=True) for _ in range(test_steps)]
# for i in range(test_steps):
# start_events[i].record()
self.attn_forward(forward_meta, prefill_hidden_states)
# self.attn_forward(forward_meta, prefill_hidden_states)
end_events[i].record()
paddle.device.synchronize()
# end_events[i].record()
# paddle.device.synchronize()
times = np.array([round(s.elapsed_time(e), 1) for s, e in zip(start_events, end_events)])[1:]
print(times[-5:])
# times = np.array([round(s.elapsed_time(e), 1) for s, e in zip(start_events, end_events)])[1:]
# print(times[-5:])
# p.stop()
@@ -349,14 +350,14 @@ class TestAttentionPerformance(unittest.TestCase):
# p.start()
# p.step()
for decode_batch_size in [10, 20, 40, 60, 80, 100, 128]:
for decode_batch_size in [32, 16, 8, 4, 2]:
decode_hidden_states = paddle.randn(
[decode_batch_size, self.fd_config.model_config.hidden_size], dtype=act_tensor_dtype
)
forward_meta = self.create_forward_meta(
batch_size=decode_batch_size,
seq_len=5000,
seq_len=36 * 1024,
mode=ForwardMode.DECODE,
fd_config=self.fd_config,
attn_backend=self.attn_backend,
@@ -383,7 +384,6 @@ class TestAttentionPerformance(unittest.TestCase):
start_events[i].record()
attn_cuda_graphs.replay()
# self.attn_forward(forward_meta, decode_hidden_states)
end_events[i].record()
paddle.device.synchronize()
@@ -391,6 +391,8 @@ class TestAttentionPerformance(unittest.TestCase):
times = np.array([round(s.elapsed_time(e), 1) for s, e in zip(start_events, end_events)])[1:]
print(times[-5:])
del forward_meta
# p.stop()


@@ -254,7 +254,6 @@ class TestTreeMask(unittest.TestCase):
decoder_block_shape_q,
self.num_q_head // self.num_kv_head,
self.block_size,
decoder_step_token_num,
)
s_time = 0
for i in range(self.run_time + self.warm_up):