[Executor] Experimental Feature - Support Prefill in CUDA Graph (#3459)

* Support prefill in Cudagraph

* Refactor GetBlockShapeAndSplitKVBlock Kernel V2

* Refactor GetBlockShapeAndSplitKVBlock Kernel V2.1

* Refactor GetBlockShapeAndSplitKVBlock Kernel V2.2

* Refactor GetBlockShapeAndSplitKVBlock Kernel V2.3

* Refactor GetBlockShapeAndSplitKVBlock Kernel V2.4

* Refactor GetBlockShapeAndSplitKVBlock Kernel V2.5

* Fix issue with encoder_num_blocks_x_cpu

* Add early-exit mechanism for attention kernel

* Fix test case for append-attention

* Update testcode, Add annotations to related tensors

* move get_input_length_list

* solve test_code

* Add annotations about early-exit for attention kernel

* Add annotations about early-exit for attention kernel2

* Address review comments

* Fix MTP (multi-token prediction) path

---------

Co-authored-by: RAM <gstian5555@outlook.com>
This commit is contained in:
Jundong Liu
2025-09-08 13:12:24 +08:00
committed by GitHub
parent 472402bf4e
commit 3d0aaa5923
21 changed files with 528 additions and 260 deletions

View File

@@ -190,30 +190,32 @@ class TestTreeMask(unittest.TestCase):
encoder_block_shape_q = 64
decoder_block_shape_q = 16
group_size = self.num_q_head // self.num_kv_head
decode_max_tile_size = (
self.bsz
* (decoder_step_token_num * (self.num_q_head // self.num_kv_head) + decoder_block_shape_q - 1)
/ decoder_block_shape_q
self.bsz * (decoder_step_token_num * group_size + decoder_block_shape_q - 1) / decoder_block_shape_q
)
encode_max_tile_size = (
self.bsz * (self.max_seq_len * group_size + encoder_block_shape_q - 1) / encoder_block_shape_q
)
kv_max_tile_size = self.bsz * (self.max_seq_len + self.block_size - 1) / self.block_size
decoder_batch_ids = paddle.full([int(decode_max_tile_size)], 0, dtype="int32")
decoder_tile_ids_per_batch = paddle.full([int(decode_max_tile_size)], 0, dtype="int32")
decoder_num_blocks = paddle.full([1], 0, dtype="int32").pin_memory()
max_len_tensor_cpu = paddle.full([8], 0, dtype="int32").cpu()
encoder_batch_ids = paddle.full([int(encode_max_tile_size)], 0, dtype="int32")
encoder_tile_ids_per_batch = paddle.full([int(encode_max_tile_size)], 0, dtype="int32")
encoder_num_blocks_x_cpu = paddle.full([1], 0, dtype="int32").cpu()
kv_batch_ids = paddle.full([int(kv_max_tile_size)], 0, dtype="int32")
kv_tile_ids_per_batch = paddle.full([int(kv_max_tile_size)], 0, dtype="int32")
kv_num_blocks_x_cpu = paddle.full([1], 0, dtype="int32").cpu()
max_len_kv_cpu = paddle.full([1], 0, dtype="int32").cpu()
q_norm_weight = np.ones([self.head_dim])
k_norm_weight = np.ones([self.head_dim])
self.q_norm_weight_tensor = paddle.to_tensor(q_norm_weight, dtype="float32")
self.k_norm_weight_tensor = paddle.to_tensor(k_norm_weight, dtype="float32")
paddle.device.synchronize()
(
encoder_batch_ids,
encoder_tile_ids_per_batch,
encoder_num_blocks,
kv_batch_ids,
kv_tile_ids_per_batch,
kv_num_blocks,
max_len_kv,
) = get_block_shape_and_split_kv_block(
get_block_shape_and_split_kv_block(
seq_lens_encoder,
seq_lens_decoder,
seq_lens_this_time,
@@ -221,6 +223,13 @@ class TestTreeMask(unittest.TestCase):
decoder_tile_ids_per_batch,
decoder_num_blocks,
max_len_tensor_cpu,
encoder_batch_ids,
encoder_tile_ids_per_batch,
encoder_num_blocks_x_cpu,
kv_batch_ids,
kv_tile_ids_per_batch,
kv_num_blocks_x_cpu,
max_len_kv_cpu,
encoder_block_shape_q,
decoder_block_shape_q,
self.num_q_head // self.num_kv_head,
@@ -243,15 +252,15 @@ class TestTreeMask(unittest.TestCase):
self.block_tables,
encoder_batch_ids,
encoder_tile_ids_per_batch,
encoder_num_blocks,
encoder_num_blocks_x_cpu,
kv_batch_ids,
kv_tile_ids_per_batch,
kv_num_blocks,
kv_num_blocks_x_cpu,
decoder_batch_ids,
decoder_tile_ids_per_batch,
decoder_num_blocks,
max_len_tensor_cpu,
max_len_kv,
max_len_kv_cpu,
rotary_embs,
attn_mask,
None, # qkv_bias