[Executor] Refactor GetBlockShapeAndSplitKVBlock Kernel (#2989)

* reset decoder_block_shape_q buffer

* refactor GetBlockShapeAndSplitKVBlock Kernel and cudagraph padding batch (see the sketch after this list)

* update decode_max_tile_size

* fix pre-commit

* update block_multihead_attn_backend

* update flash attn backend

* update MLA Attention

* update XPU Attention

* update gcu,iluvatar model runner

* update MTP

* fix MTP bug
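
For orientation, here is a minimal Python sketch (not the CUDA kernel itself, and not code from this commit) of the bookkeeping a kernel like GetBlockShapeAndSplitKVBlock performs, assuming it ceil-divides each request's query length into `block_shape_q`-row tiles and emits one `(batch_id, tile_id)` pair per tile; the signature and names are illustrative:

```python
import paddle

def get_block_shape_and_split_kv_block(seq_lens_q, block_shape_q):
    """Emit (batch_id, tile_id) pairs so one kernel block handles one tile."""
    batch_ids, tile_ids = [], []
    for batch_id, seq_len in enumerate(seq_lens_q.tolist()):
        # ceil-divide this request's query length into block_shape_q-row tiles
        num_tiles = (seq_len + block_shape_q - 1) // block_shape_q
        batch_ids.extend([batch_id] * num_tiles)
        tile_ids.extend(range(num_tiles))
    return (
        paddle.to_tensor(batch_ids, dtype="int32"),
        paddle.to_tensor(tile_ids, dtype="int32"),
        paddle.to_tensor([len(batch_ids)], dtype="int32"),  # num_blocks
    )

# e.g. seq_lens [100, 1] with block_shape_q=64 -> tiles (0,0), (0,1), (1,0)
```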
Author: RAM
Date: 2025-07-31 00:09:31 +08:00 (committed by GitHub)
Parent: 998968f1e8
Commit: d850660872
13 changed files with 222 additions and 235 deletions

@@ -44,26 +44,13 @@ class XPUAttentionMetadata(AttentionMetadata):
    XPUAttentionMetadata
    """
    max_len_kv: paddle.Tensor = None
    set_max_lengths: int = -1
    encoder_batch_ids: paddle.Tensor = None
    encoder_tile_ids_per_batch: paddle.Tensor = None
    encoder_num_blocks: paddle.Tensor = None
    kv_batch_ids: paddle.Tensor = None
    kv_tile_ids_per_batch: paddle.Tensor = None
    kv_num_blocks: paddle.Tensor = None
    decoder_batch_ids: paddle.Tensor = None
    decoder_tile_ids_per_batch: paddle.Tensor = None
    decoder_num_blocks: paddle.Tensor = None
    _dtype: paddle.dtype = paddle.bfloat16
    encoder_max_partition_size: int = 32768
    max_partition_size: int = 32768
    block_tables: Optional[paddle.Tensor] = None
    rotary_embs: Optional[paddle.Tensor] = None
    attn_mask: Optional[paddle.Tensor] = None
    encoder_block_shape_q: int = -1
    decoder_block_shape_q: int = -1
    _fuse_kernel_compute_dtype: str = "bf16"
    # pd_disaggregation
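
The per-batch scheduling buffers dropped from this metadata class still have to live somewhere with a stable shape if padded batches are to be replayed under cudagraph. A plausible sketch, assuming they move onto a shared ForwardMeta and are allocated once up front; the helper name and sizing below are assumptions, not code from this diff:

```python
import paddle

def alloc_decoder_schedule_buffers(forward_meta, max_num_seqs, decoder_block_shape_q=16):
    # Assumed helper: size the buffers for the worst case (every request
    # decoding at once) so their shapes never change between captured steps.
    max_tiles = max_num_seqs  # decode emits ~1 tile per request at this block shape
    forward_meta.decoder_batch_ids = paddle.zeros([max_tiles], dtype="int32")
    forward_meta.decoder_tile_ids_per_batch = paddle.zeros([max_tiles], dtype="int32")
    forward_meta.decoder_num_blocks = paddle.zeros([1], dtype="int32")
```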
@@ -91,7 +78,6 @@ class XPUAttentionBackend(AttentionBackend):
"""
super().__init__()
self.attention_metadata: XPUAttentionMetadata = None
# TODO(gongshaotian): Use fd_config parameters in the correct location
self.block_size: int = fd_config.cache_config.block_size
self.max_seq_len: int = fd_config.parallel_config.max_model_len
self.rope_theta: float = (
@@ -99,9 +85,6 @@ class XPUAttentionBackend(AttentionBackend):
        )
        self.rope_3d: bool = getattr(fd_config.model_config, "rope_3d", False)
        self.causal: bool = getattr(fd_config.model_config, "causal", True)
        # self.speculate_method = fd_config.parallel_config.speculate_method
        # self.use_speculate = self.speculate_method is not None
        # self.speculate_max_draft_token_num = fd_config.parallel_config.speculate_max_draft_tokens
        self.keep_pd_step_flag: bool = fd_config.speculative_config.model_type == "mtp"
        self.rank: int = fd_config.parallel_config.tensor_parallel_rank
@@ -117,8 +100,6 @@ class XPUAttentionBackend(AttentionBackend):
    def init_attention_metadata(self, forward_meta: ForwardMeta):
        """Initialize attention metadata so that all layers in the forward pass can reuse it."""
        metadata = XPUAttentionMetadata()
        metadata.encoder_block_shape_q = 64
        metadata.decoder_block_shape_q = 16
        metadata.max_partition_size = 32768
        metadata.encoder_max_partition_size = 32768
        metadata._dtype = paddle.get_default_dtype()
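
A usage sketch of the flow the docstring above describes: metadata is built once per step and then shared, rather than recomputed per layer. The runner, model, and layer call shapes here are illustrative, not FastDeploy's actual API:

```python
def run_one_step(backend, model, hidden_states, forward_meta):
    backend.init_attention_metadata(forward_meta)  # once per forward pass
    for layer in model.layers:
        # every attention layer reuses the same metadata instead of rebuilding it
        hidden_states = layer(hidden_states, forward_meta)
    return hidden_states
```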