From 6fa34102e8f26bd6a44748068539e5ab221284b4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=91=A8=E5=91=A8=E5=91=A8?= <39978853+zhoutianzi666@users.noreply.github.com> Date: Thu, 20 Nov 2025 16:40:04 +0800 Subject: [PATCH] [Others]get_block_shape_and_split_kv_block clean code (#5123) --- custom_ops/gpu_ops/append_attention.cu | 563 ++++++++++-------- .../get_block_shape_and_split_kv_block.cu | 53 +- custom_ops/gpu_ops/cpp_extensions.cc | 3 +- .../layers/attention/append_attn_backend.py | 22 +- .../layers/attention/flash_attn_backend.py | 1 - .../layers/attention/mla_attention_backend.py | 3 +- .../ops/get_block_shape_and_split_kv_block.py | 2 - .../metax/attention/mla_attn_metax_backend.py | 3 +- tests/layers/test_append_attention.py | 1 - .../test_append_attention_with_output.py | 1 - tests/layers/test_attention_layer.py | 66 +- tests/operators/test_tree_mask.py | 1 - 12 files changed, 364 insertions(+), 355 deletions(-) diff --git a/custom_ops/gpu_ops/append_attention.cu b/custom_ops/gpu_ops/append_attention.cu index 2ebcbfc2e..7982f6142 100644 --- a/custom_ops/gpu_ops/append_attention.cu +++ b/custom_ops/gpu_ops/append_attention.cu @@ -14,8 +14,8 @@ #include "append_attn/append_attention_kernel.h" #include "append_attn/decoder_write_cache_with_rope_kernel.h" -#include "append_attn/speculate_write_cache_with_rope_kernel.h" #include "append_attn/encoder_write_cache_with_rope_kernel.h" +#include "append_attn/speculate_write_cache_with_rope_kernel.h" #ifndef PD_BUILD_STATIC_OP #define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name) @@ -26,17 +26,16 @@ class type2value; template <> class type2value { - public: - static constexpr paddle::DataType value = paddle::DataType::BFLOAT16; + public: + static constexpr paddle::DataType value = paddle::DataType::BFLOAT16; }; template <> class type2value { - public: - static constexpr paddle::DataType value = paddle::DataType::FLOAT16; + public: + static constexpr paddle::DataType value = paddle::DataType::FLOAT16; }; - template void AppendAttentionKernel( const AppendAttnMetaData& meta_data, @@ -96,14 +95,12 @@ void AppendAttentionKernel( typedef typename traits_::DataType DataType_; typedef typename traits_::data_t data_t; - // set_max_lengths: max_len_this_time, max_enc_len_this_time, max_dec_len_this_time, max_enc_dec_len_this_time, - // max_just_dec_len_this_time, max_just_dec_merged_len_this_time, max_system_len, max_just_dec_len_without_system - int max_len_this_time = set_max_lengths.data()[0]; - int max_enc_len_this_time =set_max_lengths.data()[1]; - int max_dec_len_this_time = set_max_lengths.data()[2]; - int max_enc_dec_len_this_time = set_max_lengths.data()[3]; - int max_just_dec_len_this_time = set_max_lengths.data()[4]; - int max_kv_len_this_time = set_max_lengths.data()[8]; + const int max_len_this_time = set_max_lengths.data()[0]; + const int max_enc_len_this_time = set_max_lengths.data()[1]; + const int max_dec_len_this_time = set_max_lengths.data()[2]; + const int max_enc_dec_len_this_time = set_max_lengths.data()[3]; + const int max_just_dec_len_this_time = set_max_lengths.data()[4]; + const int max_kv_len_this_time = set_max_lengths.data()[5]; auto main_stream = qkv.stream(); static cudaEvent_t main_event; @@ -125,54 +122,56 @@ void AppendAttentionKernel( qkv_out = qkv; } - auto dispatch_CascadeAppendAttentionKernel = [&](auto temp_args, - const paddle::Tensor& lambda_batch_ids, - const paddle::Tensor& lambda_tile_ids_per_batch, - const int lambda_num_blocks_data, - const int lambda_block_shape_q, - const int lambda_max_dec_len, - const bool lambda_is_decoder, - const bool lambda_enable_prefill, - cudaStream_t& lambda_stream - ) -> void { - CascadeAppendAttentionKernel( - meta_data, - qkv_out, - key_cache, - value_cache, - attn_mask, - cache_quant_type_str == "block_wise_fp8" ? cache_k_quant_scales : cache_k_dequant_scales, - cache_quant_type_str == "block_wise_fp8" ? cache_v_quant_scales : cache_v_dequant_scales, - cache_k_zp, - cache_v_zp, - out_linear_shifts, - out_linear_smooths, - sinks, - seq_lens_this_time, - seq_lens_decoder, - seq_lens_encoder, - batch_id_per_token, - cu_seqlens_q, - block_tables, - lambda_batch_ids, - lambda_tile_ids_per_batch, - cache_quant_type_str, - lambda_num_blocks_data, - lambda_block_shape_q, - max_input_length, - lambda_max_dec_len, - quant_max_bound, - quant_min_bound, - out_linear_in_scale, - max_partition_size, - encoder_max_partition_size, - speculate_max_draft_token_num, - causal, - lambda_is_decoder, - lambda_enable_prefill, - lambda_stream, - &fmha_out, - sliding_window); + auto dispatch_CascadeAppendAttentionKernel = + [&](auto temp_args, + const paddle::Tensor& lambda_batch_ids, + const paddle::Tensor& lambda_tile_ids_per_batch, + const int lambda_num_blocks_data, + const int lambda_block_shape_q, + const int lambda_max_dec_len, + const bool lambda_is_decoder, + const bool lambda_enable_prefill, + cudaStream_t& lambda_stream) -> void { + CascadeAppendAttentionKernel( + meta_data, + qkv_out, + key_cache, + value_cache, + attn_mask, + cache_quant_type_str == "block_wise_fp8" ? cache_k_quant_scales + : cache_k_dequant_scales, + cache_quant_type_str == "block_wise_fp8" ? cache_v_quant_scales + : cache_v_dequant_scales, + cache_k_zp, + cache_v_zp, + out_linear_shifts, + out_linear_smooths, + sinks, + seq_lens_this_time, + seq_lens_decoder, + seq_lens_encoder, + batch_id_per_token, + cu_seqlens_q, + block_tables, + lambda_batch_ids, + lambda_tile_ids_per_batch, + cache_quant_type_str, + lambda_num_blocks_data, + lambda_block_shape_q, + max_input_length, + lambda_max_dec_len, + quant_max_bound, + quant_min_bound, + out_linear_in_scale, + max_partition_size, + encoder_max_partition_size, + speculate_max_draft_token_num, + causal, + lambda_is_decoder, + lambda_enable_prefill, + lambda_stream, + &fmha_out, + sliding_window); }; if (max_enc_len_this_time > 0) { @@ -182,8 +181,9 @@ void AppendAttentionKernel( int encoder_num_blocks_data = encoder_num_blocks.data()[0]; int kv_num_blocks_data = kv_num_blocks.data()[0]; - auto dispatch_EncoderWriteCacheWithRopeKernel = [&](auto temp_args) -> void { - EncoderWriteCacheWithRopeKernel( + auto dispatch_EncoderWriteCacheWithRopeKernel = + [&](auto temp_args) -> void { + EncoderWriteCacheWithRopeKernel( meta_data, qkv, seq_lens_this_time, @@ -225,24 +225,50 @@ void AppendAttentionKernel( } if (out_linear_in_scale > 0.0) { switch (fmha_out.dtype()) { - case paddle::DataType::INT8:{ + case paddle::DataType::INT8: { int8_t tmp; - dispatch_CascadeAppendAttentionKernel(tmp, encoder_batch_ids, encoder_tile_ids_per_batch, encoder_num_blocks_data, encoder_block_shape_q, max_enc_dec_len_this_time, false, true, main_stream); + dispatch_CascadeAppendAttentionKernel(tmp, + encoder_batch_ids, + encoder_tile_ids_per_batch, + encoder_num_blocks_data, + encoder_block_shape_q, + max_enc_dec_len_this_time, + false, + true, + main_stream); break; } - case paddle::DataType::FLOAT8_E4M3FN:{ + case paddle::DataType::FLOAT8_E4M3FN: { phi::dtype::float8_e4m3fn tmp; - dispatch_CascadeAppendAttentionKernel(tmp, encoder_batch_ids, encoder_tile_ids_per_batch, encoder_num_blocks_data, encoder_block_shape_q, max_enc_dec_len_this_time, false, true, main_stream); + dispatch_CascadeAppendAttentionKernel(tmp, + encoder_batch_ids, + encoder_tile_ids_per_batch, + encoder_num_blocks_data, + encoder_block_shape_q, + max_enc_dec_len_this_time, + false, + true, + main_stream); break; } - default:{ - PD_THROW("Only supported output fmha_out of quant dtype in ['int8', 'FLOAT8_E4M3FN']."); + default: { + PD_THROW( + "Only supported output fmha_out of quant dtype in ['int8', " + "'FLOAT8_E4M3FN']."); break; } } } else { data_t tmp; - dispatch_CascadeAppendAttentionKernel(tmp, encoder_batch_ids, encoder_tile_ids_per_batch, encoder_num_blocks_data, encoder_block_shape_q, max_enc_dec_len_this_time, false, true, main_stream); + dispatch_CascadeAppendAttentionKernel(tmp, + encoder_batch_ids, + encoder_tile_ids_per_batch, + encoder_num_blocks_data, + encoder_block_shape_q, + max_enc_dec_len_this_time, + false, + true, + main_stream); } } @@ -370,23 +396,44 @@ void AppendAttentionKernel( if (out_linear_in_scale > 0.0) { switch (fmha_out.dtype()) { - case paddle::DataType::INT8:{ - int8_t tmp; - dispatch_CascadeAppendAttentionKernel(tmp, decoder_batch_ids, decoder_tile_ids_per_batch, decoder_num_blocks_data, - decoder_block_shape_q, max_kv_len_this_time, !speculate_decoder, !speculate_decoder, exec_stream); + case paddle::DataType::INT8: { + int8_t tmp; + dispatch_CascadeAppendAttentionKernel(tmp, + decoder_batch_ids, + decoder_tile_ids_per_batch, + decoder_num_blocks_data, + decoder_block_shape_q, + max_kv_len_this_time, + !speculate_decoder, + !speculate_decoder, + exec_stream); break; } - case paddle::DataType::FLOAT8_E4M3FN:{ - phi::dtype::float8_e4m3fn tmp; - dispatch_CascadeAppendAttentionKernel(tmp, decoder_batch_ids, decoder_tile_ids_per_batch, decoder_num_blocks_data, - decoder_block_shape_q, max_kv_len_this_time, !speculate_decoder, !speculate_decoder, exec_stream); + case paddle::DataType::FLOAT8_E4M3FN: { + phi::dtype::float8_e4m3fn tmp; + dispatch_CascadeAppendAttentionKernel(tmp, + decoder_batch_ids, + decoder_tile_ids_per_batch, + decoder_num_blocks_data, + decoder_block_shape_q, + max_kv_len_this_time, + !speculate_decoder, + !speculate_decoder, + exec_stream); break; } } } else { - data_t tmp; - dispatch_CascadeAppendAttentionKernel(tmp, decoder_batch_ids, decoder_tile_ids_per_batch, decoder_num_blocks_data, - decoder_block_shape_q, max_kv_len_this_time, !speculate_decoder, !speculate_decoder, exec_stream); + data_t tmp; + dispatch_CascadeAppendAttentionKernel(tmp, + decoder_batch_ids, + decoder_tile_ids_per_batch, + decoder_num_blocks_data, + decoder_block_shape_q, + max_kv_len_this_time, + !speculate_decoder, + !speculate_decoder, + exec_stream); } if (max_enc_len_this_time > 0) { cudaEventRecord(decoder_event, exec_stream); @@ -471,8 +518,14 @@ std::vector AppendAttention( // template dtype generation phi::DataType dtype_id; switch (qkv.dtype()) { - case paddle::DataType::FLOAT16: {dtype_id = phi::DataType::FLOAT16; break;} - case paddle::DataType::BFLOAT16: {dtype_id = phi::DataType::BFLOAT16; break;} + case paddle::DataType::FLOAT16: { + dtype_id = phi::DataType::FLOAT16; + break; + } + case paddle::DataType::BFLOAT16: { + dtype_id = phi::DataType::BFLOAT16; + break; + } case paddle::DataType::INT32: { if (compute_dtype == "bf16") { dtype_id = phi::DataType::BFLOAT16; @@ -498,15 +551,15 @@ std::vector AppendAttention( if (out_linear_in_scale > 0.0) { if (fabs(quant_max_bound - 127.0f) < 0.000001) { fmha_out = paddle::zeros( - {meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims}, - paddle::DataType::INT8, - qkv.place()); + {meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims}, + paddle::DataType::INT8, + qkv.place()); } else if (fabs(quant_max_bound - 448.0f) < 0.000001) { fmha_out = paddle::zeros( - {meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims}, - paddle::DataType::FLOAT8_E4M3FN, - qkv.place()); - } else{ + {meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims}, + paddle::DataType::FLOAT8_E4M3FN, + qkv.place()); + } else { PD_THROW("Only supported attr of quant_max_bound in ['127', '448']."); } } else { @@ -521,79 +574,78 @@ std::vector AppendAttention( } auto dispatch_by_template = [&](auto temp_args) -> void { - AppendAttentionKernel::value>( - meta_data, - qkv, - key_cache, - value_cache, - seq_lens_encoder, - seq_lens_decoder, - seq_lens_this_time, - batch_id_per_token, - cu_seqlens_q, - block_tables, - encoder_batch_ids, - encoder_tile_ids_per_batch, - encoder_num_blocks, - kv_batch_ids, - kv_tile_ids_per_batch, - kv_num_blocks, - decoder_batch_ids, - decoder_tile_ids_per_batch, - decoder_num_blocks, - set_max_lengths, - fmha_out, - rotary_embs, - attn_mask, - qkv_bias, - qkv_out_scales, - cache_k_quant_scales, - cache_v_quant_scales, - cache_k_dequant_scales, - cache_v_dequant_scales, - cache_k_zp, - cache_v_zp, - out_linear_shifts, - out_linear_smooths, - kv_signal_data, - q_norm_weight, - k_norm_weight, - sinks, - rms_norm_eps, - cache_quant_type_str, - use_neox_rotary_style, - rope_3d, - max_input_length, - quant_max_bound, - quant_min_bound, - out_linear_in_scale, - encoder_block_shape_q, - decoder_block_shape_q, - max_partition_size, - encoder_max_partition_size, - speculate_max_draft_token_num, - causal, - speculate_decoder, - sliding_window); + AppendAttentionKernel::value>( + meta_data, + qkv, + key_cache, + value_cache, + seq_lens_encoder, + seq_lens_decoder, + seq_lens_this_time, + batch_id_per_token, + cu_seqlens_q, + block_tables, + encoder_batch_ids, + encoder_tile_ids_per_batch, + encoder_num_blocks, + kv_batch_ids, + kv_tile_ids_per_batch, + kv_num_blocks, + decoder_batch_ids, + decoder_tile_ids_per_batch, + decoder_num_blocks, + set_max_lengths, + fmha_out, + rotary_embs, + attn_mask, + qkv_bias, + qkv_out_scales, + cache_k_quant_scales, + cache_v_quant_scales, + cache_k_dequant_scales, + cache_v_dequant_scales, + cache_k_zp, + cache_v_zp, + out_linear_shifts, + out_linear_smooths, + kv_signal_data, + q_norm_weight, + k_norm_weight, + sinks, + rms_norm_eps, + cache_quant_type_str, + use_neox_rotary_style, + rope_3d, + max_input_length, + quant_max_bound, + quant_min_bound, + out_linear_in_scale, + encoder_block_shape_q, + decoder_block_shape_q, + max_partition_size, + encoder_max_partition_size, + speculate_max_draft_token_num, + causal, + speculate_decoder, + sliding_window); }; - phi::dtype::float16 fp16_dtype; phi::dtype::bfloat16 bp16_dtype; - switch (dtype_id){ - case phi::DataType::FLOAT16: { - dispatch_by_template(fp16_dtype); - return {fmha_out}; - } - case phi::DataType::BFLOAT16: { - dispatch_by_template(bp16_dtype); - return {fmha_out}; - } - default: - PD_THROW( + switch (dtype_id) { + case phi::DataType::FLOAT16: { + dispatch_by_template(fp16_dtype); + return {fmha_out}; + } + case phi::DataType::BFLOAT16: { + dispatch_by_template(bp16_dtype); + return {fmha_out}; + } + default: + PD_THROW( "NOT supported data type. " "Only float16 and bfloat16 are supported. "); - break; + break; } return {paddle::Tensor{}}; @@ -678,60 +730,60 @@ std::vector AppendAttentionWithOutput( } auto dispatch_by_template = [&](auto temp_args) -> void { - AppendAttentionKernel::value>( - meta_data, - qkv, - key_cache, - value_cache, - seq_lens_encoder, - seq_lens_decoder, - seq_lens_this_time, - batch_id_per_token, - cu_seqlens_q, - block_tables, - encoder_batch_ids, - encoder_tile_ids_per_batch, - encoder_num_blocks, - kv_batch_ids, - kv_tile_ids_per_batch, - kv_num_blocks, - decoder_batch_ids, - decoder_tile_ids_per_batch, - decoder_num_blocks, - set_max_lengths, - fmha_out, - rotary_embs, - attn_mask, - qkv_bias, - qkv_out_scales, - cache_k_quant_scales, - cache_v_quant_scales, - cache_k_dequant_scales, - cache_v_dequant_scales, - cache_k_zp, - cache_v_zp, - out_linear_shifts, - out_linear_smooths, - kv_signal_data, - q_norm_weight, - k_norm_weight, - sinks, - rms_norm_eps, - cache_quant_type_str, - use_neox_rotary_style, - rope_3d, - max_input_length, - quant_max_bound, - quant_min_bound, - out_linear_in_scale, - encoder_block_shape_q, - decoder_block_shape_q, - max_partition_size, - encoder_max_partition_size, - speculate_max_draft_token_num, - causal, - speculate_decoder, - sliding_window); + AppendAttentionKernel::value>( + meta_data, + qkv, + key_cache, + value_cache, + seq_lens_encoder, + seq_lens_decoder, + seq_lens_this_time, + batch_id_per_token, + cu_seqlens_q, + block_tables, + encoder_batch_ids, + encoder_tile_ids_per_batch, + encoder_num_blocks, + kv_batch_ids, + kv_tile_ids_per_batch, + kv_num_blocks, + decoder_batch_ids, + decoder_tile_ids_per_batch, + decoder_num_blocks, + set_max_lengths, + fmha_out, + rotary_embs, + attn_mask, + qkv_bias, + qkv_out_scales, + cache_k_quant_scales, + cache_v_quant_scales, + cache_k_dequant_scales, + cache_v_dequant_scales, + cache_k_zp, + cache_v_zp, + out_linear_shifts, + out_linear_smooths, + kv_signal_data, + q_norm_weight, + k_norm_weight, + sinks, + rms_norm_eps, + cache_quant_type_str, + use_neox_rotary_style, + rope_3d, + max_input_length, + quant_max_bound, + quant_min_bound, + out_linear_in_scale, + encoder_block_shape_q, + decoder_block_shape_q, + max_partition_size, + encoder_max_partition_size, + speculate_max_draft_token_num, + causal, + speculate_decoder, + sliding_window); }; phi::dtype::float16 fp16_dtype; @@ -769,7 +821,6 @@ std::vector AppendAttentionWithOutput( return {fmha_out}; } - std::vector> AppendAttentionInferShape( const std::vector& qkv_shape, const std::vector& key_cache_shape, @@ -895,8 +946,9 @@ std::vector AppendAttentionInferDtype( return {paddle::DataType::INT8}; } else if (fabs(quant_max_bound - 448.0f) < 0.000001) { return {paddle::DataType::FLOAT8_E4M3FN}; - }else{ - PD_THROW("Only supported attr of quant_max_bound in ['127.0', '448.0']."); + } else { + PD_THROW( + "Only supported attr of quant_max_bound in ['127.0', '448.0']."); } } else { return {paddle::DataType::BFLOAT16}; @@ -907,8 +959,9 @@ std::vector AppendAttentionInferDtype( return {paddle::DataType::INT8}; } else if (fabs(quant_max_bound - 448.0f) < 0.000001) { return {paddle::DataType::FLOAT8_E4M3FN}; - }else{ - PD_THROW("Only supported attr of quant_max_bound in ['127.0', '448.0']."); + } else { + PD_THROW( + "Only supported attr of quant_max_bound in ['127.0', '448.0']."); } } else { return {paddle::DataType::FLOAT16}; @@ -1034,8 +1087,6 @@ std::vector AppendAttentionWithOutputInferDtype( return {fmha_out_dtype}; } - - PD_BUILD_STATIC_OP(append_attention) .Inputs({"qkv", "key_cache", @@ -1074,24 +1125,25 @@ PD_BUILD_STATIC_OP(append_attention) paddle::Optional("k_norm_weight"), paddle::Optional("sinks")}) .Outputs({"fmha_out"}) - .Attrs({"rms_norm_eps: float", - "compute_type: std::string", - "cache_quant_type: std::string", - "use_neox_rotary_style: bool", - "rope_3d: bool", - "max_input_length: int", - "quant_max_bound: float", - "quant_min_bound: float", - "out_linear_in_scale: float", - "encoder_block_shape_q: int", - "decoder_block_shape_q: int", - "max_partition_size: int", - "encoder_max_partition_size: int", - "speculate_max_draft_token_num: int", - "causal: bool", - "speculate_decoder: bool", - "sliding_window: int", - }) + .Attrs({ + "rms_norm_eps: float", + "compute_type: std::string", + "cache_quant_type: std::string", + "use_neox_rotary_style: bool", + "rope_3d: bool", + "max_input_length: int", + "quant_max_bound: float", + "quant_min_bound: float", + "out_linear_in_scale: float", + "encoder_block_shape_q: int", + "decoder_block_shape_q: int", + "max_partition_size: int", + "encoder_max_partition_size: int", + "speculate_max_draft_token_num: int", + "causal: bool", + "speculate_decoder: bool", + "sliding_window: int", + }) .SetKernelFn(PD_KERNEL(AppendAttention)) .SetInferShapeFn(PD_INFER_SHAPE(AppendAttentionInferShape)) .SetInferDtypeFn(PD_INFER_DTYPE(AppendAttentionInferDtype)); @@ -1136,24 +1188,25 @@ PD_BUILD_STATIC_OP(append_attention_with_output) paddle::Optional("sinks")}) .Outputs({"fmha_out_out"}) .SetInplaceMap({{"fmha_out", "fmha_out_out"}}) - .Attrs({"rms_norm_eps: float", - "compute_type: std::string", - "cache_quant_type: std::string", - "use_neox_rotary_style: bool", - "rope_3d: bool", - "max_input_length: int", - "quant_max_bound: float", - "quant_min_bound: float", - "out_linear_in_scale: float", - "encoder_block_shape_q: int", - "decoder_block_shape_q: int", - "max_partition_size: int", - "encoder_max_partition_size: int", - "speculate_max_draft_token_num: int", - "causal: bool", - "speculate_decoder: bool", - "sliding_window: int", - }) + .Attrs({ + "rms_norm_eps: float", + "compute_type: std::string", + "cache_quant_type: std::string", + "use_neox_rotary_style: bool", + "rope_3d: bool", + "max_input_length: int", + "quant_max_bound: float", + "quant_min_bound: float", + "out_linear_in_scale: float", + "encoder_block_shape_q: int", + "decoder_block_shape_q: int", + "max_partition_size: int", + "encoder_max_partition_size: int", + "speculate_max_draft_token_num: int", + "causal: bool", + "speculate_decoder: bool", + "sliding_window: int", + }) .SetKernelFn(PD_KERNEL(AppendAttentionWithOutput)) .SetInferShapeFn(PD_INFER_SHAPE(AppendAttentionWithOutputInferShape)) .SetInferDtypeFn(PD_INFER_DTYPE(AppendAttentionWithOutputInferDtype)); diff --git a/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu b/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu index 3368eb620..e84f82816 100644 --- a/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu +++ b/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu @@ -79,7 +79,7 @@ __global__ void GetMaxLenKernel(const int *seq_lens_decoder, max_lens[2] = total_max_len_decoder; max_lens[3] = total; max_lens[4] = total_just_dec; - max_lens[8] = total_max_len_kv; + max_lens[5] = total_max_len_kv; } } @@ -273,8 +273,7 @@ void GetBlockShapeAndSplitKVBlock( const int encoder_block_shape_q, const int decoder_block_shape_q, const int group_size, - const int block_size, - const int decoder_step_token_num) { + const int block_size) { auto stream = seq_lens_encoder.stream(); int bsz = seq_lens_this_time.shape()[0]; @@ -302,10 +301,9 @@ void GetBlockShapeAndSplitKVBlock( int max_dec_len_this_time = max_len_cpu_ptr[2]; int max_enc_dec_len_this_time = max_len_cpu_ptr[3]; int max_just_dec_len_this_time = max_len_cpu_ptr[4]; - int max_just_dec_merged_len_this_time = max_len_cpu_ptr[5]; - int max_system_len = max_len_cpu_ptr[6]; - int max_just_dec_len_without_system = max_len_cpu_ptr[7]; - int max_kv_len_this_time = max_len_cpu_ptr[8]; + int max_kv_len_this_time = max_len_cpu_ptr[5]; + + const uint32_t decoder_batch_ele_num = decoder_batch_ids.shape()[0]; // decoder if (max_dec_len_this_time > 0) { @@ -343,25 +341,15 @@ void GetBlockShapeAndSplitKVBlock( decoder_chunk_size_device.copy_to(paddle::CPUPlace(), false); const int chunk_size = decoder_chunk_size_cpu.data()[0]; - // NOTE: (changwenbin) When using auto_chunk, - // decode_max_tile_size must take into account the maximum case, where * - // 1024 can cover 128K. const uint32_t decoder_batch_shape = - // seq_lens_decoder.dims()[0] * 1024; - - const uint32_t decoder_max_tile_size_per_bs_q = - div_up((decoder_step_token_num * group_size), decoder_block_shape_q); - const uint32_t decoder_batch_shape = - bsz * 1024 * decoder_max_tile_size_per_bs_q; - PADDLE_ENFORCE_GPU_SUCCESS( cudaMemsetAsync(decoder_batch_ids.data(), 0, - decoder_batch_shape * sizeof(int32_t), + decoder_batch_ele_num * sizeof(int32_t), stream)); PADDLE_ENFORCE_GPU_SUCCESS( cudaMemsetAsync(decoder_tile_ids_per_batch.data(), 0, - decoder_batch_shape * sizeof(int32_t), + decoder_batch_ele_num * sizeof(int32_t), stream)); split_block_for_mla<<<1, 32, 0, stream>>>( @@ -374,22 +362,15 @@ void GetBlockShapeAndSplitKVBlock( chunk_size); } else { - // Note:(changwenbin)In order to adapt to cudagraph, the maximum value - // should be taken here - const uint32_t decoder_max_tile_size_per_bs_q = - div_up((decoder_step_token_num * group_size), decoder_block_shape_q); - const uint32_t decoder_batch_shape = - bsz * 1024 * decoder_max_tile_size_per_bs_q; - PADDLE_ENFORCE_GPU_SUCCESS( cudaMemsetAsync(decoder_batch_ids.data(), 0, - decoder_batch_shape * sizeof(int32_t), + decoder_batch_ele_num * sizeof(int32_t), stream)); PADDLE_ENFORCE_GPU_SUCCESS( cudaMemsetAsync(decoder_tile_ids_per_batch.data(), 0, - decoder_batch_shape * sizeof(int32_t), + decoder_batch_ele_num * sizeof(int32_t), stream)); PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync( decoder_num_blocks_device.data(), 0, sizeof(int32_t), stream)); @@ -413,13 +394,6 @@ void GetBlockShapeAndSplitKVBlock( PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync( decoder_chunk_size_device.data(), 64, sizeof(int32_t), stream)); } - } else { - PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync( - decoder_chunk_size_device.data(), 64, sizeof(int32_t), stream)); - PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync( - decoder_num_blocks_device.data(), 0, sizeof(int32_t), stream)); - decoder_num_blocks_cpu.copy_( - decoder_num_blocks_device, decoder_num_blocks_cpu.place(), false); } // encoder @@ -486,8 +460,7 @@ std::vector> GetBlockShapeAndSplitKVBlockInferShape( const int encoder_block_shape_q, const int decoder_block_shape_q, const int group_size, - const int block_size, - const int decoder_step_token_num) { + const int block_size) { return {}; } @@ -498,8 +471,7 @@ std::vector GetBlockShapeAndSplitKVBlockInferDtype( const int encoder_block_shape_q, const int decoder_block_shape_q, const int group_size, - const int block_size, - const int decoder_step_token_num) { + const int block_size) { return {}; } @@ -527,8 +499,7 @@ PD_BUILD_STATIC_OP(get_block_shape_and_split_kv_block) .Attrs({"encoder_block_shape_q: int", "decoder_block_shape_q: int", "group_size: int", - "block_size: int", - "decoder_step_token_num: int"}) + "block_size: int"}) .SetKernelFn(PD_KERNEL(GetBlockShapeAndSplitKVBlock)) .SetInferShapeFn(PD_INFER_SHAPE(GetBlockShapeAndSplitKVBlockInferShape)) .SetInferDtypeFn(PD_INFER_DTYPE(GetBlockShapeAndSplitKVBlockInferDtype)); diff --git a/custom_ops/gpu_ops/cpp_extensions.cc b/custom_ops/gpu_ops/cpp_extensions.cc index 6ecc1ed14..0e6853d9b 100644 --- a/custom_ops/gpu_ops/cpp_extensions.cc +++ b/custom_ops/gpu_ops/cpp_extensions.cc @@ -381,8 +381,7 @@ void GetBlockShapeAndSplitKVBlock( const int encoder_block_shape_q, const int decoder_block_shape_q, const int group_size, - const int block_size, - const int decoder_step_token_num); + const int block_size); std::vector GetPaddingOffset(const paddle::Tensor& input_ids, const paddle::Tensor& token_num, diff --git a/fastdeploy/model_executor/layers/attention/append_attn_backend.py b/fastdeploy/model_executor/layers/attention/append_attn_backend.py index 3e3a56aa4..ac6231a58 100644 --- a/fastdeploy/model_executor/layers/attention/append_attn_backend.py +++ b/fastdeploy/model_executor/layers/attention/append_attn_backend.py @@ -54,9 +54,6 @@ class AppendAttentionMetadata(AttentionMetadata): _dtype: paddle.dtype = paddle.bfloat16 encoder_max_partition_size: int = 32768 max_partition_size: int = 32768 - block_tables: Optional[paddle.Tensor] = None - rotary_embs: Optional[paddle.Tensor] = None - attn_mask: Optional[paddle.Tensor] = None _fuse_kernel_compute_dtype: str = "bf16" # pd_disaggregation @@ -101,7 +98,6 @@ def allocate_launch_related_buffer( res["kv_batch_ids"] = paddle.full([int(kv_max_tile_size)], 0, dtype="int32") res["kv_tile_ids_per_batch"] = paddle.full([int(kv_max_tile_size)], 0, dtype="int32") res["kv_num_blocks_x_cpu"] = paddle.full([1], 0, dtype="int32").cpu() - return res @@ -175,10 +171,6 @@ class AppendAttentionBackend(AttentionBackend): metadata._fuse_kernel_compute_dtype = "fp16" elif metadata._dtype == "float32": metadata._fuse_kernel_compute_dtype = "fp32" - metadata.block_tables = forward_meta.block_tables - metadata.rotary_embs = forward_meta.rotary_embs - metadata.attn_mask = forward_meta.attn_mask - metadata.pre_caches_length = forward_meta.pre_caches_length # pd_disaggregation metadata.kv_signal_data_list = [None] * self.num_layers @@ -263,6 +255,7 @@ class AppendAttentionBackend(AttentionBackend): cache_v_scales = getattr(layer, "cache_v_scale", None) if layer.layer_id == 0: + # print(forward_meta.seq_lens_this_time) get_block_shape_and_split_kv_block( forward_meta.seq_lens_encoder, forward_meta.seq_lens_decoder, @@ -283,7 +276,6 @@ class AppendAttentionBackend(AttentionBackend): self.decoder_block_shape_q, self.group_size, self.block_size, - self.speculate_max_draft_token_num + 1, ) if self.use_output: @@ -330,7 +322,7 @@ class AppendAttentionBackend(AttentionBackend): forward_meta.seq_lens_this_time, forward_meta.batch_id_per_token, forward_meta.cu_seqlens_q, - metadata.block_tables, + forward_meta.block_tables, forward_meta.encoder_batch_ids, forward_meta.encoder_tile_ids_per_batch, forward_meta.encoder_num_blocks_x_cpu, @@ -342,8 +334,8 @@ class AppendAttentionBackend(AttentionBackend): forward_meta.decoder_num_blocks_cpu, forward_meta.max_len_tensor_cpu, res, - metadata.rotary_embs, - metadata.attn_mask, + forward_meta.rotary_embs, + forward_meta.attn_mask, layer.qkv_bias, layer.qkv_scale, cache_k_scales, @@ -387,7 +379,7 @@ class AppendAttentionBackend(AttentionBackend): forward_meta.seq_lens_this_time, forward_meta.batch_id_per_token, forward_meta.cu_seqlens_q, - metadata.block_tables, + forward_meta.block_tables, forward_meta.encoder_batch_ids, forward_meta.encoder_tile_ids_per_batch, forward_meta.encoder_num_blocks_x_cpu, @@ -398,8 +390,8 @@ class AppendAttentionBackend(AttentionBackend): forward_meta.decoder_tile_ids_per_batch, forward_meta.decoder_num_blocks_cpu, forward_meta.max_len_tensor_cpu, - metadata.rotary_embs, - metadata.attn_mask, + forward_meta.rotary_embs, + forward_meta.attn_mask, layer.qkv_bias, layer.qkv_scale, cache_k_scales, diff --git a/fastdeploy/model_executor/layers/attention/flash_attn_backend.py b/fastdeploy/model_executor/layers/attention/flash_attn_backend.py index bce361eb5..31d6d7488 100644 --- a/fastdeploy/model_executor/layers/attention/flash_attn_backend.py +++ b/fastdeploy/model_executor/layers/attention/flash_attn_backend.py @@ -213,7 +213,6 @@ class FlashAttentionBackend(AttentionBackend): self.decoder_block_shape_q, self.group_size, self.block_size, - self.speculate_max_draft_token_num + 1, ) ( diff --git a/fastdeploy/model_executor/layers/attention/mla_attention_backend.py b/fastdeploy/model_executor/layers/attention/mla_attention_backend.py index 54e72379e..8df65d39d 100644 --- a/fastdeploy/model_executor/layers/attention/mla_attention_backend.py +++ b/fastdeploy/model_executor/layers/attention/mla_attention_backend.py @@ -204,13 +204,12 @@ class MLAAttentionBackend(AttentionBackend): self.decoder_block_shape_q, self.group_size, self.block_size, - self.speculate_max_draft_token_num + 1, ) # MLA metadata.max_enc_len_this_time = forward_meta.max_len_tensor_cpu[1] metadata.max_dec_len_this_time = forward_meta.max_len_tensor_cpu[2] - metadata.max_kv_len_this_time = forward_meta.max_len_tensor_cpu[8] + metadata.max_kv_len_this_time = forward_meta.max_len_tensor_cpu[5] # pd_disaggregation metadata.kv_signal_data_list = [None] * self.num_layers diff --git a/fastdeploy/model_executor/layers/attention/ops/get_block_shape_and_split_kv_block.py b/fastdeploy/model_executor/layers/attention/ops/get_block_shape_and_split_kv_block.py index 1cd5f4f14..a97cf1666 100644 --- a/fastdeploy/model_executor/layers/attention/ops/get_block_shape_and_split_kv_block.py +++ b/fastdeploy/model_executor/layers/attention/ops/get_block_shape_and_split_kv_block.py @@ -44,7 +44,6 @@ def get_block_shape_and_split_kv_block( decoder_block_shape_q: int, group_size: int, block_size: int, - decoder_step_token_num: int, ): """ get_block_shape_and_split_kv_block @@ -70,7 +69,6 @@ def get_block_shape_and_split_kv_block( decoder_block_shape_q, group_size, block_size, - decoder_step_token_num, ) else: diff --git a/fastdeploy/model_executor/layers/backends/metax/attention/mla_attn_metax_backend.py b/fastdeploy/model_executor/layers/backends/metax/attention/mla_attn_metax_backend.py index ff1bce8bd..8800d497e 100644 --- a/fastdeploy/model_executor/layers/backends/metax/attention/mla_attn_metax_backend.py +++ b/fastdeploy/model_executor/layers/backends/metax/attention/mla_attn_metax_backend.py @@ -179,13 +179,12 @@ class MetaxMLAAttentionBackend(AttentionBackend): self.decoder_block_shape_q, self.group_size, self.block_size, - self.speculate_max_draft_token_num + 1, ) # MLA metadata.max_enc_len_this_time = forward_meta.max_len_tensor_cpu[1].item() metadata.max_dec_len_this_time = forward_meta.max_len_tensor_cpu[2] - metadata.max_kv_len_this_time = forward_meta.max_len_tensor_cpu[8] + metadata.max_kv_len_this_time = forward_meta.max_len_tensor_cpu[5] # pd_disaggregation metadata.kv_signal_data_list = [None] * self.num_layers diff --git a/tests/layers/test_append_attention.py b/tests/layers/test_append_attention.py index 4cc00858d..01ad4bb93 100644 --- a/tests/layers/test_append_attention.py +++ b/tests/layers/test_append_attention.py @@ -628,7 +628,6 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase): 12, (self.q_num_head + 2 * self.kv_num_head) // self.kv_num_head, self.blocksize, - speculate_max_draft_token_num + 1, ) if self.use_dynamic_quant: cache_quant_type = "block_wise_fp8" diff --git a/tests/layers/test_append_attention_with_output.py b/tests/layers/test_append_attention_with_output.py index 5f08c7371..6c15de17c 100644 --- a/tests/layers/test_append_attention_with_output.py +++ b/tests/layers/test_append_attention_with_output.py @@ -479,7 +479,6 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase): 12, (self.q_num_head + 2 * self.kv_num_head) // self.kv_num_head, self.blocksize, - speculate_max_draft_token_num + 1, ) # Warm up diff --git a/tests/layers/test_attention_layer.py b/tests/layers/test_attention_layer.py index 32d579f74..91bd43eb6 100644 --- a/tests/layers/test_attention_layer.py +++ b/tests/layers/test_attention_layer.py @@ -121,10 +121,10 @@ class TestAttentionPerformance(unittest.TestCase): "dtype": "bfloat16", "hidden_size": 4096, "max_position_embeddings": 131072, - "max_model_len": 5500, + "max_model_len": 36 * 1024 + 1024, "num_attention_heads": 32, "num_key_value_heads": 4, - "num_hidden_layers": 5, + "num_hidden_layers": 57, } model_dir = tempfile.mkdtemp(prefix="tmp_model_config_") config_path = os.path.join(model_dir, "config.json") @@ -223,7 +223,7 @@ class TestAttentionPerformance(unittest.TestCase): max_model_len=fd_config.model_config.max_model_len, encoder_block_shape_q=64, decoder_block_shape_q=16, - decoder_step_token_num=1, + decoder_step_token_num=fd_config.speculative_config.num_speculative_tokens + 1, num_heads=fd_config.model_config.num_attention_heads, kv_num_heads=fd_config.model_config.num_key_value_heads, block_size=fd_config.cache_config.block_size, @@ -294,29 +294,30 @@ class TestAttentionPerformance(unittest.TestCase): def test_decode_performance_with_prefill(self): # Test parameters test_steps = 100 - prefill_batch_size = 1 - prefill_seq_len = 4096 use_dynamic_quant = True act_tensor_dtype = paddle.bfloat16 - prefill_hidden_states = paddle.randn( - [prefill_batch_size * prefill_seq_len, self.fd_config.model_config.hidden_size], - dtype=act_tensor_dtype, - ) + # prefill_batch_size = 1 + # prefill_seq_len = 4096 - forward_meta = self.create_forward_meta( - batch_size=prefill_batch_size, - seq_len=prefill_seq_len, - mode=ForwardMode.EXTEND, - fd_config=self.fd_config, - attn_backend=self.attn_backend, - use_dynamic_quant=use_dynamic_quant, - ) + # prefill_hidden_states = paddle.randn( + # [prefill_batch_size * prefill_seq_len, self.fd_config.model_config.hidden_size], + # dtype=act_tensor_dtype, + # ) - self.attn_backend.init_attention_metadata(forward_meta) - self.attn_forward(forward_meta, prefill_hidden_states) + # forward_meta = self.create_forward_meta( + # batch_size=prefill_batch_size, + # seq_len=prefill_seq_len, + # mode=ForwardMode.EXTEND, + # fd_config=self.fd_config, + # attn_backend=self.attn_backend, + # use_dynamic_quant=use_dynamic_quant, + # ) - paddle.device.synchronize() + # self.attn_backend.init_attention_metadata(forward_meta) + # self.attn_forward(forward_meta, prefill_hidden_states) + + # paddle.device.synchronize() # import paddle.profiler as profiler # p = profiler.Profiler( @@ -326,18 +327,18 @@ class TestAttentionPerformance(unittest.TestCase): # p.start() # p.step() - start_events = [paddle.device.cuda.Event(enable_timing=True) for _ in range(test_steps)] - end_events = [paddle.device.cuda.Event(enable_timing=True) for _ in range(test_steps)] - for i in range(test_steps): - start_events[i].record() + # start_events = [paddle.device.cuda.Event(enable_timing=True) for _ in range(test_steps)] + # end_events = [paddle.device.cuda.Event(enable_timing=True) for _ in range(test_steps)] + # for i in range(test_steps): + # start_events[i].record() - self.attn_forward(forward_meta, prefill_hidden_states) + # self.attn_forward(forward_meta, prefill_hidden_states) - end_events[i].record() - paddle.device.synchronize() + # end_events[i].record() + # paddle.device.synchronize() - times = np.array([round(s.elapsed_time(e), 1) for s, e in zip(start_events, end_events)])[1:] - print(times[-5:]) + # times = np.array([round(s.elapsed_time(e), 1) for s, e in zip(start_events, end_events)])[1:] + # print(times[-5:]) # p.stop() @@ -349,14 +350,14 @@ class TestAttentionPerformance(unittest.TestCase): # p.start() # p.step() - for decode_batch_size in [10, 20, 40, 60, 80, 100, 128]: + for decode_batch_size in [32, 16, 8, 4, 2]: decode_hidden_states = paddle.randn( [decode_batch_size, self.fd_config.model_config.hidden_size], dtype=act_tensor_dtype ) forward_meta = self.create_forward_meta( batch_size=decode_batch_size, - seq_len=5000, + seq_len=36 * 1024, mode=ForwardMode.DECODE, fd_config=self.fd_config, attn_backend=self.attn_backend, @@ -383,7 +384,6 @@ class TestAttentionPerformance(unittest.TestCase): start_events[i].record() attn_cuda_graphs.replay() - # self.attn_forward(forward_meta, decode_hidden_states) end_events[i].record() paddle.device.synchronize() @@ -391,6 +391,8 @@ class TestAttentionPerformance(unittest.TestCase): times = np.array([round(s.elapsed_time(e), 1) for s, e in zip(start_events, end_events)])[1:] print(times[-5:]) + del forward_meta + # p.stop() diff --git a/tests/operators/test_tree_mask.py b/tests/operators/test_tree_mask.py index 1cfbaaf7a..57a620448 100644 --- a/tests/operators/test_tree_mask.py +++ b/tests/operators/test_tree_mask.py @@ -254,7 +254,6 @@ class TestTreeMask(unittest.TestCase): decoder_block_shape_q, self.num_q_head // self.num_kv_head, self.block_size, - decoder_step_token_num, ) s_time = 0 for i in range(self.run_time + self.warm_up):