mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-27 02:20:31 +08:00
[Excutor] Experiment Feature-Support Prefill in cudagraph (#3459)
* Support prefill in Cudagraph * Refactor GetBlockShapeAndSplitKVBlock Kernel V2 * Refactor GetBlockShapeAndSplitKVBlock Kernel V2.1 * Refactor GetBlockShapeAndSplitKVBlock Kernel V2.2 * Refactor GetBlockShapeAndSplitKVBlock Kernel V2.3 * Refactor GetBlockShapeAndSplitKVBlock Kernel V2.4 * Refactor GetBlockShapeAndSplitKVBlock Kernel V2.5 * Solve problem about encoder_num_blocks_x_cpu * Add early-exit mechanism for attention kernel * fix test case about append-attention * Update testcode, Add annotations to related tensors * move get_input_length_list * solve test_code * Add annotations about early-exit for attention kernel * Add annotations about early-exit for attention kernel2 * solve comment * solve mtp --------- Co-authored-by: RAM <gstian5555@outlook.com>
This commit is contained in:
@@ -81,14 +81,42 @@ class ForwardMeta:
|
||||
attn_mask: Optional[paddle.Tensor] = None
|
||||
# Attention mask offset
|
||||
attn_mask_offsets: Optional[paddle.Tensor] = None
|
||||
|
||||
# A common pattern for launching CUDA kernels is to set the kernel's grids.x dimension
|
||||
# using a `num_blocks` variable, and then map each thread block to a specific batch and
|
||||
# data tile using `batch_ids` and `tile_ids_per_batch`.
|
||||
#
|
||||
# The variable names below follow this pattern, using a common prefix (e.g., `encoder_`, `decoder_`, `kv_`)
|
||||
# for variables that are logically grouped together. The mapping works as follows:
|
||||
#
|
||||
# Usage: `my_kernel<<<grids, ...>>>(..., batch_ids, tile_ids, ...)`
|
||||
# `grids.x` = `num_blocks_cpu`
|
||||
# `batch_id` = `batch_ids[blockIdx.x]`
|
||||
# `tile_id` = `tile_ids[blockIdx.x]`
|
||||
|
||||
# Maps the thread block index (blockIdx.x) to the corresponding batch for the decoder stage in multi_query_append_attention_warp1_4_kernel.
|
||||
# Decoder batch id. Used by attention backend.
|
||||
decoder_batch_ids: Optional[paddle.Tensor] = None
|
||||
# Tile ID for each batch of the decoder. Used by attention backend.
|
||||
# Maps the thread block index (blockIdx.x) to the specific data tile being processed within that batch for the decoder stage in multi_query_append_attention_warp1_4_kernel.
|
||||
decoder_tile_ids_per_batch: Optional[paddle.Tensor] = None
|
||||
# The number of blocks that attention backend can use in decode stage
|
||||
# The number of CUDA blocks to launch in the x-dimension for the multi_query_append_attention_warp1_4_kernel, defining its grids.x.
|
||||
decoder_num_blocks_cpu: Optional[paddle.Tensor] = None
|
||||
# Recorded multiple lengths related to prefill or decode
|
||||
# A tensor that holds multiple lengths related to prefill or decode stages.
|
||||
max_len_tensor_cpu: Optional[paddle.Tensor] = None
|
||||
# Maps the thread block index (blockIdx.x) to the corresponding batch for the encoder stage in multi_query_append_attention_kernel.
|
||||
encoder_batch_ids: Optional[paddle.Tensor] = None
|
||||
# Maps the thread block index (blockIdx.x) to the specific data tile being processed within that batch for the encoder stage in multi_query_append_attention_kernel.
|
||||
encoder_tile_ids_per_batch: Optional[paddle.Tensor] = None
|
||||
# The number of CUDA blocks to launch in the x-dimension for the multi_query_append_attention_kernel, defining its grids.x.
|
||||
encoder_num_blocks_x_cpu: Optional[paddle.Tensor] = None
|
||||
# Maps the thread block index (blockIdx.x) to the corresponding batch for the append_write_cache_kv kernel.
|
||||
kv_batch_ids: Optional[paddle.Tensor] = None
|
||||
# Maps the thread block index (blockIdx.x) to the specific data tile being processed within that batch for the append_write_cache_kv kernel.
|
||||
kv_tile_ids_per_batch: Optional[paddle.Tensor] = None
|
||||
# The number of CUDA blocks to launch in the x-dimension for the append_write_cache_kv kernel, defining its grids.x.
|
||||
kv_num_blocks_x_cpu: Optional[paddle.Tensor] = None
|
||||
# The maximum sequence length of the KV cache, which may represent the current maximum decoder length.
|
||||
max_len_kv_cpu: Optional[paddle.Tensor] = None
|
||||
|
||||
# Sequence length of encoder for ever batch
|
||||
seq_lens_encoder: Optional[paddle.Tensor] = None
|
||||
@@ -133,6 +161,7 @@ class ForwardMeta:
|
||||
"shape": obj.shape,
|
||||
"dtype": str(obj.dtype),
|
||||
"place": str(obj.place),
|
||||
# "content": obj if obj.numel()<10 else "Too big to show"
|
||||
}
|
||||
return tensor_info
|
||||
elif isinstance(obj, (list, tuple)):
|
||||
|
||||
@@ -49,14 +49,6 @@ class AppendAttentionMetadata(AttentionMetadata):
|
||||
AppendAttentionMetadata
|
||||
"""
|
||||
|
||||
encoder_batch_ids: paddle.Tensor = None
|
||||
encoder_tile_ids_per_batch: paddle.Tensor = None
|
||||
encoder_num_blocks: paddle.Tensor = None
|
||||
kv_batch_ids: paddle.Tensor = None
|
||||
kv_tile_ids_per_batch: paddle.Tensor = None
|
||||
kv_num_blocks: paddle.Tensor = None
|
||||
max_len_kv: paddle.Tensor = None
|
||||
|
||||
_dtype: paddle.dtype = paddle.bfloat16
|
||||
encoder_max_partition_size: int = 32768
|
||||
max_partition_size: int = 32768
|
||||
@@ -142,15 +134,7 @@ class AppendAttentionBackend(AttentionBackend):
|
||||
metadata.rotary_embs = forward_meta.rotary_embs
|
||||
metadata.attn_mask = forward_meta.attn_mask
|
||||
metadata.pre_caches_length = forward_meta.pre_caches_length
|
||||
(
|
||||
metadata.encoder_batch_ids,
|
||||
metadata.encoder_tile_ids_per_batch,
|
||||
metadata.encoder_num_blocks,
|
||||
metadata.kv_batch_ids,
|
||||
metadata.kv_tile_ids_per_batch,
|
||||
metadata.kv_num_blocks,
|
||||
metadata.max_len_kv,
|
||||
) = get_block_shape_and_split_kv_block(
|
||||
get_block_shape_and_split_kv_block(
|
||||
forward_meta.seq_lens_encoder,
|
||||
forward_meta.seq_lens_decoder,
|
||||
forward_meta.seq_lens_this_time,
|
||||
@@ -158,6 +142,13 @@ class AppendAttentionBackend(AttentionBackend):
|
||||
forward_meta.decoder_tile_ids_per_batch,
|
||||
forward_meta.decoder_num_blocks_cpu,
|
||||
forward_meta.max_len_tensor_cpu,
|
||||
forward_meta.encoder_batch_ids,
|
||||
forward_meta.encoder_tile_ids_per_batch,
|
||||
forward_meta.encoder_num_blocks_x_cpu,
|
||||
forward_meta.kv_batch_ids,
|
||||
forward_meta.kv_tile_ids_per_batch,
|
||||
forward_meta.kv_num_blocks_x_cpu,
|
||||
forward_meta.max_len_kv_cpu,
|
||||
self.encoder_block_shape_q,
|
||||
self.decoder_block_shape_q,
|
||||
self.group_size,
|
||||
@@ -288,17 +279,17 @@ class AppendAttentionBackend(AttentionBackend):
|
||||
forward_meta.batch_id_per_token,
|
||||
forward_meta.cu_seqlens_q,
|
||||
metadata.block_tables,
|
||||
metadata.encoder_batch_ids,
|
||||
metadata.encoder_tile_ids_per_batch,
|
||||
metadata.encoder_num_blocks,
|
||||
metadata.kv_batch_ids,
|
||||
metadata.kv_tile_ids_per_batch,
|
||||
metadata.kv_num_blocks,
|
||||
forward_meta.encoder_batch_ids,
|
||||
forward_meta.encoder_tile_ids_per_batch,
|
||||
forward_meta.encoder_num_blocks_x_cpu,
|
||||
forward_meta.kv_batch_ids,
|
||||
forward_meta.kv_tile_ids_per_batch,
|
||||
forward_meta.kv_num_blocks_x_cpu,
|
||||
forward_meta.decoder_batch_ids,
|
||||
forward_meta.decoder_tile_ids_per_batch,
|
||||
forward_meta.decoder_num_blocks_cpu,
|
||||
forward_meta.max_len_tensor_cpu,
|
||||
metadata.max_len_kv,
|
||||
forward_meta.max_len_kv_cpu,
|
||||
res,
|
||||
metadata.rotary_embs,
|
||||
metadata.attn_mask,
|
||||
@@ -344,17 +335,17 @@ class AppendAttentionBackend(AttentionBackend):
|
||||
forward_meta.batch_id_per_token,
|
||||
forward_meta.cu_seqlens_q,
|
||||
metadata.block_tables,
|
||||
metadata.encoder_batch_ids,
|
||||
metadata.encoder_tile_ids_per_batch,
|
||||
metadata.encoder_num_blocks,
|
||||
metadata.kv_batch_ids,
|
||||
metadata.kv_tile_ids_per_batch,
|
||||
metadata.kv_num_blocks,
|
||||
forward_meta.encoder_batch_ids,
|
||||
forward_meta.encoder_tile_ids_per_batch,
|
||||
forward_meta.encoder_num_blocks_x_cpu,
|
||||
forward_meta.kv_batch_ids,
|
||||
forward_meta.kv_tile_ids_per_batch,
|
||||
forward_meta.kv_num_blocks_x_cpu,
|
||||
forward_meta.decoder_batch_ids,
|
||||
forward_meta.decoder_tile_ids_per_batch,
|
||||
forward_meta.decoder_num_blocks_cpu,
|
||||
forward_meta.max_len_tensor_cpu,
|
||||
metadata.max_len_kv,
|
||||
forward_meta.max_len_kv_cpu,
|
||||
metadata.rotary_embs,
|
||||
metadata.attn_mask,
|
||||
layer.qkv_bias,
|
||||
|
||||
@@ -65,13 +65,6 @@ class FlashAttentionMetadata(AttentionMetadata):
|
||||
|
||||
rotary_embs: Optional[paddle.Tensor] = None
|
||||
block_tables: Optional[paddle.Tensor] = None
|
||||
encoder_batch_ids: paddle.Tensor = None
|
||||
encoder_tile_ids_per_batch: paddle.Tensor = None
|
||||
encoder_num_blocks: paddle.Tensor = None
|
||||
kv_batch_ids: paddle.Tensor = None
|
||||
kv_tile_ids_per_batch: paddle.Tensor = None
|
||||
kv_num_blocks: paddle.Tensor = None
|
||||
max_len_kv: paddle.Tensor = None
|
||||
|
||||
cu_seqlens_q: paddle.Tensor = None
|
||||
cu_seqlens_k: paddle.Tensor = None
|
||||
@@ -198,15 +191,7 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
metadata.cu_seqlens_q = forward_meta.cu_seqlens_q
|
||||
metadata.rotary_embs = forward_meta.rotary_embs
|
||||
metadata.block_tables = forward_meta.block_tables
|
||||
(
|
||||
metadata.encoder_batch_ids,
|
||||
metadata.encoder_tile_ids_per_batch,
|
||||
metadata.encoder_num_blocks,
|
||||
metadata.kv_batch_ids,
|
||||
metadata.kv_tile_ids_per_batch,
|
||||
metadata.kv_num_blocks,
|
||||
metadata.max_len_kv,
|
||||
) = get_block_shape_and_split_kv_block(
|
||||
get_block_shape_and_split_kv_block(
|
||||
forward_meta.seq_lens_encoder,
|
||||
forward_meta.seq_lens_decoder,
|
||||
forward_meta.seq_lens_this_time,
|
||||
@@ -214,6 +199,13 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
forward_meta.decoder_tile_ids_per_batch,
|
||||
forward_meta.decoder_num_blocks_cpu,
|
||||
forward_meta.max_len_tensor_cpu,
|
||||
forward_meta.encoder_batch_ids,
|
||||
forward_meta.encoder_tile_ids_per_batch,
|
||||
forward_meta.encoder_num_blocks_x_cpu,
|
||||
forward_meta.kv_batch_ids,
|
||||
forward_meta.kv_tile_ids_per_batch,
|
||||
forward_meta.kv_num_blocks_x_cpu,
|
||||
forward_meta.max_len_kv_cpu,
|
||||
self.encoder_block_shape_q,
|
||||
self.decoder_block_shape_q,
|
||||
self.group_size,
|
||||
@@ -295,9 +287,9 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
forward_meta.seq_lens_decoder,
|
||||
forward_meta.batch_id_per_token,
|
||||
metadata.block_tables,
|
||||
metadata.kv_batch_ids,
|
||||
metadata.kv_tile_ids_per_batch,
|
||||
metadata.kv_num_blocks,
|
||||
forward_meta.kv_batch_ids,
|
||||
forward_meta.kv_tile_ids_per_batch,
|
||||
forward_meta.kv_num_blocks_x_cpu,
|
||||
metadata.pre_cache_batch_ids,
|
||||
metadata.pre_cache_tile_ids_per_batch,
|
||||
metadata.pre_cache_num_blocks_cpu,
|
||||
@@ -336,17 +328,17 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
forward_meta.batch_id_per_token,
|
||||
forward_meta.cu_seqlens_q,
|
||||
metadata.block_tables,
|
||||
metadata.encoder_batch_ids,
|
||||
metadata.encoder_tile_ids_per_batch,
|
||||
metadata.encoder_num_blocks,
|
||||
metadata.kv_batch_ids,
|
||||
metadata.kv_tile_ids_per_batch,
|
||||
metadata.kv_num_blocks,
|
||||
forward_meta.encoder_batch_ids,
|
||||
forward_meta.encoder_tile_ids_per_batch,
|
||||
forward_meta.encoder_num_blocks_x_cpu,
|
||||
forward_meta.kv_batch_ids,
|
||||
forward_meta.kv_tile_ids_per_batch,
|
||||
forward_meta.kv_num_blocks_x_cpu,
|
||||
forward_meta.decoder_batch_ids, # from buffer
|
||||
forward_meta.decoder_tile_ids_per_batch, # from buffer
|
||||
forward_meta.decoder_num_blocks_cpu,
|
||||
metadata.max_len_tensor_cpu_decoder,
|
||||
metadata.max_len_kv,
|
||||
forward_meta.max_len_kv_cpu,
|
||||
metadata.rotary_embs,
|
||||
forward_meta.attn_mask,
|
||||
layer.qkv_bias,
|
||||
|
||||
@@ -69,14 +69,6 @@ class MLAAttentionMetadata(AttentionMetadata):
|
||||
MLAAttentionMetadata for Multi-Layer Attention
|
||||
"""
|
||||
|
||||
encoder_batch_ids: paddle.Tensor = None
|
||||
encoder_tile_ids_per_batch: paddle.Tensor = None
|
||||
encoder_num_blocks: paddle.Tensor = None
|
||||
kv_batch_ids: paddle.Tensor = None
|
||||
kv_tile_ids_per_batch: paddle.Tensor = None
|
||||
kv_num_blocks: paddle.Tensor = None
|
||||
max_len_kv: paddle.Tensor = None
|
||||
|
||||
_dtype: paddle.dtype = paddle.bfloat16
|
||||
encoder_max_partition_size: int = 32768
|
||||
max_partition_size: int = 32768
|
||||
@@ -191,15 +183,7 @@ class MLAAttentionBackend(AttentionBackend):
|
||||
metadata.attn_mask = forward_meta.attn_mask
|
||||
metadata.pre_caches_length = forward_meta.pre_caches_length
|
||||
|
||||
(
|
||||
metadata.encoder_batch_ids,
|
||||
metadata.encoder_tile_ids_per_batch,
|
||||
metadata.encoder_num_blocks,
|
||||
metadata.kv_batch_ids,
|
||||
metadata.kv_tile_ids_per_batch,
|
||||
metadata.kv_num_blocks,
|
||||
metadata.max_len_kv,
|
||||
) = get_block_shape_and_split_kv_block(
|
||||
get_block_shape_and_split_kv_block(
|
||||
forward_meta.seq_lens_encoder,
|
||||
forward_meta.seq_lens_decoder,
|
||||
forward_meta.seq_lens_this_time,
|
||||
@@ -207,6 +191,13 @@ class MLAAttentionBackend(AttentionBackend):
|
||||
forward_meta.decoder_tile_ids_per_batch,
|
||||
forward_meta.decoder_num_blocks_cpu,
|
||||
forward_meta.max_len_tensor_cpu,
|
||||
forward_meta.encoder_batch_ids,
|
||||
forward_meta.encoder_tile_ids_per_batch,
|
||||
forward_meta.encoder_num_blocks_x_cpu,
|
||||
forward_meta.kv_batch_ids,
|
||||
forward_meta.kv_tile_ids_per_batch,
|
||||
forward_meta.kv_num_blocks_x_cpu,
|
||||
forward_meta.max_len_kv_cpu,
|
||||
self.encoder_block_shape_q,
|
||||
self.decoder_block_shape_q,
|
||||
self.group_size,
|
||||
@@ -362,19 +353,19 @@ class MLAAttentionBackend(AttentionBackend):
|
||||
forward_meta.cu_seqlens_q,
|
||||
forward_meta.batch_id_per_token,
|
||||
metadata.block_tables,
|
||||
metadata.encoder_batch_ids,
|
||||
metadata.encoder_tile_ids_per_batch,
|
||||
metadata.encoder_num_blocks,
|
||||
metadata.kv_batch_ids,
|
||||
metadata.kv_tile_ids_per_batch,
|
||||
metadata.kv_num_blocks,
|
||||
forward_meta.encoder_batch_ids,
|
||||
forward_meta.encoder_tile_ids_per_batch,
|
||||
forward_meta.encoder_num_blocks_x_cpu,
|
||||
forward_meta.kv_batch_ids,
|
||||
forward_meta.kv_tile_ids_per_batch,
|
||||
forward_meta.kv_num_blocks_x_cpu,
|
||||
forward_meta.decoder_batch_ids,
|
||||
forward_meta.decoder_tile_ids_per_batch,
|
||||
forward_meta.decoder_num_blocks_cpu,
|
||||
forward_meta.decoder_num_blocks_cpu,
|
||||
metadata.max_enc_len_this_time,
|
||||
metadata.max_dec_len_this_time,
|
||||
metadata.max_len_kv,
|
||||
forward_meta.max_len_kv_cpu,
|
||||
None, # attn_mask
|
||||
None, # qkv_bias
|
||||
None, # qkv_out_scales
|
||||
@@ -483,19 +474,19 @@ class MLAAttentionBackend(AttentionBackend):
|
||||
forward_meta.cu_seqlens_q,
|
||||
forward_meta.batch_id_per_token,
|
||||
metadata.block_tables,
|
||||
metadata.encoder_batch_ids,
|
||||
metadata.encoder_tile_ids_per_batch,
|
||||
metadata.encoder_num_blocks,
|
||||
metadata.kv_batch_ids,
|
||||
metadata.kv_tile_ids_per_batch,
|
||||
metadata.kv_num_blocks,
|
||||
forward_meta.encoder_batch_ids,
|
||||
forward_meta.encoder_tile_ids_per_batch,
|
||||
forward_meta.encoder_num_blocks_x_cpu,
|
||||
forward_meta.kv_batch_ids,
|
||||
forward_meta.kv_tile_ids_per_batch,
|
||||
forward_meta.kv_num_blocks_x_cpu,
|
||||
forward_meta.decoder_batch_ids,
|
||||
forward_meta.decoder_tile_ids_per_batch,
|
||||
forward_meta.decoder_num_blocks_cpu,
|
||||
forward_meta.decoder_num_blocks_cpu,
|
||||
metadata.max_enc_len_this_time,
|
||||
metadata.max_dec_len_this_time,
|
||||
metadata.max_len_kv,
|
||||
forward_meta.max_len_kv_cpu,
|
||||
None, # attn_mask
|
||||
None, # qkv_bias
|
||||
None, # qkv_out_scales
|
||||
|
||||
@@ -32,6 +32,13 @@ def get_block_shape_and_split_kv_block(
|
||||
decoder_tile_ids_per_batch: paddle.Tensor,
|
||||
decoder_num_blocks_x_cpu: paddle.Tensor,
|
||||
max_len_tensor_cpu: paddle.Tensor,
|
||||
encoder_batch_ids: paddle.Tensor,
|
||||
encoder_tile_ids_per_batch: paddle.Tensor,
|
||||
encoder_num_blocks_x_cpu: paddle.Tensor,
|
||||
kv_batch_ids: paddle.Tensor,
|
||||
kv_tile_ids_per_batch: paddle.Tensor,
|
||||
kv_num_blocks_x_cpu: paddle.Tensor,
|
||||
max_len_kv_cpu: paddle.Tensor,
|
||||
encoder_block_shape_q: int,
|
||||
decoder_block_shape_q: int,
|
||||
group_size: int,
|
||||
@@ -42,15 +49,7 @@ def get_block_shape_and_split_kv_block(
|
||||
get_block_shape_and_split_kv_block
|
||||
"""
|
||||
if current_platform.is_cuda():
|
||||
(
|
||||
encoder_batch_ids,
|
||||
encoder_tile_ids_per_batch,
|
||||
encoder_num_blocks,
|
||||
kv_batch_ids,
|
||||
kv_tile_ids_per_batch,
|
||||
kv_num_blocks,
|
||||
max_len_kv_cpu,
|
||||
) = get_block_shape_and_split_kv_block_cuda(
|
||||
get_block_shape_and_split_kv_block_cuda(
|
||||
seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
seq_lens_this_time,
|
||||
@@ -58,20 +57,19 @@ def get_block_shape_and_split_kv_block(
|
||||
decoder_tile_ids_per_batch,
|
||||
decoder_num_blocks_x_cpu,
|
||||
max_len_tensor_cpu,
|
||||
encoder_batch_ids,
|
||||
encoder_tile_ids_per_batch,
|
||||
encoder_num_blocks_x_cpu,
|
||||
kv_batch_ids,
|
||||
kv_tile_ids_per_batch,
|
||||
kv_num_blocks_x_cpu,
|
||||
max_len_kv_cpu,
|
||||
encoder_block_shape_q,
|
||||
decoder_block_shape_q,
|
||||
group_size,
|
||||
block_size,
|
||||
decoder_step_token_num,
|
||||
)
|
||||
return (
|
||||
encoder_batch_ids,
|
||||
encoder_tile_ids_per_batch,
|
||||
encoder_num_blocks,
|
||||
kv_batch_ids,
|
||||
kv_tile_ids_per_batch,
|
||||
kv_num_blocks,
|
||||
max_len_kv_cpu,
|
||||
)
|
||||
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
Reference in New Issue
Block a user