[Executor] Experimental feature: support prefill in CUDA Graph (#3459)

* Support prefill in CUDA Graph

* Refactor GetBlockShapeAndSplitKVBlock Kernel V2

* Refactor GetBlockShapeAndSplitKVBlock Kernel V2.1

* Refactor GetBlockShapeAndSplitKVBlock Kernel V2.2

* Refactor GetBlockShapeAndSplitKVBlock Kernel V2.3

* Refactor GetBlockShapeAndSplitKVBlock Kernel V2.4

* Refactor GetBlockShapeAndSplitKVBlock Kernel V2.5

* Fix a problem with encoder_num_blocks_x_cpu

* Add early-exit mechanism for attention kernel

* Fix test cases for append-attention

* Update test code; add annotations to the related tensors

* Move get_input_length_list

* Fix test code

* Add annotations about early-exit for attention kernel

* Add more annotations about early-exit for attention kernel

* Address review comments

* Fix MTP support

---------

Co-authored-by: RAM <gstian5555@outlook.com>
Author: Jundong Liu
Date: 2025-09-08 13:12:24 +08:00
Committed by: GitHub
Parent: 472402bf4e
Commit: 3d0aaa5923
21 changed files with 528 additions and 260 deletions
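
The common thread in the diffs below: CUDA Graph replay requires every kernel to read from and write to fixed device addresses, so the encoder/KV tiling metadata that get_block_shape_and_split_kv_block used to allocate and return on every call becomes caller-owned buffers on forward_meta, filled in place. A minimal sketch of what that buffer setup could look like (the helper name and sizes here are hypothetical; the real upper bounds come from FastDeploy's scheduler and graph-capture configuration):

from types import SimpleNamespace

import paddle

# Hypothetical worst-case tile count; not a FastDeploy constant.
MAX_TILES = 4096

def allocate_attention_buffers(forward_meta):
    # Device buffers the op fills in place on every step; their addresses
    # never change, which is what CUDA Graph replay requires.
    forward_meta.encoder_batch_ids = paddle.zeros([MAX_TILES], dtype="int32")
    forward_meta.encoder_tile_ids_per_batch = paddle.zeros([MAX_TILES], dtype="int32")
    forward_meta.kv_batch_ids = paddle.zeros([MAX_TILES], dtype="int32")
    forward_meta.kv_tile_ids_per_batch = paddle.zeros([MAX_TILES], dtype="int32")
    # Host-visible counters (hence the *_cpu suffix): the op copies back
    # how many tiles are live this step and the max KV length.
    forward_meta.encoder_num_blocks_x_cpu = paddle.zeros([1], dtype="int32").cpu()
    forward_meta.kv_num_blocks_x_cpu = paddle.zeros([1], dtype="int32").cpu()
    forward_meta.max_len_kv_cpu = paddle.zeros([1], dtype="int32").cpu()

forward_meta = SimpleNamespace()  # stand-in for FastDeploy's ForwardMeta
allocate_attention_buffers(forward_meta)

Sizing is worst-case by construction: the buffers never grow or shrink, and the counters say how much of each buffer is meaningful on a given step.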


@@ -49,14 +49,6 @@ class AppendAttentionMetadata(AttentionMetadata):
     AppendAttentionMetadata
     """
-    encoder_batch_ids: paddle.Tensor = None
-    encoder_tile_ids_per_batch: paddle.Tensor = None
-    encoder_num_blocks: paddle.Tensor = None
-    kv_batch_ids: paddle.Tensor = None
-    kv_tile_ids_per_batch: paddle.Tensor = None
-    kv_num_blocks: paddle.Tensor = None
-    max_len_kv: paddle.Tensor = None
     _dtype: paddle.dtype = paddle.bfloat16
     encoder_max_partition_size: int = 32768
     max_partition_size: int = 32768
@@ -142,15 +134,7 @@ class AppendAttentionBackend(AttentionBackend):
         metadata.rotary_embs = forward_meta.rotary_embs
         metadata.attn_mask = forward_meta.attn_mask
         metadata.pre_caches_length = forward_meta.pre_caches_length
-        (
-            metadata.encoder_batch_ids,
-            metadata.encoder_tile_ids_per_batch,
-            metadata.encoder_num_blocks,
-            metadata.kv_batch_ids,
-            metadata.kv_tile_ids_per_batch,
-            metadata.kv_num_blocks,
-            metadata.max_len_kv,
-        ) = get_block_shape_and_split_kv_block(
+        get_block_shape_and_split_kv_block(
             forward_meta.seq_lens_encoder,
             forward_meta.seq_lens_decoder,
             forward_meta.seq_lens_this_time,
@@ -158,6 +142,13 @@ class AppendAttentionBackend(AttentionBackend):
             forward_meta.decoder_tile_ids_per_batch,
             forward_meta.decoder_num_blocks_cpu,
             forward_meta.max_len_tensor_cpu,
+            forward_meta.encoder_batch_ids,
+            forward_meta.encoder_tile_ids_per_batch,
+            forward_meta.encoder_num_blocks_x_cpu,
+            forward_meta.kv_batch_ids,
+            forward_meta.kv_tile_ids_per_batch,
+            forward_meta.kv_num_blocks_x_cpu,
+            forward_meta.max_len_kv_cpu,
             self.encoder_block_shape_q,
             self.decoder_block_shape_q,
             self.group_size,
@@ -288,17 +279,17 @@ class AppendAttentionBackend(AttentionBackend):
             forward_meta.batch_id_per_token,
             forward_meta.cu_seqlens_q,
             metadata.block_tables,
-            metadata.encoder_batch_ids,
-            metadata.encoder_tile_ids_per_batch,
-            metadata.encoder_num_blocks,
-            metadata.kv_batch_ids,
-            metadata.kv_tile_ids_per_batch,
-            metadata.kv_num_blocks,
+            forward_meta.encoder_batch_ids,
+            forward_meta.encoder_tile_ids_per_batch,
+            forward_meta.encoder_num_blocks_x_cpu,
+            forward_meta.kv_batch_ids,
+            forward_meta.kv_tile_ids_per_batch,
+            forward_meta.kv_num_blocks_x_cpu,
             forward_meta.decoder_batch_ids,
             forward_meta.decoder_tile_ids_per_batch,
             forward_meta.decoder_num_blocks_cpu,
             forward_meta.max_len_tensor_cpu,
-            metadata.max_len_kv,
+            forward_meta.max_len_kv_cpu,
             res,
             metadata.rotary_embs,
             metadata.attn_mask,
@@ -344,17 +335,17 @@ class AppendAttentionBackend(AttentionBackend):
             forward_meta.batch_id_per_token,
             forward_meta.cu_seqlens_q,
             metadata.block_tables,
-            metadata.encoder_batch_ids,
-            metadata.encoder_tile_ids_per_batch,
-            metadata.encoder_num_blocks,
-            metadata.kv_batch_ids,
-            metadata.kv_tile_ids_per_batch,
-            metadata.kv_num_blocks,
+            forward_meta.encoder_batch_ids,
+            forward_meta.encoder_tile_ids_per_batch,
+            forward_meta.encoder_num_blocks_x_cpu,
+            forward_meta.kv_batch_ids,
+            forward_meta.kv_tile_ids_per_batch,
+            forward_meta.kv_num_blocks_x_cpu,
             forward_meta.decoder_batch_ids,
             forward_meta.decoder_tile_ids_per_batch,
             forward_meta.decoder_num_blocks_cpu,
             forward_meta.max_len_tensor_cpu,
-            metadata.max_len_kv,
+            forward_meta.max_len_kv_cpu,
             metadata.rotary_embs,
             metadata.attn_mask,
             layer.qkv_bias,
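
This first backend establishes the pattern the FlashAttention and MLA diffs below repeat verbatim: the tuple-unpacking assignment becomes a plain call, with the former outputs appended to the argument list as pre-allocated destinations. A toy stand-in for the op, illustrating only the in-place output convention (the function, shapes, and block size are invented for this example; the real op computes the encoder/KV tile splits on the GPU):

import paddle

def split_kv_blocks_inplace(seq_lens, batch_ids_out, num_blocks_out, block_size=64):
    # Toy stand-in for the refactored op: results land in caller-provided
    # buffers; nothing is allocated or returned.
    num_blocks = (seq_lens + block_size - 1) // block_size
    total = int(num_blocks.sum())
    ids = paddle.repeat_interleave(paddle.arange(seq_lens.shape[0]), num_blocks, axis=0)
    batch_ids_out[:total] = ids.astype(batch_ids_out.dtype)
    num_blocks_out[0] = total

seq_lens = paddle.to_tensor([100, 30, 0], dtype="int32")
batch_ids = paddle.zeros([4096], dtype="int32")   # persistent buffer
count = paddle.zeros([1], dtype="int32")          # persistent counter
split_kv_blocks_inplace(seq_lens, batch_ids, count)
# batch_ids[:3] -> [0, 0, 1]; count -> [3]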


@@ -65,13 +65,6 @@ class FlashAttentionMetadata(AttentionMetadata):
     rotary_embs: Optional[paddle.Tensor] = None
     block_tables: Optional[paddle.Tensor] = None
-    encoder_batch_ids: paddle.Tensor = None
-    encoder_tile_ids_per_batch: paddle.Tensor = None
-    encoder_num_blocks: paddle.Tensor = None
-    kv_batch_ids: paddle.Tensor = None
-    kv_tile_ids_per_batch: paddle.Tensor = None
-    kv_num_blocks: paddle.Tensor = None
-    max_len_kv: paddle.Tensor = None
     cu_seqlens_q: paddle.Tensor = None
     cu_seqlens_k: paddle.Tensor = None
@@ -198,15 +191,7 @@ class FlashAttentionBackend(AttentionBackend):
         metadata.cu_seqlens_q = forward_meta.cu_seqlens_q
         metadata.rotary_embs = forward_meta.rotary_embs
         metadata.block_tables = forward_meta.block_tables
-        (
-            metadata.encoder_batch_ids,
-            metadata.encoder_tile_ids_per_batch,
-            metadata.encoder_num_blocks,
-            metadata.kv_batch_ids,
-            metadata.kv_tile_ids_per_batch,
-            metadata.kv_num_blocks,
-            metadata.max_len_kv,
-        ) = get_block_shape_and_split_kv_block(
+        get_block_shape_and_split_kv_block(
             forward_meta.seq_lens_encoder,
             forward_meta.seq_lens_decoder,
             forward_meta.seq_lens_this_time,
@@ -214,6 +199,13 @@ class FlashAttentionBackend(AttentionBackend):
             forward_meta.decoder_tile_ids_per_batch,
             forward_meta.decoder_num_blocks_cpu,
             forward_meta.max_len_tensor_cpu,
+            forward_meta.encoder_batch_ids,
+            forward_meta.encoder_tile_ids_per_batch,
+            forward_meta.encoder_num_blocks_x_cpu,
+            forward_meta.kv_batch_ids,
+            forward_meta.kv_tile_ids_per_batch,
+            forward_meta.kv_num_blocks_x_cpu,
+            forward_meta.max_len_kv_cpu,
             self.encoder_block_shape_q,
             self.decoder_block_shape_q,
             self.group_size,
@@ -295,9 +287,9 @@ class FlashAttentionBackend(AttentionBackend):
             forward_meta.seq_lens_decoder,
             forward_meta.batch_id_per_token,
             metadata.block_tables,
-            metadata.kv_batch_ids,
-            metadata.kv_tile_ids_per_batch,
-            metadata.kv_num_blocks,
+            forward_meta.kv_batch_ids,
+            forward_meta.kv_tile_ids_per_batch,
+            forward_meta.kv_num_blocks_x_cpu,
             metadata.pre_cache_batch_ids,
             metadata.pre_cache_tile_ids_per_batch,
             metadata.pre_cache_num_blocks_cpu,
@@ -336,17 +328,17 @@ class FlashAttentionBackend(AttentionBackend):
             forward_meta.batch_id_per_token,
             forward_meta.cu_seqlens_q,
             metadata.block_tables,
-            metadata.encoder_batch_ids,
-            metadata.encoder_tile_ids_per_batch,
-            metadata.encoder_num_blocks,
-            metadata.kv_batch_ids,
-            metadata.kv_tile_ids_per_batch,
-            metadata.kv_num_blocks,
+            forward_meta.encoder_batch_ids,
+            forward_meta.encoder_tile_ids_per_batch,
+            forward_meta.encoder_num_blocks_x_cpu,
+            forward_meta.kv_batch_ids,
+            forward_meta.kv_tile_ids_per_batch,
+            forward_meta.kv_num_blocks_x_cpu,
             forward_meta.decoder_batch_ids,  # from buffer
             forward_meta.decoder_tile_ids_per_batch,  # from buffer
             forward_meta.decoder_num_blocks_cpu,
             metadata.max_len_tensor_cpu_decoder,
-            metadata.max_len_kv,
+            forward_meta.max_len_kv_cpu,
             metadata.rotary_embs,
             forward_meta.attn_mask,
             layer.qkv_bias,


@@ -69,14 +69,6 @@ class MLAAttentionMetadata(AttentionMetadata):
     MLAAttentionMetadata for Multi-Layer Attention
     """
-    encoder_batch_ids: paddle.Tensor = None
-    encoder_tile_ids_per_batch: paddle.Tensor = None
-    encoder_num_blocks: paddle.Tensor = None
-    kv_batch_ids: paddle.Tensor = None
-    kv_tile_ids_per_batch: paddle.Tensor = None
-    kv_num_blocks: paddle.Tensor = None
-    max_len_kv: paddle.Tensor = None
     _dtype: paddle.dtype = paddle.bfloat16
     encoder_max_partition_size: int = 32768
     max_partition_size: int = 32768
@@ -191,15 +183,7 @@ class MLAAttentionBackend(AttentionBackend):
         metadata.attn_mask = forward_meta.attn_mask
         metadata.pre_caches_length = forward_meta.pre_caches_length
-        (
-            metadata.encoder_batch_ids,
-            metadata.encoder_tile_ids_per_batch,
-            metadata.encoder_num_blocks,
-            metadata.kv_batch_ids,
-            metadata.kv_tile_ids_per_batch,
-            metadata.kv_num_blocks,
-            metadata.max_len_kv,
-        ) = get_block_shape_and_split_kv_block(
+        get_block_shape_and_split_kv_block(
             forward_meta.seq_lens_encoder,
             forward_meta.seq_lens_decoder,
             forward_meta.seq_lens_this_time,
@@ -207,6 +191,13 @@ class MLAAttentionBackend(AttentionBackend):
             forward_meta.decoder_tile_ids_per_batch,
             forward_meta.decoder_num_blocks_cpu,
             forward_meta.max_len_tensor_cpu,
+            forward_meta.encoder_batch_ids,
+            forward_meta.encoder_tile_ids_per_batch,
+            forward_meta.encoder_num_blocks_x_cpu,
+            forward_meta.kv_batch_ids,
+            forward_meta.kv_tile_ids_per_batch,
+            forward_meta.kv_num_blocks_x_cpu,
+            forward_meta.max_len_kv_cpu,
             self.encoder_block_shape_q,
             self.decoder_block_shape_q,
             self.group_size,
@@ -362,19 +353,19 @@ class MLAAttentionBackend(AttentionBackend):
             forward_meta.cu_seqlens_q,
             forward_meta.batch_id_per_token,
             metadata.block_tables,
-            metadata.encoder_batch_ids,
-            metadata.encoder_tile_ids_per_batch,
-            metadata.encoder_num_blocks,
-            metadata.kv_batch_ids,
-            metadata.kv_tile_ids_per_batch,
-            metadata.kv_num_blocks,
+            forward_meta.encoder_batch_ids,
+            forward_meta.encoder_tile_ids_per_batch,
+            forward_meta.encoder_num_blocks_x_cpu,
+            forward_meta.kv_batch_ids,
+            forward_meta.kv_tile_ids_per_batch,
+            forward_meta.kv_num_blocks_x_cpu,
             forward_meta.decoder_batch_ids,
             forward_meta.decoder_tile_ids_per_batch,
             forward_meta.decoder_num_blocks_cpu,
             forward_meta.decoder_num_blocks_cpu,
             metadata.max_enc_len_this_time,
             metadata.max_dec_len_this_time,
-            metadata.max_len_kv,
+            forward_meta.max_len_kv_cpu,
             None,  # attn_mask
             None,  # qkv_bias
             None,  # qkv_out_scales
@@ -483,19 +474,19 @@ class MLAAttentionBackend(AttentionBackend):
             forward_meta.cu_seqlens_q,
             forward_meta.batch_id_per_token,
             metadata.block_tables,
-            metadata.encoder_batch_ids,
-            metadata.encoder_tile_ids_per_batch,
-            metadata.encoder_num_blocks,
-            metadata.kv_batch_ids,
-            metadata.kv_tile_ids_per_batch,
-            metadata.kv_num_blocks,
+            forward_meta.encoder_batch_ids,
+            forward_meta.encoder_tile_ids_per_batch,
+            forward_meta.encoder_num_blocks_x_cpu,
+            forward_meta.kv_batch_ids,
+            forward_meta.kv_tile_ids_per_batch,
+            forward_meta.kv_num_blocks_x_cpu,
             forward_meta.decoder_batch_ids,
             forward_meta.decoder_tile_ids_per_batch,
             forward_meta.decoder_num_blocks_cpu,
             forward_meta.decoder_num_blocks_cpu,
             metadata.max_enc_len_this_time,
             metadata.max_dec_len_this_time,
-            metadata.max_len_kv,
+            forward_meta.max_len_kv_cpu,
             None,  # attn_mask
             None,  # qkv_bias
             None,  # qkv_out_scales


@@ -32,6 +32,13 @@ def get_block_shape_and_split_kv_block(
     decoder_tile_ids_per_batch: paddle.Tensor,
     decoder_num_blocks_x_cpu: paddle.Tensor,
     max_len_tensor_cpu: paddle.Tensor,
+    encoder_batch_ids: paddle.Tensor,
+    encoder_tile_ids_per_batch: paddle.Tensor,
+    encoder_num_blocks_x_cpu: paddle.Tensor,
+    kv_batch_ids: paddle.Tensor,
+    kv_tile_ids_per_batch: paddle.Tensor,
+    kv_num_blocks_x_cpu: paddle.Tensor,
+    max_len_kv_cpu: paddle.Tensor,
     encoder_block_shape_q: int,
     decoder_block_shape_q: int,
     group_size: int,
@@ -42,15 +49,7 @@ def get_block_shape_and_split_kv_block(
     get_block_shape_and_split_kv_block
     """
     if current_platform.is_cuda():
-        (
-            encoder_batch_ids,
-            encoder_tile_ids_per_batch,
-            encoder_num_blocks,
-            kv_batch_ids,
-            kv_tile_ids_per_batch,
-            kv_num_blocks,
-            max_len_kv_cpu,
-        ) = get_block_shape_and_split_kv_block_cuda(
+        get_block_shape_and_split_kv_block_cuda(
             seq_lens_encoder,
             seq_lens_decoder,
             seq_lens_this_time,
@@ -58,20 +57,19 @@ def get_block_shape_and_split_kv_block(
             decoder_tile_ids_per_batch,
             decoder_num_blocks_x_cpu,
             max_len_tensor_cpu,
+            encoder_batch_ids,
+            encoder_tile_ids_per_batch,
+            encoder_num_blocks_x_cpu,
+            kv_batch_ids,
+            kv_tile_ids_per_batch,
+            kv_num_blocks_x_cpu,
+            max_len_kv_cpu,
             encoder_block_shape_q,
             decoder_block_shape_q,
             group_size,
             block_size,
             decoder_step_token_num,
         )
-        return (
-            encoder_batch_ids,
-            encoder_tile_ids_per_batch,
-            encoder_num_blocks,
-            kv_batch_ids,
-            kv_tile_ids_per_batch,
-            kv_num_blocks,
-            max_len_kv_cpu,
-        )
     else:
        raise NotImplementedError
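
With the wrapper no longer returning anything, callers read results straight out of the buffers they passed in (e.g. forward_meta.encoder_num_blocks_x_cpu). This also suggests how the commit's early-exit mechanism likely works: under graph capture, kernel launch dimensions are frozen at worst case, so the attention kernel is dispatched over the full tile buffer and thread blocks past the live count recorded by this op bail out immediately. A host-side Python analogue of that guard (the function and names are illustrative, not FastDeploy code):

import paddle

def attend_tiles_fixed_launch(tile_scores, num_valid):
    # The "grid" is always the buffer's full worst-case size; tiles past
    # the live count contribute nothing, mirroring an in-kernel early return.
    max_tiles = tile_scores.shape[0]
    live = paddle.arange(max_tiles) < num_valid
    return paddle.where(live, tile_scores, paddle.zeros_like(tile_scores))

scores = paddle.rand([4096])                # worst-case tile buffer
n = paddle.to_tensor(3)                     # live tile count this step
out = attend_tiles_fixed_launch(scores, n)  # everything past index 2 is zeroed

The design choice is the same in both halves of the refactor: keep shapes and addresses static so CUDA Graph can replay them, and let a small amount of per-step data (the counters) decide how much of the static allocation actually does work.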