[Optimization] Put get_block_shape_and_split_kv_block in cuda graph for append attention backend (#4443)

* get block in cuda graph

* fix sot
This commit is contained in:
Sunny-bot1
2025-10-17 10:59:56 +08:00
committed by GitHub
parent 49cea8fb1c
commit 930f7b781c
2 changed files with 53 additions and 23 deletions

View File

@@ -134,28 +134,6 @@ class AppendAttentionBackend(AttentionBackend):
metadata.rotary_embs = forward_meta.rotary_embs
metadata.attn_mask = forward_meta.attn_mask
metadata.pre_caches_length = forward_meta.pre_caches_length
get_block_shape_and_split_kv_block(
forward_meta.seq_lens_encoder,
forward_meta.seq_lens_decoder,
forward_meta.seq_lens_this_time,
forward_meta.decoder_batch_ids,
forward_meta.decoder_tile_ids_per_batch,
forward_meta.decoder_num_blocks_cpu,
forward_meta.decoder_num_blocks_device,
forward_meta.decoder_chunk_size_device,
forward_meta.max_len_tensor_cpu,
forward_meta.encoder_batch_ids,
forward_meta.encoder_tile_ids_per_batch,
forward_meta.encoder_num_blocks_x_cpu,
forward_meta.kv_batch_ids,
forward_meta.kv_tile_ids_per_batch,
forward_meta.kv_num_blocks_x_cpu,
self.encoder_block_shape_q,
self.decoder_block_shape_q,
self.group_size,
self.block_size,
self.speculate_max_draft_token_num + 1,
)
# pd_disaggregation
metadata.kv_signal_data_list = [None] * self.num_layers
@@ -235,6 +213,30 @@ class AppendAttentionBackend(AttentionBackend):
cache_k_scales = getattr(layer, "cache_k_scale", None)
cache_v_scales = getattr(layer, "cache_v_scale", None)
if layer.layer_id == 0:
get_block_shape_and_split_kv_block(
forward_meta.seq_lens_encoder,
forward_meta.seq_lens_decoder,
forward_meta.seq_lens_this_time,
forward_meta.decoder_batch_ids,
forward_meta.decoder_tile_ids_per_batch,
forward_meta.decoder_num_blocks_cpu,
forward_meta.decoder_num_blocks_device,
forward_meta.decoder_chunk_size_device,
forward_meta.max_len_tensor_cpu,
forward_meta.encoder_batch_ids,
forward_meta.encoder_tile_ids_per_batch,
forward_meta.encoder_num_blocks_x_cpu,
forward_meta.kv_batch_ids,
forward_meta.kv_tile_ids_per_batch,
forward_meta.kv_num_blocks_x_cpu,
self.encoder_block_shape_q,
self.decoder_block_shape_q,
self.group_size,
self.block_size,
self.speculate_max_draft_token_num + 1,
)
if self.use_output:
quant_max_bound = getattr(layer, "quant_max_bound", 0.0)
cache_quant_type = getattr(layer, "cache_quant_type_str", "none")