[Executor] Refactor GetBlockShapeAndSplitKVBlock Kernel (#2989)

* reset decoder_block_shape_q buffer

* refactor GetBlockShapeAndSplitKVBlock Kernel and cudagraph padding batch

* update decode_max_tile_size

* fix pre-commit

* update block_multihead_attn_backend

* update flash attn backend

* update MLA Attention

* update XPU Attention

* update gcu, iluvatar model runners

* Update MTP

* fix MTP bug
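
For context, the shape of the API change behind these bullets, condensed from the diff below. This is a sketch only: get_block_shape_and_split_kv_block is a FastDeploy custom op, so it does not run standalone, and the encoder_*/kv_* outputs (unchanged, above the hunk context) are elided.

# Before: the op allocated and returned fresh decoder scheduling tensors
# (plus a host-side set_max_lengths) on every forward step.
(
    # ... encoder_* and kv_* outputs, unchanged by this commit ...
    metadata.decoder_batch_ids,
    metadata.decoder_tile_ids_per_batch,
    metadata.decoder_num_blocks,
    metadata.max_len_kv,
    metadata.set_max_lengths,
) = get_block_shape_and_split_kv_block(
    forward_meta.seq_lens_encoder,
    forward_meta.seq_lens_decoder,
    forward_meta.seq_lens_this_time,
    metadata.encoder_block_shape_q,        # per-step metadata field
    metadata.decoder_block_shape_q,
    self.num_heads // self.kv_num_heads,   # group size, recomputed inline
    self.block_size,
    self.speculate_max_draft_token_num + 1,
)

# After: the caller passes in persistent buffers owned by forward_meta and
# the op fills them in place, so tensor addresses stay stable across steps,
# which is the property CUDA graph padding and replay depend on.
(
    # ... encoder_* and kv_* outputs, unchanged ...
    metadata.max_len_kv,
) = get_block_shape_and_split_kv_block(
    forward_meta.seq_lens_encoder,
    forward_meta.seq_lens_decoder,
    forward_meta.seq_lens_this_time,
    forward_meta.decoder_batch_ids,           # written in place
    forward_meta.decoder_tile_ids_per_batch,  # written in place
    forward_meta.decoder_num_blocks_cpu,      # host-visible tile count
    forward_meta.max_len_tensor_cpu,          # replaces set_max_lengths
    self.encoder_block_shape_q,               # now set once in __init__
    self.decoder_block_shape_q,
    self.group_size,                          # precomputed in __init__
    self.block_size,
    self.speculate_max_draft_token_num + 1,
)
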
Author: RAM
Date: 2025-07-31 00:09:31 +08:00
Committed by: GitHub
Parent: 998968f1e8
Commit: d850660872
13 changed files with 222 additions and 235 deletions


@@ -53,8 +53,6 @@ class FlashAttentionMetadata(AttentionMetadata):
     FlashAttentionMetadata
     """
     max_len_kv: paddle.Tensor = None
-    set_max_lengths: int = -1
-
     rotary_embs: Optional[paddle.Tensor] = None
     block_tables: Optional[paddle.Tensor] = None
     encoder_batch_ids: paddle.Tensor = None
@@ -63,12 +61,6 @@ class FlashAttentionMetadata(AttentionMetadata):
     kv_batch_ids: paddle.Tensor = None
     kv_tile_ids_per_batch: paddle.Tensor = None
     kv_num_blocks: paddle.Tensor = None
-    decoder_batch_ids: paddle.Tensor = None
-    decoder_tile_ids_per_batch: paddle.Tensor = None
-    decoder_num_blocks: paddle.Tensor = None
-
-    encoder_block_shape_q: int = -1
-    decoder_block_shape_q: int = -1

     cu_seqlens_q: paddle.Tensor = None
     cu_seqlens_k: paddle.Tensor = None
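
The removed decoder_* fields do not disappear; per the commit message they become buffers that the model runner preallocates once on ForwardMeta. A hypothetical sketch of that allocation follows: the buffer names come from the diff, but the decode_max_tile_size arithmetic is an inference from the "update decode_max_tile_size" bullet, and the max_len_tensor_cpu length is assumed (the diff only reads indices [0] and [3]).

import paddle

def alloc_decoder_buffers(max_num_seqs: int, group_size: int,
                          decoder_block_shape_q: int,
                          speculate_max_draft_token_num: int = 0) -> dict:
    # Each request contributes at most ceil(q_tokens * group_size /
    # decoder_block_shape_q) decoder tiles; sizing for the worst case keeps
    # buffer shapes fixed, which is what CUDA graph capture requires.
    q_tokens = speculate_max_draft_token_num + 1
    decode_max_tile_size = max_num_seqs * (
        (q_tokens * group_size + decoder_block_shape_q - 1) // decoder_block_shape_q
    )
    return {
        "decoder_batch_ids": paddle.zeros([decode_max_tile_size], dtype="int32"),
        "decoder_tile_ids_per_batch": paddle.zeros([decode_max_tile_size], dtype="int32"),
        # Host-side copies the Python scheduler can read without a device
        # sync; the length 9 for max_len_tensor_cpu is an assumption.
        "decoder_num_blocks_cpu": paddle.zeros([1], dtype="int32").cpu(),
        "max_len_tensor_cpu": paddle.zeros([9], dtype="int32").cpu(),
    }
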
@@ -100,6 +92,8 @@ class FlashAttentionBackend(AttentionBackend):
         kv_num_heads: int,
         num_heads: int,
         head_dim: int,
+        encoder_block_shape_q: int = -1,
+        decoder_block_shape_q: int = -1,
     ):
         """
         FlashAttentionBackend __init__
@@ -111,10 +105,13 @@ class FlashAttentionBackend(AttentionBackend):
         self.kv_num_heads = kv_num_heads
         self.num_heads = num_heads
+        self.group_size: int = self.num_heads // self.kv_num_heads
         self.head_dim = fd_config.model_config.head_dim
         self.attn_outputsize_tp = self.num_heads * self.head_dim
         self.block_size = fd_config.cache_config.block_size
         self.num_layers: int = fd_config.model_config.num_hidden_layers

+        self.encoder_block_shape_q: int = encoder_block_shape_q
+        self.decoder_block_shape_q: int = decoder_block_shape_q

         self.speculative_method = fd_config.speculative_config.method
         self.use_speculate = self.speculative_method is not None
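
The two new constructor parameters replace the constants that init_attention_metadata used to hard-code (removed in the next hunk). A hypothetical call site in a model runner; the 64/16 values mirror the old hard-coded defaults:

# Hypothetical instantiation; fd_config and the head counts come from the
# model runner's existing setup.
backend = FlashAttentionBackend(
    fd_config,
    kv_num_heads=kv_num_heads,
    num_heads=num_heads,
    head_dim=head_dim,
    encoder_block_shape_q=64,   # was metadata.encoder_block_shape_q = 64
    decoder_block_shape_q=16,   # was metadata.decoder_block_shape_q = 16
)
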
@@ -176,8 +173,6 @@ class FlashAttentionBackend(AttentionBackend):
     def init_attention_metadata(self, forward_meta: ForwardMeta):
         metadata = FlashAttentionMetadata()
-        metadata.encoder_block_shape_q = 64
-        metadata.decoder_block_shape_q = 16
         metadata.cu_seqlens_q = forward_meta.cu_seqlens_q
         metadata.rotary_embs = forward_meta.rotary_embs
         metadata.block_tables = forward_meta.block_tables
@@ -188,18 +183,18 @@ class FlashAttentionBackend(AttentionBackend):
             metadata.kv_batch_ids,
             metadata.kv_tile_ids_per_batch,
             metadata.kv_num_blocks,
-            metadata.decoder_batch_ids,
-            metadata.decoder_tile_ids_per_batch,
-            metadata.decoder_num_blocks,
             metadata.max_len_kv,
-            metadata.set_max_lengths,
         ) = get_block_shape_and_split_kv_block(
             forward_meta.seq_lens_encoder,
             forward_meta.seq_lens_decoder,
             forward_meta.seq_lens_this_time,
-            metadata.encoder_block_shape_q,
-            metadata.decoder_block_shape_q,
-            self.num_heads // self.kv_num_heads,
+            forward_meta.decoder_batch_ids,
+            forward_meta.decoder_tile_ids_per_batch,
+            forward_meta.decoder_num_blocks_cpu,
+            forward_meta.max_len_tensor_cpu,
+            self.encoder_block_shape_q,
+            self.decoder_block_shape_q,
+            self.group_size,
             self.block_size,
             self.speculate_max_draft_token_num + 1,
         )
@@ -233,8 +228,6 @@ class FlashAttentionBackend(AttentionBackend):
                 self.rank, int(self.device_id), self.keep_pd_step_flag
             )
         self.attention_metadata = metadata
-        forward_meta.decoder_batch_ids.copy_(metadata.decoder_batch_ids, False)
-        forward_meta.decoder_tile_ids_per_batch.copy_(metadata.decoder_tile_ids_per_batch, False)

     def forward_mixed(
         self,
@@ -291,8 +284,8 @@ class FlashAttentionBackend(AttentionBackend):
             v,
             metadata.cu_seqlens_q,
             metadata.cu_seqlens_k,
-            max_seqlen_q=metadata.set_max_lengths[0],
-            max_seqlen_k=metadata.set_max_lengths[3],
+            max_seqlen_q=forward_meta.max_len_tensor_cpu[0],
+            max_seqlen_k=forward_meta.max_len_tensor_cpu[3],
             causal=self.causal,
             **self.flash_attn_kwargs,
         )[0].reshape([-1, self.attn_outputsize_tp])
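
The copy_ calls deleted in the -233 hunk follow from the same design: the op now writes straight into forward_meta's preallocated buffers, so there is no per-step metadata tensor left to mirror back. A minimal, illustrative sketch of the stable-address contract this relies on (plain Paddle, not FastDeploy code):

import paddle

# Capture-time: allocate once; the buffer's address must never change.
decoder_batch_ids = paddle.zeros([8], dtype="int32")

def step(new_ids: paddle.Tensor) -> None:
    # Replay-time: update in place, same semantics as the deleted
    # forward_meta.decoder_batch_ids.copy_(metadata.decoder_batch_ids, False).
    decoder_batch_ids.copy_(new_ids, False)  # non-blocking copy into the same storage

step(paddle.arange(8, dtype="int32"))
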