[Executor]CUDAGraph support Speculate Decode (#3769)
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled

* success run ngram

* Revert "[Code Simplification] remove cum_offsets (#3410)"

This reverts commit 32b39620bc.

* success run ngram5 tp4 42bs

* success run ngram5 tp4 42bs

* mtp draft commit

* add decorator for target model

* enable draft model in cudagraph v0.5

* revert revrt cum_offset

* enable target model in cudagraph v0.9 And clean debug code

* Revert "success run ngram"

This reverts commit 8351e83993.

* add reverted code

* enable target model in cudagraph v0.9

* solve comment

* fix bid < 0

* Enable Target Model Padding And Draft Model in cudagraph

* solve problem

* delete rebuild padding debug note

* fast compile

* Add capture list for mtp

* success run 256 tp1 mtp

* Enable Lite TP2 Bsz256

* realy enable tp2 bsz 256

* fix problem

* Solve problem for Draft model in cudagraph

* Solve comment

* replace emptytensor as zeros

* Solve comments

* Revert "fast compile"

This reverts commit 834639a7ff.

* fix bug

* fix merge bug

* fix typo

* fix bug

---------

Co-authored-by: lizexu <2694294196@qq.com>
Co-authored-by: littledgg <1658565283@qq.com>
Co-authored-by: zeroRains <linjunlu@zerorains.top>
Co-authored-by: gongshaotian <gstain5555@outlook.com>
This commit is contained in:
RAM
2025-10-09 21:18:29 +08:00
committed by GitHub
parent 7b1689f437
commit aa27b03bc0
19 changed files with 250 additions and 139 deletions

View File

@@ -494,12 +494,12 @@ std::vector<paddle::Tensor> AppendAttention(
paddle::Tensor fmha_out;
if (out_linear_in_scale > 0.0) {
if (fabs(quant_max_bound - 127.0f) < 0.000001) {
fmha_out = GetEmptyTensor(
fmha_out = paddle::zeros(
{meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims},
paddle::DataType::INT8,
qkv.place());
} else if (fabs(quant_max_bound - 448.0f) < 0.000001) {
fmha_out = GetEmptyTensor(
fmha_out = paddle::zeros(
{meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims},
paddle::DataType::FLOAT8_E4M3FN,
qkv.place());
@@ -507,7 +507,7 @@ std::vector<paddle::Tensor> AppendAttention(
PD_THROW("Only supported attr of quant_max_bound in ['127', '448'].");
}
} else {
fmha_out = GetEmptyTensor(
fmha_out = paddle::zeros(
{meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims},
dtype_id,
qkv.place());

View File

@@ -2418,6 +2418,9 @@ __global__ void merge_multi_chunks_v2_kernel(
__shared__ float md_smem[bdy * 2];
for (int qid = blockIdx.x; qid < token_num; qid += gridDim.x) {
const uint32_t bid = batch_id_per_token[qid];
if(bid == -1){
continue;
}
const uint32_t local_seq_id = qid - cu_seqlens_q[bid];
const int seq_len_q = seq_lens_q[bid];
if (seq_len_q == 0) continue;
@@ -2437,6 +2440,8 @@ __global__ void merge_multi_chunks_v2_kernel(
const int num_chunks_this_seq = div_up(seq_len_kv, chunk_size);
if (num_chunks_this_seq <= 1) {
continue;
}else if (!ENABLE_PREFILL){
continue;
}
using LoadT = AlignedVector<T, vec_size>;

View File

@@ -84,15 +84,7 @@ __global__ void append_speculate_cache_T_rope_qk_norm_kernel(
const int* block_table_now = block_tables + ori_bi * max_blocks_per_seq;
const int block_idx = block_table_now[write_seq_id / block_size];
if (block_idx < 0) {
printf(
"Fatal Error!!!, block idx %d when write_seq_id is %d\n some key var "
"%d %d %d %d\n",
block_idx,
write_seq_id,
ori_bi,
seq_lens_decoder[ori_bi],
token_id,
cu_seqlens_q[ori_bi]);
return ; // NOTE(gongshaotian): For CUDAGraph padding
}
const int block_offset = write_seq_id % block_size;
@@ -390,15 +382,7 @@ __global__ void append_speculate_cache_rope_kernel(
const int* block_table_now = block_tables + ori_bi * max_blocks_per_seq;
const int block_idx = block_table_now[write_seq_id / block_size];
if (block_idx < 0) {
printf(
"Fatal Error!!!, block idx %d when write_seq_id is %d\n some key var "
"%d %d %d %d\n",
block_idx,
write_seq_id,
ori_bi,
seq_lens_decoder[ori_bi],
token_id,
cu_seqlens_q[ori_bi]);
return ; // NOTE(gongshaotian): For CUDAGraph padding
}
const int block_offset = write_seq_id % block_size;
@@ -525,15 +509,7 @@ __global__ void append_speculate_cache_neox_rope_kernel(
const int* block_table_now = block_tables + ori_bi * max_blocks_per_seq;
const int block_idx = block_table_now[write_seq_id / block_size];
if (block_idx < 0) {
printf(
"Fatal Error!!!, block idx %d when write_seq_id is %d\n some key var "
"%d %d %d %d\n",
block_idx,
write_seq_id,
ori_bi,
seq_lens_decoder[ori_bi],
token_id,
cu_seqlens_q[ori_bi]);
return ; // NOTE(gongshaotian): For CUDAGraph padding
}
const int block_offset = write_seq_id % block_size;

View File

@@ -684,7 +684,7 @@ void SpeculateVerify(
const paddle::Tensor &output_cum_offsets,
const paddle::Tensor &actual_candidate_len,
const paddle::Tensor &actual_draft_token_nums, const paddle::Tensor &topp,
int max_seq_len, int verify_window, bool enable_topp, bool benchmark_mode);
int max_seq_len, int verify_window, bool enable_topp, bool benchmark_mode, bool accept_all_drafts);
void SpeculateUpdate(const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &seq_lens_decoder,

View File

@@ -130,7 +130,6 @@ std::vector<paddle::Tensor> rebuild_padding(
int pack_num = elem_nums / PackSize;
const int blocksize = 128;
const int grid_size = (pack_num + blocksize - 1) / blocksize;
if (output_padding_offset) {
RebuildAppendPaddingKernel<DataType_, PackSize>
<<<grid_size, blocksize, 0, cu_stream>>>(

View File

@@ -139,7 +139,7 @@ std::vector<paddle::DataType> SpeculateGetPaddingOffsetInferDtype(
PD_BUILD_STATIC_OP(speculate_get_padding_offset)
.Inputs({"input_ids",
"draft_tokens",
"cum_offsets"
"cum_offsets",
"token_num",
"seq_len",
"seq_lens_encoder"})

View File

@@ -73,7 +73,7 @@ __global__ void speculate_verify(
const int *output_cum_offsets, const int *actual_candidate_len,
const int real_bsz, const int max_draft_tokens, const int end_length,
const int max_seq_len, const int max_candidate_len, const int verify_window,
const bool prefill_one_step_stop, const bool benchmark_mode) {
const bool prefill_one_step_stop, const bool benchmark_mode, const bool accept_all_drafts) {
const int bid = threadIdx.x;
// verify and set stop flags
int accept_num_now = 1;
@@ -101,6 +101,24 @@ __global__ void speculate_verify(
if (seq_lens_encoder[bid] != 0) {
break;
}
if (accept_all_drafts) {
// accept all draft tokens
step_idx[bid]++;
auto accept_token = draft_tokens_now[i + 1];
accept_tokens[bid * max_draft_tokens + i] = accept_token;
if (is_in_end(accept_token, end_tokens, end_length) ||
step_idx[bid] >= max_dec_len[bid]) {
stop_flags[bid] = true;
stop_flag_now_int = 1;
if (step_idx[bid] >= max_dec_len[bid])
accept_tokens[bid * max_draft_tokens + i] = end_tokens[0];
break;
} else {
accept_num_now++;
}
continue;
}
if (USE_TOPK) {
if (verify_tokens_now[i * max_candidate_len] ==
draft_tokens_now[i + 1]) {
@@ -249,7 +267,7 @@ void SpeculateVerify(
const paddle::Tensor &output_cum_offsets,
const paddle::Tensor &actual_candidate_len,
const paddle::Tensor &actual_draft_token_nums, const paddle::Tensor &topp,
int max_seq_len, int verify_window, bool enable_topp, bool benchmark_mode) {
int max_seq_len, int verify_window, bool enable_topp, bool benchmark_mode, bool accept_all_drafts) {
// printf("Enter speculate update\n");
auto bsz = accept_tokens.shape()[0];
int real_bsz = seq_lens_this_time.shape()[0];
@@ -292,7 +310,7 @@ void SpeculateVerify(
is_block_step.data<bool>(), output_cum_offsets.data<int>(),
actual_candidate_len.data<int>(), real_bsz, max_draft_tokens,
end_length, max_seq_len, max_candidate_len, verify_window,
prefill_one_step_stop, benchmark_mode);
prefill_one_step_stop, benchmark_mode, accept_all_drafts);
} else {
speculate_verify<false, true>
<<<1, BlockSize, 0, accept_tokens.stream()>>>(
@@ -308,7 +326,7 @@ void SpeculateVerify(
end_tokens.data<int64_t>(), is_block_step.data<bool>(),
output_cum_offsets.data<int>(), actual_candidate_len.data<int>(),
real_bsz, max_draft_tokens, end_length, max_seq_len,
max_candidate_len, verify_window, prefill_one_step_stop, benchmark_mode);
max_candidate_len, verify_window, prefill_one_step_stop, benchmark_mode, accept_all_drafts);
}
} else {
if (enable_topp) {
@@ -326,7 +344,7 @@ void SpeculateVerify(
end_tokens.data<int64_t>(), is_block_step.data<bool>(),
output_cum_offsets.data<int>(), actual_candidate_len.data<int>(),
real_bsz, max_draft_tokens, end_length, max_seq_len,
max_candidate_len, verify_window, prefill_one_step_stop, benchmark_mode);
max_candidate_len, verify_window, prefill_one_step_stop, benchmark_mode, accept_all_drafts);
} else {
speculate_verify<false, false>
<<<1, BlockSize, 0, accept_tokens.stream()>>>(
@@ -342,7 +360,7 @@ void SpeculateVerify(
end_tokens.data<int64_t>(), is_block_step.data<bool>(),
output_cum_offsets.data<int>(), actual_candidate_len.data<int>(),
real_bsz, max_draft_tokens, end_length, max_seq_len,
max_candidate_len, verify_window, prefill_one_step_stop, benchmark_mode);
max_candidate_len, verify_window, prefill_one_step_stop, benchmark_mode, accept_all_drafts);
}
}
@@ -357,7 +375,7 @@ PD_BUILD_STATIC_OP(speculate_verify)
"actual_candidate_len", "actual_draft_token_nums", "topp"})
.Outputs({"accept_tokens_out", "accept_num_out", "step_idx_out",
"stop_flags_out"})
.Attrs({"max_seq_len: int", "verify_window: int", "enable_topp: bool", "benchmark_mode: bool"})
.Attrs({"max_seq_len: int", "verify_window: int", "enable_topp: bool", "benchmark_mode: bool","accept_all_drafts: bool"})
.SetInplaceMap({{"accept_tokens", "accept_tokens_out"},
{"accept_num", "accept_num_out"},
{"step_idx", "step_idx_out"},