Files
FastDeploy/custom_ops/gpu_ops/speculate_decoding/speculate_verify.cu
freeliuzc 7cdd8d290d
Some checks failed
Deploy GitHub Pages / deploy (push) Has been cancelled
[MTP] optimize mtp infer speed (#2840)
2025-07-14 19:50:22 +08:00

366 lines
16 KiB
Plaintext
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "helper.h" // NOLINT
#include <cstdlib>
#include <curand_kernel.h>
#include <stdexcept>
#include <string>
// Linear membership test: true iff `draft` equals one of the first
// `candidate_len` entries of `candidates` (false when candidate_len <= 0).
__device__ inline bool is_in(const int64_t *candidates, const int64_t draft,
                             const int candidate_len) {
  int pos = 0;
  while (pos < candidate_len) {
    if (candidates[pos] == draft) {
      return true;
    }
    ++pos;
  }
  return false;
}
// Host-side counters passed to setup_kernel on every SpeculateVerify call and
// incremented afterwards, so consecutive calls initialise different curand
// streams. NOTE(review): plain statics, not synchronized across host threads.
static uint64_t seed{0};
static uint64_t offset{0};
// Draws one token from a candidate list by inverse-CDF sampling.
//
// A threshold u * topp is drawn (u from curand_uniform on the state owned by
// this thread, dev_curand_states[threadIdx.x]); the first candidate whose
// running score sum reaches the threshold is returned. If the scores never
// reach it (or candidate_len <= 0) the first candidate is returned.
__device__ int64_t topp_sampling_kernel(const int64_t *candidate_ids,
                                        const float *candidate_scores,
                                        curandState_t *dev_curand_states,
                                        const int candidate_len,
                                        const float topp) {
  curandState_t *const thread_state = &dev_curand_states[threadIdx.x];
  const float threshold = topp * curand_uniform(thread_state);
  float cumulative = 0.0f;
  for (int k = 0; k < candidate_len; ++k) {
    cumulative += candidate_scores[k];
    if (threshold <= cumulative) {
      return candidate_ids[k];
    }
  }
  return candidate_ids[0];
}
// Initialises `bs` curand states with a grid-stride loop.
// With need_batch_random every slot uses its own index as the subsequence,
// giving independent streams; otherwise all slots share subsequence 0 and
// therefore produce identical random sequences.
__global__ void setup_kernel(curandState_t *state, const uint64_t seed,
                             const uint64_t offset, const int bs,
                             const bool need_batch_random) {
  const int stride = gridDim.x * blockDim.x;
  for (int slot = blockIdx.x * blockDim.x + threadIdx.x; slot < bs;
       slot += stride) {
    const int subsequence = need_batch_random ? slot : 0;
    curand_init(seed, subsequence, offset, state + slot);
  }
}
// Verifies each request's draft tokens against the target model's candidate
// tokens and writes the accepted tokens in place.
//
// Launch layout: one block; thread `bid` (= threadIdx.x) handles request
// `bid`, so requests with bid >= blockDim.x are never processed. Threads with
// bid >= real_bsz, or whose is_block_step[bid] is set, do nothing.
//
// For a request whose stop flag is already set on entry nothing is written
// (accept_num included). Otherwise:
//  * Draft token i+1 is checked against the candidate row for position i
//    (verify_tokens row start_token_id + i, max_candidate_len wide):
//      - USE_TOPK: accepted iff it equals the top-1 candidate;
//      - else: accepted iff it is among the first actual_candidate_len[i]
//        candidates (clamped to max_candidate_len). If not, but it equals the
//        top-2 candidate and the following `verify_window` drafts all equal
//        their top-1 candidates, that run of verify_window + 1 drafts is
//        accepted and verification ends.
//    Verification is skipped entirely in benchmark_mode or when
//    seq_lens_encoder[bid] != 0.
//  * Each accepted token advances step_idx[bid] by one. An end token, or
//    step_idx reaching max_dec_len, sets the stop flag; in the max_dec_len
//    case the written token is replaced by end_tokens[0].
//  * If the request did not stop, one more token is produced at the first
//    unverified position: top-p sampled from the candidates (ENABLE_TOPP) or
//    the top-1 candidate. prefill_one_step_stop then forces the stop flag.
//  * accept_num[bid] = number of tokens written to accept_tokens[bid, :].
// seq_lens_decoder and actual_draft_token_nums are not used by this kernel.
template <bool ENABLE_TOPP, bool USE_TOPK>
__global__ void speculate_verify(
    int64_t *accept_tokens, int *accept_num, int64_t *step_idx,
    bool *stop_flags, const int *seq_lens_encoder, const int *seq_lens_decoder,
    const int64_t *draft_tokens, const int *actual_draft_token_nums,
    curandState_t *dev_curand_states, const float *topp,
    const int *seq_lens_this_time, const int64_t *verify_tokens,
    const float *verify_scores, const int64_t *max_dec_len,
    const int64_t *end_tokens, const bool *is_block_step,
    const int *output_cum_offsets, const int *actual_candidate_len,
    const int real_bsz, const int max_draft_tokens, const int end_length,
    const int max_seq_len, const int max_candidate_len, const int verify_window,
    const bool prefill_one_step_stop, const bool benchmark_mode) {
  const int bid = threadIdx.x;
  // verify and set stop flags
  int accept_num_now = 1;
  int stop_flag_now_int = 0;
  // Test the batch bound first so idle threads (bid >= real_bsz) never read
  // is_block_step; it is presumably sized per batch slot, not per thread.
  if (bid < real_bsz && !is_block_step[bid]) {
    const int start_token_id = bid * max_seq_len - output_cum_offsets[bid];
    if (stop_flags[bid]) {
      stop_flag_now_int = 1;
    } else {
      // Prefill requests also get here. The verification loop exits at once
      // for them (seq_lens_encoder != 0), so they go straight to sampling.
      auto *verify_tokens_now =
          verify_tokens + start_token_id * max_candidate_len;
      auto *draft_tokens_now = draft_tokens + bid * max_draft_tokens;
      auto *actual_candidate_len_now = actual_candidate_len + start_token_id;
      int i = 0;
      for (; i < seq_lens_this_time[bid] - 1; i++) {
        if (benchmark_mode) {
          break;
        }
        if (seq_lens_encoder[bid] != 0) {
          break;
        }
        if (USE_TOPK) {
          // Top-1 verify.
          if (verify_tokens_now[i * max_candidate_len] ==
              draft_tokens_now[i + 1]) {
            step_idx[bid]++;
            auto accept_token = draft_tokens_now[i + 1];
            accept_tokens[bid * max_draft_tokens + i] = accept_token;
            if (is_in_end(accept_token, end_tokens, end_length) ||
                step_idx[bid] >= max_dec_len[bid]) {
              stop_flags[bid] = true;
              stop_flag_now_int = 1;
              if (step_idx[bid] >= max_dec_len[bid])
                accept_tokens[bid * max_draft_tokens + i] = end_tokens[0];
              break;
            } else {
              accept_num_now++;
            }
          } else {
            break;
          }
        } else {
          auto actual_candidate_len_value =
              actual_candidate_len_now[i] > max_candidate_len
                  ? max_candidate_len
                  : actual_candidate_len_now[i];
          if (is_in(verify_tokens_now + i * max_candidate_len,
                    draft_tokens_now[i + 1], actual_candidate_len_value)) {
            // Top-P verify.
            step_idx[bid]++;
            auto accept_token = draft_tokens_now[i + 1];
            accept_tokens[bid * max_draft_tokens + i] = accept_token;
            if (is_in_end(accept_token, end_tokens, end_length) ||
                step_idx[bid] >= max_dec_len[bid]) {
              stop_flags[bid] = true;
              stop_flag_now_int = 1;
              if (step_idx[bid] >= max_dec_len[bid])
                accept_tokens[bid * max_draft_tokens + i] = end_tokens[0];
              break;
            } else {
              accept_num_now++;
            }
          } else {
            // Top-K window verify: the draft matches the top-2 candidate and
            // the next verify_window drafts must all match their top-1.
            int ii = i;
            if (max_candidate_len >= 2 &&
                verify_tokens_now[ii * max_candidate_len + 1] ==
                    draft_tokens_now[ii + 1]) {  // top-2
              int j = 0;
              ii += 1;
              for (; j < verify_window && ii < seq_lens_this_time[bid] - 1;
                   j++, ii++) {
                if (verify_tokens_now[ii * max_candidate_len] !=
                    draft_tokens_now[ii + 1]) {
                  break;
                }
              }
              if (j >= verify_window) {  // accept the whole run
                // Account for each token as it is written, exactly like the
                // single-token paths above. This lets the max_dec_len check see
                // the true step and keeps step_idx/accept_num consistent with
                // what was written when the run stops early.
                for (; i < ii; i++) {
                  step_idx[bid]++;
                  auto accept_token = draft_tokens_now[i + 1];
                  accept_tokens[bid * max_draft_tokens + i] = accept_token;
                  if (is_in_end(accept_token, end_tokens, end_length) ||
                      step_idx[bid] >= max_dec_len[bid]) {
                    stop_flags[bid] = true;
                    stop_flag_now_int = 1;
                    if (step_idx[bid] >= max_dec_len[bid])
                      accept_tokens[bid * max_draft_tokens + i] = end_tokens[0];
                    break;
                  }
                  accept_num_now++;
                }
              }
            }
            break;
          }
        }
      }
      // Sampling stage: produce one token from the candidate row at position
      // i. This happens either because draft_token[i + 1] was rejected or
      // because every draft was accepted (i == seq_lens_this_time[bid] - 1).
      // A request that stopped during verification is skipped.
      if (!stop_flag_now_int) {
        int64_t accept_token;
        const float *verify_scores_now =
            verify_scores + start_token_id * max_candidate_len;
        step_idx[bid]++;
        if (ENABLE_TOPP) {
          auto actual_candidate_len_value =
              actual_candidate_len_now[i] > max_candidate_len
                  ? max_candidate_len
                  : actual_candidate_len_now[i];
          accept_token = topp_sampling_kernel(
              verify_tokens_now + i * max_candidate_len,
              verify_scores_now + i * max_candidate_len, dev_curand_states,
              actual_candidate_len_value, topp[bid]);
        } else {
          accept_token = verify_tokens_now[i * max_candidate_len];
        }
        accept_tokens[bid * max_draft_tokens + i] = accept_token;
        if (prefill_one_step_stop) {
          stop_flags[bid] = true;
        }
        if (is_in_end(accept_token, end_tokens, end_length) ||
            step_idx[bid] >= max_dec_len[bid]) {
          stop_flags[bid] = true;
          stop_flag_now_int = 1;
          if (step_idx[bid] >= max_dec_len[bid])
            accept_tokens[bid * max_draft_tokens + i] = end_tokens[0];
        }
      }
      accept_num[bid] = accept_num_now;
    }
  }
}
// Host entry point of the `speculate_verify` custom op.
//
// Launches setup_kernel and then one speculate_verify instantiation on
// accept_tokens' stream. The kernel updates accept_tokens, accept_num,
// step_idx and stop_flags in place.
//
// Environment switches (read on every call):
//   SPECULATE_VERIFY_USE_TOPK  - integer; non-zero selects top-1-only
//                                verification (USE_TOPK = true). Non-numeric
//                                text is treated as 0.
//   PREFILL_NODE_ONE_STEP_STOP - a value starting with '1' forces the stop
//                                flag of every request that samples a token.
//
// Preconditions, not checked here:
//   * real_bsz <= 512, because the kernel runs one thread per request in a
//     single 512-thread block.
//   * real_bsz <= accept_tokens.shape()[0], because curand states are
//     allocated per accept_tokens row but indexed by request id.
//
// Throws std::runtime_error if the curand-state allocation or a kernel launch
// fails.
void SpeculateVerify(
    const paddle::Tensor &accept_tokens, const paddle::Tensor &accept_num,
    const paddle::Tensor &step_idx, const paddle::Tensor &stop_flags,
    const paddle::Tensor &seq_lens_encoder,
    const paddle::Tensor &seq_lens_decoder, const paddle::Tensor &draft_tokens,
    const paddle::Tensor &seq_lens_this_time,
    const paddle::Tensor &verify_tokens, const paddle::Tensor &verify_scores,
    const paddle::Tensor &max_dec_len, const paddle::Tensor &end_tokens,
    const paddle::Tensor &is_block_step,
    const paddle::Tensor &output_cum_offsets,
    const paddle::Tensor &actual_candidate_len,
    const paddle::Tensor &actual_draft_token_nums, const paddle::Tensor &topp,
    int max_seq_len, int verify_window, bool enable_topp, bool benchmark_mode) {
  auto bsz = accept_tokens.shape()[0];
  int real_bsz = seq_lens_this_time.shape()[0];
  auto max_draft_tokens = draft_tokens.shape()[1];
  auto end_length = end_tokens.shape()[0];
  auto max_candidate_len = verify_tokens.shape()[1];
  constexpr int BlockSize = 512;
  auto stream = accept_tokens.stream();

  // Read the environment switches before allocating anything. Previously
  // std::stoi could throw here after cudaMalloc and leak the buffer.
  bool use_topk = false;
  if (const char *env_var = std::getenv("SPECULATE_VERIFY_USE_TOPK")) {
    use_topk = std::strtol(env_var, nullptr, 10) != 0;
  }
  bool prefill_one_step_stop = false;
  if (const char *env_p = std::getenv("PREFILL_NODE_ONE_STEP_STOP")) {
    prefill_one_step_stop = env_p[0] == '1';
  }

  // All four instantiations share one signature. Select the pointer once
  // instead of repeating the launch four times.
  auto kernel = enable_topp ? &speculate_verify<true, false>
                            : &speculate_verify<false, false>;
  if (use_topk) {
    kernel = enable_topp ? &speculate_verify<true, true>
                         : &speculate_verify<false, true>;
  }

  curandState_t *dev_curand_states = nullptr;
  const cudaError_t alloc_err =
      cudaMalloc(&dev_curand_states, sizeof(curandState_t) * bsz);
  if (alloc_err != cudaSuccess) {
    throw std::runtime_error(
        std::string("speculate_verify: cudaMalloc of curand states failed: ") +
        cudaGetErrorString(alloc_err));
  }
  setup_kernel<<<1, BlockSize, 0, stream>>>(dev_curand_states, seed, offset,
                                            bsz, true);
  // Advance the counters so the next call draws different random numbers.
  seed++;
  offset++;

  kernel<<<1, BlockSize, 0, stream>>>(
      const_cast<int64_t *>(accept_tokens.data<int64_t>()),
      const_cast<int *>(accept_num.data<int>()),
      const_cast<int64_t *>(step_idx.data<int64_t>()),
      const_cast<bool *>(stop_flags.data<bool>()),
      seq_lens_encoder.data<int>(), seq_lens_decoder.data<int>(),
      draft_tokens.data<int64_t>(), actual_draft_token_nums.data<int>(),
      dev_curand_states, topp.data<float>(), seq_lens_this_time.data<int>(),
      verify_tokens.data<int64_t>(), verify_scores.data<float>(),
      max_dec_len.data<int64_t>(), end_tokens.data<int64_t>(),
      is_block_step.data<bool>(), output_cum_offsets.data<int>(),
      actual_candidate_len.data<int>(), real_bsz, max_draft_tokens,
      end_length, max_seq_len, max_candidate_len, verify_window,
      prefill_one_step_stop, benchmark_mode);
  const cudaError_t launch_err = cudaGetLastError();
  // Release the states before reporting, so the error path does not leak.
  // NOTE(review): relies on cudaFree not reclaiming memory still used by the
  // in-flight launch (cudaFree implicitly synchronizes); confirm if this
  // moves to a stream-ordered allocator.
  cudaFree(dev_curand_states);
  if (launch_err != cudaSuccess) {
    throw std::runtime_error(
        std::string("speculate_verify: kernel launch failed: ") +
        cudaGetErrorString(launch_err));
  }
}
// Registers the `speculate_verify` custom op.
// Paddle passes inputs to SpeculateVerify in `.Inputs` order, so the labels
// below are listed in the same order as its parameters. This matters because
// `stop_flags` must sit before `seq_lens_encoder`/`seq_lens_decoder`; otherwise
// the label, and the inplace map entry that uses it, refers to a different
// tensor slot than the kernel writes.
PD_BUILD_STATIC_OP(speculate_verify)
    .Inputs({"accept_tokens", "accept_num", "step_idx", "stop_flags",
             "seq_lens_encoder", "seq_lens_decoder", "draft_tokens",
             "seq_lens_this_time", "verify_tokens", "verify_scores",
             "max_dec_len", "end_tokens", "is_block_step", "output_cum_offsets",
             "actual_candidate_len", "actual_draft_token_nums", "topp"})
    .Outputs({"accept_tokens_out", "accept_num_out", "step_idx_out",
              "stop_flags_out"})
    .Attrs({"max_seq_len: int", "verify_window: int", "enable_topp: bool", "benchmark_mode: bool"})
    .SetInplaceMap({{"accept_tokens", "accept_tokens_out"},
                    {"accept_num", "accept_num_out"},
                    {"step_idx", "step_idx_out"},
                    {"stop_flags", "stop_flags_out"}})
    .SetKernelFn(PD_KERNEL(SpeculateVerify));