【Inference Optimize】DeepSeek-V3-model MLA Optimize (#3886)

* support MLA chunk_size auto search & cuda_graph
Author: AIbin
Date: 2025-09-11 10:46:09 +08:00
Committed by: GitHub
Parent: 637d96c6ae
Commit: a7392a0ff9

23 changed files with 375 additions and 310 deletions
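The change running through every hunk below is the replacement of the host-side `int chunk_size` kernel argument with a device-resident `IdType* chunk_size_device` (plus the commenting-out of the `seq_lens_encoder` field). Moving the chunk size into device memory is what lets the two goals in the commit message coexist: an auto-search step can keep re-tuning the chunk size at runtime, and the attention kernel can still be captured into a CUDA graph, because scalar launch arguments are frozen at capture time, while a value read through a device pointer can simply be overwritten between replays. Below is a minimal, self-contained sketch of that pattern; it is not the kernel in this diff, and all names, shapes, and the CUDA 12-style `cudaGraphInstantiate` call are illustrative.

// chunk_size_graph_sketch.cu -- hypothetical demo, compile with nvcc
#include <cuda_runtime.h>
#include <cstdio>

__global__ void chunked_kernel(const int* chunk_size_device, int kv_len) {
  // Read the chunk size at run time instead of baking it into the launch.
  const int chunk_size = chunk_size_device[0];
  const int chunk_num = (kv_len + chunk_size - 1) / chunk_size;
  if (threadIdx.x == 0 && blockIdx.x == 0) {
    printf("chunk_size=%d chunk_num=%d\n", chunk_size, chunk_num);
  }
}

int main() {
  int* chunk_size_device = nullptr;
  cudaMalloc(&chunk_size_device, sizeof(int));

  cudaStream_t stream;
  cudaStreamCreate(&stream);

  // Capture the launch once; only the pointer (not the pointed-to value)
  // becomes part of the graph.
  cudaGraph_t graph;
  cudaGraphExec_t graph_exec;
  cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);
  chunked_kernel<<<1, 32, 0, stream>>>(chunk_size_device, /*kv_len=*/8192);
  cudaStreamEndCapture(stream, &graph);
  cudaGraphInstantiate(&graph_exec, graph, /*flags=*/0);

  // Replay the same graph with chunk sizes an auto-search step might pick.
  const int candidates[] = {1024, 2048, 4096};
  for (int cs : candidates) {
    cudaMemcpy(chunk_size_device, &cs, sizeof(int), cudaMemcpyHostToDevice);
    cudaGraphLaunch(graph_exec, stream);
    cudaStreamSynchronize(stream);
  }

  cudaGraphExecDestroy(graph_exec);
  cudaGraphDestroy(graph);
  cudaStreamDestroy(stream);
  cudaFree(chunk_size_device);
  return 0;
}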


@@ -128,12 +128,13 @@ struct CollectiveMainloop {
DTypeMD const* d_ptr;
IdType const* kv_block_tables;
IdType const* seq_lens_this_time;
- IdType const* seq_lens_encoder;
+ // IdType const* seq_lens_encoder;
IdType const* seq_lens_decoder;
IdType const* cumsum_q_seqlens;
IdType const* batch_ids;
IdType const* tile_ids_per_batch;
IdType const* num_blocks_x;
+ IdType const* chunk_size_device;
float sm_scale;
int bsz;
int max_block_num;
@@ -144,7 +145,7 @@ struct CollectiveMainloop {
int kv_stride_block_size;
int o_stride_bsz;
int o_stride_head_num;
- int chunk_size;
+ // int chunk_size;
int chunk_num;
int max_draft_token_num;
};
@@ -160,12 +161,13 @@ struct CollectiveMainloop {
DTypeMD* d_ptr;
IdType* kv_block_tables;
IdType* seq_lens_this_time;
- IdType* seq_lens_encoder;
+ // IdType* seq_lens_encoder;
IdType* seq_lens_decoder;
IdType* cumsum_q_seqlens;
IdType* batch_ids;
IdType* tile_ids_per_batch;
IdType* num_blocks_x;
+ IdType* chunk_size_device;
float sm_scale;
int bsz;
int max_block_num;
@@ -176,7 +178,7 @@ struct CollectiveMainloop {
int kv_stride_block_size;
int o_stride_bsz;
int o_stride_head_num;
- int chunk_size;
+ // int chunk_size;
int chunk_num;
int max_draft_token_num;
TMA_KV tma_load_KV;
@@ -198,12 +200,13 @@ struct CollectiveMainloop {
const_cast<DTypeMD*>(args.d_ptr),
const_cast<IdType*>(args.kv_block_tables),
const_cast<IdType*>(args.seq_lens_this_time),
- const_cast<IdType*>(args.seq_lens_encoder),
+ // const_cast<IdType*>(args.seq_lens_encoder),
const_cast<IdType*>(args.seq_lens_decoder),
const_cast<IdType*>(args.cumsum_q_seqlens),
const_cast<IdType*>(args.batch_ids),
const_cast<IdType*>(args.tile_ids_per_batch),
const_cast<IdType*>(args.num_blocks_x),
+ const_cast<IdType*>(args.chunk_size_device),
args.sm_scale,
args.bsz,
args.max_block_num,
@@ -214,7 +217,7 @@ struct CollectiveMainloop {
args.kv_stride_block_size,
args.o_stride_bsz,
args.o_stride_head_num,
- args.chunk_size,
+ // args.chunk_size,
args.chunk_num,
args.max_draft_token_num,
tma_load_KV
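The same field has to change in three places above because this follows the CUTLASS 3.x-style split between a host-facing Arguments struct and the device-side Params struct that the kernel actually reads: anything added to Arguments must also be added to Params and copied across in the conversion (the brace-initializer hunk). A minimal sketch of that pattern with illustrative names, not the actual mainloop:

// arguments_to_params_sketch.cc -- illustrative only
#include <cstdint>

struct CollectiveSketch {
  using IdType = int32_t;

  // Host-facing arguments supplied by the caller.
  struct Arguments {
    IdType const* chunk_size_device;  // device pointer, contents tunable at run time
    int chunk_num;
    float sm_scale;
  };

  // Device-side parameters read inside the kernel.
  struct Params {
    IdType* chunk_size_device;
    int chunk_num;
    float sm_scale;
  };

  // Mirrors the const_cast brace-initializer in the diff: every new field must
  // be threaded through here, or the kernel never sees it.
  static Params to_underlying_arguments(Arguments const& args) {
    return {const_cast<IdType*>(args.chunk_size_device),
            args.chunk_num,
            args.sm_scale};
  }
};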
@@ -281,9 +284,9 @@ struct CollectiveMainloop {
auto gmem_thr_copy_kv = gmem_tiled_copy_kv.get_slice(thread_idx);
static constexpr int BLOCK_SHAPE_KV = get<1>(TileShape_QKD{});
- const int start_len = tile_idx * mainloop_params.chunk_size;
+ const int start_len = tile_idx * mainloop_params.chunk_size_device[0];
const int start_tile_idx = start_len / BLOCK_SHAPE_KV;
- const int end_tile_idx = cute::ceil_div(min(start_len + mainloop_params.chunk_size, kv_len), BLOCK_SHAPE_KV) - 1;
+ const int end_tile_idx = cute::ceil_div(min(start_len + mainloop_params.chunk_size_device[0], kv_len), BLOCK_SHAPE_KV) - 1;
auto kv_block_tables = make_tensor(make_gmem_ptr(mainloop_params.kv_block_tables), make_layout(make_shape(mainloop_params.bsz, mainloop_params.max_block_num_per_seq), make_stride(mainloop_params.max_block_num_per_seq, 1)));
@@ -322,9 +325,9 @@ struct CollectiveMainloop {
group_modes<0, 2>(sK), group_modes<0, 2>(gKV));
static constexpr int BLOCK_SHAPE_KV = get<1>(TileShape_QKD{});
- const int start_len = tile_idx * mainloop_params.chunk_size;
+ const int start_len = tile_idx * mainloop_params.chunk_size_device[0];
const int start_tile_idx = start_len / BLOCK_SHAPE_KV;
- const int end_tile_idx = cute::ceil_div(min(start_len + mainloop_params.chunk_size, kv_len), BLOCK_SHAPE_KV) - 1;
+ const int end_tile_idx = cute::ceil_div(min(start_len + mainloop_params.chunk_size_device[0], kv_len), BLOCK_SHAPE_KV) - 1;
auto kv_block_tables = make_tensor(make_gmem_ptr(mainloop_params.kv_block_tables), make_layout(make_shape(mainloop_params.bsz, mainloop_params.max_block_num_per_seq), make_stride(mainloop_params.max_block_num_per_seq, 1)));
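For reference, the arithmetic in the two hunks above maps a chunk index to a half-open token range and then to an inclusive range of KV tiles; only the source of the chunk size changed (from a host scalar to `chunk_size_device[0]`). A standalone sketch with illustrative values, using a plain int in place of the device read:

// kv_tile_range_sketch.cc -- illustrative only
#include <algorithm>
#include <cstdio>

constexpr int ceil_div(int a, int b) { return (a + b - 1) / b; }

int main() {
  const int BLOCK_SHAPE_KV = 64;    // KV tile size along the sequence dimension
  const int chunk_size     = 2048;  // stands in for mainloop_params.chunk_size_device[0]
  const int kv_len         = 5000;  // KV length of this sequence
  const int tile_idx       = 2;     // which chunk this work item processes

  const int start_len      = tile_idx * chunk_size;       // 4096
  const int start_tile_idx = start_len / BLOCK_SHAPE_KV;  // 64
  const int end_len        = std::min(start_len + chunk_size, kv_len);  // 5000 (clamped)
  const int end_tile_idx   = ceil_div(end_len, BLOCK_SHAPE_KV) - 1;     // 78

  printf("chunk %d covers tokens [%d, %d) -> KV tiles [%d, %d]\n",
         tile_idx, start_len, end_len, start_tile_idx, end_tile_idx);
  return 0;
}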