Mirror of https://github.com/PaddlePaddle/FastDeploy.git
【Inference Optimize】DeepSeek-V3-model MLA Optimize (#3886)
* support MLA chunk_size auto search & cuda_graph
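For orientation (the commit message above is the only description): reading the chunk size from a device-side buffer (chunk_size_device) lets the host search for a good chunk size at run time and update it without re-capturing a CUDA graph, because captured kernels dereference the pointer on every replay. The sketch below illustrates that idea only; the helper names (search_chunk_size, set_mla_chunk_size) and the power-of-two heuristic are assumptions, not code from this commit.

// Illustration only: publish a host-chosen chunk size to a device scalar that
// captured kernels (e.g. one reading chunk_size_device[0]) pick up on each replay.
#include <cuda_runtime.h>

// Assumed heuristic: grow a power-of-two chunk until a target chunk count per
// sequence is reached (values are illustrative, not from the commit).
static int search_chunk_size(int max_seq_len_kv, int target_chunks = 8) {
  int chunk = 64;
  while (chunk * target_chunks < max_seq_len_kv && chunk < 8192) chunk *= 2;
  return chunk;
}

// Synchronous copy for clarity; a pinned staging buffer with cudaMemcpyAsync
// would avoid the implicit synchronization in a real serving pipeline.
static void set_mla_chunk_size(int* chunk_size_device, int max_seq_len_kv) {
  const int chunk = search_chunk_size(max_seq_len_kv);
  cudaMemcpy(chunk_size_device, &chunk, sizeof(int), cudaMemcpyHostToDevice);
}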
@@ -249,18 +249,16 @@ struct prefill_softmax_state_t {
 };
 
 template <typename T, int vec_size, uint32_t bdy, uint32_t HEAD_DIM>
-__global__ void merge_multi_chunks_kernel(const T * __restrict__ multi_out, // [num_chunks, bsz, max_draft_token, num_heads, head_dim]
-                                          const float * __restrict__ multi_m, // [num_chunks, bsz, max_draft_token, num_heads]
-                                          const float * __restrict__ multi_d, // [num_chunks, bsz, max_draft_token, num_heads]
+__global__ void merge_multi_chunks_kernel(const T * __restrict__ multi_out, // [max_num_chunks, bsz, max_draft_token, num_heads, head_dim]
+                                          const float * __restrict__ multi_m, // [max_num_chunks, bsz, max_draft_token, num_heads]
+                                          const float * __restrict__ multi_d, // [max_num_chunks, bsz, max_draft_token, num_heads]
                                           const int * __restrict__ seq_lens_this_time,
                                           const int * __restrict__ seq_lens_decoder,
                                           const int * __restrict__ seq_lens_encoder,
                                           const int *__restrict__ cu_seqlens_q,
                                           const int * __restrict__ batch_id_per_token,
+                                          const int * __restrict__ chunk_size_device,
                                           T * __restrict__ out, // [token_num, num_heads, head_dim]
                                           const int num_chunks,
                                           const int num_heads,
                                           const int chunk_size,
                                           const int head_dim,
                                           const int token_num,
                                           const int bsz,
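Since the kernel now reads the chunk size from device memory, host code presumably sizes the partial-result buffers for the worst case, which is what the renamed max_num_chunks dimension in the comments suggests. The sketch below shows one way such a workspace could be allocated; the struct, the function name, and the idea of deriving max_num_chunks from the smallest chunk size the auto-search may choose are assumptions for illustration, not FastDeploy code.

// Illustrative workspace sizing (assumed): allocate the per-chunk partial
// buffers for the largest possible chunk count so any run-time chunk size fits.
#include <cuda_runtime.h>
#include <cstddef>

struct MlaMergeWorkspace {
  void*  multi_out;          // [max_num_chunks, bsz, max_draft_token, num_heads, head_dim]
  float* multi_m;            // [max_num_chunks, bsz, max_draft_token, num_heads]
  float* multi_d;            // [max_num_chunks, bsz, max_draft_token, num_heads]
  int*   chunk_size_device;  // single int the kernels dereference at run time
};

static MlaMergeWorkspace alloc_merge_workspace(int max_seq_len, int min_chunk_size,
                                               int bsz, int max_draft_token,
                                               int num_heads, int head_dim,
                                               std::size_t elem_bytes) {
  // Worst case: the auto-search picks the smallest chunk size it is allowed to.
  const int max_num_chunks = (max_seq_len + min_chunk_size - 1) / min_chunk_size;
  const std::size_t rows = static_cast<std::size_t>(max_num_chunks) * bsz *
                           max_draft_token * num_heads;
  MlaMergeWorkspace ws{};
  cudaMalloc(&ws.multi_out, rows * head_dim * elem_bytes);
  cudaMalloc(reinterpret_cast<void**>(&ws.multi_m), rows * sizeof(float));
  cudaMalloc(reinterpret_cast<void**>(&ws.multi_d), rows * sizeof(float));
  cudaMalloc(reinterpret_cast<void**>(&ws.chunk_size_device), sizeof(int));
  return ws;
}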
@@ -271,13 +269,15 @@ __global__ void merge_multi_chunks_kernel(const T * __restrict__ multi_out, // [
   __shared__ float md_smem[bdy * 2];
   for (int qid = blockIdx.x; qid < token_num; qid += gridDim.x) {
     const uint32_t bid = batch_id_per_token[qid];
+    // NOTE : (changwenbin) Batch_id_per_token is initialized to [:]=-1, Marking meaningless batch IDs.
+    if (bid == -1) continue;
     const int seq_len_q = seq_lens_this_time[bid];
     if (seq_len_q == 0) continue;
     const uint32_t local_seq_id = qid - cu_seqlens_q[bid];
     int seq_len_kv = seq_lens_decoder[bid];
     if (seq_len_kv == 0) continue;
     seq_len_kv += seq_len_q;
-    const int num_chunks_this_seq = cute::ceil_div(seq_len_kv, chunk_size);
+    const int num_chunks_this_seq = cute::ceil_div(seq_len_kv, chunk_size_device[0]);
     if (num_chunks_this_seq <= 1) {
       // not need merge
       continue;
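As background for the hunk above: the multi_m / multi_d / multi_out buffers hold per-chunk partial softmax states (running max, denominator, un-normalized output), and a merge kernel like this rescales each partial by exp(m_chunk - m_max) before summing; when num_chunks_this_seq <= 1 the single partial already equals the final output, which is why those tokens are skipped. The snippet below is a generic, host-side illustration of that merge math under those assumptions, not the FastDeploy kernel itself.

// Generic illustration of merging per-chunk partial attention results.
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

struct ChunkPartial {
  float m;                 // running max of attention logits within the chunk
  float d;                 // softmax denominator accumulated relative to m
  std::vector<float> out;  // un-normalized weighted sum of values (length head_dim)
};

std::vector<float> merge_chunks(const std::vector<ChunkPartial>& chunks) {
  float m_max = -INFINITY;
  for (const auto& c : chunks) m_max = std::max(m_max, c.m);

  const std::size_t head_dim = chunks.front().out.size();
  std::vector<float> merged(head_dim, 0.0f);
  float d_total = 0.0f;
  for (const auto& c : chunks) {
    const float scale = std::exp(c.m - m_max);  // rebase each partial onto the global max
    d_total += c.d * scale;
    for (std::size_t i = 0; i < head_dim; ++i) merged[i] += c.out[i] * scale;
  }
  for (std::size_t i = 0; i < head_dim; ++i) merged[i] /= d_total;  // final normalization
  return merged;
}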