[Executor] Refactor GetBlockShapeAndSplitKVBlock Kernel (#2989)

* reset decoder_block_shape_q buffer

* refactor GetBlockShapeAndSplitKVBlock kernel and cudagraph batch padding

* update decode_max_tile_size

* fix pre-commit

* update block_multihead_attn_backend

* update flash attn backend

* update MLA Attention

* update XPU Attention

* update gcu, iluvatar model runners

* Update MTP

* fix MTP bug
Author: RAM
Date: 2025-07-31 00:09:31 +08:00
Committed by: GitHub
Parent: 998968f1e8
Commit: d850660872
13 changed files with 222 additions and 235 deletions
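
The encoder_block_shape_q = 64 and decoder_block_shape_q = 16 values introduced in the diff below control how many query rows go into one tile before the KV cache is split. A minimal sketch of that tiling arithmetic, assuming ceil-division into fixed-size tiles (the helper name and bookkeeping are illustrative, not taken from the kernel):

    import math

    def split_q_into_tiles(seq_lens_this_time, block_shape_q=16):
        # Each batch entry contributes ceil(q_len / block_shape_q) tiles; the
        # kernel records which batch owns each tile and the tile's index
        # within that batch.
        batch_ids, tile_ids = [], []
        for bid, q_len in enumerate(seq_lens_this_time):
            for tid in range(math.ceil(q_len / block_shape_q)):
                batch_ids.append(bid)
                tile_ids.append(tid)
        return batch_ids, tile_ids

    # e.g. q lens [1, 1, 20] -> tiles (0,0), (1,0), (2,0), (2,1)

Under this reading, decoder_batch_ids and decoder_tile_ids_per_batch would hold the two output lists and decoder_num_blocks_cpu the total tile count, which would explain why the commit resets them together and pads the batch for cudagraph capture.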


@@ -184,13 +184,26 @@ class MTPProposer(Proposer):
         """
         assert len(self.attn_backends) == 0
         # TODO(gongshaotian): Get rank from config
         num_heads = self.model_config.num_attention_heads // self.parallel_config.tensor_parallel_size
-        self.model_config.kv_num_heads = (
-            int(self.model_config.num_key_value_heads) // self.parallel_config.tensor_parallel_size
+        self.model_config.kv_num_heads = max(
+            1,
+            int(self.model_config.num_key_value_heads) // self.parallel_config.tensor_parallel_size,
         )
         head_dim = self.model_config.head_dim
+        # Initialize AttentionBackend buffers
+        encoder_block_shape_q = 64
+        decoder_block_shape_q = 16
+        self.model_inputs["decoder_batch_ids"] = paddle.zeros_like(self.main_model_inputs["decoder_batch_ids"])
+        self.model_inputs["decoder_tile_ids_per_batch"] = paddle.zeros_like(
+            self.main_model_inputs["decoder_tile_ids_per_batch"]
+        )
+        self.model_inputs["decoder_num_blocks_cpu"] = paddle.zeros_like(
+            self.main_model_inputs["decoder_num_blocks_cpu"]
+        ).pin_memory()
+        self.model_inputs["max_len_tensor_cpu"] = paddle.zeros_like(self.main_model_inputs["max_len_tensor_cpu"]).cpu()
         # Get the attention backend
         attn_cls = get_attention_backend()
         attn_backend = attn_cls(
@@ -198,6 +211,8 @@ class MTPProposer(Proposer):
             kv_num_heads=self.model_config.kv_num_heads,
             num_heads=num_heads,
             head_dim=head_dim,
+            encoder_block_shape_q=encoder_block_shape_q,
+            decoder_block_shape_q=decoder_block_shape_q,
         )
         if attn_backend is None:
             raise NotImplementedError(
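
The allocations in the first hunk split buffers by who reads them: tile metadata written and consumed on the GPU stays on device, while counts the CPU must read every step land in pinned host memory (cheap asynchronous device-to-host copies) or plain CPU memory. A standalone sketch of the same placement pattern in Paddle; the shapes and dtypes are assumptions, only the device placement mirrors the diff:

    import paddle

    max_num_seqs = 256  # assumed capacity

    # Device buffers filled by the split-KV kernel each step.
    decoder_batch_ids = paddle.zeros([max_num_seqs], dtype="int32")
    decoder_tile_ids_per_batch = paddle.zeros([max_num_seqs], dtype="int32")

    # Host-side mirrors: pin_memory() enables async copies off the device,
    # while a plain .cpu() tensor suffices for synchronously read values.
    decoder_num_blocks_cpu = paddle.zeros([1], dtype="int32").pin_memory()
    max_len_tensor_cpu = paddle.zeros([8], dtype="int32").cpu()  # assumed length

Mirroring the main model's buffers with paddle.zeros_like, as the diff does, additionally keeps the proposer's shapes in lockstep with the base model's.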
@@ -293,6 +308,12 @@ class MTPProposer(Proposer):
         self.model_inputs["base_model_draft_tokens"] = self.main_model_inputs["draft_tokens"]
         self.model_inputs["substep"] = 0
+        # Declare AttentionBackend buffers
+        self.model_inputs["decoder_batch_ids"] = None
+        self.model_inputs["decoder_tile_ids_per_batch"] = None
+        self.model_inputs["decoder_num_blocks_cpu"] = None  # Pinning Memory
+        self.model_inputs["max_len_tensor_cpu"] = None  # CPU
         # Input tokens
         self.model_inputs["draft_tokens"] = paddle.full(shape=[self.max_num_seqs, 2], fill_value=-1, dtype="int64")
@@ -405,6 +426,8 @@ class MTPProposer(Proposer):
             attn_backend=self.attn_backends[0],
             decoder_batch_ids=self.model_inputs["decoder_batch_ids"],
             decoder_tile_ids_per_batch=self.model_inputs["decoder_tile_ids_per_batch"],
+            decoder_num_blocks_cpu=self.model_inputs["decoder_num_blocks_cpu"],
+            max_len_tensor_cpu=self.model_inputs["max_len_tensor_cpu"],
             seq_lens_encoder=self.model_inputs["seq_lens_encoder"],
             seq_lens_decoder=self.model_inputs["seq_lens_decoder"],
             seq_lens_this_time=self.model_inputs["seq_lens_this_time"],
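
The call in the hunk above now threads both host-side tensors through to the attention setup. The callee's name sits outside the hunk, so the signature below is hypothetical, reconstructed only from the keyword arguments visible here:

    def initialize_attention_metadata(  # hypothetical name
        attn_backend,
        decoder_batch_ids,           # device: batch id owning each decoder tile
        decoder_tile_ids_per_batch,  # device: tile index within its batch
        decoder_num_blocks_cpu,      # pinned host: total decoder tile count
        max_len_tensor_cpu,          # host: per-step max-length statistics
        seq_lens_encoder,
        seq_lens_decoder,
        seq_lens_this_time,
    ):
        # A GetBlockShapeAndSplitKVBlock-style kernel would fill the four
        # buffers above from the three sequence-length tensors each step.
        ...

Passing pre-allocated buffers in, rather than allocating inside the callee, is consistent with cudagraph capture, which needs stable tensor addresses across replays.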