From 930f7b781c1aecd2247dbdbfdbcd657fe627b21e Mon Sep 17 00:00:00 2001
From: Sunny-bot1 <68891411+Sunny-bot1@users.noreply.github.com>
Date: Fri, 17 Oct 2025 10:59:56 +0800
Subject: [PATCH] [Optimization] Put get_block_shape_and_split_kv_block in CUDA graph for append attention backend (#4443)

* get block in CUDA graph

* fix SOT
---
 .../get_block_shape_and_split_kv_block.cu     | 30 +++++++++++-
 .../layers/attention/append_attn_backend.py   | 46 ++++++++++---------
 2 files changed, 53 insertions(+), 23 deletions(-)

diff --git a/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu b/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu
index 9451a521e..169be8524 100644
--- a/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu
+++ b/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu
@@ -462,6 +462,32 @@ void GetBlockShapeAndSplitKVBlock(
 }
 
+std::vector<std::vector<int64_t>> GetBlockShapeAndSplitKVBlockInferShape(
+    const std::vector<int64_t> &seq_lens_encoder,
+    const std::vector<int64_t> &seq_lens_decoder,
+    const std::vector<int64_t> &seq_lens_this_time,
+    const int encoder_block_shape_q,
+    const int decoder_block_shape_q,
+    const int group_size,
+    const int block_size,
+    const int decoder_step_token_num
+) {
+  return {};
+}
+
+std::vector<paddle::DataType> GetBlockShapeAndSplitKVBlockInferDtype(
+    const paddle::DataType &seq_lens_encoder,
+    const paddle::DataType &seq_lens_decoder,
+    const paddle::DataType &seq_lens_this_time,
+    const int encoder_block_shape_q,
+    const int decoder_block_shape_q,
+    const int group_size,
+    const int block_size,
+    const int decoder_step_token_num
+) {
+  return {};
+}
+
 PD_BUILD_STATIC_OP(get_block_shape_and_split_kv_block)
     .Inputs({
         "seq_lens_encoder",
         "seq_lens_decoder",
         "seq_lens_this_time",
@@ -490,4 +516,6 @@ PD_BUILD_STATIC_OP(get_block_shape_and_split_kv_block)
         "block_size: int",
         "decoder_step_token_num: int"
     })
-    .SetKernelFn(PD_KERNEL(GetBlockShapeAndSplitKVBlock));
+    .SetKernelFn(PD_KERNEL(GetBlockShapeAndSplitKVBlock))
+    .SetInferShapeFn(PD_INFER_SHAPE(GetBlockShapeAndSplitKVBlockInferShape))
+    .SetInferDtypeFn(PD_INFER_DTYPE(GetBlockShapeAndSplitKVBlockInferDtype));

diff --git a/fastdeploy/model_executor/layers/attention/append_attn_backend.py b/fastdeploy/model_executor/layers/attention/append_attn_backend.py
index eb0d6dfd0..f88b3de2d 100644
--- a/fastdeploy/model_executor/layers/attention/append_attn_backend.py
+++ b/fastdeploy/model_executor/layers/attention/append_attn_backend.py
@@ -134,28 +134,6 @@ class AppendAttentionBackend(AttentionBackend):
         metadata.rotary_embs = forward_meta.rotary_embs
         metadata.attn_mask = forward_meta.attn_mask
         metadata.pre_caches_length = forward_meta.pre_caches_length
-        get_block_shape_and_split_kv_block(
-            forward_meta.seq_lens_encoder,
-            forward_meta.seq_lens_decoder,
-            forward_meta.seq_lens_this_time,
-            forward_meta.decoder_batch_ids,
-            forward_meta.decoder_tile_ids_per_batch,
-            forward_meta.decoder_num_blocks_cpu,
-            forward_meta.decoder_num_blocks_device,
-            forward_meta.decoder_chunk_size_device,
-            forward_meta.max_len_tensor_cpu,
-            forward_meta.encoder_batch_ids,
-            forward_meta.encoder_tile_ids_per_batch,
-            forward_meta.encoder_num_blocks_x_cpu,
-            forward_meta.kv_batch_ids,
-            forward_meta.kv_tile_ids_per_batch,
-            forward_meta.kv_num_blocks_x_cpu,
-            self.encoder_block_shape_q,
-            self.decoder_block_shape_q,
-            self.group_size,
-            self.block_size,
-            self.speculate_max_draft_token_num + 1,
-        )
 
         # pd_disaggregation
         metadata.kv_signal_data_list = [None] * self.num_layers
@@ -235,6 +213,30 @@ class AppendAttentionBackend(AttentionBackend):
         cache_k_scales = getattr(layer, "cache_k_scale", None)
getattr(layer, "cache_k_scale", None) cache_v_scales = getattr(layer, "cache_v_scale", None) + if layer.layer_id == 0: + get_block_shape_and_split_kv_block( + forward_meta.seq_lens_encoder, + forward_meta.seq_lens_decoder, + forward_meta.seq_lens_this_time, + forward_meta.decoder_batch_ids, + forward_meta.decoder_tile_ids_per_batch, + forward_meta.decoder_num_blocks_cpu, + forward_meta.decoder_num_blocks_device, + forward_meta.decoder_chunk_size_device, + forward_meta.max_len_tensor_cpu, + forward_meta.encoder_batch_ids, + forward_meta.encoder_tile_ids_per_batch, + forward_meta.encoder_num_blocks_x_cpu, + forward_meta.kv_batch_ids, + forward_meta.kv_tile_ids_per_batch, + forward_meta.kv_num_blocks_x_cpu, + self.encoder_block_shape_q, + self.decoder_block_shape_q, + self.group_size, + self.block_size, + self.speculate_max_draft_token_num + 1, + ) + if self.use_output: quant_max_bound = getattr(layer, "quant_max_bound", 0.0) cache_quant_type = getattr(layer, "cache_quant_type_str", "none")