[Optimization] Put get_block_shape_and_split_kv_block in cuda graph for append attention backend (#4443)

* get block in cuda graph

* fix sot
This commit is contained in:
Sunny-bot1
2025-10-17 10:59:56 +08:00
committed by GitHub
parent 49cea8fb1c
commit 930f7b781c
2 changed files with 53 additions and 23 deletions

View File

@@ -462,6 +462,32 @@ void GetBlockShapeAndSplitKVBlock(
}
std::vector<std::vector<int64_t>> GetBlockShapeAndSplitKVBlockInferShape(
const std::vector<int64_t> &seq_lens_encoder,
const std::vector<int64_t> &seq_lens_decoder,
const std::vector<int64_t> &seq_lens_this_time,
const int encoder_block_shape_q,
const int decoder_block_shape_q,
const int group_size,
const int block_size,
const int decoder_step_token_num
) {
return {};
}
std::vector<paddle::DataType> GetBlockShapeAndSplitKVBlockInferDtype(
const paddle::DataType &seq_lens_encoder,
const paddle::DataType &seq_lens_decoder,
const paddle::DataType &seq_lens_this_time,
const int encoder_block_shape_q,
const int decoder_block_shape_q,
const int group_size,
const int block_size,
const int decoder_step_token_num
) {
return {};
}
PD_BUILD_STATIC_OP(get_block_shape_and_split_kv_block)
.Inputs({
"seq_lens_encoder",
@@ -490,4 +516,6 @@ PD_BUILD_STATIC_OP(get_block_shape_and_split_kv_block)
"block_size: int",
"decoder_step_token_num: int"
})
.SetKernelFn(PD_KERNEL(GetBlockShapeAndSplitKVBlock));
.SetKernelFn(PD_KERNEL(GetBlockShapeAndSplitKVBlock))
.SetInferShapeFn(PD_INFER_SHAPE(GetBlockShapeAndSplitKVBlockInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(GetBlockShapeAndSplitKVBlockInferDtype));

View File

@@ -134,28 +134,6 @@ class AppendAttentionBackend(AttentionBackend):
metadata.rotary_embs = forward_meta.rotary_embs
metadata.attn_mask = forward_meta.attn_mask
metadata.pre_caches_length = forward_meta.pre_caches_length
get_block_shape_and_split_kv_block(
forward_meta.seq_lens_encoder,
forward_meta.seq_lens_decoder,
forward_meta.seq_lens_this_time,
forward_meta.decoder_batch_ids,
forward_meta.decoder_tile_ids_per_batch,
forward_meta.decoder_num_blocks_cpu,
forward_meta.decoder_num_blocks_device,
forward_meta.decoder_chunk_size_device,
forward_meta.max_len_tensor_cpu,
forward_meta.encoder_batch_ids,
forward_meta.encoder_tile_ids_per_batch,
forward_meta.encoder_num_blocks_x_cpu,
forward_meta.kv_batch_ids,
forward_meta.kv_tile_ids_per_batch,
forward_meta.kv_num_blocks_x_cpu,
self.encoder_block_shape_q,
self.decoder_block_shape_q,
self.group_size,
self.block_size,
self.speculate_max_draft_token_num + 1,
)
# pd_disaggregation
metadata.kv_signal_data_list = [None] * self.num_layers
@@ -235,6 +213,30 @@ class AppendAttentionBackend(AttentionBackend):
cache_k_scales = getattr(layer, "cache_k_scale", None)
cache_v_scales = getattr(layer, "cache_v_scale", None)
if layer.layer_id == 0:
get_block_shape_and_split_kv_block(
forward_meta.seq_lens_encoder,
forward_meta.seq_lens_decoder,
forward_meta.seq_lens_this_time,
forward_meta.decoder_batch_ids,
forward_meta.decoder_tile_ids_per_batch,
forward_meta.decoder_num_blocks_cpu,
forward_meta.decoder_num_blocks_device,
forward_meta.decoder_chunk_size_device,
forward_meta.max_len_tensor_cpu,
forward_meta.encoder_batch_ids,
forward_meta.encoder_tile_ids_per_batch,
forward_meta.encoder_num_blocks_x_cpu,
forward_meta.kv_batch_ids,
forward_meta.kv_tile_ids_per_batch,
forward_meta.kv_num_blocks_x_cpu,
self.encoder_block_shape_q,
self.decoder_block_shape_q,
self.group_size,
self.block_size,
self.speculate_max_draft_token_num + 1,
)
if self.use_output:
quant_max_bound = getattr(layer, "quant_max_bound", 0.0)
cache_quant_type = getattr(layer, "cache_quant_type_str", "none")