Sync v2.0 version of code to github repo

This commit is contained in:
Jiang-Jia-Jun
2025-06-29 23:29:37 +00:00
parent d151496038
commit 92c2cfa2e7
597 changed files with 78776 additions and 22905 deletions

View File

@@ -24,8 +24,6 @@ __global__ void process_splitwise_prefill(
int* seq_lens_encoder,
int* seq_lens_decoder,
int64_t* step_idx,
int* seq_lens_encoder_record,
int* seq_lens_decoder_record,
bool* not_need_stop,
bool* batch_drop,
const int64_t* accept_tokens,
@@ -51,21 +49,18 @@ __global__ void process_splitwise_prefill(
int base_model_step_idx_now = base_model_step_idx[tid];
auto* input_ids_now = input_ids + tid * input_ids_len;
auto* accept_tokens_now = accept_tokens + tid * accept_tokens_len;
// printf("bid: %d, base_model_step_idx_now: %d seq_lens_encoder_record: %d\n", tid, base_model_step_idx_now, seq_lens_encoder_record[tid]);
if (base_model_step_idx_now == 1 && seq_lens_encoder_record[tid] > 0) {
if (seq_lens_encoder[tid] > 0) {
not_stop_flag = 1;
int seq_len_encoder_record = seq_lens_encoder_record[tid];
seq_lens_encoder[tid] = seq_len_encoder_record;
seq_lens_encoder_record[tid] = -1;
int seq_len_encoder = seq_lens_encoder[tid];
stop_flags[tid] = false;
int64_t base_model_first_token = accept_tokens_now[0];
int position = seq_len_encoder_record;
int position = seq_len_encoder;
if (TRCUNCATE_FIRST_TOKEN) {
input_ids_now[position - 1] = base_model_first_token;
seq_lens_this_time[tid] = seq_len_encoder_record;
seq_lens_this_time[tid] = seq_len_encoder;
} else {
input_ids_now[position] = base_model_first_token;
seq_lens_this_time[tid] = seq_len_encoder_record + 1;
seq_lens_this_time[tid] = seq_len_encoder + 1;
}
} else {
stop_flags[tid] = true;
@@ -95,8 +90,6 @@ __global__ void draft_model_preprocess_kernel(
int* seq_lens_encoder,
int* seq_lens_decoder,
int64_t* step_idx,
int* seq_lens_encoder_record,
int* seq_lens_decoder_record,
bool* not_need_stop,
bool* batch_drop,
const int64_t* accept_tokens,
@@ -131,11 +124,7 @@ __global__ void draft_model_preprocess_kernel(
for (int i = 1; i < base_model_draft_tokens_len; i++) {
base_model_draft_tokens_now[i] = -1;
}
// 处理 base_model recover 逻辑
// 1. 已处于 recover 状态
// if (batch_drop[tid]) {
// }
if (base_model_stop_flags[tid] && base_model_is_block_step[tid]) {
batch_drop[tid] = true;
stop_flags[tid] = true;
@@ -147,22 +136,18 @@ __global__ void draft_model_preprocess_kernel(
if (base_model_step_idx_now == 0) {
seq_lens_this_time[tid] = 0;
not_stop_flag = 0;
} else if (base_model_step_idx_now == 1 && seq_lens_encoder_record[tid] > 0) {
} else if (seq_lens_encoder[tid] > 0) {
// Can be extended to first few tokens
int seq_len_encoder_record = seq_lens_encoder_record[tid];
seq_lens_encoder[tid] = seq_len_encoder_record;
seq_lens_encoder_record[tid] = -1;
seq_lens_decoder[tid] = seq_lens_decoder_record[tid];
seq_lens_decoder_record[tid] = 0;
int seq_len_encoder = seq_lens_encoder[tid];
stop_flags[tid] = false;
int64_t base_model_first_token = accept_tokens_now[0];
int position = seq_len_encoder_record;
int position = seq_len_encoder;
if (TRCUNCATE_FIRST_TOKEN) {
input_ids_now[position - 1] = base_model_first_token;
seq_lens_this_time[tid] = seq_len_encoder_record;
seq_lens_this_time[tid] = seq_len_encoder;
} else {
input_ids_now[position] = base_model_first_token;
seq_lens_this_time[tid] = seq_len_encoder_record + 1;
seq_lens_this_time[tid] = seq_len_encoder + 1;
}
} else if (accept_num_now <=
max_draft_token) /*Accept partial draft tokens*/ {
@@ -207,8 +192,6 @@ void DispatchRunner(
int* seq_lens_encoder,
int* seq_lens_decoder,
int64_t* step_idx,
int* seq_lens_encoder_record,
int* seq_lens_decoder_record,
bool* not_need_stop,
bool* batch_drop,
const int64_t* accept_tokens,
@@ -237,8 +220,6 @@ void DispatchRunner(
seq_lens_encoder,
seq_lens_decoder,
step_idx,
seq_lens_encoder_record,
seq_lens_decoder_record,
not_need_stop,
batch_drop,
accept_tokens,
@@ -265,8 +246,6 @@ void DispatchRunner(
seq_lens_encoder,
seq_lens_decoder,
step_idx,
seq_lens_encoder_record,
seq_lens_decoder_record,
not_need_stop,
batch_drop,
accept_tokens,
@@ -295,8 +274,6 @@ void DispatchTokenMode(
int* seq_lens_encoder,
int* seq_lens_decoder,
int64_t* step_idx,
int* seq_lens_encoder_record,
int* seq_lens_decoder_record,
bool* not_need_stop,
bool* batch_drop,
const int64_t* accept_tokens,
@@ -325,8 +302,6 @@ void DispatchTokenMode(
seq_lens_encoder,
seq_lens_decoder,
step_idx,
seq_lens_encoder_record,
seq_lens_decoder_record,
not_need_stop,
batch_drop,
accept_tokens,
@@ -355,8 +330,6 @@ void DispatchTokenMode(
seq_lens_encoder,
seq_lens_decoder,
step_idx,
seq_lens_encoder_record,
seq_lens_decoder_record,
not_need_stop,
batch_drop,
accept_tokens,
@@ -388,8 +361,6 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& step_idx,
const paddle::Tensor& seq_lens_encoder_record,
const paddle::Tensor& seq_lens_decoder_record,
const paddle::Tensor& not_need_stop,
const paddle::Tensor& batch_drop,
const paddle::Tensor& accept_tokens,
@@ -422,8 +393,6 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
const_cast<int*>(seq_lens_encoder.data<int>()),
const_cast<int*>(seq_lens_decoder.data<int>()),
const_cast<int64_t*>(step_idx.data<int64_t>()),
const_cast<int*>(seq_lens_encoder_record.data<int>()),
const_cast<int*>(seq_lens_decoder_record.data<int>()),
const_cast<bool*>(not_need_stop_gpu.data<bool>()),
const_cast<bool*>(batch_drop.data<bool>()),
accept_tokens.data<int64_t>(),
@@ -443,10 +412,6 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
truncate_first_token,
splitwise_prefill);
auto not_need_stop_cpu =
not_need_stop_gpu.copy_to(not_need_stop.place(), false);
bool* not_need_stop_data = const_cast<bool*>(not_need_stop.data<bool>());
@@ -462,8 +427,6 @@ PD_BUILD_STATIC_OP(draft_model_preprocess)
"seq_lens_encoder",
"seq_lens_decoder",
"step_idx",
"seq_lens_encoder_record",
"seq_lens_decoder_record",
"not_need_stop",
"batch_drop",
"accept_tokens",
@@ -482,9 +445,7 @@ PD_BUILD_STATIC_OP(draft_model_preprocess)
"seq_lens_decoder_out",
"step_idx_out",
"not_need_stop_out",
"batch_drop_out",
"seq_lens_encoder_record_out",
"seq_lens_decoder_record_out"})
"batch_drop_out"})
.Attrs({"max_draft_token: int", "truncate_first_token: bool", "splitwise_prefill: bool"})
.SetInplaceMap({{"draft_tokens", "draft_tokens_out"},
{"input_ids", "input_ids_out"},
@@ -494,7 +455,5 @@ PD_BUILD_STATIC_OP(draft_model_preprocess)
{"seq_lens_decoder", "seq_lens_decoder_out"},
{"step_idx", "step_idx_out"},
{"not_need_stop", "not_need_stop_out"},
{"batch_drop", "batch_drop_out"},
{"seq_lens_encoder_record", "seq_lens_encoder_record_out"},
{"seq_lens_decoder_record", "seq_lens_decoder_record_out"}})
{"batch_drop", "batch_drop_out"}})
.SetKernelFn(PD_KERNEL(DraftModelPreprocess));

View File

@@ -23,7 +23,7 @@
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif
#define MAX_BSZ 512
#define MAX_BSZ 256
// #define SAVE_WITH_OUTPUT_DEBUG
#define MAX_DRAFT_TOKENS 6
@@ -57,7 +57,7 @@ void MTPSaveFirstToken(const paddle::Tensor& x,
#endif
msg_queue_id = inference_msg_queue_id_from_env;
}
static key_t key = ftok("./", msg_queue_id);
static int msgid = msgget(key, IPC_CREAT | 0666);

View File

@@ -1,11 +1,11 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//
// http://www.apache.org/licenses/LICENSE-2.0
//
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@@ -45,7 +45,7 @@ void find_candidate_pred_tokens(const int64_t *input_ids,
int64_t input_ids_stride,
int64_t pre_ids_stride,
int64_t draft_tokens_stride,
const int real_batch_size,
int64_t max_batch_size,
int max_ngram_size = 3,
int max_draft_tokens = 10) {
int threshold = 128;
@@ -53,13 +53,13 @@ void find_candidate_pred_tokens(const int64_t *input_ids,
if (env_var) {
threshold = std::stoi(env_var);
}
bool is_insert = false;
for (int batch_idx = 0; batch_idx < real_batch_size; batch_idx++) {
if (seq_lens_encoder[batch_idx] > 0) {
is_insert = true;
int unprocessed_batch_size = 0;
for (int batch_idx = 0; batch_idx < max_batch_size; batch_idx++) {
if (seq_lens_encoder[batch_idx] > 0 || seq_lens_decoder[batch_idx] > 0) {
unprocessed_batch_size++;
}
}
for (int batch_idx = 0; batch_idx < real_batch_size; batch_idx++) {
for (int batch_idx = 0; batch_idx < max_batch_size; batch_idx++) {
max_draft_tokens = std::min(static_cast<int64_t>(
draft_token_num[batch_idx]), max_dec_len[batch_idx] - step_idx[batch_idx] - 1);
if (seq_lens_encoder[batch_idx] > 0) {
@@ -68,26 +68,27 @@ void find_candidate_pred_tokens(const int64_t *input_ids,
seq_lens_this_time[batch_idx] = 0;
continue;
}
// printf("bid: %d. enc: %d. dec. %d\n", batch_idx, seq_lens_encoder[batch_idx], seq_lens_decoder[batch_idx]);
const int64_t *cur_input_ids = input_ids + batch_idx * input_ids_stride;
int64_t *cur_draft_tokens = draft_tokens + batch_idx * draft_tokens_stride;
const int64_t *cur_pre_ids = pre_ids + batch_idx * pre_ids_stride;
const int64_t cur_step_idx = step_idx[batch_idx];
const int64_t cur_input_ids_len = input_ids_len[batch_idx];
seq_lens_this_time[batch_idx] = 1;
if (!is_insert) {
auto sum_token_num = sum(seq_lens_this_time, batch_idx);
int left_min_token_num = real_batch_size - batch_idx;
unprocessed_batch_size--;
if (sum_token_num + max_draft_tokens + left_min_token_num > threshold) {
int tmp_max_draft_tokens = threshold - sum_token_num - left_min_token_num;
max_draft_tokens = tmp_max_draft_tokens < max_draft_tokens ? tmp_max_draft_tokens : max_draft_tokens;
}
auto sum_token_num = sum(seq_lens_this_time, batch_idx);
int left_min_token_num = unprocessed_batch_size;
if (sum_token_num + left_min_token_num >= threshold - 1) {
continue;
}
if (sum_token_num + max_draft_tokens + left_min_token_num > threshold) {
int tmp_max_draft_tokens = threshold - sum_token_num - left_min_token_num;
max_draft_tokens = tmp_max_draft_tokens < max_draft_tokens ? tmp_max_draft_tokens : max_draft_tokens;
}
if (sum_token_num + left_min_token_num >= threshold - 1) {
continue;
}
for (int ngram_size = max_ngram_size; ngram_size > 0; --ngram_size) {
// Extract the last n tokens as our search ngram
@@ -164,7 +165,6 @@ void NgramMatch(const paddle::Tensor &input_ids,
const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &max_dec_len,
const int real_batch_size,
const int max_ngram_size,
const int max_draft_tokens) {
@@ -177,6 +177,8 @@ void NgramMatch(const paddle::Tensor &input_ids,
auto draft_tokens_shape = draft_tokens.shape();
const int64_t draft_tokens_stride = draft_tokens_shape[1];
const int64_t max_batch_size = seq_lens_this_time.shape()[0];
find_candidate_pred_tokens(input_ids.data<int64_t>(),
input_ids_len.data<int64_t>(),
pre_ids.data<int64_t>(),
@@ -190,7 +192,7 @@ void NgramMatch(const paddle::Tensor &input_ids,
input_ids_stride,
pre_ids_stride,
draft_tokens_stride,
real_batch_size,
max_batch_size,
max_ngram_size,
max_draft_tokens);
}
@@ -206,7 +208,7 @@ PD_BUILD_STATIC_OP(ngram_match)
"seq_lens_encoder",
"seq_lens_decoder",
"max_dec_len"})
.Attrs({"real_batch_size: int", "max_ngram_size: int", "max_draft_tokens: int"})
.Attrs({"max_ngram_size: int", "max_draft_tokens: int"})
.Outputs({"draft_tokens_out", "seq_lens_this_time_out"})
.SetKernelFn(PD_KERNEL(NgramMatch))
.SetInplaceMap({{"draft_tokens", "draft_tokens_out"}, {"seq_lens_this_time", "seq_lens_this_time_out"}});

View File

@@ -23,7 +23,7 @@
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif
#define MAX_BSZ 512
#define MAX_BSZ 256
#define MAX_DRAFT_TOKENS 6
struct msgdata {

View File

@@ -23,7 +23,7 @@
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif
#define MAX_BSZ 512
#define MAX_BSZ 256
#define MAX_DRAFT_TOKENS 6
struct msgdata {

View File

@@ -12,20 +12,19 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <curand_kernel.h>
#include "helper.h" // NOLINT
#include <cstdlib>
#include <curand_kernel.h>
#include <string>
#include "helper.h" // NOLINT
__device__ inline bool is_in(const int64_t *candidates,
const int64_t draft,
__device__ inline bool is_in(const int64_t *candidates, const int64_t draft,
const int candidate_len) {
for (int i = 0; i < candidate_len; i++) {
if (draft == candidates[i]) {
return true;
}
for (int i = 0; i < candidate_len; i++) {
if (draft == candidates[i]) {
return true;
}
return false;
}
return false;
}
static uint64_t seed = 0;
@@ -36,439 +35,336 @@ __device__ int64_t topp_sampling_kernel(const int64_t *candidate_ids,
curandState_t *dev_curand_states,
const int candidate_len,
const float topp) {
const int tid = threadIdx.x;
const int tid = threadIdx.x;
float sum_scores = 0.0f;
float rand_top_p = curand_uniform(dev_curand_states + tid) * topp;
for (int i = 0; i < candidate_len; i++) {
sum_scores += candidate_scores[i];
if (rand_top_p <= sum_scores) {
return candidate_ids[i];
}
float sum_scores = 0.0f;
float rand_top_p = curand_uniform(dev_curand_states + tid) * topp;
for (int i = 0; i < candidate_len; i++) {
sum_scores += candidate_scores[i];
if (rand_top_p <= sum_scores) {
return candidate_ids[i];
}
return candidate_ids[0];
}
return candidate_ids[0];
}
__global__ void setup_kernel(curandState_t *state,
const uint64_t seed,
const uint64_t offset,
const int bs,
__global__ void setup_kernel(curandState_t *state, const uint64_t seed,
const uint64_t offset, const int bs,
const bool need_batch_random) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
for (int i = idx; i < bs; i += gridDim.x * blockDim.x) {
if (need_batch_random) {
curand_init(seed, i, offset, &state[i]);
} else {
curand_init(seed, 0, offset, &state[i]);
}
int idx = blockIdx.x * blockDim.x + threadIdx.x;
for (int i = idx; i < bs; i += gridDim.x * blockDim.x) {
if (need_batch_random) {
curand_init(seed, i, offset, &state[i]);
} else {
curand_init(seed, 0, offset, &state[i]);
}
}
}
template <bool ENABLE_TOPP, bool USE_TOPK>
__global__ void speculate_verify(int64_t *accept_tokens,
int *accept_num,
int64_t *step_idx,
bool *stop_flags,
const int *seq_lens_encoder,
const int *seq_lens_decoder,
const int64_t *draft_tokens,
const int *actual_draft_token_nums,
curandState_t *dev_curand_states,
const float *topp,
const int *seq_lens_this_time,
const int64_t *verify_tokens,
const float *verify_scores,
const int64_t *max_dec_len,
const int64_t *end_tokens,
const bool *is_block_step,
const int *output_cum_offsets,
const int *actual_candidate_len,
const int real_bsz,
const int max_draft_tokens,
const int end_length,
const int max_seq_len,
const int max_candidate_len,
const int verify_window,
const bool prefill_one_step_stop) {
const int bid = threadIdx.x;
__global__ void speculate_verify(
int64_t *accept_tokens, int *accept_num, int64_t *step_idx,
bool *stop_flags, const int *seq_lens_encoder, const int *seq_lens_decoder,
const int64_t *draft_tokens, const int *actual_draft_token_nums,
curandState_t *dev_curand_states, const float *topp,
const int *seq_lens_this_time, const int64_t *verify_tokens,
const float *verify_scores, const int64_t *max_dec_len,
const int64_t *end_tokens, const bool *is_block_step,
const int *output_cum_offsets, const int *actual_candidate_len,
const int real_bsz, const int max_draft_tokens, const int end_length,
const int max_seq_len, const int max_candidate_len, const int verify_window,
const bool prefill_one_step_stop) {
const int bid = threadIdx.x;
// verify and set stop flags
int accept_num_now = 1;
int stop_flag_now_int = 0;
if (!(is_block_step[bid] || bid >= real_bsz)) {
const int start_token_id = bid * max_seq_len - output_cum_offsets[bid];
// verify and set stop flags
int accept_num_now = 1;
int stop_flag_now_int = 0;
if (!(is_block_step[bid] || bid >= real_bsz)) {
// printf("bid %d\n", bid);
if (stop_flags[bid]) {
stop_flag_now_int = 1;
} else { // 这里prefill阶段也会进入但是因为draft
// tokens会置零因此会直接到最后的采样阶段
auto *verify_tokens_now =
verify_tokens + start_token_id * max_candidate_len;
auto *draft_tokens_now = draft_tokens + bid * max_draft_tokens;
auto *actual_candidate_len_now = actual_candidate_len + start_token_id;
if (stop_flags[bid]) {
stop_flag_now_int = 1;
} else { // 这里prefill阶段也会进入但是因为draft
// tokens会置零因此会直接到最后的采样阶段
auto *verify_tokens_now =
verify_tokens + start_token_id * max_candidate_len;
auto *draft_tokens_now = draft_tokens + bid * max_draft_tokens;
auto *actual_candidate_len_now =
actual_candidate_len + start_token_id;
int i = 0;
// printf("seq_lens_this_time[%d]-1: %d \n",bid,
// seq_lens_this_time[bid]-1);
for (; i < seq_lens_this_time[bid] - 1; i++) {
if (seq_lens_encoder[bid] != 0) {
break;
}
if (USE_TOPK) {
if (verify_tokens_now[i * max_candidate_len] ==
draft_tokens_now[i + 1]) {
// accept_num_now++;
step_idx[bid]++;
auto accept_token = draft_tokens_now[i + 1];
// printf("[USE_TOPK] bid %d Top 1 verify write accept
// %d is %lld\n", bid, i, accept_token);
accept_tokens[bid * max_draft_tokens + i] =
accept_token;
if (is_in_end(accept_token, end_tokens, end_length) ||
step_idx[bid] >= max_dec_len[bid]) {
stop_flags[bid] = true;
stop_flag_now_int = 1;
if (step_idx[bid] >= max_dec_len[bid])
accept_tokens[bid * max_draft_tokens + i] =
end_tokens[0];
// printf("[USE_TOPK] bid %d Top 1 verify write
// accept %d is %lld\n", bid, i, accept_token);
break;
} else {
accept_num_now++;
}
} else {
break;
}
} else {
auto actual_candidate_len_value =
actual_candidate_len_now[i] > max_candidate_len
? max_candidate_len
: actual_candidate_len_now[i];
if (is_in(verify_tokens_now + i * max_candidate_len,
draft_tokens_now[i + 1],
actual_candidate_len_value)) {
// Top P verify
// accept_num_now++;
step_idx[bid]++;
auto accept_token = draft_tokens_now[i + 1];
accept_tokens[bid * max_draft_tokens + i] =
accept_token;
if (is_in_end(accept_token, end_tokens, end_length) ||
step_idx[bid] >= max_dec_len[bid]) {
stop_flags[bid] = true;
stop_flag_now_int = 1;
if (step_idx[bid] >= max_dec_len[bid])
accept_tokens[bid * max_draft_tokens + i] =
end_tokens[0];
// printf("bid %d Top P verify write accept %d is
// %lld\n", bid, i, accept_token);
break;
} else {
accept_num_now++;
}
} else {
// TopK verify
int ii = i;
if (max_candidate_len >= 2 &&
verify_tokens_now[ii * max_candidate_len + 1] ==
draft_tokens_now[ii + 1]) { // top-2
int j = 0;
ii += 1;
for (; j < verify_window &&
ii < seq_lens_this_time[bid] - 1;
j++, ii++) {
if (verify_tokens_now[ii * max_candidate_len] !=
draft_tokens_now[ii + 1]) {
break;
}
}
if (j >= verify_window) { // accept all
accept_num_now += verify_window + 1;
step_idx[bid] += verify_window + 1;
for (; i < ii; i++) {
auto accept_token = draft_tokens_now[i + 1];
accept_tokens[bid * max_draft_tokens + i] =
accept_token;
// printf(
// "bid %d TopK verify write accept %d
// is "
// "%lld\n",
// bid,
// i,
// accept_token);
if (is_in_end(accept_token,
end_tokens,
end_length) ||
step_idx[bid] >= max_dec_len[bid]) {
stop_flags[bid] = true;
stop_flag_now_int = 1;
if (step_idx[bid] >= max_dec_len[bid])
accept_tokens[bid *
max_draft_tokens +
i] = end_tokens[0];
// printf("bid %d TopK verify write
// accept %d is %lld\n", bid, i,
// end_tokens[0]);
accept_num_now--;
step_idx[bid]--;
break;
}
}
}
}
break;
}
}
int i = 0;
// printf("seq_lens_this_time[%d]-1: %d \n",bid,
// seq_lens_this_time[bid]-1);
for (; i < seq_lens_this_time[bid] - 1; i++) {
if (seq_lens_encoder[bid] != 0) {
break;
}
if (USE_TOPK) {
if (verify_tokens_now[i * max_candidate_len] ==
draft_tokens_now[i + 1]) {
// accept_num_now++;
step_idx[bid]++;
auto accept_token = draft_tokens_now[i + 1];
// printf("[USE_TOPK] bid %d Top 1 verify write accept
// %d is %lld\n", bid, i, accept_token);
accept_tokens[bid * max_draft_tokens + i] = accept_token;
if (is_in_end(accept_token, end_tokens, end_length) ||
step_idx[bid] >= max_dec_len[bid]) {
stop_flags[bid] = true;
stop_flag_now_int = 1;
if (step_idx[bid] >= max_dec_len[bid])
accept_tokens[bid * max_draft_tokens + i] = end_tokens[0];
// printf("[USE_TOPK] bid %d Top 1 verify write
// accept %d is %lld\n", bid, i, accept_token);
break;
} else {
accept_num_now++;
}
// sampling阶段
// 第一种draft_token[i+1]被拒绝需要从verify_tokens_now[i]中选一个
// 第二种i == seq_lens_this_time[bid]-1,
// 也是从verify_tokens_now[i]中选一个 但是停止的情况不算
if (!stop_flag_now_int) {
int64_t accept_token;
const float *verify_scores_now =
verify_scores + start_token_id * max_candidate_len;
step_idx[bid]++;
if (ENABLE_TOPP) {
auto actual_candidate_len_value =
actual_candidate_len_now[i] > max_candidate_len
? max_candidate_len
: actual_candidate_len_now[i];
} else {
break;
}
} else {
auto actual_candidate_len_value =
actual_candidate_len_now[i] > max_candidate_len
? max_candidate_len
: actual_candidate_len_now[i];
if (is_in(verify_tokens_now + i * max_candidate_len,
draft_tokens_now[i + 1], actual_candidate_len_value)) {
// Top P verify
// accept_num_now++;
step_idx[bid]++;
auto accept_token = draft_tokens_now[i + 1];
accept_tokens[bid * max_draft_tokens + i] = accept_token;
accept_token = topp_sampling_kernel(
verify_tokens_now + i * max_candidate_len,
verify_scores_now + i * max_candidate_len,
dev_curand_states,
actual_candidate_len_value,
topp[bid]);
} else {
accept_token = verify_tokens_now[i * max_candidate_len];
if (is_in_end(accept_token, end_tokens, end_length) ||
step_idx[bid] >= max_dec_len[bid]) {
stop_flags[bid] = true;
stop_flag_now_int = 1;
if (step_idx[bid] >= max_dec_len[bid])
accept_tokens[bid * max_draft_tokens + i] = end_tokens[0];
// printf("bid %d Top P verify write accept %d is
// %lld\n", bid, i, accept_token);
break;
} else {
accept_num_now++;
}
} else {
// TopK verify
int ii = i;
if (max_candidate_len >= 2 &&
verify_tokens_now[ii * max_candidate_len + 1] ==
draft_tokens_now[ii + 1]) { // top-2
int j = 0;
ii += 1;
for (; j < verify_window && ii < seq_lens_this_time[bid] - 1;
j++, ii++) {
if (verify_tokens_now[ii * max_candidate_len] !=
draft_tokens_now[ii + 1]) {
break;
}
accept_tokens[bid * max_draft_tokens + i] = accept_token;
if (prefill_one_step_stop) {
stop_flags[bid] = true;
}
if (is_in_end(accept_token, end_tokens, end_length) ||
step_idx[bid] >= max_dec_len[bid]) {
}
if (j >= verify_window) { // accept all
accept_num_now += verify_window + 1;
step_idx[bid] += verify_window + 1;
for (; i < ii; i++) {
auto accept_token = draft_tokens_now[i + 1];
accept_tokens[bid * max_draft_tokens + i] = accept_token;
// printf(
// "bid %d TopK verify write accept %d
// is "
// "%lld\n",
// bid,
// i,
// accept_token);
if (is_in_end(accept_token, end_tokens, end_length) ||
step_idx[bid] >= max_dec_len[bid]) {
stop_flags[bid] = true;
stop_flag_now_int = 1;
if (step_idx[bid] >= max_dec_len[bid])
accept_tokens[bid * max_draft_tokens + i] =
end_tokens[0];
accept_tokens[bid * max_draft_tokens + i] = end_tokens[0];
// printf("bid %d TopK verify write
// accept %d is %lld\n", bid, i,
// end_tokens[0]);
accept_num_now--;
step_idx[bid]--;
break;
}
}
}
}
accept_num[bid] = accept_num_now;
break;
}
}
}
// sampling阶段
// 第一种draft_token[i+1]被拒绝需要从verify_tokens_now[i]中选一个
// 第二种i == seq_lens_this_time[bid]-1,
// 也是从verify_tokens_now[i]中选一个 但是停止的情况不算
if (!stop_flag_now_int) {
int64_t accept_token;
const float *verify_scores_now =
verify_scores + start_token_id * max_candidate_len;
step_idx[bid]++;
if (ENABLE_TOPP) {
auto actual_candidate_len_value =
actual_candidate_len_now[i] > max_candidate_len
? max_candidate_len
: actual_candidate_len_now[i];
accept_token = topp_sampling_kernel(
verify_tokens_now + i * max_candidate_len,
verify_scores_now + i * max_candidate_len, dev_curand_states,
actual_candidate_len_value, topp[bid]);
} else {
accept_token = verify_tokens_now[i * max_candidate_len];
}
accept_tokens[bid * max_draft_tokens + i] = accept_token;
if (prefill_one_step_stop) {
stop_flags[bid] = true;
}
if (is_in_end(accept_token, end_tokens, end_length) ||
step_idx[bid] >= max_dec_len[bid]) {
stop_flags[bid] = true;
stop_flag_now_int = 1;
if (step_idx[bid] >= max_dec_len[bid])
accept_tokens[bid * max_draft_tokens + i] = end_tokens[0];
}
}
accept_num[bid] = accept_num_now;
}
}
}
void SpeculateVerify(const paddle::Tensor &accept_tokens,
const paddle::Tensor &accept_num,
const paddle::Tensor &step_idx,
const paddle::Tensor &stop_flags,
const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &draft_tokens,
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &verify_tokens,
const paddle::Tensor &verify_scores,
const paddle::Tensor &max_dec_len,
const paddle::Tensor &end_tokens,
const paddle::Tensor &is_block_step,
const paddle::Tensor &output_cum_offsets,
const paddle::Tensor &actual_candidate_len,
const paddle::Tensor &actual_draft_token_nums,
const paddle::Tensor &topp,
int max_seq_len,
int verify_window,
bool enable_topp) {
// printf("Enter speculate update\n");
auto bsz = accept_tokens.shape()[0];
int real_bsz = seq_lens_this_time.shape()[0];
auto max_draft_tokens = draft_tokens.shape()[1];
auto end_length = end_tokens.shape()[0];
auto max_candidate_len = verify_tokens.shape()[1];
void SpeculateVerify(
const paddle::Tensor &accept_tokens, const paddle::Tensor &accept_num,
const paddle::Tensor &step_idx, const paddle::Tensor &stop_flags,
const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &seq_lens_decoder, const paddle::Tensor &draft_tokens,
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &verify_tokens, const paddle::Tensor &verify_scores,
const paddle::Tensor &max_dec_len, const paddle::Tensor &end_tokens,
const paddle::Tensor &is_block_step,
const paddle::Tensor &output_cum_offsets,
const paddle::Tensor &actual_candidate_len,
const paddle::Tensor &actual_draft_token_nums, const paddle::Tensor &topp,
int max_seq_len, int verify_window, bool enable_topp) {
// printf("Enter speculate update\n");
auto bsz = accept_tokens.shape()[0];
int real_bsz = seq_lens_this_time.shape()[0];
auto max_draft_tokens = draft_tokens.shape()[1];
auto end_length = end_tokens.shape()[0];
auto max_candidate_len = verify_tokens.shape()[1];
constexpr int BlockSize = 512;
constexpr int BlockSize = 512;
curandState_t *dev_curand_states;
cudaMalloc(&dev_curand_states, sizeof(curandState_t) * bsz);
setup_kernel<<<1, BlockSize, 0, accept_tokens.stream()>>>(
dev_curand_states, seed, offset, bsz, true);
seed++;
offset++;
curandState_t *dev_curand_states;
cudaMalloc(&dev_curand_states, sizeof(curandState_t) * bsz);
setup_kernel<<<1, BlockSize, 0, accept_tokens.stream()>>>(
dev_curand_states, seed, offset, bsz, true);
seed++;
offset++;
auto err = cudaDeviceSynchronize();
if (err != 0) {
printf("err %d\n", err);
auto err = cudaDeviceSynchronize();
if (err != 0) {
printf("err %d\n", err);
}
err = cudaGetLastError();
if (err != 0) {
printf("err %d\n", err);
}
// printf("inited curand\n");
bool use_topk = false;
char *env_var = getenv("SPECULATE_VERIFY_USE_TOPK");
if (env_var) {
use_topk = static_cast<bool>(std::stoi(env_var));
}
bool prefill_one_step_stop = false;
if (const char *env_p = std::getenv("PREFILL_NODE_ONE_STEP_STOP")) {
if (env_p[0] == '1') {
prefill_one_step_stop = true;
}
err = cudaGetLastError();
if (err != 0) {
printf("err %d\n", err);
}
// printf("inited curand\n");
bool use_topk = false;
char *env_var = getenv("SPECULATE_VERIFY_USE_TOPK");
if (env_var) {
use_topk = static_cast<bool>(std::stoi(env_var));
}
bool prefill_one_step_stop = false;
if (const char *env_p = std::getenv("PREFILL_NODE_ONE_STEP_STOP")) {
// std::cout << "Your PATH is: " << env_p << '\n';
if (env_p[0] == '1') {
prefill_one_step_stop = true;
}
}
if (use_topk) {
// printf("use_topk \n");
if (enable_topp) {
speculate_verify<true, true>
<<<1, BlockSize, 0, accept_tokens.stream()>>>(
const_cast<int64_t *>(accept_tokens.data<int64_t>()),
const_cast<int *>(accept_num.data<int>()),
const_cast<int64_t *>(step_idx.data<int64_t>()),
const_cast<bool *>(stop_flags.data<bool>()),
seq_lens_encoder.data<int>(),
seq_lens_decoder.data<int>(),
draft_tokens.data<int64_t>(),
actual_draft_token_nums.data<int>(),
dev_curand_states,
topp.data<float>(),
seq_lens_this_time.data<int>(),
verify_tokens.data<int64_t>(),
verify_scores.data<float>(),
max_dec_len.data<int64_t>(),
end_tokens.data<int64_t>(),
is_block_step.data<bool>(),
output_cum_offsets.data<int>(),
actual_candidate_len.data<int>(),
real_bsz,
max_draft_tokens,
end_length,
max_seq_len,
max_candidate_len,
verify_window,
prefill_one_step_stop);
} else {
speculate_verify<false, true>
<<<1, BlockSize, 0, accept_tokens.stream()>>>(
const_cast<int64_t *>(accept_tokens.data<int64_t>()),
const_cast<int *>(accept_num.data<int>()),
const_cast<int64_t *>(step_idx.data<int64_t>()),
const_cast<bool *>(stop_flags.data<bool>()),
seq_lens_encoder.data<int>(),
seq_lens_decoder.data<int>(),
draft_tokens.data<int64_t>(),
actual_draft_token_nums.data<int>(),
dev_curand_states,
topp.data<float>(),
seq_lens_this_time.data<int>(),
verify_tokens.data<int64_t>(),
verify_scores.data<float>(),
max_dec_len.data<int64_t>(),
end_tokens.data<int64_t>(),
is_block_step.data<bool>(),
output_cum_offsets.data<int>(),
actual_candidate_len.data<int>(),
real_bsz,
max_draft_tokens,
end_length,
max_seq_len,
max_candidate_len,
verify_window,
prefill_one_step_stop);
}
}
if (use_topk) {
if (enable_topp) {
speculate_verify<true, true><<<1, BlockSize, 0, accept_tokens.stream()>>>(
const_cast<int64_t *>(accept_tokens.data<int64_t>()),
const_cast<int *>(accept_num.data<int>()),
const_cast<int64_t *>(step_idx.data<int64_t>()),
const_cast<bool *>(stop_flags.data<bool>()),
seq_lens_encoder.data<int>(), seq_lens_decoder.data<int>(),
draft_tokens.data<int64_t>(), actual_draft_token_nums.data<int>(),
dev_curand_states, topp.data<float>(), seq_lens_this_time.data<int>(),
verify_tokens.data<int64_t>(), verify_scores.data<float>(),
max_dec_len.data<int64_t>(), end_tokens.data<int64_t>(),
is_block_step.data<bool>(), output_cum_offsets.data<int>(),
actual_candidate_len.data<int>(), real_bsz, max_draft_tokens,
end_length, max_seq_len, max_candidate_len, verify_window,
prefill_one_step_stop);
} else {
if (enable_topp) {
speculate_verify<true, false>
<<<1, BlockSize, 0, accept_tokens.stream()>>>(
const_cast<int64_t *>(accept_tokens.data<int64_t>()),
const_cast<int *>(accept_num.data<int>()),
const_cast<int64_t *>(step_idx.data<int64_t>()),
const_cast<bool *>(stop_flags.data<bool>()),
seq_lens_encoder.data<int>(),
seq_lens_decoder.data<int>(),
draft_tokens.data<int64_t>(),
actual_draft_token_nums.data<int>(),
dev_curand_states,
topp.data<float>(),
seq_lens_this_time.data<int>(),
verify_tokens.data<int64_t>(),
verify_scores.data<float>(),
max_dec_len.data<int64_t>(),
end_tokens.data<int64_t>(),
is_block_step.data<bool>(),
output_cum_offsets.data<int>(),
actual_candidate_len.data<int>(),
real_bsz,
max_draft_tokens,
end_length,
max_seq_len,
max_candidate_len,
verify_window,
prefill_one_step_stop);
} else {
speculate_verify<false, false>
<<<1, BlockSize, 0, accept_tokens.stream()>>>(
const_cast<int64_t *>(accept_tokens.data<int64_t>()),
const_cast<int *>(accept_num.data<int>()),
const_cast<int64_t *>(step_idx.data<int64_t>()),
const_cast<bool *>(stop_flags.data<bool>()),
seq_lens_encoder.data<int>(),
seq_lens_decoder.data<int>(),
draft_tokens.data<int64_t>(),
actual_draft_token_nums.data<int>(),
dev_curand_states,
topp.data<float>(),
seq_lens_this_time.data<int>(),
verify_tokens.data<int64_t>(),
verify_scores.data<float>(),
max_dec_len.data<int64_t>(),
end_tokens.data<int64_t>(),
is_block_step.data<bool>(),
output_cum_offsets.data<int>(),
actual_candidate_len.data<int>(),
real_bsz,
max_draft_tokens,
end_length,
max_seq_len,
max_candidate_len,
verify_window,
prefill_one_step_stop);
}
speculate_verify<false, true>
<<<1, BlockSize, 0, accept_tokens.stream()>>>(
const_cast<int64_t *>(accept_tokens.data<int64_t>()),
const_cast<int *>(accept_num.data<int>()),
const_cast<int64_t *>(step_idx.data<int64_t>()),
const_cast<bool *>(stop_flags.data<bool>()),
seq_lens_encoder.data<int>(), seq_lens_decoder.data<int>(),
draft_tokens.data<int64_t>(), actual_draft_token_nums.data<int>(),
dev_curand_states, topp.data<float>(),
seq_lens_this_time.data<int>(), verify_tokens.data<int64_t>(),
verify_scores.data<float>(), max_dec_len.data<int64_t>(),
end_tokens.data<int64_t>(), is_block_step.data<bool>(),
output_cum_offsets.data<int>(), actual_candidate_len.data<int>(),
real_bsz, max_draft_tokens, end_length, max_seq_len,
max_candidate_len, verify_window, prefill_one_step_stop);
}
} else {
if (enable_topp) {
speculate_verify<true, false>
<<<1, BlockSize, 0, accept_tokens.stream()>>>(
const_cast<int64_t *>(accept_tokens.data<int64_t>()),
const_cast<int *>(accept_num.data<int>()),
const_cast<int64_t *>(step_idx.data<int64_t>()),
const_cast<bool *>(stop_flags.data<bool>()),
seq_lens_encoder.data<int>(), seq_lens_decoder.data<int>(),
draft_tokens.data<int64_t>(), actual_draft_token_nums.data<int>(),
dev_curand_states, topp.data<float>(),
seq_lens_this_time.data<int>(), verify_tokens.data<int64_t>(),
verify_scores.data<float>(), max_dec_len.data<int64_t>(),
end_tokens.data<int64_t>(), is_block_step.data<bool>(),
output_cum_offsets.data<int>(), actual_candidate_len.data<int>(),
real_bsz, max_draft_tokens, end_length, max_seq_len,
max_candidate_len, verify_window, prefill_one_step_stop);
} else {
speculate_verify<false, false>
<<<1, BlockSize, 0, accept_tokens.stream()>>>(
const_cast<int64_t *>(accept_tokens.data<int64_t>()),
const_cast<int *>(accept_num.data<int>()),
const_cast<int64_t *>(step_idx.data<int64_t>()),
const_cast<bool *>(stop_flags.data<bool>()),
seq_lens_encoder.data<int>(), seq_lens_decoder.data<int>(),
draft_tokens.data<int64_t>(), actual_draft_token_nums.data<int>(),
dev_curand_states, topp.data<float>(),
seq_lens_this_time.data<int>(), verify_tokens.data<int64_t>(),
verify_scores.data<float>(), max_dec_len.data<int64_t>(),
end_tokens.data<int64_t>(), is_block_step.data<bool>(),
output_cum_offsets.data<int>(), actual_candidate_len.data<int>(),
real_bsz, max_draft_tokens, end_length, max_seq_len,
max_candidate_len, verify_window, prefill_one_step_stop);
}
}
cudaFree(dev_curand_states);
cudaFree(dev_curand_states);
}
PD_BUILD_STATIC_OP(speculate_verify)
.Inputs({"accept_tokens",
"accept_num",
"step_idx",
"seq_lens_encoder",
"seq_lens_decoder",
"stop_flags",
"draft_tokens",
"seq_lens_this_time",
"verify_tokens",
"verify_scores",
"max_dec_len",
"end_tokens",
"is_block_step",
"output_cum_offsets",
"actual_candidate_len",
"actual_draft_token_nums",
"topp"})
.Outputs({"accept_tokens_out",
"accept_num_out",
"step_idx_out",
.Inputs({"accept_tokens", "accept_num", "step_idx", "seq_lens_encoder",
"seq_lens_decoder", "stop_flags", "draft_tokens",
"seq_lens_this_time", "verify_tokens", "verify_scores",
"max_dec_len", "end_tokens", "is_block_step", "output_cum_offsets",
"actual_candidate_len", "actual_draft_token_nums", "topp"})
.Outputs({"accept_tokens_out", "accept_num_out", "step_idx_out",
"stop_flags_out"})
.Attrs({"max_seq_len: int", "verify_window: int", "enable_topp: bool"})
.SetInplaceMap({{"accept_tokens", "accept_tokens_out"},