Mirror of https://github.com/PaddlePaddle/FastDeploy.git, synced 2025-10-05 16:48:03 +08:00
[Executor] Refactor GetBlockShapeAndSplitKVBlock Kernel (#2989)
* Reset the decoder_block_shape_q buffer
* Refactor the GetBlockShapeAndSplitKVBlock kernel and CUDA graph padding-batch handling
* Update decode_max_tile_size
* Fix pre-commit issues
* Update block_multihead_attn_backend
* Update flash attention backend
* Update MLA attention
* Update XPU attention
* Update the GCU and Iluvatar model runners
* Update MTP
* Fix an MTP bug
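For orientation: `get_block_shape_and_split_kv_block` splits each request's query rows into fixed-width tiles so the attention kernel can be launched over a flat (batch_id, tile_id) worklist; `encoder_block_shape_q` and `decoder_block_shape_q` are the tile widths for prefill and decode. Below is a minimal pure-Python model of that split — a hedged reconstruction of the tile math, not code from the repo:

```python
import math

def split_q_into_tiles(seq_lens_this_time, group_size, block_shape_q):
    """Model of the tile split done by get_block_shape_and_split_kv_block.

    Each request contributes ceil(seq_len * group_size / block_shape_q)
    tiles; per tile we record which batch it belongs to and the tile's
    index within that batch. (Hedged reconstruction, not repo code.)
    """
    batch_ids, tile_ids = [], []
    for bid, seq_len in enumerate(seq_lens_this_time):
        if seq_len <= 0:  # skipped or finished request
            continue
        num_tiles = math.ceil(seq_len * group_size / block_shape_q)
        for tid in range(num_tiles):
            batch_ids.append(bid)
            tile_ids.append(tid)
    return batch_ids, tile_ids

# Decode step (seq_len == 1 per request), GQA group_size 8, tile width 16:
print(split_q_into_tiles([1, 1, 0, 1], group_size=8, block_shape_q=16))
# -> ([0, 1, 3], [0, 0, 0])
```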
```diff
@@ -53,8 +53,6 @@ class FlashAttentionMetadata(AttentionMetadata):
     FlashAttentionMetadata
     """
 
-    max_len_kv: paddle.Tensor = None
-    set_max_lengths: int = -1
     rotary_embs: Optional[paddle.Tensor] = None
     block_tables: Optional[paddle.Tensor] = None
     encoder_batch_ids: paddle.Tensor = None
```
```diff
@@ -63,12 +61,6 @@ class FlashAttentionMetadata(AttentionMetadata):
     kv_batch_ids: paddle.Tensor = None
     kv_tile_ids_per_batch: paddle.Tensor = None
     kv_num_blocks: paddle.Tensor = None
-    decoder_batch_ids: paddle.Tensor = None
-    decoder_tile_ids_per_batch: paddle.Tensor = None
-    decoder_num_blocks: paddle.Tensor = None
-
-    encoder_block_shape_q: int = -1
-    decoder_block_shape_q: int = -1
 
     cu_seqlens_q: paddle.Tensor = None
     cu_seqlens_k: paddle.Tensor = None
```
```diff
@@ -100,6 +92,8 @@ class FlashAttentionBackend(AttentionBackend):
         kv_num_heads: int,
         num_heads: int,
         head_dim: int,
+        encoder_block_shape_q: int = -1,
+        decoder_block_shape_q: int = -1,
     ):
         """
         FlashAttentionBackend __init__
```
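With the q-block shapes now constructor arguments, call sites configure them once when the backend is built instead of stamping them into metadata on every step. A hedged fragment of what a construction site might look like (the config field names are assumptions; 64 and 16 mirror the constants removed from init_attention_metadata in a later hunk):

```python
# Hedged sketch of a model-runner construction site after this change.
# fd_config and its field names are assumptions, not repo code.
backend = FlashAttentionBackend(
    fd_config,
    kv_num_heads=fd_config.model_config.num_key_value_heads,  # assumed field
    num_heads=fd_config.model_config.num_attention_heads,     # assumed field
    head_dim=fd_config.model_config.head_dim,
    encoder_block_shape_q=64,  # previously hard-coded per step
    decoder_block_shape_q=16,  # now fixed once at construction time
)
```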
```diff
@@ -111,10 +105,13 @@ class FlashAttentionBackend(AttentionBackend):
 
         self.kv_num_heads = kv_num_heads
         self.num_heads = num_heads
+        self.group_size: int = self.num_heads // self.kv_num_heads
         self.head_dim = fd_config.model_config.head_dim
         self.attn_outputsize_tp = self.num_heads * self.head_dim
         self.block_size = fd_config.cache_config.block_size
         self.num_layers: int = fd_config.model_config.num_hidden_layers
+        self.encoder_block_shape_q: int = encoder_block_shape_q
+        self.decoder_block_shape_q: int = decoder_block_shape_q
 
         self.speculative_method = fd_config.speculative_config.method
         self.use_speculate = self.speculative_method is not None
```
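The new `group_size` field is the grouped-query-attention fan-out: how many query heads share a single KV head, and hence the factor by which one decode token expands into query rows for the tile split. A quick worked example (numbers illustrative):

```python
num_heads, kv_num_heads = 64, 8          # illustrative GQA config
group_size = num_heads // kv_num_heads   # 8 query heads per KV head
assert num_heads % kv_num_heads == 0     # GQA requires an exact multiple
print(group_size)                        # -> 8
```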
```diff
@@ -176,8 +173,6 @@ class FlashAttentionBackend(AttentionBackend):
 
     def init_attention_metadata(self, forward_meta: ForwardMeta):
         metadata = FlashAttentionMetadata()
-        metadata.encoder_block_shape_q = 64
-        metadata.decoder_block_shape_q = 16
         metadata.cu_seqlens_q = forward_meta.cu_seqlens_q
         metadata.rotary_embs = forward_meta.rotary_embs
         metadata.block_tables = forward_meta.block_tables
```
```diff
@@ -188,18 +183,18 @@ class FlashAttentionBackend(AttentionBackend):
             metadata.kv_batch_ids,
             metadata.kv_tile_ids_per_batch,
             metadata.kv_num_blocks,
-            metadata.decoder_batch_ids,
-            metadata.decoder_tile_ids_per_batch,
-            metadata.decoder_num_blocks,
-            metadata.max_len_kv,
-            metadata.set_max_lengths,
         ) = get_block_shape_and_split_kv_block(
             forward_meta.seq_lens_encoder,
             forward_meta.seq_lens_decoder,
             forward_meta.seq_lens_this_time,
-            metadata.encoder_block_shape_q,
-            metadata.decoder_block_shape_q,
-            self.num_heads // self.kv_num_heads,
+            forward_meta.decoder_batch_ids,
+            forward_meta.decoder_tile_ids_per_batch,
+            forward_meta.decoder_num_blocks_cpu,
+            forward_meta.max_len_tensor_cpu,
+            self.encoder_block_shape_q,
+            self.decoder_block_shape_q,
+            self.group_size,
+            self.block_size,
             self.speculate_max_draft_token_num + 1,
         )
```
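This hunk is the heart of the refactor: the decoder scheduling results are no longer returned and copied around; the kernel writes them in place into buffers owned by ForwardMeta, so their addresses stay fixed across steps and the call is safe to capture in a CUDA graph. A hedged sketch of the preallocation this pattern relies on (tensor names follow the diff; sizes, dtypes, and placement are assumptions):

```python
from types import SimpleNamespace

import paddle

forward_meta = SimpleNamespace()   # stand-in for FastDeploy's ForwardMeta

max_num_seqs = 256                 # illustrative upper bound on batch size
max_decoder_tiles = max_num_seqs   # worst-case decoder q-tile count (assumption)
NUM_MAX_LEN_FIELDS = 9             # number of packed maxima (assumption)

# Device-side tile worklist, filled in place by the kernel each step.
forward_meta.decoder_batch_ids = paddle.zeros([max_decoder_tiles], dtype="int32")
forward_meta.decoder_tile_ids_per_batch = paddle.zeros([max_decoder_tiles], dtype="int32")

# CPU-side outputs the host can read without a device sync
# (assumption: these live on the CPU place, as the _cpu suffix suggests).
forward_meta.decoder_num_blocks_cpu = paddle.zeros([1], dtype="int32").cpu()
forward_meta.max_len_tensor_cpu = paddle.zeros([NUM_MAX_LEN_FIELDS], dtype="int32").cpu()
```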
```diff
@@ -233,8 +228,6 @@ class FlashAttentionBackend(AttentionBackend):
                 self.rank, int(self.device_id), self.keep_pd_step_flag
             )
         self.attention_metadata = metadata
-        forward_meta.decoder_batch_ids.copy_(metadata.decoder_batch_ids, False)
-        forward_meta.decoder_tile_ids_per_batch.copy_(metadata.decoder_tile_ids_per_batch, False)
 
     def forward_mixed(
         self,
```
```diff
@@ -291,8 +284,8 @@ class FlashAttentionBackend(AttentionBackend):
                 v,
                 metadata.cu_seqlens_q,
                 metadata.cu_seqlens_k,
-                max_seqlen_q=metadata.set_max_lengths[0],
-                max_seqlen_k=metadata.set_max_lengths[3],
+                max_seqlen_q=forward_meta.max_len_tensor_cpu[0],
+                max_seqlen_k=forward_meta.max_len_tensor_cpu[3],
                 causal=self.causal,
                 **self.flash_attn_kwargs,
             )[0].reshape([-1, self.attn_outputsize_tp])
```
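forward_mixed now reads the per-step maxima from forward_meta.max_len_tensor_cpu instead of the removed set_max_lengths tuple; per this diff, slot 0 carries max_seqlen_q and slot 3 carries max_seqlen_k, and keeping the tensor on the CPU lets the host read both without a device sync. A small hedged model of the lookup (the layout of the other slots is an assumption):

```python
import paddle

# Hedged model: max_len_tensor_cpu packs several per-step maxima. The diff
# only shows slot 0 (max query length) and slot 3 (max key length) being
# read; the remaining slots here are illustrative padding.
max_len_tensor_cpu = paddle.to_tensor(
    [128, 0, 0, 4096, 0, 0, 0, 0, 0], dtype="int32"
).cpu()

max_seqlen_q = int(max_len_tensor_cpu[0])  # longest query chunk this step
max_seqlen_k = int(max_len_tensor_cpu[3])  # longest key/value context this step
print(max_seqlen_q, max_seqlen_k)          # -> 128 4096
```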