[Inference Optimize] DeepSeek-V3 model MLA optimization (#3886)

* Support MLA chunk_size auto-search and CUDA Graph
AIbin authored on 2025-09-11 10:46:09 +08:00, committed by GitHub
parent 637d96c6ae
commit a7392a0ff9
23 changed files with 375 additions and 310 deletions

@@ -141,6 +141,8 @@ class AppendAttentionBackend(AttentionBackend):
forward_meta.decoder_batch_ids,
forward_meta.decoder_tile_ids_per_batch,
forward_meta.decoder_num_blocks_cpu,
+forward_meta.decoder_num_blocks_device,
+forward_meta.decoder_chunk_size_device,
forward_meta.max_len_tensor_cpu,
forward_meta.encoder_batch_ids,
forward_meta.encoder_tile_ids_per_batch,

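The same two-line change recurs across the AppendAttention, FlashAttention, and MLA backends: ForwardMeta grows two device-resident tensors, decoder_num_blocks_device and decoder_chunk_size_device, threaded through next to the existing host-side decoder_num_blocks_cpu. A minimal sketch of how such fields might be allocated follows; the shapes, dtypes, and placements are illustrative assumptions, not the PR's actual initialization code.

import paddle

# Illustrative allocation only; assumed shapes/dtypes.
# Host-side count, read on the CPU when configuring kernel launches.
decoder_num_blocks_cpu = paddle.full([1], 0, dtype="int32").cpu()

# Device-side bookkeeping. Keeping these on the GPU lets a captured
# CUDA Graph replay while the kernels read fresh values straight from
# device memory, with no host synchronization in the decode loop.
decoder_num_blocks_device = paddle.full([1], 0, dtype="int32")
decoder_chunk_size_device = paddle.full([1], 64, dtype="int32")
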
@@ -198,6 +198,8 @@ class FlashAttentionBackend(AttentionBackend):
forward_meta.decoder_batch_ids,
forward_meta.decoder_tile_ids_per_batch,
forward_meta.decoder_num_blocks_cpu,
+forward_meta.decoder_num_blocks_device,
+forward_meta.decoder_chunk_size_device,
forward_meta.max_len_tensor_cpu,
forward_meta.encoder_batch_ids,
forward_meta.encoder_tile_ids_per_batch,

@@ -187,9 +187,11 @@ class MLAAttentionBackend(AttentionBackend):
forward_meta.seq_lens_encoder,
forward_meta.seq_lens_decoder,
forward_meta.seq_lens_this_time,
-forward_meta.decoder_batch_ids,
-forward_meta.decoder_tile_ids_per_batch,
+forward_meta.decoder_batch_ids, # decoder_batch_ids_per_ctax
+forward_meta.decoder_tile_ids_per_batch, # decoder_chunk_ids_per_ctax_each_batch
forward_meta.decoder_num_blocks_cpu,
+forward_meta.decoder_num_blocks_device,
+forward_meta.decoder_chunk_size_device,
forward_meta.max_len_tensor_cpu,
forward_meta.encoder_batch_ids,
forward_meta.encoder_tile_ids_per_batch,
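
The new inline comments pin down the semantics of the two buffers: each CTA of the split-KV decode grid gets a (batch id, chunk id) pair. The mapping itself is computed inside the fused CUDA op; the Python below is only a rough, assumption-laden rendering of the idea, not the PR's implementation.

import paddle

def split_decode_kv(seq_lens_decoder, chunk_size):
    # One (batch_id, tile_id) pair per chunk_size-sized slice of each
    # request's KV history; each pair becomes one CTA of the split-KV
    # decode kernel, and the partial results are reduced afterwards.
    batch_ids, tile_ids = [], []
    for b, kv_len in enumerate(seq_lens_decoder.numpy().tolist()):
        if kv_len <= 0:
            continue  # request not in the decode phase
        num_tiles = -(-kv_len // chunk_size)  # ceiling division
        for t in range(num_tiles):
            batch_ids.append(b)
            tile_ids.append(t)
    return (
        paddle.to_tensor(batch_ids, dtype="int32"),  # decoder_batch_ids
        paddle.to_tensor(tile_ids, dtype="int32"),  # decoder_tile_ids_per_batch
        paddle.to_tensor([len(batch_ids)], dtype="int32"),  # decoder_num_blocks
    )
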
@@ -347,23 +349,18 @@ class MLAAttentionBackend(AttentionBackend):
q,
latent_cache,
latent_cache,
forward_meta.seq_lens_encoder,
forward_meta.seq_lens_decoder,
forward_meta.seq_lens_this_time,
forward_meta.cu_seqlens_q,
forward_meta.batch_id_per_token,
metadata.block_tables,
-forward_meta.encoder_batch_ids,
-forward_meta.encoder_tile_ids_per_batch,
-forward_meta.encoder_num_blocks_x_cpu,
-forward_meta.kv_batch_ids,
-forward_meta.kv_tile_ids_per_batch,
-forward_meta.kv_num_blocks_x_cpu,
forward_meta.decoder_batch_ids,
forward_meta.decoder_tile_ids_per_batch,
forward_meta.decoder_num_blocks_cpu,
-forward_meta.decoder_num_blocks_cpu,
metadata.max_enc_len_this_time,
+forward_meta.decoder_num_blocks_device,
+forward_meta.decoder_chunk_size_device,
metadata.max_dec_len_this_time,
forward_meta.max_len_kv_cpu,
None, # attn_mask
@@ -468,23 +465,18 @@ class MLAAttentionBackend(AttentionBackend):
q,
latent_cache,
latent_cache,
forward_meta.seq_lens_encoder,
forward_meta.seq_lens_decoder,
forward_meta.seq_lens_this_time,
forward_meta.cu_seqlens_q,
forward_meta.batch_id_per_token,
metadata.block_tables,
-forward_meta.encoder_batch_ids,
-forward_meta.encoder_tile_ids_per_batch,
-forward_meta.encoder_num_blocks_x_cpu,
-forward_meta.kv_batch_ids,
-forward_meta.kv_tile_ids_per_batch,
-forward_meta.kv_num_blocks_x_cpu,
forward_meta.decoder_batch_ids,
forward_meta.decoder_tile_ids_per_batch,
forward_meta.decoder_num_blocks_cpu,
-forward_meta.decoder_num_blocks_cpu,
metadata.max_enc_len_this_time,
+forward_meta.decoder_num_blocks_device,
+forward_meta.decoder_chunk_size_device,
metadata.max_dec_len_this_time,
forward_meta.max_len_kv_cpu,
None, # attn_mask

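Both MLA call sites (the mixed/prefill path above and the decode path here) drop the encoder- and kv-tiling arguments and pick up the two device tensors. Device-resident metadata is what makes the CUDA Graph half of this PR work: a captured graph replays a fixed sequence of kernel launches, so any value that changes between steps must be read from GPU memory rather than baked in at capture time. A hedged sketch of that pattern follows, with invented helper names; graph is assumed to be an already-captured paddle.device.cuda.graphs.CUDAGraph.

import paddle

def decode_step(graph, forward_meta, num_blocks, chunk_size):
    # paddle.assign writes in place, so the tensor addresses that the
    # captured kernels hold pointers to stay stable across steps.
    paddle.assign(paddle.to_tensor([num_blocks], dtype="int32"),
                  forward_meta.decoder_num_blocks_device)
    paddle.assign(paddle.to_tensor([chunk_size], dtype="int32"),
                  forward_meta.decoder_chunk_size_device)
    graph.replay()  # fixed launches; kernels see the updated values
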
@@ -30,7 +30,9 @@ def get_block_shape_and_split_kv_block(
seq_lens_this_time: paddle.Tensor,
decoder_batch_ids: paddle.Tensor,
decoder_tile_ids_per_batch: paddle.Tensor,
-decoder_num_blocks_x_cpu: paddle.Tensor,
+decoder_num_blocks_cpu: paddle.Tensor,
+decoder_num_blocks_device: paddle.Tensor,
+decoder_chunk_size_device: paddle.Tensor,
max_len_tensor_cpu: paddle.Tensor,
encoder_batch_ids: paddle.Tensor,
encoder_tile_ids_per_batch: paddle.Tensor,
@@ -55,7 +57,9 @@ def get_block_shape_and_split_kv_block(
seq_lens_this_time,
decoder_batch_ids,
decoder_tile_ids_per_batch,
-decoder_num_blocks_x_cpu,
+decoder_num_blocks_cpu,
+decoder_num_blocks_device,
+decoder_chunk_size_device,
max_len_tensor_cpu,
encoder_batch_ids,
encoder_tile_ids_per_batch,