mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-08 10:00:29 +08:00
【Inference Optimize】DeepSeek-V3-model MLA Optimize (#3886)
* support MLA chunk_size auto search & cuda_graph
This commit is contained in:
@@ -62,13 +62,12 @@ struct Params {
|
||||
alignas(16) DTypeQ *Q; // [token_num, head_num, dim_head]
|
||||
alignas(16) DTypeKV *KV; // [max_block_num, block_size, dim_head]
|
||||
alignas(16) DTypeO *O; // [token_num, head_num, dim_head]
|
||||
alignas(16) DTypeO *O_tmp; // [num_chunks, bsz, head_num, dim_head]
|
||||
alignas(16) float *m; // [num_chunks, bsz * max_draft_token_num * head_num]
|
||||
alignas(16) float *d; // [num_chunks, bsz * max_draft_token_num * head_num]
|
||||
alignas(16) DTypeO *O_tmp; // [max_num_chunks, bsz, head_num, dim_head]
|
||||
alignas(16) float *m; // [max_num_chunks, bsz * max_draft_token_num * head_num]
|
||||
alignas(16) float *d; // [max_num_chunks, bsz * max_draft_token_num * head_num]
|
||||
|
||||
alignas(16) IdType *block_tables;
|
||||
alignas(16) IdType *seq_lens_this_time;
|
||||
alignas(16) IdType *seq_lens_encoder;
|
||||
alignas(16) IdType *seq_lens_decoder;
|
||||
alignas(16) IdType *cumsum_q_seqlens;
|
||||
alignas(16) IdType *batch_id_per_token;
|
||||
@@ -76,7 +75,7 @@ struct Params {
|
||||
alignas(16) IdType *batch_ids;
|
||||
alignas(16) IdType *tile_ids_per_batch;
|
||||
alignas(16) IdType *num_blocks_x;
|
||||
|
||||
alignas(16) IdType *chunk_size_device;
|
||||
|
||||
uint32_t q_stride_bsz;
|
||||
uint32_t q_stride_head_num;
|
||||
@@ -96,9 +95,7 @@ struct Params {
|
||||
int vo_head_dim;
|
||||
int block_size;
|
||||
int max_draft_token_num;
|
||||
int chunk_size;
|
||||
int chunk_num;
|
||||
int num_blocks_x_int;
|
||||
|
||||
float sm_scale;
|
||||
};
|
||||
@@ -118,7 +115,7 @@ struct Params {
|
||||
return cudaErrorNotSupported; \
|
||||
}
|
||||
|
||||
template <typename CollectiveMainloop, typename CollectiveEpilogue, typename Ktraits, bool CAUSAL, int SM_COUNT = 132, bool USE_REG_EALLOC=false, bool USE_FIXED_BLOCK=false>
|
||||
template <typename CollectiveMainloop, typename CollectiveEpilogue, typename Ktraits, bool CAUSAL, int SM_COUNT = 132, bool USE_REG_EALLOC=false, bool USE_FIXED_BLOCK=true>
|
||||
__global__ void __launch_bounds__(Ktraits::NUM_WARPS * cutlass::NumThreadsPerWarp, 1)
|
||||
MLAWithKVCacheKernel(CUTE_GRID_CONSTANT
|
||||
typename CollectiveMainloop::Params const mainloop_params,
|
||||
@@ -137,6 +134,7 @@ MLAWithKVCacheKernel(CUTE_GRID_CONSTANT
|
||||
static constexpr int BLOCK_SHAPE_Q = Ktraits::BLOCK_SHAPE_Q;
|
||||
static constexpr int BLOCK_SHAPE_KV = Ktraits::BLOCK_SHAPE_KV;
|
||||
const int num_blocks_x = mainloop_params.num_blocks_x[0];
|
||||
const int chunk_size = mainloop_params.chunk_size_device[0];
|
||||
|
||||
static constexpr bool use_tma_load_kv = CollectiveMainloop::USE_TMA_LOAD_KV;
|
||||
|
||||
@@ -205,58 +203,10 @@ MLAWithKVCacheKernel(CUTE_GRID_CONSTANT
|
||||
|
||||
PipelineStateQ smem_pipe_write_q = cutlass::make_producer_start_state<MainloopPipelineQ>();
|
||||
PipelineState smem_pipe_write_kv = cutlass::make_producer_start_state<MainloopPipeline>();
|
||||
if constexpr(USE_FIXED_BLOCK) {
|
||||
for (int i = blockIdx.x; i < num_blocks_x; i += SM_COUNT) {
|
||||
const int bid = mainloop_params.batch_ids[i];
|
||||
const int tile_id = mainloop_params.tile_ids_per_batch[i];
|
||||
const int seq_len_now = mainloop_params.seq_lens_this_time[bid];
|
||||
const int seq_len_encoder_now = mainloop_params.seq_lens_encoder[bid];
|
||||
const int seq_len_decoder_now = mainloop_params.seq_lens_decoder[bid] + seq_len_now;
|
||||
const int start_token_idx = mainloop_params.cumsum_q_seqlens[bid];
|
||||
cutlass::arch::NamedBarrier::sync(Ktraits::NUM_THREADS,
|
||||
/*id=*/static_cast<int>(NamedBarriers::kWG0WG1WG2Sync));
|
||||
|
||||
// load Q
|
||||
collective_mainloop.load_q(
|
||||
mainloop_params,
|
||||
pipeline_q,
|
||||
smem_pipe_write_q,
|
||||
shared_storage,
|
||||
threadIdx.x,
|
||||
bid);
|
||||
|
||||
if constexpr (!use_tma_load_kv) {
|
||||
// load kv
|
||||
collective_mainloop.load_kv(
|
||||
mainloop_params,
|
||||
pipeline_kv,
|
||||
smem_pipe_write_kv,
|
||||
shared_storage,
|
||||
bid,
|
||||
seq_len_decoder_now,
|
||||
tile_id
|
||||
);
|
||||
} else {
|
||||
if (warp_idx_in_warpgroup == 0) {
|
||||
// load kv tma
|
||||
collective_mainloop.load_kv_tma(
|
||||
mainloop_params,
|
||||
pipeline_kv,
|
||||
smem_pipe_write_kv,
|
||||
shared_storage,
|
||||
bid,
|
||||
seq_len_decoder_now,
|
||||
tile_id
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
const int block_id = blockIdx.x;
|
||||
const int bid = mainloop_params.batch_ids[block_id];
|
||||
const int tile_id = mainloop_params.tile_ids_per_batch[block_id];
|
||||
for (int i = blockIdx.x; i < num_blocks_x; i += SM_COUNT) {
|
||||
const int bid = mainloop_params.batch_ids[i];
|
||||
const int tile_id = mainloop_params.tile_ids_per_batch[i];
|
||||
const int seq_len_now = mainloop_params.seq_lens_this_time[bid];
|
||||
const int seq_len_encoder_now = mainloop_params.seq_lens_encoder[bid];
|
||||
const int seq_len_decoder_now = mainloop_params.seq_lens_decoder[bid] + seq_len_now;
|
||||
const int start_token_idx = mainloop_params.cumsum_q_seqlens[bid];
|
||||
cutlass::arch::NamedBarrier::sync(Ktraits::NUM_THREADS,
|
||||
@@ -309,76 +259,12 @@ MLAWithKVCacheKernel(CUTE_GRID_CONSTANT
|
||||
Tensor tOrO = partition_fragment_C(tiled_mma_pv, select<0, 1>(TileShape_PDV{}));
|
||||
|
||||
auto attention_updater = OnlineSoftmax<2 * size<1>(tOrO), /*WITH_SCALE=*/true>(mainloop_params.sm_scale);
|
||||
if constexpr(USE_FIXED_BLOCK) {
|
||||
for (int i = blockIdx.x; i < num_blocks_x; i += SM_COUNT) {
|
||||
clear(tOrO);
|
||||
clear(attention_updater.scores_scale);
|
||||
const int bid = mainloop_params.batch_ids[i];
|
||||
const int tile_id = mainloop_params.tile_ids_per_batch[i];
|
||||
const int seq_len_now = mainloop_params.seq_lens_this_time[bid];
|
||||
const int seq_len_encoder_now = mainloop_params.seq_lens_encoder[bid];
|
||||
const int seq_len_decoder_now = mainloop_params.seq_lens_decoder[bid] + seq_len_now;
|
||||
const int start_token_idx = mainloop_params.cumsum_q_seqlens[bid];
|
||||
cutlass::arch::NamedBarrier::sync(Ktraits::NUM_THREADS,
|
||||
/*id=*/static_cast<int>(NamedBarriers::kWG0WG1WG2Sync));
|
||||
|
||||
if constexpr (BLOCK_SHAPE_KV == 64) {
|
||||
mma_f16<Ktraits, CAUSAL>(
|
||||
mainloop_params,
|
||||
pipeline_q,
|
||||
smem_pipe_read_q,
|
||||
pipeline_kv,
|
||||
smem_pipe_read_kv,
|
||||
tOrO,
|
||||
attention_updater,
|
||||
threadIdx.x - NUM_COPY_THREADS,
|
||||
bid,
|
||||
seq_len_decoder_now,
|
||||
seq_len_now,
|
||||
tile_id,
|
||||
shared_storage);
|
||||
} else if (BLOCK_SHAPE_KV == 32) {
|
||||
mma_f16_two_stages<Ktraits, CAUSAL>(
|
||||
mainloop_params,
|
||||
pipeline_q,
|
||||
smem_pipe_read_q,
|
||||
pipeline_kv,
|
||||
smem_pipe_read_kv,
|
||||
tOrO,
|
||||
attention_updater,
|
||||
threadIdx.x - NUM_COPY_THREADS,
|
||||
bid,
|
||||
seq_len_decoder_now,
|
||||
seq_len_now,
|
||||
tile_id,
|
||||
shared_storage);
|
||||
}
|
||||
|
||||
collective_epilogue.store(
|
||||
epilogue_params,
|
||||
tOrO,
|
||||
attention_updater.get_lse(),
|
||||
shared_storage,
|
||||
tiled_mma_pv,
|
||||
threadIdx.x - NUM_COPY_THREADS,
|
||||
bid,
|
||||
mainloop_params.bsz,
|
||||
seq_len_now,
|
||||
start_token_idx,
|
||||
tile_id,
|
||||
seq_len_decoder_now,
|
||||
mainloop_params.chunk_size,
|
||||
mainloop_params.max_draft_token_num,
|
||||
mainloop_params.o_stride_bsz);
|
||||
}
|
||||
} else {
|
||||
const int block_id = blockIdx.x;
|
||||
for (int i = blockIdx.x; i < num_blocks_x; i += SM_COUNT) {
|
||||
clear(tOrO);
|
||||
clear(attention_updater.scores_scale);
|
||||
const int bid = mainloop_params.batch_ids[block_id];
|
||||
const int tile_id = mainloop_params.tile_ids_per_batch[block_id];
|
||||
const int bid = mainloop_params.batch_ids[i];
|
||||
const int tile_id = mainloop_params.tile_ids_per_batch[i];
|
||||
const int seq_len_now = mainloop_params.seq_lens_this_time[bid];
|
||||
const int seq_len_encoder_now = mainloop_params.seq_lens_encoder[bid];
|
||||
const int seq_len_decoder_now = mainloop_params.seq_lens_decoder[bid] + seq_len_now;
|
||||
const int start_token_idx = mainloop_params.cumsum_q_seqlens[bid];
|
||||
cutlass::arch::NamedBarrier::sync(Ktraits::NUM_THREADS,
|
||||
@@ -429,15 +315,15 @@ MLAWithKVCacheKernel(CUTE_GRID_CONSTANT
|
||||
start_token_idx,
|
||||
tile_id,
|
||||
seq_len_decoder_now,
|
||||
mainloop_params.chunk_size,
|
||||
chunk_size,
|
||||
mainloop_params.max_draft_token_num,
|
||||
mainloop_params.o_stride_bsz);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
template <typename KernelTraits, bool CAUSAL, typename Params, bool USE_REG_EALLOC=false, bool USE_FIXED_BLOCK=false>
|
||||
template <typename KernelTraits, bool CAUSAL, typename Params, bool USE_REG_EALLOC=false, bool USE_FIXED_BLOCK=true>
|
||||
cudaError_t BatchMLAWithPagedKVCacheKernelTraitsDispatched(Params& params,
|
||||
cudaStream_t stream) {
|
||||
using DTypeQ = typename KernelTraits::DTypeQ;
|
||||
@@ -460,12 +346,12 @@ cudaError_t BatchMLAWithPagedKVCacheKernelTraitsDispatched(Params& params,
|
||||
params.d,
|
||||
params.block_tables,
|
||||
params.seq_lens_this_time,
|
||||
params.seq_lens_encoder,
|
||||
params.seq_lens_decoder,
|
||||
params.cumsum_q_seqlens,
|
||||
params.batch_ids,
|
||||
params.tile_ids_per_batch,
|
||||
params.num_blocks_x,
|
||||
params.chunk_size_device,
|
||||
params.sm_scale,
|
||||
params.bsz,
|
||||
params.max_block_num,
|
||||
@@ -476,7 +362,6 @@ cudaError_t BatchMLAWithPagedKVCacheKernelTraitsDispatched(Params& params,
|
||||
params.kv_stride_block_size,
|
||||
params.o_stride_bsz,
|
||||
params.o_stride_head_num,
|
||||
params.chunk_size,
|
||||
params.chunk_num,
|
||||
params.max_draft_token_num
|
||||
});
|
||||
@@ -500,13 +385,9 @@ cudaError_t BatchMLAWithPagedKVCacheKernelTraitsDispatched(Params& params,
|
||||
cudaOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
&act_blocks_per_sm, kernel, KernelTraits::NUM_WARPS * 32, smem_size);
|
||||
|
||||
int gridx;
|
||||
if constexpr(USE_FIXED_BLOCK) {
|
||||
gridx = multiprocessor_count;
|
||||
} else {
|
||||
gridx = params.num_blocks_x_int;
|
||||
}
|
||||
dim3 grid_dims = {gridx, 1, 1};
|
||||
// NOTE: (changwenbin) Here the grid size is fixed so that MLA can be captured
|
||||
// by the graph.
|
||||
dim3 grid_dims = {multiprocessor_count, 1, 1};
|
||||
static constexpr int ctaSize = KernelTraits::NUM_WARPS * 32;
|
||||
dim3 block_dims(ctaSize, 1, 1);
|
||||
kernel<<<grid_dims, block_dims, smem_size, stream>>>(
|
||||
@@ -517,37 +398,38 @@ cudaError_t BatchMLAWithPagedKVCacheKernelTraitsDispatched(Params& params,
|
||||
constexpr int merge_block_size = 256;
|
||||
constexpr int blockx = KernelTraits::HEAD_DIM_VO / vec_size;
|
||||
constexpr int blocky = (merge_block_size + blockx - 1) / blockx;
|
||||
dim3 grids_merge(min(multiprocessor_count, params.token_num), params.q_num_head); // 128k is too large
|
||||
dim3 grids_merge(multiprocessor_count, params.q_num_head); // 128k is too large
|
||||
dim3 blocks_merge(blockx, blocky);
|
||||
merge_multi_chunks_kernel<NV_TYPE, vec_size, blocky, KernelTraits::HEAD_DIM_VO><<<grids_merge, blocks_merge, 0, stream>>>(
|
||||
reinterpret_cast<NV_TYPE*>(params.O_tmp),
|
||||
params.m,
|
||||
params.d,
|
||||
params.seq_lens_this_time,
|
||||
params.seq_lens_decoder,
|
||||
params.seq_lens_encoder,
|
||||
params.cumsum_q_seqlens,
|
||||
params.batch_id_per_token,
|
||||
reinterpret_cast<NV_TYPE*>(params.O),
|
||||
params.chunk_num,
|
||||
params.q_num_head,
|
||||
params.chunk_size,
|
||||
params.vo_head_dim,
|
||||
params.token_num,
|
||||
params.bsz,
|
||||
params.max_draft_token_num
|
||||
);
|
||||
merge_multi_chunks_kernel<NV_TYPE,
|
||||
vec_size,
|
||||
blocky,
|
||||
KernelTraits::HEAD_DIM_VO>
|
||||
<<<grids_merge, blocks_merge, 0, stream>>>(
|
||||
reinterpret_cast<NV_TYPE *>(params.O_tmp),
|
||||
params.m,
|
||||
params.d,
|
||||
params.seq_lens_this_time,
|
||||
params.seq_lens_decoder,
|
||||
params.cumsum_q_seqlens,
|
||||
params.batch_id_per_token,
|
||||
params.chunk_size_device,
|
||||
reinterpret_cast<NV_TYPE *>(params.O),
|
||||
params.q_num_head,
|
||||
params.vo_head_dim,
|
||||
params.token_num,
|
||||
params.bsz,
|
||||
params.max_draft_token_num);
|
||||
}
|
||||
return cudaSuccess;
|
||||
}
|
||||
|
||||
template <uint32_t HEAD_DIM_QK, uint32_t HEAD_DIM_VO, typename NV_TYPE, typename Params, bool USE_REG_EALLOC=false, bool USE_FIXED_BLOCK=false>
|
||||
template <uint32_t HEAD_DIM_QK, uint32_t HEAD_DIM_VO, typename NV_TYPE, typename Params, bool USE_REG_EALLOC=false, bool USE_FIXED_BLOCK=true>
|
||||
cudaError_t BatchMLAWithPagedKVCacheDispatched(Params& params, cudaStream_t stream) {
|
||||
constexpr bool CAUSAL = true;
|
||||
if constexpr (HEAD_DIM_QK == 576) {
|
||||
DISPATCH_GROUP_SIZE(params.q_num_head, GROUP_SIZE,
|
||||
BatchMLAWithPagedKVCacheKernelTraitsDispatched<
|
||||
AttentionKernelTraits</*USE_TMA_LOAD_KV=*/false,
|
||||
AttentionKernelTraits</*USE_TMA_LOAD_KV=*/true,
|
||||
HEAD_DIM_QK,
|
||||
HEAD_DIM_VO,
|
||||
GROUP_SIZE,
|
||||
|
Reference in New Issue
Block a user