mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-04 16:22:57 +08:00
[Feature] Add speculative decoding simulation benchmark. (#2751)
* Add speculative decoding simulation benchmark
* Fix the name of the parameter
This commit is contained in:
@@ -73,7 +73,7 @@ __global__ void speculate_verify(
|
||||
const int *output_cum_offsets, const int *actual_candidate_len,
|
||||
const int real_bsz, const int max_draft_tokens, const int end_length,
|
||||
const int max_seq_len, const int max_candidate_len, const int verify_window,
|
||||
const bool prefill_one_step_stop) {
|
||||
const bool prefill_one_step_stop, const bool benchmark_mode) {
|
||||
const int bid = threadIdx.x;
|
||||
// verify and set stop flags
|
||||
int accept_num_now = 1;
|
||||
@@ -95,6 +95,9 @@ __global__ void speculate_verify(
|
||||
// printf("seq_lens_this_time[%d]-1: %d \n",bid,
|
||||
// seq_lens_this_time[bid]-1);
|
||||
for (; i < seq_lens_this_time[bid] - 1; i++) {
|
||||
if (benchmark_mode) {
|
||||
break;
|
||||
}
|
||||
if (seq_lens_encoder[bid] != 0) {
|
||||
break;
|
||||
}
|
||||
@@ -246,7 +249,7 @@ void SpeculateVerify(
|
||||
const paddle::Tensor &output_cum_offsets,
|
||||
const paddle::Tensor &actual_candidate_len,
|
||||
const paddle::Tensor &actual_draft_token_nums, const paddle::Tensor &topp,
|
||||
int max_seq_len, int verify_window, bool enable_topp) {
|
||||
int max_seq_len, int verify_window, bool enable_topp, bool benchmark_mode) {
|
||||
// printf("Enter speculate update\n");
|
||||
auto bsz = accept_tokens.shape()[0];
|
||||
int real_bsz = seq_lens_this_time.shape()[0];
|
||||
@@ -301,7 +304,7 @@ void SpeculateVerify(
|
||||
is_block_step.data<bool>(), output_cum_offsets.data<int>(),
|
||||
actual_candidate_len.data<int>(), real_bsz, max_draft_tokens,
|
||||
end_length, max_seq_len, max_candidate_len, verify_window,
|
||||
prefill_one_step_stop);
|
||||
prefill_one_step_stop, benchmark_mode);
|
||||
} else {
|
||||
speculate_verify<false, true>
|
||||
<<<1, BlockSize, 0, accept_tokens.stream()>>>(
|
||||
@@ -317,7 +320,7 @@ void SpeculateVerify(
|
||||
end_tokens.data<int64_t>(), is_block_step.data<bool>(),
|
||||
output_cum_offsets.data<int>(), actual_candidate_len.data<int>(),
|
||||
real_bsz, max_draft_tokens, end_length, max_seq_len,
|
||||
max_candidate_len, verify_window, prefill_one_step_stop);
|
||||
max_candidate_len, verify_window, prefill_one_step_stop, benchmark_mode);
|
||||
}
|
||||
} else {
|
||||
if (enable_topp) {
|
||||
@@ -335,7 +338,7 @@ void SpeculateVerify(
|
||||
end_tokens.data<int64_t>(), is_block_step.data<bool>(),
|
||||
output_cum_offsets.data<int>(), actual_candidate_len.data<int>(),
|
||||
real_bsz, max_draft_tokens, end_length, max_seq_len,
|
||||
max_candidate_len, verify_window, prefill_one_step_stop);
|
||||
max_candidate_len, verify_window, prefill_one_step_stop, benchmark_mode);
|
||||
} else {
|
||||
speculate_verify<false, false>
|
||||
<<<1, BlockSize, 0, accept_tokens.stream()>>>(
|
||||
@@ -351,7 +354,7 @@ void SpeculateVerify(
|
||||
end_tokens.data<int64_t>(), is_block_step.data<bool>(),
|
||||
output_cum_offsets.data<int>(), actual_candidate_len.data<int>(),
|
||||
real_bsz, max_draft_tokens, end_length, max_seq_len,
|
||||
max_candidate_len, verify_window, prefill_one_step_stop);
|
||||
max_candidate_len, verify_window, prefill_one_step_stop, benchmark_mode);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -366,7 +369,7 @@ PD_BUILD_STATIC_OP(speculate_verify)
|
||||
"actual_candidate_len", "actual_draft_token_nums", "topp"})
|
||||
.Outputs({"accept_tokens_out", "accept_num_out", "step_idx_out",
|
||||
"stop_flags_out"})
|
||||
.Attrs({"max_seq_len: int", "verify_window: int", "enable_topp: bool"})
|
||||
.Attrs({"max_seq_len: int", "verify_window: int", "enable_topp: bool", "benchmark_mode: bool"})
|
||||
.SetInplaceMap({{"accept_tokens", "accept_tokens_out"},
|
||||
{"accept_num", "accept_num_out"},
|
||||
{"step_idx", "step_idx_out"},
|
||||
|
Reference in New Issue
Block a user