【Inference Optimize】DeepSeek-V3-model MLA Optimize (#3886)
* support MLA chunk_size auto search & cuda_graph
@@ -99,6 +99,8 @@ class ForwardMeta:
     decoder_batch_ids: Optional[paddle.Tensor] = None
     # Maps the thread block index (blockIdx.x) to the specific data tile being processed within that batch for the decoder stage in multi_query_append_attention_warp1_4_kernel.
     decoder_tile_ids_per_batch: Optional[paddle.Tensor] = None
+    # The number of blocks that the attention backend can use in the decode stage
+    decoder_num_blocks_device: Optional[paddle.Tensor] = None
     # The number of CUDA blocks to launch in the x-dimension for multi_query_append_attention_warp1_4_kernel, defining its grids.x.
     decoder_num_blocks_cpu: Optional[paddle.Tensor] = None
     # A tensor that holds multiple lengths related to the prefill and decode stages.
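The device/CPU split above is deliberate. A minimal sketch of the idea, assuming the fields behave as their comments describe (the sizes and names below are illustrative, not FastDeploy's actual allocation code):

    import paddle

    max_batch_size, max_tiles_per_batch = 64, 8

    # Host-side copy: read once on the CPU to fix grids.x for the kernel
    # launch, e.g. at CUDA-graph capture time.
    decoder_num_blocks_cpu = paddle.full(
        [1], max_batch_size * max_tiles_per_batch, dtype="int32"
    ).cpu()
    grids_x = int(decoder_num_blocks_cpu)  # no GPU sync: already on the host

    # Device-side copy: the per-step block count the kernel reads on the GPU,
    # updated in place so a captured graph sees fresh values without a sync.
    decoder_num_blocks_device = paddle.full([1], 0, dtype="int32")

Keeping the launch geometry on the host while the live count stays on the device is what lets the decode kernel run under CUDA graphs without a host-device round trip per step.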
@@ -118,6 +120,8 @@ class ForwardMeta:
     # The maximum sequence length of the KV cache, which may represent the current maximum decoder length.
     max_len_kv_cpu: Optional[paddle.Tensor] = None

+    decoder_chunk_size_device: Optional[paddle.Tensor] = None
+
     # Sequence length of the encoder for every batch
     seq_lens_encoder: Optional[paddle.Tensor] = None
     # Sequence length of the decoder for every batch
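The commit message mentions "chunk_size auto search". A hedged sketch of what such a search could look like; the heuristic, names, and constants here are assumptions for illustration, not the kernel's actual policy:

    def auto_chunk_size(max_len_kv: int, min_chunk: int = 64, max_chunks: int = 8) -> int:
        """Pick the smallest power-of-two chunk that splits the longest KV
        sequence into at most max_chunks pieces, so the decode tile count
        stays bounded and stable across steps."""
        chunk = min_chunk
        while (max_len_kv + chunk - 1) // chunk > max_chunks:
            chunk *= 2
        return chunk

    assert auto_chunk_size(4096) == 512   # 4096 / 512 = 8 chunks
    assert auto_chunk_size(100) == 64     # short caches keep the minimum

The chosen value would then be written into decoder_chunk_size_device so the attention kernel can read it on the GPU each step.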
@@ -141,6 +141,8 @@ class AppendAttentionBackend(AttentionBackend):
             forward_meta.decoder_batch_ids,
             forward_meta.decoder_tile_ids_per_batch,
             forward_meta.decoder_num_blocks_cpu,
+            forward_meta.decoder_num_blocks_device,
+            forward_meta.decoder_chunk_size_device,
             forward_meta.max_len_tensor_cpu,
             forward_meta.encoder_batch_ids,
             forward_meta.encoder_tile_ids_per_batch,
@@ -198,6 +198,8 @@ class FlashAttentionBackend(AttentionBackend):
             forward_meta.decoder_batch_ids,
             forward_meta.decoder_tile_ids_per_batch,
             forward_meta.decoder_num_blocks_cpu,
+            forward_meta.decoder_num_blocks_device,
+            forward_meta.decoder_chunk_size_device,
             forward_meta.max_len_tensor_cpu,
             forward_meta.encoder_batch_ids,
             forward_meta.encoder_tile_ids_per_batch,
@@ -187,9 +187,11 @@ class MLAAttentionBackend(AttentionBackend):
             forward_meta.seq_lens_encoder,
             forward_meta.seq_lens_decoder,
             forward_meta.seq_lens_this_time,
-            forward_meta.decoder_batch_ids,
-            forward_meta.decoder_tile_ids_per_batch,
+            forward_meta.decoder_batch_ids,  # decoder_batch_ids_per_ctax
+            forward_meta.decoder_tile_ids_per_batch,  # decoder_chunk_ids_per_ctax_each_batch
             forward_meta.decoder_num_blocks_cpu,
+            forward_meta.decoder_num_blocks_device,
+            forward_meta.decoder_chunk_size_device,
             forward_meta.max_len_tensor_cpu,
             forward_meta.encoder_batch_ids,
             forward_meta.encoder_tile_ids_per_batch,
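For reference, the mapping that decoder_batch_ids and decoder_tile_ids_per_batch encode (per the blockIdx.x comment in ForwardMeta) can be pictured with a small host-side model; this is an illustrative reconstruction, not the CUDA implementation:

    def build_tile_map(seq_lens_decoder, chunk_size):
        """Flatten (batch, tile) pairs so CUDA block i can look up which
        request and which KV chunk it owns."""
        batch_ids, tile_ids = [], []
        for b, seq_len in enumerate(seq_lens_decoder):
            num_tiles = (seq_len + chunk_size - 1) // chunk_size
            for t in range(num_tiles):
                batch_ids.append(b)
                tile_ids.append(t)
        return batch_ids, tile_ids  # len(batch_ids) == blocks to launch

    # Two requests with 1000 and 3000 cached tokens, chunk_size 1024:
    # blocks get (batch, tile) = (0, 0), (1, 0), (1, 1), (1, 2).
    print(build_tile_map([1000, 3000], 1024))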
@@ -347,23 +349,18 @@ class MLAAttentionBackend(AttentionBackend):
            q,
            latent_cache,
            latent_cache,
            forward_meta.seq_lens_encoder,
            forward_meta.seq_lens_decoder,
            forward_meta.seq_lens_this_time,
            forward_meta.cu_seqlens_q,
            forward_meta.batch_id_per_token,
            metadata.block_tables,
            forward_meta.encoder_batch_ids,
            forward_meta.encoder_tile_ids_per_batch,
            forward_meta.encoder_num_blocks_x_cpu,
            forward_meta.kv_batch_ids,
            forward_meta.kv_tile_ids_per_batch,
            forward_meta.kv_num_blocks_x_cpu,
            forward_meta.decoder_batch_ids,
            forward_meta.decoder_tile_ids_per_batch,
            forward_meta.decoder_num_blocks_cpu,
            forward_meta.decoder_num_blocks_cpu,
            metadata.max_enc_len_this_time,
            forward_meta.decoder_num_blocks_device,
            forward_meta.decoder_chunk_size_device,
            metadata.max_dec_len_this_time,
            forward_meta.max_len_kv_cpu,
            None,  # attn_mask
@@ -468,23 +465,18 @@ class MLAAttentionBackend(AttentionBackend):
            q,
            latent_cache,
            latent_cache,
            forward_meta.seq_lens_encoder,
            forward_meta.seq_lens_decoder,
            forward_meta.seq_lens_this_time,
            forward_meta.cu_seqlens_q,
            forward_meta.batch_id_per_token,
            metadata.block_tables,
            forward_meta.encoder_batch_ids,
            forward_meta.encoder_tile_ids_per_batch,
            forward_meta.encoder_num_blocks_x_cpu,
            forward_meta.kv_batch_ids,
            forward_meta.kv_tile_ids_per_batch,
            forward_meta.kv_num_blocks_x_cpu,
            forward_meta.decoder_batch_ids,
            forward_meta.decoder_tile_ids_per_batch,
            forward_meta.decoder_num_blocks_cpu,
            forward_meta.decoder_num_blocks_cpu,
            metadata.max_enc_len_this_time,
            forward_meta.decoder_num_blocks_device,
            forward_meta.decoder_chunk_size_device,
            metadata.max_dec_len_this_time,
            forward_meta.max_len_kv_cpu,
            None,  # attn_mask
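Both call sites pass the new device-resident tensors into the attention op. A speculative sketch of the CUDA-graph constraint that motivates this (plain Paddle, not FastDeploy code; requires a CUDA build of PaddlePaddle, and the tensor names are stand-ins):

    import paddle
    from paddle.device.cuda.graphs import CUDAGraph

    chunk_size = paddle.full([1], 1024, dtype="int32")  # stand-in for decoder_chunk_size_device
    doubled = paddle.zeros([1], dtype="int32")

    graph = CUDAGraph()
    graph.capture_begin()
    paddle.assign(chunk_size * 2, doubled)  # the work is recorded once
    graph.capture_end()

    graph.replay()  # doubled == 2048
    paddle.assign(paddle.full([1], 512, dtype="int32"), chunk_size)  # in-place update
    graph.replay()  # doubled == 1024, with no re-capture

A replayed graph re-issues exactly the same kernels, so per-step scalars such as the chunk size or block count can only change through in-place writes to device memory, which is what the *_device tensors provide.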
@@ -30,7 +30,9 @@ def get_block_shape_and_split_kv_block(
     seq_lens_this_time: paddle.Tensor,
     decoder_batch_ids: paddle.Tensor,
     decoder_tile_ids_per_batch: paddle.Tensor,
-    decoder_num_blocks_x_cpu: paddle.Tensor,
+    decoder_num_blocks_cpu: paddle.Tensor,
+    decoder_num_blocks_device: paddle.Tensor,
+    decoder_chunk_size_device: paddle.Tensor,
     max_len_tensor_cpu: paddle.Tensor,
     encoder_batch_ids: paddle.Tensor,
     encoder_tile_ids_per_batch: paddle.Tensor,
@@ -55,7 +57,9 @@ def get_block_shape_and_split_kv_block(
         seq_lens_this_time,
         decoder_batch_ids,
         decoder_tile_ids_per_batch,
-        decoder_num_blocks_x_cpu,
+        decoder_num_blocks_cpu,
+        decoder_num_blocks_device,
+        decoder_chunk_size_device,
         max_len_tensor_cpu,
         encoder_batch_ids,
         encoder_tile_ids_per_batch,
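Putting the pieces together, the buffers that feed the updated signature might be allocated like this; shapes and dtypes are assumptions consistent with the surrounding metadata tensors, not the project's actual initialization code:

    import paddle

    max_decoder_blocks = 512  # hypothetical upper bound on decode CUDA blocks

    decoder_batch_ids = paddle.zeros([max_decoder_blocks], dtype="int32")
    decoder_tile_ids_per_batch = paddle.zeros([max_decoder_blocks], dtype="int32")
    decoder_num_blocks_cpu = paddle.zeros([1], dtype="int32").cpu()  # host copy for grids.x
    decoder_num_blocks_device = paddle.zeros([1], dtype="int32")     # device copy for graph replay
    decoder_chunk_size_device = paddle.zeros([1], dtype="int32")     # chosen KV chunk size, on device
    max_len_tensor_cpu = paddle.zeros([8], dtype="int32").cpu()      # assorted max lengths, on host

get_block_shape_and_split_kv_block would then fill these in place each step, with the CPU copies read on the host and the device copies consumed directly by the kernels.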