Mirror of https://github.com/PaddlePaddle/FastDeploy.git
[Executor] Refactor GetBlockShapeAndSplitKVBlock Kernel (#2989)
* reset decoder_block_shape_q buffer
* refactor GetBlockShapeAndSplitKVBlock kernel and cudagraph padding batch
* update decode_max_tile_size
* fix pre-commit
* update block_multihead_attn_backend
* update flash attn backend
* update MLA Attention
* update XPU Attention
* update gcu, iluvatar model runner
* update MTP
* fix MTP bug
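For readers unfamiliar with the kernel being refactored, the block-shape/split step partitions each sequence's query rows for the current step into tiles of decoder_block_shape_q and records, for every tile, which batch it belongs to and which tile index it is; those are the decoder_batch_ids / decoder_tile_ids_per_batch / decoder_num_blocks buffers that the hunks below move onto ForwardMeta. A toy CPU-only model of that idea, for intuition only (it is not the CUDA kernel and ignores details such as the GQA group size and cudagraph padding):

def split_decoder_tiles(seq_lens_this_time, decoder_block_shape_q):
    # Toy CPU model: emit one (batch_id, tile_id) pair per query tile, mirroring
    # decoder_batch_ids / decoder_tile_ids_per_batch / decoder_num_blocks.
    batch_ids, tile_ids = [], []
    for batch_id, q_len in enumerate(seq_lens_this_time):
        num_tiles = (q_len + decoder_block_shape_q - 1) // decoder_block_shape_q
        for tile_id in range(num_tiles):
            batch_ids.append(batch_id)
            tile_ids.append(tile_id)
    return batch_ids, tile_ids, len(batch_ids)

# Example: three sequences decoding 1 token each plus one speculative batch of 5,
# with decoder_block_shape_q = 16 as set in init_attention_metadata below.
print(split_decoder_tiles([1, 1, 1, 5], 16))   # ([0, 1, 2, 3], [0, 0, 0, 0], 4)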
@@ -64,17 +64,13 @@ class MLAAttentionMetadata(AttentionMetadata):
    MLAAttentionMetadata for Multi-Layer Attention
    """

    max_len_kv: paddle.Tensor = None
    set_max_lengths: int = -1
    encoder_batch_ids: paddle.Tensor = None
    encoder_tile_ids_per_batch: paddle.Tensor = None
    encoder_num_blocks: paddle.Tensor = None
    kv_batch_ids: paddle.Tensor = None
    kv_tile_ids_per_batch: paddle.Tensor = None
    kv_num_blocks: paddle.Tensor = None
    decoder_batch_ids: paddle.Tensor = None
    decoder_tile_ids_per_batch: paddle.Tensor = None
    decoder_num_blocks: paddle.Tensor = None
    max_len_kv: paddle.Tensor = None

    _dtype: paddle.dtype = paddle.bfloat16
    encoder_max_partition_size: int = 32768
@@ -82,8 +78,6 @@ class MLAAttentionMetadata(AttentionMetadata):
    block_tables: Optional[paddle.Tensor] = None
    rotary_embs: Optional[paddle.Tensor] = None
    attn_mask: Optional[paddle.Tensor] = None
    encoder_block_shape_q: int = -1
    decoder_block_shape_q: int = -1
    _fuse_kernel_compute_dtype: str = "bf16"

    # pd_disaggregation
@@ -105,6 +99,8 @@ class MLAAttentionBackend(AttentionBackend):
        kv_num_heads: int,
        num_heads: int,
        head_dim: int,
        encoder_block_shape_q: int = -1,
        decoder_block_shape_q: int = -1,
    ) -> None:
        """
        MLAAttentionBackend __init__
@@ -128,8 +124,11 @@ class MLAAttentionBackend(AttentionBackend):

        self.kv_num_heads: int = kv_num_heads
        self.num_heads: int = num_heads
        self.group_size: int = self.num_heads // self.kv_num_heads
        self.head_dim: int = fd_config.model_config.head_dim
        self.num_layers: int = fd_config.model_config.num_hidden_layers
        self.encoder_block_shape_q: int = encoder_block_shape_q
        self.decoder_block_shape_q: int = decoder_block_shape_q

        # For Multi Head Latent Attention
        self.kv_lora_rank: int = fd_config.model_config.kv_lora_rank
@@ -152,8 +151,6 @@ class MLAAttentionBackend(AttentionBackend):
    def init_attention_metadata(self, forward_meta: ForwardMeta):
        """Initialize attention metadata hence all layers in the forward pass can reuse it."""
        metadata = MLAAttentionMetadata()
        metadata.encoder_block_shape_q = 64
        metadata.decoder_block_shape_q = 16
        metadata.max_partition_size = 32768
        metadata.encoder_max_partition_size = self.max_seq_len
        metadata._dtype = paddle.get_default_dtype()
@@ -176,27 +173,25 @@ class MLAAttentionBackend(AttentionBackend):
            metadata.kv_batch_ids,
            metadata.kv_tile_ids_per_batch,
            metadata.kv_num_blocks,
            metadata.decoder_batch_ids,
            metadata.decoder_tile_ids_per_batch,
            metadata.decoder_num_blocks,
            metadata.max_len_kv,
            metadata.set_max_lengths,
        ) = get_block_shape_and_split_kv_block(
            forward_meta.seq_lens_encoder,
            forward_meta.seq_lens_decoder,
            forward_meta.seq_lens_this_time,
            metadata.encoder_block_shape_q,
            metadata.decoder_block_shape_q,
            self.num_heads // self.kv_num_heads,
            forward_meta.decoder_batch_ids,
            forward_meta.decoder_tile_ids_per_batch,
            forward_meta.decoder_num_blocks_cpu,
            forward_meta.max_len_tensor_cpu,
            self.encoder_block_shape_q,
            self.decoder_block_shape_q,
            self.group_size,
            self.block_size,
            self.speculate_max_draft_token_num + 1,
        )

        # MLA
        metadata.max_enc_len_this_time = metadata.set_max_lengths[1]
        metadata.max_dec_len_this_time = metadata.set_max_lengths[2]
        forward_meta.max_enc_len_this_time = metadata.set_max_lengths[1]
        forward_meta.max_dec_len_this_time = metadata.set_max_lengths[2]
        metadata.max_enc_len_this_time = forward_meta.max_len_tensor_cpu[1]
        metadata.max_dec_len_this_time = forward_meta.max_len_tensor_cpu[2]

        # pd_disaggregation
        metadata.kv_signal_data_list = [None] * self.num_layers
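Since the refactored call above writes into forward_meta.decoder_batch_ids, forward_meta.decoder_tile_ids_per_batch, forward_meta.decoder_num_blocks_cpu, and forward_meta.max_len_tensor_cpu in place, those buffers must exist before the first forward step. A minimal sketch of what such one-time allocation could look like; the capacity constants, tensor sizes, and the SimpleNamespace stand-in for ForwardMeta are assumptions for illustration only, and the real shapes live in the model runner rather than in this diff:

from types import SimpleNamespace

import paddle

max_num_seqs = 128        # assumed padded decode batch (e.g. for cudagraph replay)
max_tiles_per_seq = 8     # assumed worst-case decoder q-tiles per sequence

forward_meta = SimpleNamespace()
# Device-side scheduling buffers, filled in place by the kernel every step.
forward_meta.decoder_batch_ids = paddle.zeros([max_num_seqs * max_tiles_per_seq], dtype="int32")
forward_meta.decoder_tile_ids_per_batch = paddle.zeros([max_num_seqs * max_tiles_per_seq], dtype="int32")
# Host-side tensors so counts and max lengths can be read without a device sync;
# the diff reads max_len_tensor_cpu[1] and [2], the total size of 9 is assumed.
forward_meta.decoder_num_blocks_cpu = paddle.zeros([1], dtype="int32").cpu()
forward_meta.max_len_tensor_cpu = paddle.zeros([9], dtype="int32").cpu()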
@@ -216,9 +211,6 @@ class MLAAttentionBackend(AttentionBackend):

        self.attention_metadata: AttentionMetadata = metadata

        forward_meta.decoder_batch_ids.copy_(metadata.decoder_batch_ids, False)
        forward_meta.decoder_tile_ids_per_batch.copy_(metadata.decoder_tile_ids_per_batch, False)

    def get_attntion_meta(self) -> AttentionMetadata:
        """get_attntion_meta"""
        return self.attention_metadata
@@ -354,8 +346,8 @@ class MLAAttentionBackend(AttentionBackend):
            metadata.kv_num_blocks,
            forward_meta.decoder_batch_ids,
            forward_meta.decoder_tile_ids_per_batch,
            metadata.decoder_num_blocks,
            metadata.decoder_num_blocks,  # PaddleNLP passes in decoder_num_blocks_cpu here
            forward_meta.decoder_num_blocks_cpu,
            forward_meta.decoder_num_blocks_cpu,
            metadata.max_enc_len_this_time,
            metadata.max_dec_len_this_time,
            metadata.max_len_kv,
@@ -476,8 +468,8 @@ class MLAAttentionBackend(AttentionBackend):
            metadata.kv_num_blocks,
            forward_meta.decoder_batch_ids,
            forward_meta.decoder_tile_ids_per_batch,
            metadata.decoder_num_blocks,
            metadata.decoder_num_blocks,  # PaddleNLP passes in decoder_num_blocks_cpu here
            forward_meta.decoder_num_blocks_cpu,
            forward_meta.decoder_num_blocks_cpu,
            metadata.max_enc_len_this_time,
            metadata.max_dec_len_this_time,
            metadata.max_len_kv,