[LLM] First commit the llm deployment code

jiangjiajun
2025-06-09 19:20:15 +08:00
parent 980c0a1d2c
commit 684703fd72
11814 changed files with 127294 additions and 1293102 deletions

View File

@@ -0,0 +1,80 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/extension.h"
#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif
__global__ void draft_model_update_seq_lens_this_time_kernel(
const int64_t* base_model_draft_tokens,
int* base_model_seq_lens_this_time,
const int* base_model_seq_lens_encoder,
const bool* base_model_stop_flags,
int bsz,
int base_model_draft_token_len) {
int tid = threadIdx.x;
if (tid < bsz) {
if (!base_model_stop_flags[tid] &&
base_model_seq_lens_encoder[tid] == 0) {
const int64_t* base_model_draft_tokens_now =
base_model_draft_tokens + tid * base_model_draft_token_len;
int token_num = 0;
for (int i = 0; i < base_model_draft_token_len; ++i) {
if (base_model_draft_tokens_now[i] != -1) {
token_num++;
}
}
base_model_seq_lens_this_time[tid] = token_num;
} else if (base_model_stop_flags[tid]) {
base_model_seq_lens_this_time[tid] = 0;
}
}
}
void DraftModelPostprocess(const paddle::Tensor& base_model_draft_tokens,
const paddle::Tensor& base_model_seq_lens_this_time,
const paddle::Tensor& base_model_seq_lens_encoder,
const paddle::Tensor& base_model_stop_flags) {
int real_bsz = base_model_seq_lens_this_time.shape()[0];
auto cu_stream = base_model_seq_lens_this_time.stream();
int block_size = (real_bsz + 32 - 1) / 32 * 32;  // round bsz up to a warp multiple
int base_model_draft_token_len = base_model_draft_tokens.shape()[1];
draft_model_update_seq_lens_this_time_kernel<<<1,
block_size,
0,
cu_stream>>>(
base_model_draft_tokens.data<int64_t>(),
const_cast<int*>(base_model_seq_lens_this_time.data<int>()),
base_model_seq_lens_encoder.data<int>(),
base_model_stop_flags.data<bool>(),
real_bsz,
base_model_draft_token_len);
}
PD_BUILD_STATIC_OP(draft_model_postprocess)
.Inputs({"base_model_draft_tokens",
"base_model_seq_lens_this_time",
"base_model_seq_lens_encoder",
"base_model_stop_flags"})
.Outputs({"base_model_draft_tokens_out",
"base_model_seq_lens_this_time_out",
"base_model_stop_flags_out"})
.SetInplaceMap({{"base_model_draft_tokens", "base_model_draft_tokens_out"},
{"base_model_seq_lens_this_time",
"base_model_seq_lens_this_time_out"},
{"base_model_stop_flags", "base_model_stop_flags_out"}})
.SetKernelFn(PD_KERNEL(DraftModelPostprocess));
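
The kernel above recounts base_model_seq_lens_this_time for each live decode lane by counting draft-token slots that are not -1, and zeroes it for stopped lanes. The standalone sketch below restates that per-row rule on the host; the helper name and toy values are illustrative only, not part of the op.

#include <cstdint>
#include <cstdio>
#include <vector>

// -1 marks an unused draft slot; a live lane's new seq_len_this_time is the
// count of filled slots in its draft-token row.
int count_valid_draft_tokens(const std::vector<int64_t>& row) {
    int token_num = 0;
    for (int64_t t : row) {
        if (t != -1) ++token_num;
    }
    return token_num;
}

int main() {
    std::vector<int64_t> row = {7, 12, -1, -1};     // assumed toy draft row
    printf("%d\n", count_valid_draft_tokens(row));  // prints: 2
    return 0;
}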

View File

@@ -0,0 +1,500 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "helper.h"
#include "paddle/extension.h"
template <int THREADBLOCK_SIZE, bool TRUNCATE_FIRST_TOKEN>
__global__ void process_splitwise_prefill(
int64_t* draft_tokens,
int64_t* input_ids,
bool* stop_flags,
int* seq_lens_this_time,
int* seq_lens_encoder,
int* seq_lens_decoder,
int64_t* step_idx,
int* seq_lens_encoder_record,
int* seq_lens_decoder_record,
bool* not_need_stop,
bool* batch_drop,
const int64_t* accept_tokens,
const int* accept_num,
const int* base_model_seq_lens_encoder,
const int* base_model_seq_lens_decoder,
const int64_t* base_model_step_idx,
const bool* base_model_stop_flags,
const bool* base_model_is_block_step,
int64_t* base_model_draft_tokens,
const int bsz,
const int max_draft_token,
const int accept_tokens_len,
const int draft_tokens_len,
const int input_ids_len,
const int base_model_draft_tokens_len) {
typedef cub::BlockReduce<int64_t, THREADBLOCK_SIZE> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
int64_t not_stop_flag = 0;
int tid = threadIdx.x;
if (tid < bsz) {
int base_model_step_idx_now = base_model_step_idx[tid];
auto* input_ids_now = input_ids + tid * input_ids_len;
auto* accept_tokens_now = accept_tokens + tid * accept_tokens_len;
// printf("bid: %d, base_model_step_idx_now: %d seq_lens_encoder_record: %d\n", tid, base_model_step_idx_now, seq_lens_encoder_record[tid]);
if (base_model_step_idx_now == 1 && seq_lens_encoder_record[tid] > 0) {
not_stop_flag = 1;
int seq_len_encoder_record = seq_lens_encoder_record[tid];
seq_lens_encoder[tid] = seq_len_encoder_record;
seq_lens_encoder_record[tid] = -1;
stop_flags[tid] = false;
int64_t base_model_first_token = accept_tokens_now[0];
int position = seq_len_encoder_record;
if (TRUNCATE_FIRST_TOKEN) {
input_ids_now[position - 1] = base_model_first_token;
seq_lens_this_time[tid] = seq_len_encoder_record;
} else {
input_ids_now[position] = base_model_first_token;
seq_lens_this_time[tid] = seq_len_encoder_record + 1;
}
} else {
stop_flags[tid] = true;
seq_lens_this_time[tid] = 0;
seq_lens_decoder[tid] = 0;
seq_lens_encoder[tid] = 0;
not_stop_flag = 0;
}
}
__syncthreads();
int64_t not_stop_flag_sum = BlockReduce(temp_storage).Sum(not_stop_flag);
if (tid == 0) {
not_need_stop[0] = not_stop_flag_sum > 0;
}
}
template <int THREADBLOCK_SIZE, bool TRUNCATE_FIRST_TOKEN>
__global__ void draft_model_preprocess_kernel(
int64_t* draft_tokens,
int64_t* input_ids,
bool* stop_flags,
int* seq_lens_this_time,
int* seq_lens_encoder,
int* seq_lens_decoder,
int64_t* step_idx,
int* seq_lens_encoder_record,
int* seq_lens_decoder_record,
bool* not_need_stop,
bool* batch_drop,
const int64_t* accept_tokens,
const int* accept_num,
const int* base_model_seq_lens_encoder,
const int* base_model_seq_lens_decoder,
const int64_t* base_model_step_idx,
const bool* base_model_stop_flags,
const bool* base_model_is_block_step,
int64_t* base_model_draft_tokens,
const int bsz,
const int max_draft_token,
const int accept_tokens_len,
const int draft_tokens_len,
const int input_ids_len,
const int base_model_draft_tokens_len) {
typedef cub::BlockReduce<int64_t, THREADBLOCK_SIZE> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
int64_t not_stop_flag = 0;
int tid = threadIdx.x;
if (tid < bsz) {
auto base_model_step_idx_now = base_model_step_idx[tid];
auto* accept_tokens_now = accept_tokens + tid * accept_tokens_len;
auto* draft_tokens_now = draft_tokens + tid * draft_tokens_len;
auto accept_num_now = accept_num[tid];
auto* input_ids_now = input_ids + tid * input_ids_len;
auto* base_model_draft_tokens_now =
base_model_draft_tokens + tid * base_model_draft_tokens_len;
#pragma unroll
for (int i = 1; i < base_model_draft_tokens_len; i++) {
base_model_draft_tokens_now[i] = -1;
}
// Handle the base model's recover logic: a request already in the recover
// state (batch_drop set) is skipped by the !batch_drop check below.
if (base_model_stop_flags[tid] && base_model_is_block_step[tid]) {
batch_drop[tid] = true;
stop_flags[tid] = true;
}
if (!(base_model_stop_flags[tid] || batch_drop[tid])) {
not_stop_flag = 1;
// 1. first token
if (base_model_step_idx_now == 0) {
seq_lens_this_time[tid] = 0;
not_stop_flag = 0;
} else if (base_model_step_idx_now == 1 && seq_lens_encoder_record[tid] > 0) {
// Can be extended to first few tokens
int seq_len_encoder_record = seq_lens_encoder_record[tid];
seq_lens_encoder[tid] = seq_len_encoder_record;
seq_lens_encoder_record[tid] = -1;
seq_lens_decoder[tid] = seq_lens_decoder_record[tid];
seq_lens_decoder_record[tid] = 0;
stop_flags[tid] = false;
int64_t base_model_first_token = accept_tokens_now[0];
int position = seq_len_encoder_record;
if (TRUNCATE_FIRST_TOKEN) {
input_ids_now[position - 1] = base_model_first_token;
seq_lens_this_time[tid] = seq_len_encoder_record;
} else {
input_ids_now[position] = base_model_first_token;
seq_lens_this_time[tid] = seq_len_encoder_record + 1;
}
} else if (accept_num_now <=
max_draft_token) /*Accept partial draft tokens*/ {
// Base Model reject stop
if (stop_flags[tid]) {
stop_flags[tid] = false;
seq_lens_decoder[tid] = base_model_seq_lens_decoder[tid];
step_idx[tid] = base_model_step_idx[tid];
} else {
seq_lens_decoder[tid] -= max_draft_token - accept_num_now;
step_idx[tid] -= max_draft_token - accept_num_now;
}
int64_t modified_token = accept_tokens_now[accept_num_now - 1];
draft_tokens_now[0] = modified_token;
seq_lens_this_time[tid] = 1;
} else /*Accept all draft tokens*/ {
draft_tokens_now[1] = accept_tokens_now[max_draft_token];
seq_lens_this_time[tid] = 2;
}
} else {
stop_flags[tid] = true;
seq_lens_this_time[tid] = 0;
seq_lens_decoder[tid] = 0;
seq_lens_encoder[tid] = 0;
}
}
__syncthreads();
int64_t not_stop_flag_sum = BlockReduce(temp_storage).Sum(not_stop_flag);
if (tid == 0) {
not_need_stop[0] = not_stop_flag_sum > 0;
}
}
template <bool TRUNCATE_FIRST_TOKEN>
void DispatchRunner(
const cudaStream_t& stream,
int64_t* draft_tokens,
int64_t* input_ids,
bool* stop_flags,
int* seq_lens_this_time,
int* seq_lens_encoder,
int* seq_lens_decoder,
int64_t* step_idx,
int* seq_lens_encoder_record,
int* seq_lens_decoder_record,
bool* not_need_stop,
bool* batch_drop,
const int64_t* accept_tokens,
const int* accept_num,
const int* base_model_seq_lens_encoder,
const int* base_model_seq_lens_decoder,
const int64_t* base_model_step_idx,
const bool* base_model_stop_flags,
const bool* base_model_is_block_step,
int64_t* base_model_draft_tokens,
const int bsz,
const int max_draft_token,
const int accept_tokens_len,
const int draft_tokens_len,
const int input_ids_len,
const int base_model_draft_tokens_len,
const bool splitwise_prefill) {
constexpr int BlockSize = 512;
if (splitwise_prefill) {
process_splitwise_prefill<BlockSize, TRUNCATE_FIRST_TOKEN>
<<<1, BlockSize, 0, stream>>>(
draft_tokens,
input_ids,
stop_flags,
seq_lens_this_time,
seq_lens_encoder,
seq_lens_decoder,
step_idx,
seq_lens_encoder_record,
seq_lens_decoder_record,
not_need_stop,
batch_drop,
accept_tokens,
accept_num,
base_model_seq_lens_encoder,
base_model_seq_lens_decoder,
base_model_step_idx,
base_model_stop_flags,
base_model_is_block_step,
base_model_draft_tokens,
bsz,
max_draft_token,
accept_tokens_len,
draft_tokens_len,
input_ids_len,
base_model_draft_tokens_len);
} else {
draft_model_preprocess_kernel<BlockSize, TRUNCATE_FIRST_TOKEN>
<<<1, BlockSize, 0, stream>>>(
draft_tokens,
input_ids,
stop_flags,
seq_lens_this_time,
seq_lens_encoder,
seq_lens_decoder,
step_idx,
seq_lens_encoder_record,
seq_lens_decoder_record,
not_need_stop,
batch_drop,
accept_tokens,
accept_num,
base_model_seq_lens_encoder,
base_model_seq_lens_decoder,
base_model_step_idx,
base_model_stop_flags,
base_model_is_block_step,
base_model_draft_tokens,
bsz,
max_draft_token,
accept_tokens_len,
draft_tokens_len,
input_ids_len,
base_model_draft_tokens_len);
}
}
void DispatchTokenMode(
const cudaStream_t &stream,
int64_t* draft_tokens,
int64_t* input_ids,
bool* stop_flags,
int* seq_lens_this_time,
int* seq_lens_encoder,
int* seq_lens_decoder,
int64_t* step_idx,
int* seq_lens_encoder_record,
int* seq_lens_decoder_record,
bool* not_need_stop,
bool* batch_drop,
const int64_t* accept_tokens,
const int* accept_num,
const int* base_model_seq_lens_encoder,
const int* base_model_seq_lens_decoder,
const int64_t* base_model_step_idx,
const bool* base_model_stop_flags,
const bool* base_model_is_block_step,
int64_t* base_model_draft_tokens,
const int bsz,
const int max_draft_token,
const int accept_tokens_len,
const int draft_tokens_len,
const int input_ids_len,
const int base_model_draft_tokens_len,
const bool truncate_first_token,
const bool splitwise_prefill) {
if (truncate_first_token) {
DispatchRunner<true>(
stream,
draft_tokens,
input_ids,
stop_flags,
seq_lens_this_time,
seq_lens_encoder,
seq_lens_decoder,
step_idx,
seq_lens_encoder_record,
seq_lens_decoder_record,
not_need_stop,
batch_drop,
accept_tokens,
accept_num,
base_model_seq_lens_encoder,
base_model_seq_lens_decoder,
base_model_step_idx,
base_model_stop_flags,
base_model_is_block_step,
base_model_draft_tokens,
bsz,
max_draft_token,
accept_tokens_len,
draft_tokens_len,
input_ids_len,
base_model_draft_tokens_len,
splitwise_prefill
);
} else {
DispatchRunner<false>(
stream,
draft_tokens,
input_ids,
stop_flags,
seq_lens_this_time,
seq_lens_encoder,
seq_lens_decoder,
step_idx,
seq_lens_encoder_record,
seq_lens_decoder_record,
not_need_stop,
batch_drop,
accept_tokens,
accept_num,
base_model_seq_lens_encoder,
base_model_seq_lens_decoder,
base_model_step_idx,
base_model_stop_flags,
base_model_is_block_step,
base_model_draft_tokens,
bsz,
max_draft_token,
accept_tokens_len,
draft_tokens_len,
input_ids_len,
base_model_draft_tokens_len,
splitwise_prefill
);
}
}
void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
const paddle::Tensor& input_ids,
const paddle::Tensor& stop_flags,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& step_idx,
const paddle::Tensor& seq_lens_encoder_record,
const paddle::Tensor& seq_lens_decoder_record,
const paddle::Tensor& not_need_stop,
const paddle::Tensor& batch_drop,
const paddle::Tensor& accept_tokens,
const paddle::Tensor& accept_num,
const paddle::Tensor& base_model_seq_lens_encoder,
const paddle::Tensor& base_model_seq_lens_decoder,
const paddle::Tensor& base_model_step_idx,
const paddle::Tensor& base_model_stop_flags,
const paddle::Tensor& base_model_is_block_step,
const paddle::Tensor& base_model_draft_tokens,
const int max_draft_token,
const bool truncate_first_token,
const bool splitwise_prefill) {
int real_bsz = seq_lens_this_time.shape()[0];
int accept_tokens_len = accept_tokens.shape()[1];
int input_ids_len = input_ids.shape()[1];
int draft_tokens_len = draft_tokens.shape()[1];
auto cu_stream = seq_lens_this_time.stream();
int base_model_draft_tokens_len = base_model_draft_tokens.shape()[1];
auto not_need_stop_gpu =
not_need_stop.copy_to(seq_lens_this_time.place(), false);
DispatchTokenMode(
cu_stream,
const_cast<int64_t*>(draft_tokens.data<int64_t>()),
const_cast<int64_t*>(input_ids.data<int64_t>()),
const_cast<bool*>(stop_flags.data<bool>()),
const_cast<int*>(seq_lens_this_time.data<int>()),
const_cast<int*>(seq_lens_encoder.data<int>()),
const_cast<int*>(seq_lens_decoder.data<int>()),
const_cast<int64_t*>(step_idx.data<int64_t>()),
const_cast<int*>(seq_lens_encoder_record.data<int>()),
const_cast<int*>(seq_lens_decoder_record.data<int>()),
const_cast<bool*>(not_need_stop_gpu.data<bool>()),
const_cast<bool*>(batch_drop.data<bool>()),
accept_tokens.data<int64_t>(),
accept_num.data<int>(),
base_model_seq_lens_encoder.data<int>(),
base_model_seq_lens_decoder.data<int>(),
base_model_step_idx.data<int64_t>(),
base_model_stop_flags.data<bool>(),
base_model_is_block_step.data<bool>(),
const_cast<int64_t*>(base_model_draft_tokens.data<int64_t>()),
real_bsz,
max_draft_token,
accept_tokens_len,
draft_tokens_len,
input_ids_len,
base_model_draft_tokens_len,
truncate_first_token,
splitwise_prefill);
auto not_need_stop_cpu =
not_need_stop_gpu.copy_to(not_need_stop.place(), false);
bool* not_need_stop_data = const_cast<bool*>(not_need_stop.data<bool>());
not_need_stop_data[0] = not_need_stop_cpu.data<bool>()[0];
}
PD_BUILD_STATIC_OP(draft_model_preprocess)
.Inputs({"draft_tokens",
"input_ids",
"stop_flags",
"seq_lens_this_time",
"seq_lens_encoder",
"seq_lens_decoder",
"step_idx",
"seq_lens_encoder_record",
"seq_lens_decoder_record",
"not_need_stop",
"batch_drop",
"accept_tokens",
"accept_num",
"base_model_seq_lens_encoder",
"base_model_seq_lens_decoder",
"base_model_step_idx",
"base_model_stop_flags",
"base_model_is_block_step",
"base_model_draft_tokens"})
.Outputs({"draft_tokens_out",
"input_ids_out",
"stop_flags_out",
"seq_lens_this_time_out",
"seq_lens_encoder_out",
"seq_lens_decoder_out",
"step_idx_out",
"not_need_stop_out",
"batch_drop_out",
"seq_lens_encoder_record_out",
"seq_lens_decoder_record_out"})
.Attrs({"max_draft_token: int", "truncate_first_token: bool", "splitwise_prefill: bool"})
.SetInplaceMap({{"draft_tokens", "draft_tokens_out"},
{"input_ids", "input_ids_out"},
{"stop_flags", "stop_flags_out"},
{"seq_lens_this_time", "seq_lens_this_time_out"},
{"seq_lens_encoder", "seq_lens_encoder_out"},
{"seq_lens_decoder", "seq_lens_decoder_out"},
{"step_idx", "step_idx_out"},
{"not_need_stop", "not_need_stop_out"},
{"batch_drop", "batch_drop_out"},
{"seq_lens_encoder_record", "seq_lens_encoder_record_out"},
{"seq_lens_decoder_record", "seq_lens_decoder_record_out"}})
.SetKernelFn(PD_KERNEL(DraftModelPreprocess));
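
The decode-time branches in draft_model_preprocess_kernel hinge on one piece of bookkeeping: when the base model accepts only accept_num of the max_draft_token speculated tokens, seq_lens_decoder and step_idx are rolled back by the difference. A standalone sketch with assumed toy values:

#include <cstdio>

int main() {
    // Assumed toy values: the draft model speculated 4 tokens, 2 accepted.
    int seq_len_decoder = 20;
    int step_idx = 20;
    const int max_draft_token = 4, accept_num = 2;
    seq_len_decoder -= max_draft_token - accept_num;  // rolls back to 18
    step_idx -= max_draft_token - accept_num;         // rolls back to 18
    printf("%d %d\n", seq_len_decoder, step_idx);     // prints: 18 18
    // Decoding then restarts from accept_tokens[accept_num - 1] with
    // seq_lens_this_time = 1; on a full accept the kernel instead appends
    // accept_tokens[max_draft_token] and sets seq_lens_this_time = 2.
    return 0;
}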

View File

@@ -0,0 +1,76 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "helper.h"
__global__ void update_pre_ids_kernel(const int64_t* draft_tokens,
int64_t* pre_ids_all,
const bool* stop_flags,
int* seq_lens_this_time,
const int64_t* step_idx,
int bs,
int pre_id_length,
int max_draft_token) {
int tid = threadIdx.x;
if (tid < bs && seq_lens_this_time[tid] != 0 && !stop_flags[tid]) {
int64_t* pre_ids_all_now = pre_ids_all + tid * pre_id_length;
const int64_t* draft_token_now = draft_tokens + tid * max_draft_token;
const int seq_len_this_time = seq_lens_this_time[tid];
if (step_idx[tid] - 1 > 0 /*Decoder Step*/) {
for (int i = 0; i < seq_len_this_time; ++i) {
pre_ids_all_now[step_idx[tid] - i] =
draft_token_now[seq_len_this_time - 1 - i];
}
} else if (step_idx[tid] == 1 /*Encoder Step*/) {
pre_ids_all_now[1] = draft_token_now[0];
}
seq_lens_this_time[tid] = 1;
}
}
void SpeculateDraftModelUpdate(const paddle::Tensor& draft_tokens,
const paddle::Tensor& pre_ids_all,
const paddle::Tensor& stop_flags,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& step_idx) {
int64_t real_bs = seq_lens_this_time.shape()[0];
int64_t pre_id_length = pre_ids_all.shape()[1];
auto cu_stream = seq_lens_this_time.stream();
int64_t max_draft_token = draft_tokens.shape()[1];
int block_size = (real_bs + 32 - 1) / 32 * 32;  // round bsz up to a warp multiple
update_pre_ids_kernel<<<1, block_size, 0, cu_stream>>>(
draft_tokens.data<int64_t>(),
const_cast<int64_t*>(pre_ids_all.data<int64_t>()),
stop_flags.data<bool>(),
const_cast<int*>(seq_lens_this_time.data<int>()),
step_idx.data<int64_t>(),
real_bs,
pre_id_length,
max_draft_token);
}
PD_BUILD_STATIC_OP(draft_model_set_value_by_flags)
.Inputs({"draft_tokens",
"pre_ids_all",
"stop_flags",
"seq_lens_this_time",
"seq_lens_encoder",
"seq_lens_decoder",
"step_idx"})
.Outputs({"pre_ids_all_out"})
.SetInplaceMap({{"pre_ids_all", "pre_ids_all_out"}})
.SetKernelFn(PD_KERNEL(SpeculateDraftModelUpdate));
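
To make the reverse-order write in the decoder branch of update_pre_ids_kernel concrete, here is a standalone demo with assumed toy values (step_idx = 5, two draft tokens); it mirrors only the indexing, not the op itself.

#include <cstdint>
#include <cstdio>

int main() {
    // Assumed toy values mirroring the decoder branch above.
    int64_t pre_ids_all_now[8] = {0};
    const int64_t draft_token_now[2] = {101, 102};
    const int64_t step_idx = 5;
    const int seq_len_this_time = 2;
    for (int i = 0; i < seq_len_this_time; ++i) {
        pre_ids_all_now[step_idx - i] = draft_token_now[seq_len_this_time - 1 - i];
    }
    // Tokens land in generation order, ending at index step_idx.
    printf("%lld %lld\n", (long long)pre_ids_all_now[4],
           (long long)pre_ids_all_now[5]);  // prints: 101 102
    return 0;
}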

View File

@@ -0,0 +1,215 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "helper.h"
#include "paddle/extension.h"
template <int THREADBLOCK_SIZE>
__global__ void draft_model_update_kernel(const int64_t* inter_next_tokens,
int64_t* draft_tokens,
int64_t* pre_ids,
int* seq_lens_this_time,
int* seq_lens_encoder,
int* seq_lens_decoder,
int64_t* step_idx,
const int* output_cum_offsets,
bool* stop_flags,
bool* not_need_stop,
const int64_t* max_dec_len,
const int64_t* end_ids,
int64_t* base_model_draft_tokens,
const int bsz,
const int max_draft_token,
const int pre_id_length,
const int max_base_model_draft_token,
const int end_ids_len,
const int max_seq_len,
const int substep,
const bool prefill_one_step_stop) {
typedef cub::BlockReduce<int64_t, THREADBLOCK_SIZE> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
int64_t stop_flag_now_int = 0;
int tid = threadIdx.x;
if (tid < bsz) {
auto* draft_token_now = draft_tokens + tid * max_draft_token;
auto* pre_ids_now = pre_ids + tid * pre_id_length;
auto* base_model_draft_tokens_now =
base_model_draft_tokens + tid * max_base_model_draft_token;
const int next_tokens_start_id =
tid * max_seq_len - output_cum_offsets[tid];
auto* next_tokens_start = inter_next_tokens + next_tokens_start_id;
auto seq_len_this_time = seq_lens_this_time[tid];
auto seq_len_encoder = seq_lens_encoder[tid];
auto seq_len_decoder = seq_lens_decoder[tid];
// 1. update step_idx && seq_lens_dec
if (!stop_flags[tid] /* seq_lens_decoder > 0 or seq_lens_encoder > 0 */) {
int64_t token_this_time = -1;
// decoder step
if (seq_len_decoder > 0 && seq_len_encoder <= 0) {
seq_lens_decoder[tid] += seq_len_this_time;
token_this_time = next_tokens_start[seq_len_this_time - 1];
draft_token_now[0] = next_tokens_start[seq_len_this_time - 1];
base_model_draft_tokens_now[substep + 1] = token_this_time;
for (int i = 0; i < seq_len_this_time; ++i) {
pre_ids_now[step_idx[tid] + 1 + i] = next_tokens_start[i];
}
step_idx[tid] += seq_len_this_time;
} else {
token_this_time = next_tokens_start[0];
seq_lens_decoder[tid] = seq_len_encoder + seq_len_decoder;
seq_lens_encoder[tid] = 0;
pre_ids_now[1] = token_this_time;
step_idx[tid] += 1;
draft_token_now[0] = token_this_time;
base_model_draft_tokens_now[substep + 1] = token_this_time;
}
// multi_end
if (is_in_end(token_this_time, end_ids, end_ids_len) || prefill_one_step_stop) {
stop_flags[tid] = true;
stop_flag_now_int = 1;
// max_dec_len
} else if (step_idx[tid] >= max_dec_len[tid]) {
stop_flags[tid] = true;
draft_token_now[seq_len_this_time - 1] = end_ids[0];
base_model_draft_tokens_now[substep + 1] = end_ids[0];
stop_flag_now_int = 1;
}
} else {
draft_token_now[0] = -1;
base_model_draft_tokens_now[substep + 1] = -1;
stop_flag_now_int = 1;
}
// 2. set end
if (!stop_flags[tid]) {
seq_lens_this_time[tid] = 1;
} else {
seq_lens_this_time[tid] = 0;
seq_lens_encoder[tid] = 0;
}
}
__syncthreads();
int64_t stop_sum = BlockReduce(temp_storage).Sum(stop_flag_now_int);
if (tid == 0) {
not_need_stop[0] = stop_sum < bsz;
}
}
void DraftModelUpdate(const paddle::Tensor& inter_next_tokens,
const paddle::Tensor& draft_tokens,
const paddle::Tensor& pre_ids,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& step_idx,
const paddle::Tensor& output_cum_offsets,
const paddle::Tensor& stop_flags,
const paddle::Tensor& not_need_stop,
const paddle::Tensor& max_dec_len,
const paddle::Tensor& end_ids,
const paddle::Tensor& base_model_draft_tokens,
const int max_seq_len,
const int substep) {
auto seq_lens_this_time_shape = seq_lens_this_time.shape();
auto cu_stream = seq_lens_this_time.stream();
const int real_bsz = seq_lens_this_time_shape[0];
auto not_need_stop_gpu =
not_need_stop.copy_to(seq_lens_this_time.place(), false);
const int end_ids_len = end_ids.shape()[0];
const int max_draft_token = draft_tokens.shape()[1];
const int pre_id_length = pre_ids.shape()[1];
const int max_base_model_draft_token = base_model_draft_tokens.shape()[1];
constexpr int BlockSize = 512;
bool prefill_one_step_stop = false;
if (const char *env_p = std::getenv("PREFILL_NODE_ONE_STEP_STOP")) {
// std::cout << "Your PATH is: " << env_p << '\n';
if (env_p[0] == '1') {
prefill_one_step_stop = true;
}
}
draft_model_update_kernel<BlockSize><<<1, BlockSize, 0, cu_stream>>>(
inter_next_tokens.data<int64_t>(),
const_cast<int64_t*>(draft_tokens.data<int64_t>()),
const_cast<int64_t*>(pre_ids.data<int64_t>()),
const_cast<int*>(seq_lens_this_time.data<int>()),
const_cast<int*>(seq_lens_encoder.data<int>()),
const_cast<int*>(seq_lens_decoder.data<int>()),
const_cast<int64_t*>(step_idx.data<int64_t>()),
output_cum_offsets.data<int>(),
const_cast<bool*>(stop_flags.data<bool>()),
not_need_stop_gpu.data<bool>(),
max_dec_len.data<int64_t>(),
end_ids.data<int64_t>(),
const_cast<int64_t*>(base_model_draft_tokens.data<int64_t>()),
real_bsz,
max_draft_token,
pre_id_length,
max_base_model_draft_token,
end_ids_len,
max_seq_len,
substep,
prefill_one_step_stop);
auto not_need_stop_cpu =
not_need_stop_gpu.copy_to(not_need_stop.place(), false);
bool* not_need_stop_data = const_cast<bool*>(not_need_stop.data<bool>());
not_need_stop_data[0] = not_need_stop_cpu.data<bool>()[0];
}
PD_BUILD_STATIC_OP(draft_model_update)
.Inputs({"inter_next_tokens",
"draft_tokens",
"pre_ids",
"seq_lens_this_time",
"seq_lens_encoder",
"seq_lens_decoder",
"step_idx",
"output_cum_offsets",
"stop_flags",
"not_need_stop",
"max_dec_len",
"end_ids",
"base_model_draft_tokens"})
.Attrs({"max_seq_len: int", "substep: int"})
.Outputs({"draft_tokens_out",
"pre_ids_out",
"seq_lens_this_time_out",
"seq_lens_encoder_out",
"seq_lens_decoder_out",
"step_idx_out",
"stop_flags_out",
"not_need_stop_out",
"base_model_draft_tokens_out"})
.SetInplaceMap({{"draft_tokens", "draft_tokens_out"},
{"pre_ids", "pre_ids_out"},
{"seq_lens_this_time", "seq_lens_this_time_out"},
{"seq_lens_encoder", "seq_lens_encoder_out"},
{"seq_lens_decoder", "seq_lens_decoder_out"},
{"step_idx", "step_idx_out"},
{"stop_flags", "stop_flags_out"},
{"not_need_stop", "not_need_stop_out"},
{"base_model_draft_tokens", "base_model_draft_tokens_out"}})
.SetKernelFn(PD_KERNEL(DraftModelUpdate));
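
The least obvious indexing above is next_tokens_start_id: inter_next_tokens is a packed buffer with padding removed, so a lane's start is its padded offset minus the cumulative padding recorded in output_cum_offsets. A standalone illustration with assumed toy values:

#include <cstdio>

int main() {
    // Assumed toy values: max_seq_len = 8; lane 0 used 3 of its 8 slots,
    // so 5 pad positions were removed before lane 1's tokens.
    const int max_seq_len = 8;
    const int output_cum_offsets[2] = {0, 5};
    for (int tid = 0; tid < 2; ++tid) {
        int start = tid * max_seq_len - output_cum_offsets[tid];
        printf("lane %d starts at packed row %d\n", tid, start);  // 0, then 3
    }
    return 0;
}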

View File

@@ -0,0 +1,240 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/extension.h"
#include "helper.h"
// #define DEBUG_EAGLE_KERNEL
__global__ void ComputeOrderKernel(
const int* seq_lens_this_time,
const int* seq_lens_encoder,
const int* base_model_seq_lens_this_time,
const int* base_model_seq_lens_encoder,
const int* accept_nums,
int* position_map,
int* output_token_num,
const int bsz,
const int actual_draft_token_num,
const int input_token_num) {
int in_offset = 0;   // offset into the unpacked input token stream (longer)
int out_offset = 0;  // offset into the packed output stream (shorter)
if (threadIdx.x == 0) {
for (int i = 0; i < bsz; ++i) {
int cur_base_model_seq_lens_this_time = base_model_seq_lens_this_time[i];
int cur_base_model_seq_lens_encoder = base_model_seq_lens_encoder[i];
int cur_seq_lens_this_time = seq_lens_this_time[i];
int accept_num = accept_nums[i];
int cur_seq_lens_encoder = seq_lens_encoder[i];
#ifdef DEBUG_EAGLE_KERNEL
printf("batch %d: cur_base_model_seq_lens_this_time%d. cur_seq_lens_this_time%d, accept_num %d\n", i, cur_base_model_seq_lens_this_time, cur_seq_lens_this_time, accept_num);
#endif
// 1. eagle encoder. Base step=1
if (cur_seq_lens_encoder > 0) {
#ifdef DEBUG_EAGLE_KERNEL
printf("batch %d: cur_seq_lens_encoder > 0 \n", i);
#endif
for (int j = 0; j < cur_seq_lens_encoder; j++) {
position_map[in_offset++] = out_offset++;
}
// 2. base model encoder step (base step=0): nothing to copy yet
} else if (cur_base_model_seq_lens_encoder != 0) {
// 3. newly stopped in this step
} else if (cur_base_model_seq_lens_this_time != 0 && cur_seq_lens_this_time == 0) {
#ifdef DEBUG_EAGLE_KERNEL
printf("batch %d: base=0. draft !=0 \n", i);
#endif
in_offset += cur_base_model_seq_lens_this_time;
// 4. stopped
} else if (cur_base_model_seq_lens_this_time == 0 && cur_seq_lens_this_time == 0) /* end */ {
} else {
if (accept_num <= actual_draft_token_num) /*Accept partial draft tokens*/ {
#ifdef DEBUG_EAGLE_KERNEL
printf("batch %d: accept_num <= actual_draft_token_num \n", i);
#endif
position_map[in_offset + accept_num - 1] = out_offset++;
in_offset += cur_base_model_seq_lens_this_time;
} else /*Accept all draft tokens*/ {
#ifdef DEBUG_EAGLE_KERNEL
printf("batch %d: accept_num > actual_draft_token_num \n", i);
#endif
position_map[in_offset + accept_num - 2] = out_offset++;
position_map[in_offset + accept_num - 1] = out_offset++;
in_offset += cur_base_model_seq_lens_this_time;
}
}
}
output_token_num[0] = out_offset;
#ifdef DEBUG_EAGLE_KERNEL
printf("position map output_token_num%d:\n", output_token_num[0]);
for (int i = 0; i < output_token_num[0]; i++) {
printf("%d ", position_map[i]);
}
printf("\n");
#endif
}
}
template <typename T, int VecSize>
__global__ void rebuildHiddenStatesKernel(
const T* input,
const int* position_map,
T* out,
const int dim_embed,
const int elem_cnt) {
using LoadT = AlignedVector<T, VecSize>;
LoadT src_vec;
int global_thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
for (int elem_idx = global_thread_idx * VecSize; elem_idx < elem_cnt; elem_idx += blockDim.x * gridDim.x * VecSize) {
int ori_token_idx = elem_idx / dim_embed;
int token_idx = position_map[ori_token_idx];
if (token_idx >= 0) {
int offset = elem_idx % dim_embed;
Load<T, VecSize>(input + ori_token_idx * dim_embed + offset, &src_vec);
Store<T, VecSize>(src_vec, out + token_idx * dim_embed + offset);
}
}
}
template <paddle::DataType D>
std::vector<paddle::Tensor> DispatchDtype(
const paddle::Tensor& input,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& stop_flags,
const paddle::Tensor& accept_nums,
const paddle::Tensor& base_model_seq_lens_this_time,
const paddle::Tensor& base_model_seq_lens_encoder,
const int actual_draft_token_num) {
typedef PDTraits<D> traits_;
typedef typename traits_::DataType DataType_;
typedef typename traits_::data_t data_t;
auto input_token_num = input.shape()[0];
auto dim_embed = input.shape()[1];
int bsz = seq_lens_this_time.shape()[0];
auto position_map = paddle::empty({input_token_num}, seq_lens_this_time.dtype(), input.place());
cudaMemsetAsync(position_map.data<int>(), 0xFF, input_token_num * sizeof(int), seq_lens_this_time.stream());  // fill the map with -1
auto output_token_num = paddle::empty({1}, seq_lens_this_time.dtype(), seq_lens_this_time.place());
ComputeOrderKernel<<<1, 1, 0, seq_lens_this_time.stream()>>>(seq_lens_this_time.data<int>(),
seq_lens_encoder.data<int>(),
base_model_seq_lens_this_time.data<int>(),
base_model_seq_lens_encoder.data<int>(),
accept_nums.data<int>(),
position_map.data<int>(),
output_token_num.data<int>(),
bsz,
actual_draft_token_num,
input_token_num);
int output_token_num_cpu = output_token_num.copy_to(paddle::CPUPlace(), false).data<int>()[0];
auto out = paddle::empty({output_token_num_cpu, dim_embed}, input.dtype(), input.place());
constexpr int packSize = VEC_16B / (sizeof(DataType_));
int elem_cnt = input_token_num * dim_embed;
assert(elem_cnt % packSize == 0);
int pack_num = elem_cnt / packSize;
int grid_size = 1;
GetNumBlocks(pack_num, &grid_size);
constexpr int thread_per_block = 128;
rebuildHiddenStatesKernel<DataType_, packSize><<<grid_size, thread_per_block, 0, input.stream()>>>(
reinterpret_cast<const DataType_*>(input.data<data_t>()),
position_map.data<int>(),
reinterpret_cast<DataType_*>(out.data<data_t>()),
dim_embed,
elem_cnt);
return {out};
}
std::vector<paddle::Tensor> EagleGetHiddenStates(
const paddle::Tensor& input,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& stop_flags,
const paddle::Tensor& accept_nums,
const paddle::Tensor& base_model_seq_lens_this_time,
const paddle::Tensor& base_model_seq_lens_encoder,
const int actual_draft_token_num) {
switch (input.dtype()) {
case paddle::DataType::FLOAT16: {
return DispatchDtype<paddle::DataType::FLOAT16>(
input,
seq_lens_this_time,
seq_lens_encoder,
seq_lens_decoder,
stop_flags,
accept_nums,
base_model_seq_lens_this_time,
base_model_seq_lens_encoder,
actual_draft_token_num);
}
case paddle::DataType::BFLOAT16: {
return DispatchDtype<paddle::DataType::BFLOAT16>(
input,
seq_lens_this_time,
seq_lens_encoder,
seq_lens_decoder,
stop_flags,
accept_nums,
base_model_seq_lens_this_time,
base_model_seq_lens_encoder,
actual_draft_token_num);
}
default: {
PD_THROW("Not support this data type");
}
}
}
PD_BUILD_STATIC_OP(eagle_get_hidden_states)
.Inputs({"input",
"seq_lens_this_time",
"seq_lens_encoder",
"seq_lens_decoder",
"stop_flags",
"accept_nums",
"base_model_seq_lens_this_time",
"base_model_seq_lens_encoder"})
.Attrs({"actual_draft_token_num: int"})
.Outputs({"out"})
.SetKernelFn(PD_KERNEL(EagleGetHiddenStates));
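
Putting the two kernels together: ComputeOrderKernel builds position_map (input row to output row, -1 meaning dropped), then rebuildHiddenStatesKernel gathers only the mapped rows. The standalone sketch below replays the map construction for an assumed toy batch: one eagle-encoder lane with 3 prompt rows, then one decode lane whose base step emitted 5 rows with accept_num = 3 (partial accept).

#include <cstdio>

int main() {
    int position_map[8];
    for (int i = 0; i < 8; ++i) position_map[i] = -1;
    int in_offset = 0, out_offset = 0;
    // Lane 0: encoder lane, every prompt row is kept.
    for (int j = 0; j < 3; ++j) position_map[in_offset++] = out_offset++;
    // Lane 1: partial accept, only the row of the last accepted token survives.
    const int accept_num = 3, base_rows = 5;
    position_map[in_offset + accept_num - 1] = out_offset++;
    in_offset += base_rows;
    printf("output_token_num = %d\n", out_offset);  // 4
    for (int i = 0; i < in_offset; ++i) printf("%d ", position_map[i]);
    printf("\n");  // prints: 0 1 2 -1 -1 3 -1 -1
    return 0;
}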

View File

@@ -0,0 +1,190 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/extension.h"
#include "helper.h"
// #define DEBUG_EAGLE_KERNEL
__global__ void computeOrderKernel(
const int* last_seq_lens_this_time,
const int* seq_lens_this_time,
const int64_t* step_idx,
int* src_map,
int* output_token_num,
int bsz) {
int in_offset = 0;
int out_offset = 0;
if (threadIdx.x == 0) {
for (int i = 0; i < bsz; ++i) {
int cur_seq_lens_this_time = seq_lens_this_time[i];
int cur_last_seq_lens_this_time = last_seq_lens_this_time[i];
#ifdef DEBUG_EAGLE_KERNEL
printf("batch %d: cur_seq_lens_this_time:%d. cur_last_seq_lens_this_time:%d\n", i, cur_seq_lens_this_time, cur_last_seq_lens_this_time);
#endif
// 1. encoder
if (step_idx[i] == 1 && cur_seq_lens_this_time > 0) {
#ifdef DEBUG_EAGLE_KERNEL
printf("batch %d last_step is encoder \n", i);
#endif
in_offset += 1;
src_map[out_offset++] = in_offset - 1;
#ifdef DEBUG_EAGLE_KERNEL
printf("batch %d finish. src_map[%d]=%d \n", i, out_offset - 1, in_offset - 1);
#endif
// 2. decoder
} else if (cur_seq_lens_this_time > 0) /* =1 */ {
#ifdef DEBUG_EAGLE_KERNEL
printf("batch %d is decoder\n", i);
#endif
in_offset += cur_last_seq_lens_this_time;
src_map[out_offset++] = in_offset - 1;
// 3. stop
} else {
// first token end
if (step_idx[i] == 1) {
#ifdef DEBUG_EAGLE_KERNEL
printf("batch %d finished in first token \n", i);
#endif
in_offset += cur_last_seq_lens_this_time > 0 ? 1 : 0;
// normal end
} else {
#ifdef DEBUG_EAGLE_KERNEL
printf("batch %d finished in non-first token \n", i);
#endif
in_offset += cur_last_seq_lens_this_time;
}
}
}
output_token_num[0] = out_offset;
#ifdef DEBUG_EAGLE_KERNEL
printf("position map output_token_num%d:\n", output_token_num[0]);
for (int i = 0; i < output_token_num[0]; i++) {
printf("%d ", src_map[i]);
}
printf("\n");
#endif
}
}
template<typename T, int PackSize>
__global__ void rebuildSelfHiddenStatesKernel(
const T* input,
int* src_map,
T* output,
int dim_embed,
int elem_cnt) {
using LoadT = AlignedVector<T, PackSize>;
LoadT src_vec;
int global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
for (int elem_id = global_thread_idx * PackSize; elem_id < elem_cnt; elem_id += blockDim.x * gridDim.x * PackSize) {
int output_token_idx = elem_id / dim_embed;
int input_token_idx = src_map[output_token_idx];
int offset = elem_id % dim_embed;
Load<T, PackSize>(input + input_token_idx * dim_embed + offset, &src_vec);
Store<T, PackSize>(src_vec, output + output_token_idx * dim_embed + offset);
}
}
template<paddle::DataType D>
std::vector<paddle::Tensor> DispatchDtype(
const paddle::Tensor input,
const paddle::Tensor last_seq_lens_this_time,
const paddle::Tensor seq_lens_this_time,
const paddle::Tensor step_idx) {
typedef PDTraits<D> traits_;
typedef typename traits_::DataType DataType_;
typedef typename traits_::data_t data_t;
int input_token_num = input.shape()[0];
int dim_embed = input.shape()[1];
int bsz = seq_lens_this_time.shape()[0];
auto src_map = paddle::full({input_token_num}, -1, seq_lens_this_time.dtype(), seq_lens_this_time.place());
auto output_token_num = paddle::full({1}, 0, seq_lens_this_time.dtype(), seq_lens_this_time.place());
computeOrderKernel<<<1, 1, 0, seq_lens_this_time.stream()>>>(
last_seq_lens_this_time.data<int>(),
seq_lens_this_time.data<int>(),
step_idx.data<int64_t>(),
src_map.data<int>(),
output_token_num.data<int>(),
bsz);
int output_token_num_cpu = output_token_num.copy_to(paddle::CPUPlace(), false).data<int>()[0];
auto out = paddle::full({output_token_num_cpu, dim_embed}, -1, input.type(), input.place());
constexpr int packSize = VEC_16B / (sizeof(DataType_));
int elem_cnt = output_token_num_cpu * dim_embed;
// printf("output_token_num: %d, dim_embed: %d, cnt: %d. packSize: %d\n", output_token_num_cpu, dim_embed,elem_cnt, packSize);
assert(elem_cnt % packSize == 0);
int pack_num = elem_cnt / packSize;
int grid_size = 1;
GetNumBlocks(pack_num, &grid_size);
constexpr int threadPerBlock = 128;
rebuildSelfHiddenStatesKernel<DataType_, packSize><<<grid_size, threadPerBlock, 0, input.stream()>>>(
reinterpret_cast<const DataType_*>(input.data<data_t>()),
src_map.data<int>(),
reinterpret_cast<DataType_*>(out.data<data_t>()),
dim_embed,
elem_cnt);
return {out};
}
std::vector<paddle::Tensor> EagleGetSelfHiddenStates(
const paddle::Tensor& input,
const paddle::Tensor& last_seq_lens_this_time,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& step_idx) {
switch (input.dtype()) {
case paddle::DataType::BFLOAT16:
return DispatchDtype<paddle::DataType::BFLOAT16>(
input,
last_seq_lens_this_time,
seq_lens_this_time,
step_idx);
case paddle::DataType::FLOAT16:
return DispatchDtype<paddle::DataType::FLOAT16>(
input,
last_seq_lens_this_time,
seq_lens_this_time,
step_idx);
default:
PD_THROW("Not support this data type");
}
}
PD_BUILD_STATIC_OP(eagle_get_self_hidden_states)
.Inputs({"input",
"last_seq_lens_this_time",
"seq_lens_this_time",
"step_idx"})
.Outputs({"out"})
.SetKernelFn(PD_KERNEL(EagleGetSelfHiddenStates));
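
src_map goes the other way: for each live lane it selects the row of the last token produced in the previous step, so the draft model consumes exactly one hidden state per lane. A standalone replay with assumed toy values:

#include <cstdio>

int main() {
    // Assumed toy batch: lane 0 just left the encoder (step_idx = 1, one
    // fresh row); lane 1 is a decode lane whose previous step emitted 3 rows.
    int src_map[2];
    int in_offset = 0, out_offset = 0;
    in_offset += 1;                         // lane 0: single encoder row
    src_map[out_offset++] = in_offset - 1;
    in_offset += 3;                         // lane 1: last_seq_lens_this_time = 3
    src_map[out_offset++] = in_offset - 1;
    printf("%d %d\n", src_map[0], src_map[1]);  // prints: 0 3
    return 0;
}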

View File

@@ -0,0 +1,136 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "helper.h"
#include "paddle/extension.h"
template <typename T, int VecSize>
__global__ void HydraFetchHiddenStatesKernel(const T* output_hidden_states,
const int* output_padding_offset,
const int* accept_token_num,
T* hidden_states,
const int bsz,
const int max_seq_len,
const int hidden_size) {
const int token_id = blockIdx.x;
const int ori_token_id = token_id + output_padding_offset[token_id];
const int bid = ori_token_id / max_seq_len;
const int start_ori_token_id = bid * max_seq_len;
const int local_token_id = ori_token_id - start_ori_token_id;
if (local_token_id != accept_token_num[bid] - 1) return;
using LoadT = AlignedVector<T, VecSize>;
LoadT vec;
for (int idx = threadIdx.x * VecSize; idx < hidden_size;
idx += blockDim.x * VecSize) {
Load(&output_hidden_states[token_id * hidden_size + idx], &vec);
Store(vec, &hidden_states[bid * hidden_size + idx]);
}
}
template <paddle::DataType D>
std::vector<paddle::Tensor> HydraFetchHiddenStatesImpl(
const paddle::Tensor& output_hidden_states,
const paddle::Tensor& output_padding_offset,
const paddle::Tensor& accept_token_num,
const int bsz,
const int max_seq_length) {
typedef PDTraits<D> traits_;
typedef typename traits_::DataType DataType_;
typedef typename traits_::data_t data_t;
auto cu_stream = output_hidden_states.stream();
auto output_token_num = output_hidden_states.shape()[0];
auto hidden_size = output_hidden_states.shape()[1];
auto hidden_states = paddle::full({bsz, hidden_size},
0,
output_hidden_states.dtype(),
output_hidden_states.place());
constexpr int VecSize = 16 / sizeof(data_t);
HydraFetchHiddenStatesKernel<data_t, VecSize>
<<<output_token_num, 256, 0, cu_stream>>>(
output_hidden_states.data<data_t>(),
output_padding_offset.data<int>(),
accept_token_num.data<int>(),
hidden_states.data<data_t>(),
bsz,
max_seq_length,
hidden_size);
return {hidden_states};
}
std::vector<paddle::Tensor> HydraFetchHiddenStates(
const paddle::Tensor& output_hidden_states,
const paddle::Tensor& output_padding_offset,
const paddle::Tensor& accept_token_num,
const int bsz,
const int max_seq_length) {
switch (output_hidden_states.dtype()) {
case paddle::DataType::BFLOAT16: {
return HydraFetchHiddenStatesImpl<paddle::DataType::BFLOAT16>(
output_hidden_states,
output_padding_offset,
accept_token_num,
bsz,
max_seq_length);
}
case paddle::DataType::FLOAT16: {
return HydraFetchHiddenStatesImpl<paddle::DataType::FLOAT16>(
output_hidden_states,
output_padding_offset,
accept_token_num,
bsz,
max_seq_length);
}
default: {
PD_THROW(
"NOT supported data type. "
"Only float16, bfloat16 are supported. ");
break;
}
}
}
std::vector<std::vector<int64_t>> HydraFetchHiddenStatesInferShape(
const std::vector<int64_t>& output_hidden_states_shape,
const std::vector<int64_t>& output_padding_offset_shape,
const std::vector<int64_t>& accept_token_num_shape,
const int bsz,
const int max_seq_length) {
return {{bsz, output_hidden_states_shape[1]}};
}
std::vector<paddle::DataType> HydraFetchHiddenStatesInferDtype(
const paddle::DataType& output_hidden_states_dtype,
const paddle::DataType& output_padding_offset_dtype,
const paddle::DataType& accept_token_num_dtype) {
return {output_hidden_states_dtype};
}
PD_BUILD_STATIC_OP(hydra_fetch_hidden_states)
.Inputs({"output_hidden_states",
"output_padding_offset",
"accept_token_num"})
.Outputs({"hidden_states"})
.Attrs({"bsz: int", "max_seq_length,: int"})
.SetKernelFn(PD_KERNEL(HydraFetchHiddenStates))
.SetInferShapeFn(PD_INFER_SHAPE(HydraFetchHiddenStatesInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(HydraFetchHiddenStatesInferDtype));
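
HydraFetchHiddenStatesKernel launches one thread block per packed token row, recovers the row's (bid, local position) from output_padding_offset, and copies it only when it belongs to the last accepted token. A standalone replay of that selection with assumed toy values:

#include <cstdio>

int main() {
    // Assumed toy values: max_seq_len = 8, one sequence (bid 0) whose packed
    // rows 0..2 came from padded positions 0..2, accept_token_num[0] = 3.
    const int max_seq_len = 8;
    const int output_padding_offset[3] = {0, 0, 0};
    const int accept_token_num[1] = {3};
    for (int token_id = 0; token_id < 3; ++token_id) {
        int ori = token_id + output_padding_offset[token_id];
        int bid = ori / max_seq_len;
        int local = ori - bid * max_seq_len;
        if (local == accept_token_num[bid] - 1)
            printf("row %d is copied for bid %d\n", token_id, bid);  // row 2
    }
    return 0;
}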

View File

@@ -0,0 +1,160 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <stdio.h>
#include <string.h>
#include <sys/ipc.h>
#include <sys/msg.h>
#include <sys/types.h>
#include "paddle/extension.h"
#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif
#define MAX_BSZ 512
// #define SAVE_WITH_OUTPUT_DEBUG
#define MAX_DRAFT_TOKENS 6
struct msgdata {
long mtype;
int mtext[2 + MAX_BSZ + MAX_BSZ * MAX_DRAFT_TOKENS]; // stop_flag, token_num, tokens
};
void MTPSaveFirstToken(const paddle::Tensor& x,
const paddle::Tensor& not_need_stop,
int64_t rank_id,
int msg_queue_id,
bool save_each_rank) {
if (!save_each_rank && rank_id > 0) {
return;
}
int x_dim = x.shape()[1];
auto x_cpu = x.copy_to(paddle::CPUPlace(), false);
int64_t* x_data = x_cpu.data<int64_t>();
static struct msgdata msg_sed;
if (const char* inference_msg_queue_id_env_p =
std::getenv("INFERENCE_MSG_QUEUE_ID")) {
std::string inference_msg_queue_id_env_str(
inference_msg_queue_id_env_p);
int inference_msg_queue_id_from_env =
std::stoi(inference_msg_queue_id_env_str);
#ifdef SAVE_WITH_OUTPUT_DEBUG
std::cout << "Your INFERENCE_MSG_QUEUE_ID is: "
<< inference_msg_queue_id_from_env << std::endl;
#endif
msg_queue_id = inference_msg_queue_id_from_env;
}
// Note: these statics bind the key and queue id on the first call.
static key_t key = ftok("./", msg_queue_id);
static int msgid = msgget(key, IPC_CREAT | 0666);
msg_sed.mtype = 1;
bool not_need_stop_data = not_need_stop.data<bool>()[0];
int inference_msg_id_from_env = 1;
if (const char* inference_msg_id_env_p = std::getenv("INFERENCE_MSG_ID")) {
std::string inference_msg_id_env_str(inference_msg_id_env_p);
inference_msg_id_from_env = std::stoi(inference_msg_id_env_str);
if (inference_msg_id_from_env == 2) {
// 2 and -2 is preserve for no-output indication.
throw std::runtime_error(
" INFERENCE_MSG_ID cannot be 2, please use other number.");
}
if (inference_msg_id_from_env < 0) {
throw std::runtime_error(
" INFERENCE_MSG_ID cannot be negative, please use other "
"number.");
}
#ifdef SAVE_WITH_OUTPUT_DEBUG
std::cout << "Your INFERENCE_MSG_ID is: " << inference_msg_id_from_env
<< std::endl;
#endif
} else {
#ifdef SAVE_WITH_OUTPUT_DEBUG
std::cout
<< "Failed to get INFERENCE_MSG_ID from env, using (int)1 as default."
<< std::endl;
#endif
}
#ifdef SAVE_WITH_OUTPUT_DEBUG
std::cout << "save_output_key: " << key << std::endl;
std::cout << "save msgid: " << msgid << std::endl;
#endif
msg_sed.mtext[0] = not_need_stop_data ? inference_msg_id_from_env
: -inference_msg_id_from_env;
int bsz = x.shape()[0];
msg_sed.mtext[1] = bsz;
for (int i = 0; i < bsz; i++) {
#ifdef SAVE_WITH_OUTPUT_DEBUG
printf("bid: %d. 1: %d. 2: %d.\n", i, (int)x_data[i * x_dim], (int)x_data[i * x_dim + 1]);
#endif
msg_sed.mtext[i + 2] = 2;
msg_sed.mtext[i * MAX_DRAFT_TOKENS + 2 + MAX_BSZ] = (int)x_data[i * x_dim];
msg_sed.mtext[i * MAX_DRAFT_TOKENS + 1 + 2 + MAX_BSZ] = (int)x_data[i * x_dim + 1];
#ifdef SAVE_WITH_OUTPUT_DEBUG
printf("mtext[%d]:%d. mtext[%d]:%d. \n", i * MAX_DRAFT_TOKENS + 2 + MAX_BSZ,
msg_sed.mtext[i * MAX_DRAFT_TOKENS + 2 + MAX_BSZ],
i * MAX_DRAFT_TOKENS + 1 + 2 + MAX_BSZ,
msg_sed.mtext[i * MAX_DRAFT_TOKENS + 1 + 2 + MAX_BSZ]);
#endif
}
#ifdef SAVE_WITH_OUTPUT_DEBUG
std::cout << "msg data: ";
for (int i = 0; i < bsz; i++) {
std::cout << " " << (int)x_data[2*i] << " ";
std::cout << " " << (int)x_data[2*i + 1];
}
std::cout << std::endl;
#endif
if ((msgsnd(msgid,
&msg_sed,
(2 + MAX_BSZ + MAX_BSZ * MAX_DRAFT_TOKENS) * sizeof(int), 0)) == -1) {
printf("full msg buffer\n");
}
return;
}
void MTPSaveFirstTokenStatic(const paddle::Tensor& x,
const paddle::Tensor& not_need_stop,
int64_t rank_id,
bool save_each_rank) {
MTPSaveFirstToken(x, not_need_stop, rank_id, 1, save_each_rank);
}
void MTPSaveFirstTokenDynamic(const paddle::Tensor& x,
const paddle::Tensor& not_need_stop,
int64_t rank_id,
int msg_queue_id,
bool save_each_rank) {
MTPSaveFirstToken(x, not_need_stop, rank_id, msg_queue_id, save_each_rank);
}
PD_BUILD_STATIC_OP(mtp_save_first_token)
.Inputs({"x", "not_need_stop"})
.Attrs({"rank_id: int64_t",
"save_each_rank: bool"})
.Outputs({"x_out"})
.SetInplaceMap({{"x", "x_out"}})
.SetKernelFn(PD_KERNEL(MTPSaveFirstTokenStatic));
PD_BUILD_STATIC_OP(mtp_save_first_token_dynamic)
.Inputs({"x", "not_need_stop"})
.Attrs({"rank_id: int64_t", "msg_queue_id: int", "save_each_rank: bool"})
.Outputs({"x_out"})
.SetInplaceMap({{"x", "x_out"}})
.SetKernelFn(PD_KERNEL(MTPSaveFirstTokenDynamic));
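
The message layout is positional: mtext[0] carries the signed message id (negated when generation should stop), mtext[1] the batch size, mtext[2..] a per-bid token count (always 2 here), and each bid's two first-step tokens start at offset 2 + MAX_BSZ. A minimal consumer sketch, assuming the same MAX_BSZ and MAX_DRAFT_TOKENS, the default msg_queue_id of 1, and the same ftok path as the sender:

#include <sys/ipc.h>
#include <sys/msg.h>
#include <sys/types.h>
#include <cstdio>

#define MAX_BSZ 512
#define MAX_DRAFT_TOKENS 6

struct msgdata {
    long mtype;
    int mtext[2 + MAX_BSZ + MAX_BSZ * MAX_DRAFT_TOKENS];
};

int main() {
    key_t key = ftok("./", 1);  // must match the sender's queue id and path
    int msgid = msgget(key, IPC_CREAT | 0666);
    msgdata msg;
    msgrcv(msgid, &msg, sizeof(msg.mtext), 1, 0);
    int bsz = msg.mtext[1];
    for (int i = 0; i < bsz; i++) {
        int t0 = msg.mtext[i * MAX_DRAFT_TOKENS + 2 + MAX_BSZ];
        int t1 = msg.mtext[i * MAX_DRAFT_TOKENS + 3 + MAX_BSZ];
        printf("bid %d: %d %d\n", i, t0, t1);
    }
    return 0;
}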

View File

@@ -0,0 +1,226 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "helper.h" // NOLINT
template<int NUM_THREADS, int MAX_BATCH_SIZE=256>
__global__ void mtp_free_and_dispatch_block(
bool *base_model_stop_flags,
bool *stop_flags,
bool *batch_drop,
int *seq_lens_this_time,
int *seq_lens_decoder,
int *block_tables,
int *encoder_block_lens,
int *used_list_len,
int *free_list,
int *free_list_len,
const int bsz,
const int block_size,
const int block_num_per_seq,
const int max_draft_tokens) {
typedef cub::BlockReduce<cub::KeyValuePair<int, int>, NUM_THREADS> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
__shared__ int need_block_len;
__shared__ int need_block_list[MAX_BATCH_SIZE];
const int tid = threadIdx.x;
if (tid < bsz) {
if (tid == 0) {
need_block_len = 0;
}
need_block_list[tid] = 0;
int *block_table_now = block_tables + tid * block_num_per_seq;
if (base_model_stop_flags[tid] || batch_drop[tid]) {
// Reclaim this sequence's blocks
const int encoder_block_len = encoder_block_lens[tid];
const int decoder_used_len = used_list_len[tid];
if (decoder_used_len > 0) {
const int ori_free_list_len =
atomicAdd(free_list_len, decoder_used_len);
#ifdef DEBUG_STEP
printf(
"free block seq_id: %d, free block num: %d, "
"encoder_block_len: %d, ori_free_list_len: %d\n",
tid,
decoder_used_len,
encoder_block_len,
ori_free_list_len);
#endif
for (int i = 0; i < decoder_used_len; i++) {
free_list[ori_free_list_len + i] =
block_table_now[encoder_block_len + i];
block_table_now[encoder_block_len + i] = -1;
}
encoder_block_lens[tid] = 0;
used_list_len[tid] = 0;
}
}
}
__syncthreads();
if (tid < bsz) {
int *block_table_now = block_tables + tid * block_num_per_seq;
int max_possible_block_idx = (seq_lens_decoder[tid] + max_draft_tokens + 1) / block_size;
if (!base_model_stop_flags[tid] && !batch_drop[tid] && max_possible_block_idx < block_num_per_seq &&
block_table_now[max_possible_block_idx] == -1) {
int ori_need_block_len = atomicAdd(&need_block_len, 1);
need_block_list[ori_need_block_len] = tid;
// Record which lanes need a new block and the total count; the actual
// allocation is deferred until after possible evictions below.
#ifdef DEBUG_STEP
printf("seq_id: %d need block\n", tid);
#endif
}
}
__syncthreads();
#ifdef DEBUG_STEP
if (tid == 0) {
printf("need_block_len:%d, free_list_len: %d\n", need_block_len, free_list_len[0]);
}
#endif
// Iterate directly from bid 0 here
while (need_block_len > free_list_len[0]) {
#ifdef DEBUG_STEP
if (tid == 0) {
printf("in while need_block_len:%d, free_list_len: %d\n", need_block_len, free_list_len[0]);
}
#endif
const int used_block_num =
tid < bsz && !base_model_stop_flags[tid] ? used_list_len[tid] : 0;
cub::KeyValuePair<int, int> kv_pair = {tid, used_block_num};
kv_pair = BlockReduce(temp_storage).Reduce(kv_pair, cub::ArgMax());
if (tid == 0) {
const int encoder_block_len = encoder_block_lens[kv_pair.key];
int *block_table_now =
block_tables + kv_pair.key * block_num_per_seq;
for (int i = 0; i < kv_pair.value; i++) {
free_list[free_list_len[0] + i] =
block_table_now[encoder_block_len + i];
block_table_now[encoder_block_len + i] = -1;
}
const int ori_free_list_len = atomicAdd(free_list_len, kv_pair.value);
printf(
"MTP STEP need_block_len: %d. free_list_len: %d."
"Drop bid: %d, free block num: %d, "
"encoder_block_len: %d,"
"After drop free_list_len %d \n",
need_block_len,
ori_free_list_len,
kv_pair.key,
kv_pair.value,
encoder_block_len,
free_list_len[0]);
stop_flags[kv_pair.key] = true;
batch_drop[kv_pair.key] = true;
seq_lens_this_time[kv_pair.key] = 0;
seq_lens_decoder[kv_pair.key] = 0;
used_list_len[kv_pair.key] = 0;
}
__syncthreads();
}
if (tid < need_block_len) {
const int need_block_id = need_block_list[tid];
// Must check batch_drop here, not stop_flags
if (!batch_drop[need_block_id]) {
used_list_len[need_block_id] += 1;
const int ori_free_list_len = atomicSub(free_list_len, 1);
int *block_table_now =
block_tables + need_block_id * block_num_per_seq;
#ifdef DEBUG_STEP
printf("bid: %d allocate block_id %d. seq_lens_decoder:%d \n", need_block_id, free_list[ori_free_list_len - 1], seq_lens_decoder[need_block_id]);
#endif
block_table_now[(seq_lens_decoder[need_block_id] +
max_draft_tokens + 1) /
block_size] = free_list[ori_free_list_len - 1];
}
}
}
void MTPStepPaddle(
const paddle::Tensor &base_model_stop_flags,
const paddle::Tensor &stop_flags,
const paddle::Tensor &batch_drop,
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &block_tables, // [bsz, block_num_per_seq]
const paddle::Tensor &encoder_block_lens,
const paddle::Tensor &used_list_len,
const paddle::Tensor &free_list,
const paddle::Tensor &free_list_len,
const int block_size,
const int max_draft_tokens) {
auto cu_stream = seq_lens_this_time.stream();
const int bsz = seq_lens_this_time.shape()[0];
const int block_num_per_seq = block_tables.shape()[1];
constexpr int BlockSize = 512;  // doubles as MAX_BATCH_SIZE, so bsz <= 512
#ifdef DEBUG_STEP
printf("bsz: %d, block_num_per_seq: %d\n", bsz, block_num_per_seq);
#endif
mtp_free_and_dispatch_block<BlockSize, BlockSize><<<1, BlockSize, 0, cu_stream>>>(
const_cast<bool *>(base_model_stop_flags.data<bool>()),
const_cast<bool *>(stop_flags.data<bool>()),
const_cast<bool *>(batch_drop.data<bool>()),
const_cast<int *>(seq_lens_this_time.data<int>()),
const_cast<int *>(seq_lens_decoder.data<int>()),
const_cast<int *>(block_tables.data<int>()),
const_cast<int *>(encoder_block_lens.data<int>()),
const_cast<int *>(used_list_len.data<int>()),
const_cast<int *>(free_list.data<int>()),
const_cast<int *>(free_list_len.data<int>()),
bsz,
block_size,
block_num_per_seq,
max_draft_tokens);
}
PD_BUILD_STATIC_OP(mtp_step_paddle)
.Inputs({"base_model_stop_flags",
"stop_flags",
"batch_drop",
"seq_lens_this_time",
"seq_lens_encoder",
"seq_lens_decoder",
"block_tables",
"encoder_block_lens",
"used_list_len",
"free_list",
"free_list_len"})
.Attrs({"block_size: int",
"max_draft_tokens: int"})
.Outputs({"block_tables_out",
"stop_flags_out",
"used_list_len_out",
"free_list_out",
"free_list_len_out"})
.SetInplaceMap({{"block_tables", "block_tables_out"},
{"stop_flags", "stop_flags_out"},
{"used_list_len", "used_list_len_out"},
{"free_list", "free_list_out"},
{"free_list_len", "free_list_len_out"}})
.SetKernelFn(PD_KERNEL(MTPStepPaddle));
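
The allocation trigger is (seq_lens_decoder + max_draft_tokens + 1) / block_size, the furthest block index the next speculative step could touch. A standalone check with assumed toy values:

#include <cstdio>

int main() {
    const int block_size = 64, max_draft_tokens = 5;
    int seq_lens_decoder = 60;  // assumed toy value
    int max_possible_block_idx =
        (seq_lens_decoder + max_draft_tokens + 1) / block_size;  // = 1
    printf("next step may touch block %d\n", max_possible_block_idx);
    // If block_table[1] == -1 the lane joins need_block_list; when the free
    // list runs short, the while-loop above evicts the lane holding the most
    // decoder blocks (ArgMax over used_list_len) and recycles its blocks.
    return 0;
}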

View File

@@ -0,0 +1,212 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <iostream>
#include <vector>
#include <string>
#include <algorithm>
#include <chrono>
#include <cstdlib>
#include "paddle/extension.h"
#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif
// Inclusive sum of value[0] .. value[num].
int sum(const int *value, int num) {
int sum_value = 0;
for (int i = 0; i <= num; i++) {
sum_value += value[i];
}
return sum_value;
}
void find_candidate_pred_tokens(const int64_t *input_ids,
const int64_t *input_ids_len,
const int64_t *pre_ids,
const int64_t *step_idx,
const int *draft_token_num,
int64_t *draft_tokens,
int32_t *seq_lens_this_time,
int32_t *seq_lens_encoder,
int32_t *seq_lens_decoder,
int64_t *max_dec_len,
int64_t input_ids_stride,
int64_t pre_ids_stride,
int64_t draft_tokens_stride,
const int real_batch_size,
int max_ngram_size = 3,
int max_draft_tokens = 10) {
int threshold = 128;
char *env_var = getenv("INFER_WITH_REFERENCE_TOKENUM_THRESHOLD");
if (env_var) {
threshold = std::stoi(env_var);
}
bool is_insert = false;
for (int batch_idx = 0; batch_idx < real_batch_size; batch_idx++) {
if (seq_lens_encoder[batch_idx] > 0) {
is_insert = true;
}
}
for (int batch_idx = 0; batch_idx < real_batch_size; batch_idx++) {
        max_draft_tokens = std::min(
            static_cast<int64_t>(draft_token_num[batch_idx]),
            max_dec_len[batch_idx] - step_idx[batch_idx] - 1);
if (seq_lens_encoder[batch_idx] > 0) {
continue;
} else if (seq_lens_decoder[batch_idx] == 0) {
seq_lens_this_time[batch_idx] = 0;
continue;
}
const int64_t *cur_input_ids = input_ids + batch_idx * input_ids_stride;
int64_t *cur_draft_tokens = draft_tokens + batch_idx * draft_tokens_stride;
const int64_t *cur_pre_ids = pre_ids + batch_idx * pre_ids_stride;
const int64_t cur_step_idx = step_idx[batch_idx];
const int64_t cur_input_ids_len = input_ids_len[batch_idx];
seq_lens_this_time[batch_idx] = 1;
if (!is_insert) {
auto sum_token_num = sum(seq_lens_this_time, batch_idx);
int left_min_token_num = real_batch_size - batch_idx;
if (sum_token_num + max_draft_tokens + left_min_token_num > threshold) {
int tmp_max_draft_tokens = threshold - sum_token_num - left_min_token_num;
max_draft_tokens = tmp_max_draft_tokens < max_draft_tokens ? tmp_max_draft_tokens : max_draft_tokens;
}
if (sum_token_num + left_min_token_num >= threshold - 1) {
continue;
}
}
for (int ngram_size = max_ngram_size; ngram_size > 0; --ngram_size) {
// Extract the last n tokens as our search ngram
if (cur_step_idx < ngram_size) {
continue;
}
const int64_t *ngram = cur_pre_ids + (cur_step_idx + 1 - ngram_size);
// Iterate through sliding windows of size ngram_size
bool match_input = false;
for (int64_t i = 0; i <= cur_input_ids_len - ngram_size; ++i) {
// Check if the current window matches the ngram
bool match = true;
for (int j = 0; j < ngram_size; j++) {
if (ngram[j] != cur_input_ids[i + j]) {
match = false;
break;
}
}
if (match) {
int64_t start_idx = i + ngram_size;
int64_t end_idx = std::min(start_idx + max_draft_tokens, cur_input_ids_len);
if (start_idx >= end_idx)
continue;
int64_t cur_draft_token_num = end_idx - start_idx;
seq_lens_this_time[batch_idx] = cur_draft_token_num + 1;
memcpy(cur_draft_tokens + 1, cur_input_ids + start_idx, sizeof(int64_t) * cur_draft_token_num);
                    // Zeroing ngram_size exits the enclosing ngram loop,
                    // moving on to the next batch.
                    ngram_size = 0;
                    match_input = true;
                    break;
}
}
if (!match_input) {
for (int64_t i = 0; i <= cur_step_idx - ngram_size; ++i) {
// Check if the current window matches the ngram
bool match = true;
for (int j = 0; j < ngram_size; j++) {
if (ngram[j] != cur_pre_ids[i + j]) {
match = false;
break;
}
}
if (match) {
int64_t start_idx = i + ngram_size;
int64_t end_idx = std::min(start_idx + max_draft_tokens, cur_step_idx);
int64_t cur_draft_token_num = end_idx - start_idx;
if (start_idx >= end_idx)
continue;
seq_lens_this_time[batch_idx] = cur_draft_token_num + 1;
memcpy(cur_draft_tokens + 1, cur_pre_ids + start_idx, sizeof(int64_t) * cur_draft_token_num);
ngram_size = 0;
break;
}
}
}
}
}
}
void NgramMatch(const paddle::Tensor &input_ids,
const paddle::Tensor &input_ids_len,
const paddle::Tensor &pre_ids,
const paddle::Tensor &step_idx,
const paddle::Tensor &draft_token_num,
const paddle::Tensor &draft_tokens,
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &max_dec_len,
const int real_batch_size,
const int max_ngram_size,
const int max_draft_tokens) {
auto input_ids_shape = input_ids.shape();
const int64_t input_ids_stride = input_ids_shape[1];
auto pre_ids_shape = pre_ids.shape();
const int64_t pre_ids_stride = pre_ids_shape[1];
auto draft_tokens_shape = draft_tokens.shape();
const int64_t draft_tokens_stride = draft_tokens_shape[1];
find_candidate_pred_tokens(input_ids.data<int64_t>(),
input_ids_len.data<int64_t>(),
pre_ids.data<int64_t>(),
step_idx.data<int64_t>(),
draft_token_num.data<int>(),
const_cast<int64_t *>(draft_tokens.data<int64_t>()),
const_cast<int32_t *>(seq_lens_this_time.data<int32_t>()),
const_cast<int32_t *>(seq_lens_encoder.data<int32_t>()),
const_cast<int32_t *>(seq_lens_decoder.data<int32_t>()),
const_cast<int64_t *>(max_dec_len.data<int64_t>()),
input_ids_stride,
pre_ids_stride,
draft_tokens_stride,
real_batch_size,
max_ngram_size,
max_draft_tokens);
}
PD_BUILD_STATIC_OP(ngram_match)
.Inputs({"input_ids",
"input_ids_len",
"pre_ids",
"step_idx",
"draft_token_num",
"draft_tokens",
"seq_lens_this_time",
"seq_lens_encoder",
"seq_lens_decoder",
"max_dec_len"})
.Attrs({"real_batch_size: int", "max_ngram_size: int", "max_draft_tokens: int"})
.Outputs({"draft_tokens_out", "seq_lens_this_time_out"})
.SetKernelFn(PD_KERNEL(NgramMatch))
.SetInplaceMap({{"draft_tokens", "draft_tokens_out"}, {"seq_lens_this_time", "seq_lens_this_time_out"}});

View File

@@ -0,0 +1,80 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "helper.h"
template <int THREADBLOCK_SIZE>
__global__ void CalculateKernel(int32_t* sum_draft_num,
int32_t* sum_accept_num,
const int32_t* accept_nums,
const int32_t* seq_lens_this_time,
const int32_t* seq_lens_decoder,
const bool* stop_flags,
int real_bsz) {
int tid = threadIdx.x;
int draft_num = 0, accept_num = 0;
if (tid < real_bsz) {
if (seq_lens_decoder[tid] > 0 &&
seq_lens_this_time[tid] != seq_lens_decoder[tid]) {
draft_num = seq_lens_this_time[tid] - 1;
accept_num = accept_nums[tid] - 1;
} else if (seq_lens_this_time[tid] > 0 &&
stop_flags[tid]) { // last step
draft_num = seq_lens_this_time[tid] - 1;
accept_num = accept_nums[tid];
}
}
    __syncthreads();
    typedef cub::BlockReduce<int, THREADBLOCK_SIZE> BlockReduce;
    __shared__ typename BlockReduce::TempStorage temp_storage;
    int draft_nums_sum = BlockReduce(temp_storage).Sum(draft_num);
    // cub requires a barrier before temp_storage is reused.
    __syncthreads();
    int accept_nums_sum = BlockReduce(temp_storage).Sum(accept_num);
if (tid == 0 && draft_nums_sum != 0) {
sum_draft_num[0] += draft_nums_sum;
sum_accept_num[0] += accept_nums_sum;
}
}
void Calculate(const paddle::Tensor& sum_draft_num,
const paddle::Tensor& sum_accept_num,
const paddle::Tensor& accept_nums,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& stop_flags) {
int real_bsz = seq_lens_this_time.shape()[0];
constexpr int BLOCK_SIZE = 512;
CalculateKernel<BLOCK_SIZE><<<1, BLOCK_SIZE, 0, accept_nums.stream()>>>(
const_cast<int*>(sum_draft_num.data<int32_t>()),
const_cast<int*>(sum_accept_num.data<int32_t>()),
accept_nums.data<int32_t>(),
seq_lens_this_time.data<int32_t>(),
seq_lens_decoder.data<int32_t>(),
stop_flags.data<bool>(),
real_bsz);
}
PD_BUILD_STATIC_OP(speculate_calcu_accept_ratio)
.Inputs({"sum_draft_num",
"sum_accept_num",
"accept_nums",
"seq_lens_this_time",
"seq_lens_decoder",
"stop_flags"})
.Outputs({"sum_draft_num_out", "sum_accept_num_out"})
.SetInplaceMap({{"sum_draft_num", "sum_draft_num_out"},
{"sum_accept_num", "sum_accept_num_out"}})
.SetKernelFn(PD_KERNEL(Calculate));

View File

@@ -0,0 +1,39 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "helper.h" // NOLINT
__global__ void speculate_clear_accept_nums_kernel(int* accept_num,
const int* seq_lens_decoder,
const int max_bsz) {
const int bid = threadIdx.x;
if (bid >= max_bsz) return;
accept_num[bid] = seq_lens_decoder[bid] == 0 ? 0 : accept_num[bid];
}
void SpeculateClearAcceptNums(const paddle::Tensor& accept_num,
const paddle::Tensor& seq_lens_decoder) {
// printf("enter clear \n");
const int max_bsz = seq_lens_decoder.shape()[0];
speculate_clear_accept_nums_kernel<<<1, 1024, 0, accept_num.stream()>>>(
const_cast<int*>(accept_num.data<int>()),
seq_lens_decoder.data<int>(),
max_bsz);
}
PD_BUILD_STATIC_OP(speculate_clear_accept_nums)
.Inputs({"accept_num", "seq_lens_decoder"})
.Outputs({"seq_lens_decoder_out"})
.SetInplaceMap({{"seq_lens_decoder", "seq_lens_decoder_out"}})
.SetKernelFn(PD_KERNEL(SpeculateClearAcceptNums));

View File

@@ -0,0 +1,118 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <stdio.h>
#include <string.h>
#include <sys/ipc.h>
#include <sys/msg.h>
#include <sys/types.h>
#include "paddle/extension.h"
#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif
#define MAX_BSZ 512
#define MAX_DRAFT_TOKENS 6
struct msgdata {
    long mtype;  // System V message queues expect a long mtype field
int mtext[MAX_BSZ * MAX_DRAFT_TOKENS + MAX_BSZ +
2]; // stop_flag, bsz, accept_num*bsz, tokens...
};
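// Layout of mtext, as filled by the producer in speculate_save_output:
//   mtext[0]              -> message id; negated when generation should stop
//   mtext[1]              -> bsz
//   mtext[2 .. MAX_BSZ+1] -> accept_num per batch slot (0 beyond bsz)
//   mtext[MAX_BSZ+2 .. ]  -> accepted tokens, MAX_DRAFT_TOKENS per slot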
void SpeculateGetOutput(const paddle::Tensor& x,
int64_t rank_id,
bool wait_flag,
int msg_queue_id,
bool get_each_rank) {
if (!get_each_rank && rank_id > 0) {
return;
}
if (const char* inference_msg_queue_id_env_p =
std::getenv("INFERENCE_MSG_QUEUE_ID")) {
std::string inference_msg_queue_id_env_str(
inference_msg_queue_id_env_p);
int inference_msg_queue_id_from_env =
std::stoi(inference_msg_queue_id_env_str);
#ifdef GET_OUTPUT_DEBUG
std::cout << "Your INFERENCE_MSG_QUEUE_ID is: "
<< inference_msg_queue_id_from_env << std::endl;
#endif
msg_queue_id = inference_msg_queue_id_from_env;
}
    static struct msgdata msg_rcv;
    // Note: these statics bind the queue to the first msg_queue_id seen.
    static key_t key = ftok("./", msg_queue_id);
    static int msgid = msgget(key, IPC_CREAT | 0666);
int64_t* out_data = const_cast<int64_t*>(x.data<int64_t>());
int ret = -1;
    if (!wait_flag) {
        ret = msgrcv(msgid,
                     &msg_rcv,
                     (MAX_BSZ * MAX_DRAFT_TOKENS + MAX_BSZ + 2) * sizeof(int),
                     0,
                     IPC_NOWAIT);
    } else {
        ret = msgrcv(msgid,
                     &msg_rcv,
                     (MAX_BSZ * MAX_DRAFT_TOKENS + MAX_BSZ + 2) * sizeof(int),
                     0,
                     0);
    }
if (ret == -1) {
out_data[0] = -2;
out_data[1] = 0;
return;
}
    for (int64_t i = 0; i < MAX_BSZ * MAX_DRAFT_TOKENS + MAX_BSZ + 2; i++) {
        out_data[i] = (int64_t)msg_rcv.mtext[i];
    }
return;
}
void SpeculateGetOutputStatic(const paddle::Tensor& x,
int64_t rank_id,
bool wait_flag,
bool get_each_rank) {
SpeculateGetOutput(x, rank_id, wait_flag, 1, get_each_rank);
}
void SpeculateGetOutputDynamic(const paddle::Tensor& x,
int64_t rank_id,
bool wait_flag,
int msg_queue_id,
bool get_each_rank) {
SpeculateGetOutput(x, rank_id, wait_flag, msg_queue_id, get_each_rank);
}
PD_BUILD_STATIC_OP(speculate_get_output)
.Inputs({"x"})
.Attrs({"rank_id: int64_t", "wait_flag: bool", "get_each_rank: bool"})
.Outputs({"x_out"})
.SetInplaceMap({{"x", "x_out"}})
.SetKernelFn(PD_KERNEL(SpeculateGetOutputStatic));
PD_BUILD_STATIC_OP(speculate_get_output_dynamic)
.Inputs({"x"})
.Attrs({"rank_id: int64_t", "wait_flag: bool", "msg_queue_id: int", "get_each_rank: bool"})
.Outputs({"x_out"})
.SetInplaceMap({{"x", "x_out"}})
.SetKernelFn(PD_KERNEL(SpeculateGetOutputDynamic));

View File

@@ -0,0 +1,88 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/extension.h"
#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif
__global__ void SpeculateGetOutputPaddingOffsetKernel(
int* output_padding_offset,
int* output_cum_offsets,
const int* output_cum_offsets_tmp,
const int* seq_lens_output,
const int max_seq_len) {
// get padding offset of each batch
const int bi = blockIdx.x;
const int ti = threadIdx.x;
int cum_offset = bi == 0 ? 0 : output_cum_offsets_tmp[bi - 1];
for (int i = ti; i < seq_lens_output[bi]; i += blockDim.x) {
output_padding_offset[bi * max_seq_len - cum_offset + i] = cum_offset;
}
if (ti == 0) {
output_cum_offsets[bi] = cum_offset;
}
}
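// Worked example (illustrative): with max_seq_len = 4,
// seq_lens_output = [2, 1] and output_cum_offsets_tmp = [2, 5]
// (cumulative padding per batch), batch 0 writes offset 0 at positions
// 0..1 and batch 1 writes offset 2 at position 2, giving
// output_padding_offset = [0, 0, 2] and output_cum_offsets = [0, 2].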
std::vector<paddle::Tensor> SpeculateGetOutputPaddingOffset(
const paddle::Tensor& output_cum_offsets_tmp,
const paddle::Tensor& out_token_num,
const paddle::Tensor& seq_lens_output,
const int max_seq_len) {
auto cu_stream = output_cum_offsets_tmp.stream();
std::vector<int64_t> output_cum_offsets_tmp_shape =
output_cum_offsets_tmp.shape();
const int bsz = output_cum_offsets_tmp_shape[0];
auto cpu_out_token_num = out_token_num.copy_to(paddle::CPUPlace(), false);
auto output_padding_offset = paddle::full({cpu_out_token_num},
0,
paddle::DataType::INT32,
output_cum_offsets_tmp.place());
auto output_cum_offsets =
output_cum_offsets_tmp.copy_to(output_cum_offsets_tmp.place(), false);
SpeculateGetOutputPaddingOffsetKernel<<<bsz, 256, 0, cu_stream>>>(
output_padding_offset.data<int>(),
output_cum_offsets.data<int>(),
output_cum_offsets_tmp.data<int>(),
seq_lens_output.data<int>(),
max_seq_len);
return {output_padding_offset, output_cum_offsets};
}
std::vector<std::vector<int64_t>> SpeculateGetOutputPaddingOffsetInferShape(
const std::vector<int64_t>& output_cum_offsets_tmp_shape,
const std::vector<int64_t>& out_token_num_shape,
const std::vector<int64_t>& seq_lens_output_shape) {
int64_t bsz = output_cum_offsets_tmp_shape[0];
return {{-1}, {bsz}};
}
std::vector<paddle::DataType> SpeculateGetOutputPaddingOffsetInferDtype(
const paddle::DataType& output_cum_offsets_tmp_dtype,
const paddle::DataType& out_token_num_dtype,
const paddle::DataType& seq_lens_output_dtype) {
return {output_cum_offsets_tmp_dtype, output_cum_offsets_tmp_dtype};
}
PD_BUILD_STATIC_OP(speculate_get_output_padding_offset)
.Inputs({"output_cum_offsets_tmp", "out_token_num", "seq_lens_output"})
.Outputs({"output_padding_offset", "output_cum_offsets"})
.Attrs({"max_seq_len: int"})
.SetKernelFn(PD_KERNEL(SpeculateGetOutputPaddingOffset))
.SetInferShapeFn(PD_INFER_SHAPE(SpeculateGetOutputPaddingOffsetInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(SpeculateGetOutputPaddingOffsetInferDtype));

View File

@@ -0,0 +1,155 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/extension.h"
#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif
__global__ void SpeculateRemovePadding(int64_t* output_data,
const int64_t* input_data,
const int64_t* draft_tokens,
const int* seq_lens,
const int* seq_lens_encoder,
const int* cum_offsets,
const int sequence_length,
const int max_draft_tokens) {
const int bi = blockIdx.x;
const int tid = threadIdx.x;
for (int i = tid; i < seq_lens[bi]; i += blockDim.x) {
const int tgt_seq_id = bi * sequence_length - cum_offsets[bi] + i;
if (seq_lens_encoder[bi] > 0) {
const int src_seq_id = bi * sequence_length + i;
output_data[tgt_seq_id] = input_data[src_seq_id];
} else {
const int src_seq_id = bi * max_draft_tokens + i;
output_data[tgt_seq_id] = draft_tokens[src_seq_id];
}
}
}
__global__ void SpeculateGetPaddingOffsetKernel(int* padding_offset,
int* cum_offsets_out,
int* cu_seqlens_q,
int* cu_seqlens_k,
const int* cum_offsets,
const int* seq_lens,
const int max_seq_len) {
// get padding offset of each batch
const int bi = blockIdx.x;
const int ti = threadIdx.x;
int cum_offset = bi == 0 ? 0 : cum_offsets[bi - 1];
for (int i = ti; i < seq_lens[bi]; i += blockDim.x) {
padding_offset[bi * max_seq_len - cum_offset + i] = cum_offset;
}
if (ti == 0) {
cum_offsets_out[bi] = cum_offset;
int cum_seq_len = (bi + 1) * max_seq_len - cum_offsets[bi];
cu_seqlens_q[bi + 1] = cum_seq_len;
cu_seqlens_k[bi + 1] = cum_seq_len;
}
}
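// Worked example (illustrative): with max_seq_len = 4, seq_lens = [2, 3]
// and cum_offsets = [2, 3] (cumulative padding), batch 0 writes offset 0
// at positions 0..1, batch 1 writes offset 2 at positions 2..4, and
// cu_seqlens_q = cu_seqlens_k = [0, 2, 5], the prefix sums of the real
// sequence lengths.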
std::vector<paddle::Tensor> SpeculateGetPaddingOffset(
const paddle::Tensor& input_ids,
const paddle::Tensor& draft_tokens,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& token_num,
const paddle::Tensor& seq_len,
const paddle::Tensor& seq_lens_encoder) {
auto cu_stream = input_ids.stream();
std::vector<int64_t> input_ids_shape = input_ids.shape();
const int bsz = seq_len.shape()[0];
const int seq_length = input_ids_shape[1];
const int max_draft_tokens = draft_tokens.shape()[1];
auto cum_offsets_out = cum_offsets.copy_to(cum_offsets.place(), false);
auto cpu_token_num = token_num.copy_to(paddle::CPUPlace(), false);
const int token_num_data = cpu_token_num.data<int64_t>()[0];
auto x_remove_padding = paddle::full(
{token_num_data}, 0, paddle::DataType::INT64, input_ids.place());
auto padding_offset = paddle::full(
{token_num_data}, 0, paddle::DataType::INT32, input_ids.place());
auto cu_seqlens_q =
paddle::full({bsz + 1}, 0, paddle::DataType::INT32, input_ids.place());
auto cu_seqlens_k =
paddle::full({bsz + 1}, 0, paddle::DataType::INT32, input_ids.place());
int blockSize = min((token_num_data + 32 - 1) / 32 * 32, 128);
SpeculateGetPaddingOffsetKernel<<<bsz, 128, 0, cu_stream>>>(
padding_offset.data<int>(),
cum_offsets_out.data<int>(),
cu_seqlens_q.data<int>(),
cu_seqlens_k.data<int>(),
cum_offsets.data<int>(),
seq_len.data<int>(),
seq_length);
SpeculateRemovePadding<<<bsz, blockSize, 0, cu_stream>>>(
x_remove_padding.data<int64_t>(),
input_ids.data<int64_t>(),
draft_tokens.data<int64_t>(),
seq_len.data<int>(),
seq_lens_encoder.data<int>(),
cum_offsets_out.data<int>(),
seq_length,
max_draft_tokens);
return {x_remove_padding,
cum_offsets_out,
padding_offset,
cu_seqlens_q,
cu_seqlens_k}; // , enc_token_num, dec_token_num};
}
std::vector<std::vector<int64_t>> SpeculateGetPaddingOffsetInferShape(
const std::vector<int64_t>& input_ids_shape,
const std::vector<int64_t>& draft_tokens_shape,
const std::vector<int64_t>& cum_offsets_shape,
const std::vector<int64_t>& token_num_shape,
const std::vector<int64_t>& seq_len_shape,
const std::vector<int64_t>& seq_lens_encoder_shape) {
int64_t bsz = seq_len_shape[0];
int64_t seq_len = input_ids_shape[1];
return {{-1}, {bsz}, {-1}, {bsz + 1}, {bsz + 1}};
}
std::vector<paddle::DataType> SpeculateGetPaddingOffsetInferDtype(
const paddle::DataType& input_ids_dtype,
const paddle::DataType& draft_tokens_dtype,
const paddle::DataType& cum_offsets_dtype,
const paddle::DataType& token_num_dtype,
const paddle::DataType& seq_len_dtype,
const paddle::DataType& seq_lens_encoder_dtype) {
return {input_ids_dtype,
seq_len_dtype,
seq_len_dtype,
seq_len_dtype,
seq_len_dtype};
}
PD_BUILD_STATIC_OP(speculate_get_padding_offset)
.Inputs({"input_ids",
"draft_tokens",
"cum_offsets",
"token_num",
"seq_len",
"seq_lens_encoder"})
.Outputs({"x_remove_padding",
"cum_offsets_out",
"padding_offset",
"cu_seqlens_q",
"cu_seqlens_k"})
.SetKernelFn(PD_KERNEL(SpeculateGetPaddingOffset))
.SetInferShapeFn(PD_INFER_SHAPE(SpeculateGetPaddingOffsetInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(SpeculateGetPaddingOffsetInferDtype));

View File

@@ -0,0 +1,80 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/extension.h"
#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif
__global__ void SpeculateGetSeqLensOutputKernel(int* seq_lens_output,
const int* seq_lens_this_time,
const int* seq_lens_encoder,
const int* seq_lens_decoder,
const int real_bsz) {
for (int bid = threadIdx.x; bid < real_bsz; bid += blockDim.x) {
if (seq_lens_this_time[bid] == 0) {
continue;
} else if (seq_lens_this_time[bid] == 1) {
seq_lens_output[bid] = 1;
} else if (seq_lens_encoder[bid] != 0) {
seq_lens_output[bid] = 1;
} else {
seq_lens_output[bid] = seq_lens_this_time[bid];
}
}
}
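// Per-slot mapping implemented above:
//   seq_lens_this_time == 0 -> unchanged (slot inactive)
//   seq_lens_this_time == 1 -> 1 (plain decode step)
//   seq_lens_encoder   != 0 -> 1 (prefill emits a single token)
//   otherwise               -> seq_lens_this_time (verify phase)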
std::vector<paddle::Tensor> SpeculateGetSeqLensOutput(
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder) {
auto cu_stream = seq_lens_this_time.stream();
std::vector<int64_t> seq_lens_this_time_shape = seq_lens_this_time.shape();
const int bsz = seq_lens_this_time_shape[0];
auto seq_lens_output = paddle::full(
{bsz}, 0, paddle::DataType::INT32, seq_lens_this_time.place());
SpeculateGetSeqLensOutputKernel<<<1, 256, 0, cu_stream>>>(
seq_lens_output.data<int>(),
seq_lens_this_time.data<int>(),
seq_lens_encoder.data<int>(),
seq_lens_decoder.data<int>(),
bsz);
return {seq_lens_output};
}
std::vector<std::vector<int64_t>> SpeculateGetSeqLensOutputInferShape(
const std::vector<int64_t>& seq_lens_this_time_shape,
const std::vector<int64_t>& seq_lens_encoder_shape,
const std::vector<int64_t>& seq_lens_decoder_shape) {
int64_t bsz = seq_lens_this_time_shape[0];
return {{bsz}};
}
std::vector<paddle::DataType> SpeculateGetSeqLensOutputInferDtype(
const paddle::DataType& seq_lens_this_time_dtype,
const paddle::DataType& seq_lens_encoder_dtype,
const paddle::DataType& seq_lens_decoder_dtype) {
return {seq_lens_this_time_dtype};
}
PD_BUILD_STATIC_OP(speculate_get_seq_lens_output)
.Inputs({"seq_lens_this_time", "seq_lens_encoder", "seq_lens_decoder"})
.Outputs({"seq_lens_output"})
.SetKernelFn(PD_KERNEL(SpeculateGetSeqLensOutput))
.SetInferShapeFn(PD_INFER_SHAPE(SpeculateGetSeqLensOutputInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(SpeculateGetSeqLensOutputInferDtype));

View File

@@ -0,0 +1,69 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/extension.h"
#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif
__global__ void SpeculateHydraSetScoreThresholdKernel(
float* threshold,
const int* seq_lens_this_time,
const int* seq_lens_encoder,
const int* accept_num,
const int real_bsz,
const float default_threshold = 0.3,
const float upper_threshold = 0.8,
const float lower_threshold = 0.0,
const float threshold_step = 0.1,
const float threshold_step_fac = 0.5) {
for (int bid = threadIdx.x; bid < real_bsz; bid += blockDim.x) {
if (seq_lens_encoder[bid] > 0) {
threshold[bid] = default_threshold;
} else if (seq_lens_this_time[bid] <= 1) {
continue;
} else if (accept_num[bid] >= seq_lens_this_time[bid] &&
threshold[bid] >
lower_threshold + threshold_step * threshold_step_fac) {
threshold[bid] -= threshold_step * threshold_step_fac;
} else if (accept_num[bid] < seq_lens_this_time[bid] &&
threshold[bid] < upper_threshold - threshold_step) {
threshold[bid] += threshold_step;
}
}
}
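// The kernel is a simple per-query feedback controller: prefill resets the
// threshold to default_threshold; if every drafted token was accepted the
// threshold is lowered (speculate more aggressively), and if any draft was
// rejected it is raised (speculate more conservatively), keeping it inside
// (lower_threshold, upper_threshold).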
void SpeculateHydraSetScoreThreshold(const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& accept_num,
const paddle::Tensor& threshold) {
auto cu_stream = seq_lens_this_time.stream();
std::vector<int64_t> seq_lens_this_time_shape = seq_lens_this_time.shape();
const int bsz = seq_lens_this_time_shape[0];
SpeculateHydraSetScoreThresholdKernel<<<1, 256, 0, cu_stream>>>(
const_cast<float*>(threshold.data<float>()),
seq_lens_this_time.data<int>(),
seq_lens_encoder.data<int>(),
accept_num.data<int>(),
bsz);
}
PD_BUILD_STATIC_OP(speculate_hydra_set_score_threshold)
.Inputs(
{"seq_lens_this_time", "seq_lens_encoder", "accept_num", "threshold"})
.Outputs({"threshold_out"})
.SetInplaceMap({{"threshold", "threshold_out"}})
.SetKernelFn(PD_KERNEL(SpeculateHydraSetScoreThreshold));

View File

@@ -0,0 +1,68 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "helper.h"
__global__ void hydra_update_this_time(int* seq_lens_this_time,
const int* seq_lens_encoder,
const int* seq_lens_decoder,
const float* topk_scores,
const float* score_threshold,
int real_bsz,
int idx) {
int linear_idx = threadIdx.x;
// verify and set stop flags
for (; linear_idx < real_bsz; linear_idx += blockDim.x) {
if (seq_lens_encoder[linear_idx] == 0 &&
seq_lens_decoder[linear_idx] != 0) {
if (topk_scores[linear_idx] > score_threshold[linear_idx] &&
seq_lens_this_time[linear_idx] == idx + 1) {
seq_lens_this_time[linear_idx]++;
}
} else if (seq_lens_encoder[linear_idx] == 0 &&
seq_lens_decoder[linear_idx] == 0) {
seq_lens_this_time[linear_idx] = 0;
}
}
}
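// Usage note: this kernel is expected to be launched once per draft
// position idx. A query grows seq_lens_this_time by one only when its
// current length is exactly idx + 1 and the top-1 score clears its
// threshold, so drafting stops at the first rejected position.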
void HydraUpdateThisTime(const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& topk_scores,
const paddle::Tensor& score_threshold,
const int real_bsz,
const int idx) {
constexpr int BlockSize = 512;
hydra_update_this_time<<<1, BlockSize, 0, seq_lens_this_time.stream()>>>(
const_cast<int*>(seq_lens_this_time.data<int>()),
seq_lens_encoder.data<int>(),
seq_lens_decoder.data<int>(),
topk_scores.data<float>(),
score_threshold.data<float>(),
real_bsz,
idx);
}
PD_BUILD_STATIC_OP(speculate_hydra_update_seqlens_this_time)
.Inputs({"seq_lens_this_time",
"seq_lens_encoder",
"seq_lens_decoder",
"topk_scores",
"score_threshold"})
.Outputs({"seq_lens_this_time_out"})
.Attrs({"real_bsz: int", "idx: int"})
.SetInplaceMap({{"seq_lens_this_time", "seq_lens_this_time_out"}})
.SetKernelFn(PD_KERNEL(HydraUpdateThisTime));

View File

@@ -0,0 +1,32 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <stdio.h>
#include <string.h>
#include <sys/ipc.h>
#include <sys/msg.h>
#include <sys/types.h>
#include "paddle/extension.h"
#define MAX_BSZ 256
#define MAX_DRAFT_TOKENS 6
// TODO: replace all msgdata in speculate-decoding
struct speculate_msgdata {
long mtype;
int mtext[MAX_BSZ * MAX_DRAFT_TOKENS + MAX_BSZ +
2]; // stop_flag, bsz, tokens
};
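// Note: MAX_BSZ here (256) differs from the MAX_BSZ (512) hard-coded in the
// speculate_get_output and speculate_save_output ops; the TODO above tracks
// unifying those definitions through this shared header.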

View File

@@ -0,0 +1,149 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/extension.h"
#include "helper.h"
template <typename T, int VecSize>
__global__ void RebuildAppendPaddingKernel(
T *out,
const T *full_hidden_states,
const int *cum_offset,
const int *seq_len_encoder,
const int *seq_len_decoder,
const int *output_padding_offset,
const int seq_len,
const int dim_embed,
const size_t elem_nums) {
using LoadT = AlignedVector<T, VecSize>;
LoadT src_vec;
const int64_t global_idx = blockDim.x * blockIdx.x + threadIdx.x;
for (int64_t i = global_idx * VecSize; i < elem_nums; i += gridDim.x * blockDim.x * VecSize) {
const int out_token_id = i / dim_embed;
const int ori_token_id = out_token_id + output_padding_offset[out_token_id];
const int bi = ori_token_id / seq_len;
int seq_id = 0;
if (seq_len_decoder[bi] == 0 && seq_len_encoder[bi] == 0) continue;
else if (seq_len_encoder[bi] != 0) {
seq_id = seq_len_encoder[bi] - 1;
}
const int input_token_id = ori_token_id - cum_offset[bi] + seq_id;
const int bias_idx = i % dim_embed;
Load<T, VecSize>(&full_hidden_states[input_token_id * dim_embed + bias_idx], &src_vec);
Store<T, VecSize>(src_vec, &out[i]);
}
}
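// Index math (sketch): output element i belongs to flattened output token
// out_token_id = i / dim_embed; adding output_padding_offset recovers its
// padded position ori_token_id, whose batch is ori_token_id / seq_len.
// Decode batches gather their hidden state directly (seq_id = 0), while
// prefill batches gather only the last encoder token
// (seq_id = seq_len_encoder - 1).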
template <paddle::DataType D>
std::vector<paddle::Tensor> DispatchDtype(
const paddle::Tensor& full_hidden_states,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& seq_len_encoder,
const paddle::Tensor& seq_len_decoder,
const paddle::Tensor& output_padding_offset,
const int max_seq_len) {
// src: [token_num, dim_embed]
// dst: [batch_size, 1, dim_embed]
typedef PDTraits<D> traits_;
typedef typename traits_::DataType DataType_;
typedef typename traits_::data_t data_t;
int dim_embed = full_hidden_states.shape()[1];
int output_token_num = output_padding_offset.shape()[0];
int elem_nums = output_token_num * dim_embed;
constexpr int PackSize = VEC_16B / sizeof(DataType_);
assert(elem_nums % PackSize == 0);
auto out = paddle::full({output_token_num, dim_embed}, 0, full_hidden_states.dtype(), full_hidden_states.place());
int pack_num = elem_nums / PackSize;
const int threads_per_block = 128;
int grid_size = 1;
GetNumBlocks(pack_num, &grid_size);
RebuildAppendPaddingKernel<DataType_, PackSize><<<grid_size, threads_per_block, 0, full_hidden_states.stream()>>>(
reinterpret_cast<DataType_*>(out.data<data_t>()),
reinterpret_cast<const DataType_*>(full_hidden_states.data<data_t>()),
cum_offsets.data<int32_t>(),
seq_len_encoder.data<int32_t>(),
seq_len_decoder.data<int32_t>(),
output_padding_offset.data<int32_t>(),
max_seq_len,
dim_embed,
elem_nums);
return {out};
}
std::vector<paddle::Tensor> RebuildAppendPadding(
const paddle::Tensor& full_hidden_states,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& seq_len_encoder,
const paddle::Tensor& seq_len_decoder,
const paddle::Tensor& output_padding_offset,
const int max_seq_len) {
switch (full_hidden_states.dtype()) {
case paddle::DataType::BFLOAT16:
return DispatchDtype<paddle::DataType::BFLOAT16>(
full_hidden_states, cum_offsets, seq_len_encoder, seq_len_decoder, output_padding_offset, max_seq_len);
case paddle::DataType::FLOAT16:
return DispatchDtype<paddle::DataType::FLOAT16>(
full_hidden_states, cum_offsets, seq_len_encoder, seq_len_decoder, output_padding_offset, max_seq_len);
default:
PD_THROW("Unsupported data type.");
}
}
std::vector<std::vector<int64_t>> RebuildAppendPaddingInferShape(
const std::vector<int64_t>& full_hidden_states_shape,
const std::vector<int64_t>& cum_offsets_shape,
const std::vector<int64_t>& seq_len_encoder_shape,
const std::vector<int64_t>& seq_len_decoder_shape,
const std::vector<int64_t>& output_padding_offset_shape) {
const int64_t output_token_num = output_padding_offset_shape[0];
const int64_t dim_embed = full_hidden_states_shape[1];
std::vector<int64_t> out_shape = {output_token_num, dim_embed};
return {out_shape};
}
std::vector<paddle::DataType> RebuildAppendPaddingInferDtype(
const paddle::DataType& full_hidden_states_dtype,
const paddle::DataType& cum_offsets_dtype,
const paddle::DataType& seq_len_encoder_dtype,
const paddle::DataType& seq_len_decoder_dtype,
const paddle::DataType& output_padding_offset_dtype) {
return {full_hidden_states_dtype};
}
PD_BUILD_STATIC_OP(speculate_rebuild_append_padding)
.Inputs({"full_hidden_states",
"cum_offsets",
"seq_len_encoder",
"seq_len_decoder",
"output_padding_offset"})
.Attrs({"max_seq_len: int"})
.Outputs({"out"})
.SetKernelFn(PD_KERNEL(RebuildAppendPadding))
.SetInferShapeFn(PD_INFER_SHAPE(RebuildAppendPaddingInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(RebuildAppendPaddingInferDtype));

View File

@@ -0,0 +1,163 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <stdio.h>
#include <string.h>
#include <sys/ipc.h>
#include <sys/msg.h>
#include <sys/types.h>
#include "paddle/extension.h"
#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif
#define MAX_BSZ 512
#define MAX_DRAFT_TOKENS 6
struct msgdata {
long mtype;
int mtext[MAX_BSZ * MAX_DRAFT_TOKENS + MAX_BSZ +
2]; // stop_flag, bsz, tokens
};
void SpeculateSaveWithOutputMsg(const paddle::Tensor& accept_tokens,
const paddle::Tensor& accept_num,
const paddle::Tensor& not_need_stop,
int64_t rank_id,
int msg_queue_id,
                                bool save_each_rank) {
// printf("enter save output");
if (!save_each_rank && rank_id > 0) {
return;
}
int max_draft_tokens = accept_tokens.shape()[1];
auto accept_tokens_cpu = accept_tokens.copy_to(paddle::CPUPlace(), true);
auto accept_num_cpu = accept_num.copy_to(paddle::CPUPlace(), true);
int64_t* accept_tokens_data = accept_tokens_cpu.data<int64_t>();
int* accept_num_data = accept_num_cpu.data<int>();
if (const char* inference_msg_queue_id_env_p =
std::getenv("INFERENCE_MSG_QUEUE_ID")) {
std::string inference_msg_queue_id_env_str(
inference_msg_queue_id_env_p);
int inference_msg_queue_id_from_env =
std::stoi(inference_msg_queue_id_env_str);
#ifdef GET_OUTPUT_DEBUG
std::cout << "Your INFERENCE_MSG_QUEUE_ID is: "
<< inference_msg_queue_id_from_env << std::endl;
#endif
msg_queue_id = inference_msg_queue_id_from_env;
}
static struct msgdata msg_sed;
static key_t key = ftok("./", msg_queue_id);
static int msgid = msgget(key, IPC_CREAT | 0666);
msg_sed.mtype = 1;
bool not_need_stop_data = not_need_stop.data<bool>()[0];
int inference_msg_id_from_env = 1;
if (const char* inference_msg_id_env_p = std::getenv("INFERENCE_MSG_ID")) {
std::string inference_msg_id_env_str(inference_msg_id_env_p);
inference_msg_id_from_env = std::stoi(inference_msg_id_env_str);
if (inference_msg_id_from_env == 2) {
// 2 and -2 is preserve for no-output indication.
throw std::runtime_error(
" INFERENCE_MSG_ID cannot be 2, please use other number.");
}
if (inference_msg_id_from_env < 0) {
throw std::runtime_error(
" INFERENCE_MSG_ID cannot be negative, please use other "
"number.");
}
#ifdef SAVE_WITH_OUTPUT_DEBUG
std::cout << "Your INFERENCE_MSG_ID is: " << inference_msg_id_from_env
<< std::endl;
#endif
} else {
#ifdef SAVE_WITH_OUTPUT_DEBUG
std::cout
<< "Failed to got INFERENCE_MSG_ID at env, use (int)1 as default."
<< std::endl;
#endif
}
msg_sed.mtext[0] = not_need_stop_data ? inference_msg_id_from_env
: -inference_msg_id_from_env;
int bsz = accept_tokens.shape()[0];
msg_sed.mtext[1] = bsz;
for (int i = 2; i < MAX_BSZ + 2; i++) {
if (i - 2 >= bsz) {
msg_sed.mtext[i] = 0;
} else {
msg_sed.mtext[i] = (int)accept_num_data[i - 2];
}
}
    for (int i = MAX_BSZ + 2; i < MAX_BSZ * MAX_DRAFT_TOKENS + MAX_BSZ + 2;
         i++) {
        int token_id = i - MAX_BSZ - 2;
        int bid = token_id / MAX_DRAFT_TOKENS;
        int local_token_id = token_id % MAX_DRAFT_TOKENS;
        // Note: assumes accept_tokens has at least MAX_DRAFT_TOKENS columns.
        if (bid >= bsz) {
            msg_sed.mtext[i] = 0;
        } else {
            msg_sed.mtext[i] =
                accept_tokens_data[bid * max_draft_tokens + local_token_id];
        }
    }
    if ((msgsnd(msgid,
                &msg_sed,
                (MAX_BSZ * MAX_DRAFT_TOKENS + MAX_BSZ + 2) * sizeof(int),
                0)) == -1) {
        printf("speculate_save_output: message queue is full\n");
}
return;
}
void SpeculateSaveWithOutputMsgStatic(const paddle::Tensor& accept_tokens,
const paddle::Tensor& accept_num,
const paddle::Tensor& not_need_stop,
int64_t rank_id,
bool save_each_rank) {
SpeculateSaveWithOutputMsg(
accept_tokens, accept_num, not_need_stop, rank_id, 1, save_each_rank);
}
void SpeculateSaveWithOutputMsgDynamic(const paddle::Tensor& accept_tokens,
const paddle::Tensor& accept_num,
const paddle::Tensor& not_need_stop,
int64_t rank_id,
int msg_queue_id,
bool save_each_rank) {
SpeculateSaveWithOutputMsg(
accept_tokens, accept_num, not_need_stop, rank_id, msg_queue_id, save_each_rank);
}
PD_BUILD_STATIC_OP(speculate_save_output)
.Inputs({"accept_tokens", "accept_num", "not_need_stop"})
.Attrs({"rank_id: int64_t", "save_each_rank: bool"})
.Outputs({"x_out"})
.SetInplaceMap({{"accept_tokens", "x_out"}})
.SetKernelFn(PD_KERNEL(SpeculateSaveWithOutputMsgStatic));
PD_BUILD_STATIC_OP(speculate_save_output_dynamic)
.Inputs({"accept_tokens", "accept_num", "not_need_stop"})
.Attrs({"rank_id: int64_t", "msg_queue_id: int", "save_each_rank: bool"})
.Outputs({"x_out"})
.SetInplaceMap({{"accept_tokens", "x_out"}})
.SetKernelFn(PD_KERNEL(SpeculateSaveWithOutputMsgDynamic));

View File

@@ -0,0 +1,91 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/extension.h"
#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif
__global__ void speculate_set_value_by_flag_and_id(int64_t *pre_ids_all,
const int64_t *accept_tokens,
const int *accept_num,
const bool *stop_flags,
const int *seq_lens_encoder,
const int *seq_lens_decoder,
const int64_t *step_idx,
int bs,
int length,
int max_draft_tokens) {
int tid = threadIdx.x;
if (tid < bs && !stop_flags[tid]) {
int64_t *pre_ids_all_now = pre_ids_all + tid * length;
const int64_t *accept_tokens_now =
accept_tokens + tid * max_draft_tokens;
const int seq_len_dec = seq_lens_decoder[tid];
const int seq_len_enc = seq_lens_encoder[tid];
        if (seq_len_dec == 0 && seq_len_enc == 0) return;  // stopped
        if (step_idx[tid] >= 0) {
            // Write the accepted tokens into pre_ids in generation order.
            for (int i = 0; i < accept_num[tid]; i++) {
                pre_ids_all_now[step_idx[tid] - i] =
                    accept_tokens_now[accept_num[tid] - 1 - i];
}
}
}
}
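// Worked example (illustrative): with step_idx = 5 and accept_num = 3, the
// loop writes accept_tokens[0..2] to pre_ids_all[3..5] in generation
// order, since token i lands at step_idx - (accept_num - 1 - i).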
void SpeculateSetValueByFlagsAndIdx(const paddle::Tensor &pre_ids_all,
const paddle::Tensor &accept_tokens,
const paddle::Tensor &accept_num,
const paddle::Tensor &stop_flags,
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &step_idx) {
// printf("enter set value \n");
auto cu_stream = stop_flags.stream();
std::vector<int64_t> pre_ids_all_shape = pre_ids_all.shape();
int bs = seq_lens_this_time.shape()[0];
int length = pre_ids_all_shape[1];
int max_draft_tokens = accept_tokens.shape()[1];
int block_size = (bs + 32 - 1) / 32 * 32;
speculate_set_value_by_flag_and_id<<<1, block_size, 0, cu_stream>>>(
const_cast<int64_t *>(pre_ids_all.data<int64_t>()),
accept_tokens.data<int64_t>(),
accept_num.data<int>(),
stop_flags.data<bool>(),
seq_lens_encoder.data<int>(),
seq_lens_decoder.data<int>(),
step_idx.data<int64_t>(),
bs,
length,
max_draft_tokens);
}
PD_BUILD_STATIC_OP(speculate_set_value_by_flags_and_idx)
.Inputs({"pre_ids_all",
"accept_tokens",
"accept_num",
"stop_flags",
"seq_lens_this_time",
"seq_lens_encoder",
"seq_lens_decoder",
"step_idx"})
.Outputs({"pre_ids_all_out"})
.SetInplaceMap({{"pre_ids_all", "pre_ids_all_out"}})
.SetKernelFn(PD_KERNEL(SpeculateSetValueByFlagsAndIdx));

View File

@@ -0,0 +1,481 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "helper.h" // NOLINT
// #define DEBUG_STEP
__device__ bool speculate_free_and_dispatch_block(const int &qid,
int *need_block_list,
const int &need_block_len) {
bool res = false;
for (int i = 0; i < need_block_len; i++) {
if (qid == need_block_list[i]) {
res = true;
need_block_list[i] = -1;
break;
}
}
return res;
}
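// The helper above removes qid from need_block_list, returning true if it
// was present. The kernel below runs in four phases separated by
// __syncthreads():
//   1. free the blocks of stopped queries, and collect queries whose block
//      table has no block for position seq_lens_decoder + max_draft_tokens + 1;
//   2. while demand exceeds the free list, evict ("step") the query
//      holding the most decoder blocks;
//   3. hand one free block to each query still waiting;
//   4. thread 0 marks the most recently stepped query for recovery once
//      enough blocks are free.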
__global__ void speculate_free_and_dispatch_block(
bool *stop_flags,
int *seq_lens_this_time,
int *seq_lens_decoder,
int *block_tables,
int *encoder_block_lens,
bool *is_block_step,
int *step_block_list, // [bsz]
int *step_len,
int *recover_block_list,
int *recover_len,
int *need_block_list,
int *need_block_len,
int *used_list_len,
int *free_list,
int *free_list_len,
int64_t *first_token_ids,
int *accept_num,
const int bsz,
const int block_size,
const int block_num_per_seq,
const int max_decoder_block_num,
const int max_draft_tokens) {
typedef cub::BlockReduce<cub::KeyValuePair<int, int>, 256> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
__shared__ bool step_max_block_flag;
__shared__ int in_need_block_list_len;
const int tid = threadIdx.x;
if (tid < bsz) {
if (tid == 0) {
step_max_block_flag = false;
in_need_block_list_len = 0;
}
int *block_table_now = block_tables + tid * block_num_per_seq;
        int max_possible_block_idx =
            (seq_lens_decoder[tid] + max_draft_tokens + 1) / block_size;
if (stop_flags[tid] && !is_block_step[tid]) {
            // Reclaim this query's decoder blocks
first_token_ids[tid] = -1;
const int encoder_block_len = encoder_block_lens[tid];
const int decoder_used_len = used_list_len[tid];
if (decoder_used_len > 0) {
const int ori_free_list_len =
atomicAdd(free_list_len, decoder_used_len);
#ifdef DEBUG_STEP
printf(
"free block seq_id: %d, free block num: %d, "
"encoder_block_len: %d, ori_free_list_len: %d\n",
tid,
decoder_used_len,
encoder_block_len,
ori_free_list_len);
#endif
for (int i = 0; i < decoder_used_len; i++) {
free_list[ori_free_list_len + i] =
block_table_now[encoder_block_len + i];
block_table_now[encoder_block_len + i] = -1;
}
encoder_block_lens[tid] = 0;
used_list_len[tid] = 0;
}
} else if (seq_lens_this_time[tid] != 0 && max_possible_block_idx < block_num_per_seq &&
block_table_now[(seq_lens_decoder[tid] + max_draft_tokens +
1) /
block_size] == -1) {
            // Record which queries need a new block, and how many in total
const int ori_need_block_len = atomicAdd(need_block_len, 1);
need_block_list[ori_need_block_len] = tid;
#ifdef DEBUG_STEP
printf("seq_id: %d need block\n", tid);
#endif
}
}
__syncthreads();
while (need_block_len[0] > free_list_len[0]) {
#ifdef DEBUG_STEP
if (tid == 0) {
printf("need_block_len: %d, free_list_len: %d\n",
need_block_len[0],
free_list_len[0]);
}
#endif
        // Scheduling: reclaim blocks from queries in descending order of
        // used_list_len until need_block_len is satisfied. Queries already
        // decoding into their last block are about to finish and are
        // excluded from scheduling.
const int used_block_num =
tid < bsz && !is_block_step[tid] &&
(step_max_block_flag ||
used_list_len[tid] != max_decoder_block_num)
? used_list_len[tid]
: 0;
cub::KeyValuePair<int, int> kv_pair = {tid, used_block_num};
kv_pair = BlockReduce(temp_storage).Reduce(kv_pair, cub::ArgMax());
if (tid == 0) {
if (kv_pair.value == 0) {
step_max_block_flag = true;
} else {
const int encoder_block_len = encoder_block_lens[kv_pair.key];
#ifdef DEBUG_STEP
                printf("max_id: %d, max_num: %d, encoder_block_len: %d\n",
                       kv_pair.key,
                       kv_pair.value,
                       encoder_block_len);
#endif
int *block_table_now =
block_tables + kv_pair.key * block_num_per_seq;
for (int i = 0; i < kv_pair.value; i++) {
free_list[free_list_len[0] + i] =
block_table_now[encoder_block_len + i];
block_table_now[encoder_block_len + i] = -1;
}
step_block_list[step_len[0]] = kv_pair.key;
if (speculate_free_and_dispatch_block(
kv_pair.key,
need_block_list,
need_block_len[0] + in_need_block_list_len)) {
need_block_len[0] -= 1;
in_need_block_list_len += 1;
}
step_len[0] += 1;
free_list_len[0] += kv_pair.value;
stop_flags[kv_pair.key] = true;
is_block_step[kv_pair.key] = true;
seq_lens_this_time[kv_pair.key] = 0;
seq_lens_decoder[kv_pair.key] = 0;
                // Note(@wufeisheng): while a query is stepped its accept_num
                // is nonzero, so on the next step save_output would still
                // stream tokens for it; reset accept_num to 0 to avoid that.
accept_num[kv_pair.key] = 0;
}
}
__syncthreads();
}
    // Allocate one block to each query that still needs one
if (tid < need_block_len[0] + in_need_block_list_len) {
const int need_block_id = need_block_list[tid];
if (need_block_id != -1) {
if (!stop_flags[need_block_id]) {
                // Queries that were just stepped (freed) above are skipped
used_list_len[need_block_id] += 1;
const int ori_free_list_len = atomicSub(free_list_len, 1);
int *block_table_now =
block_tables + need_block_id * block_num_per_seq;
#ifdef DEBUG_STEP
printf("need_block_id %d\n", need_block_id);
printf("ori_free_list_len %d\n", ori_free_list_len);
printf("max_draft_tokens %d\n", max_draft_tokens);
printf("seq_lens_decoder[need_block_id] %d\n",
seq_lens_decoder[need_block_id]);
printf("free_list[ori_free_list_len - 1] %d\n",
free_list[ori_free_list_len - 1]);
#endif
block_table_now[(seq_lens_decoder[need_block_id] +
max_draft_tokens + 1) /
block_size] = free_list[ori_free_list_len - 1];
}
need_block_list[tid] = -1;
}
}
__syncthreads();
    // Determine which stepped query ids can be recovered
if (tid == 0) {
int ori_free_list_len = free_list_len[0];
int ori_step_len = step_len[0];
if (ori_step_len > 0) {
int ori_step_block_id = step_block_list[ori_step_len - 1];
int tmp_used_len = used_list_len[ori_step_block_id];
            // Allocate one more block than at scheduling time, to avoid
            // immediately re-stepping a query that was just recovered
            // (e.g. when the reclaimed seq_id is still in need_block_list).
const int max_decoder_block_num_this_seq =
max_decoder_block_num - encoder_block_lens[ori_step_block_id];
int used_len = tmp_used_len + 1 < max_decoder_block_num_this_seq
? tmp_used_len + 1
: max_decoder_block_num_this_seq;
if (ori_step_len > 0 && ori_free_list_len >= used_len) {
#ifdef DEBUG_STEP
                printf(
                    "recover seq_id: %d, free_list_len: %d, used_list_len: "
                    "%d\n",
                    ori_step_block_id,
                    ori_free_list_len,
                    used_len);
#endif
recover_block_list[recover_len[0]] = ori_step_block_id;
is_block_step[ori_step_block_id] = false;
used_list_len[ori_step_block_id] = used_len;
ori_free_list_len -= used_len;
step_block_list[ori_step_len - 1] = -1;
step_len[0] -= 1;
recover_len[0] += 1;
ori_step_len = step_len[0];
if (ori_step_len > 0) {
ori_step_block_id = step_block_list[ori_step_len - 1];
tmp_used_len = used_list_len[ori_step_block_id];
used_len = tmp_used_len + 1 < max_decoder_block_num_this_seq
? tmp_used_len + 1
: max_decoder_block_num_this_seq;
}
}
}
need_block_len[0] = 0;
}
}
// Restore state for the query ids marked recoverable by the previous kernel
__global__ void speculate_recover_block(int *recover_block_list, // [bsz]
int *recover_len,
bool *stop_flags,
int *seq_lens_this_time,
int *ori_seq_lens_encoder,
int *seq_lens_encoder,
int *seq_lens_decoder,
int *block_tables,
int *free_list,
int *free_list_len,
int64_t *input_ids,
int64_t *pre_ids,
int64_t *step_idx,
int *encoder_block_lens,
int *used_list_len,
const int64_t *next_tokens,
const int64_t *first_token_ids,
const int bsz,
const int block_num_per_seq,
const int length,
const int pre_id_length) {
const int bid = blockIdx.x;
const int tid = threadIdx.x;
__shared__ int ori_free_list_len;
if (bid < recover_len[0]) {
const int recover_id = recover_block_list[bid];
const int ori_seq_len_encoder = ori_seq_lens_encoder[recover_id];
const int step_idx_now = step_idx[recover_id];
const int seq_len = ori_seq_len_encoder + step_idx_now;
const int encoder_block_len = encoder_block_lens[recover_id];
const int decoder_used_len = used_list_len[recover_id];
int *block_table_now = block_tables + recover_id * block_num_per_seq;
int64_t *input_ids_now = input_ids + recover_id * length;
int64_t *pre_ids_now = pre_ids + recover_id * pre_id_length;
if (tid == 0) {
seq_lens_this_time[recover_id] = seq_len;
seq_lens_encoder[recover_id] = seq_len;
stop_flags[recover_id] = false;
// input_ids_now[ori_seq_len_encoder + step_idx_now - 1] =
// next_tokens[recover_id]; // next tokens
input_ids_now[0] =
first_token_ids[recover_id]; // set first prompt token
const int ori_free_list_len_tid0 =
atomicSub(free_list_len, decoder_used_len);
ori_free_list_len = ori_free_list_len_tid0;
#ifdef DEBUG_STEP
printf(
"seq_id: %d, ori_seq_len_encoder: %d, step_idx_now: %d, "
"seq_len: %d, ori_free_list_len_tid0: %d, "
"ori_free_list_len: %d\n",
recover_id,
ori_seq_len_encoder,
step_idx_now,
seq_len,
ori_free_list_len_tid0,
ori_free_list_len);
#endif
}
__syncthreads();
        // Restore the block table
for (int i = tid; i < decoder_used_len; i += blockDim.x) {
block_table_now[encoder_block_len + i] =
free_list[ori_free_list_len - i - 1];
}
        // Restore input_ids from pre_ids
for (int i = tid; i < step_idx_now; i += blockDim.x) {
input_ids_now[ori_seq_len_encoder + i] = pre_ids_now[i + 1];
}
}
if (bid == 0 && tid == 0) {
recover_len[0] = 0;
}
}
void SpeculateStepPaddle(
const paddle::Tensor &stop_flags,
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &ori_seq_lens_encoder,
const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &block_tables, // [bsz, block_num_per_seq]
const paddle::Tensor &encoder_block_lens,
const paddle::Tensor &is_block_step,
const paddle::Tensor &step_block_list,
const paddle::Tensor &step_lens,
const paddle::Tensor &recover_block_list,
const paddle::Tensor &recover_lens,
const paddle::Tensor &need_block_list,
const paddle::Tensor &need_block_len,
const paddle::Tensor &used_list_len,
const paddle::Tensor &free_list,
const paddle::Tensor &free_list_len,
const paddle::Tensor &input_ids,
const paddle::Tensor &pre_ids,
const paddle::Tensor &step_idx,
const paddle::Tensor &next_tokens,
const paddle::Tensor &first_token_ids,
const paddle::Tensor &accept_num,
const int block_size,
const int encoder_decoder_block_num,
const int max_draft_tokens) {
auto cu_stream = seq_lens_this_time.stream();
const int bsz = seq_lens_this_time.shape()[0];
const int block_num_per_seq = block_tables.shape()[1];
const int length = input_ids.shape()[1];
const int pre_id_length = pre_ids.shape()[1];
constexpr int BlockSize = 256; // bsz <= 256
const int max_decoder_block_num = length / block_size;
// const int max_decoder_block_num = 2048 / block_size -
// encoder_decoder_block_num;
#ifdef DEBUG_STEP
printf(
"bsz: %d, block_num_per_seq: %d, length: %d, max_decoder_block_num: "
"%d\n",
bsz,
block_num_per_seq,
length,
max_decoder_block_num);
#endif
speculate_free_and_dispatch_block<<<1, BlockSize, 0, cu_stream>>>(
const_cast<bool *>(stop_flags.data<bool>()),
const_cast<int *>(seq_lens_this_time.data<int>()),
const_cast<int *>(seq_lens_decoder.data<int>()),
const_cast<int *>(block_tables.data<int>()),
const_cast<int *>(encoder_block_lens.data<int>()),
const_cast<bool *>(is_block_step.data<bool>()),
const_cast<int *>(step_block_list.data<int>()),
const_cast<int *>(step_lens.data<int>()),
const_cast<int *>(recover_block_list.data<int>()),
const_cast<int *>(recover_lens.data<int>()),
const_cast<int *>(need_block_list.data<int>()),
const_cast<int *>(need_block_len.data<int>()),
const_cast<int *>(used_list_len.data<int>()),
const_cast<int *>(free_list.data<int>()),
const_cast<int *>(free_list_len.data<int>()),
const_cast<int64_t *>(first_token_ids.data<int64_t>()),
const_cast<int *>(accept_num.data<int>()),
bsz,
block_size,
block_num_per_seq,
max_decoder_block_num,
max_draft_tokens);
#ifdef DEBUG_STEP
cudaDeviceSynchronize();
#endif
auto cpu_recover_lens = recover_lens.copy_to(paddle::CPUPlace(), false);
const int grid_size = cpu_recover_lens.data<int>()[0];
#ifdef DEBUG_STEP
printf("grid_size2 %d\n", grid_size);
#endif
if (grid_size > 0) {
speculate_recover_block<<<grid_size, BlockSize, 0, cu_stream>>>(
const_cast<int *>(recover_block_list.data<int>()),
const_cast<int *>(recover_lens.data<int>()),
const_cast<bool *>(stop_flags.data<bool>()),
const_cast<int *>(seq_lens_this_time.data<int>()),
const_cast<int *>(ori_seq_lens_encoder.data<int>()),
const_cast<int *>(seq_lens_encoder.data<int>()),
const_cast<int *>(seq_lens_decoder.data<int>()),
const_cast<int *>(block_tables.data<int>()),
const_cast<int *>(free_list.data<int>()),
const_cast<int *>(free_list_len.data<int>()),
const_cast<int64_t *>(input_ids.data<int64_t>()),
const_cast<int64_t *>(pre_ids.data<int64_t>()),
const_cast<int64_t *>(step_idx.data<int64_t>()),
const_cast<int *>(encoder_block_lens.data<int>()),
const_cast<int *>(used_list_len.data<int>()),
next_tokens.data<int64_t>(),
first_token_ids.data<int64_t>(),
bsz,
block_num_per_seq,
length,
pre_id_length);
#ifdef DEBUG_STEP
cudaDeviceSynchronize();
#endif
}
}
PD_BUILD_STATIC_OP(speculate_step_paddle)
.Inputs({"stop_flags",
"seq_lens_this_time",
"ori_seq_lens_encoder",
"seq_lens_encoder",
"seq_lens_decoder",
"block_tables",
"encoder_block_lens",
"is_block_step",
"step_block_list",
"step_lens",
"recover_block_list",
"recover_lens",
"need_block_list",
"need_block_len",
"used_list_len",
"free_list",
"free_list_len",
"input_ids",
"pre_ids",
"step_idx",
"next_tokens",
"first_token_ids",
"accept_num"})
.Attrs({"block_size: int",
"encoder_decoder_block_num: int",
"max_draft_tokens: int"})
.Outputs({"stop_flags_out",
"seq_lens_this_time_out",
"seq_lens_encoder_out",
"seq_lens_decoder_out",
"block_tables_out",
"encoder_block_lens_out",
"is_block_step_out",
"step_block_list_out",
"step_lens_out",
"recover_block_list_out",
"recover_lens_out",
"need_block_list_out",
"need_block_len_out",
"used_list_len_out",
"free_list_out",
"free_list_len_out",
"input_ids_out",
"first_token_ids_out"})
.SetInplaceMap({{"stop_flags", "stop_flags_out"},
{"seq_lens_this_time", "seq_lens_this_time_out"},
{"seq_lens_encoder", "seq_lens_encoder_out"},
{"seq_lens_decoder", "seq_lens_decoder_out"},
{"block_tables", "block_tables_out"},
{"encoder_block_lens", "encoder_block_lens_out"},
{"is_block_step", "is_block_step_out"},
{"step_block_list", "step_block_list_out"},
{"step_lens", "step_lens_out"},
{"recover_block_list", "recover_block_list_out"},
{"recover_lens", "recover_lens_out"},
{"need_block_list", "need_block_list_out"},
{"need_block_len", "need_block_len_out"},
{"used_list_len", "used_list_len_out"},
{"free_list", "free_list_out"},
{"free_list_len", "free_list_len_out"},
{"input_ids", "input_ids_out"},
{"first_token_ids", "first_token_ids_out"}})
.SetKernelFn(PD_KERNEL(SpeculateStepPaddle));

View File

@@ -0,0 +1,389 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "helper.h"
#include "speculate_msg.h"
#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif
__device__ __forceinline__ bool in_need_block_list_schedule(const int &qid,
int *need_block_list,
const int &need_block_len) {
bool res = false;
for (int i = 0; i < need_block_len; i++) {
if (qid == need_block_list[i]) {
res = true;
need_block_list[i] = -1;
break;
}
}
return res;
}
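// Illustrative sketch (hypothetical host-side analogue, not part of the op):
// the device helper above does a linear scan and, on a hit, consumes the
// entry by overwriting it with the -1 sentinel so the slot is skipped during
// the later allocation pass.
static inline bool in_need_block_list_host(int qid,
                                           int *need_block_list,
                                           int need_block_len) {
    for (int i = 0; i < need_block_len; i++) {
        if (need_block_list[i] == qid) {
            need_block_list[i] = -1;  // mark the entry consumed
            return true;
        }
    }
    return false;
}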
__global__ void speculate_free_and_reschedule(bool *stop_flags,
int *seq_lens_this_time,
int *seq_lens_decoder,
int *block_tables,
int *encoder_block_lens,
bool *is_block_step,
int *step_block_list, // [bsz]
int *step_len,
int *recover_block_list,
int *recover_len,
int *need_block_list,
int *need_block_len,
int *used_list_len,
int *free_list,
int *free_list_len,
int64_t *first_token_ids,
int* accept_num,
const int bsz,
const int block_size,
const int block_num_per_seq,
const int max_decoder_block_num,
const int max_draft_tokens) {
typedef cub::BlockReduce<cub::KeyValuePair<int, int>, 256> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
__shared__ bool step_max_block_flag;
__shared__ int in_need_block_list_len;
const int tid = threadIdx.x;
if (tid < bsz) {
if (tid == 0) {
step_max_block_flag = false;
in_need_block_list_len = 0;
}
int *block_table_now = block_tables + tid * block_num_per_seq;
        int max_possible_block_idx =
            (seq_lens_decoder[tid] + max_draft_tokens + 1) / block_size;
if (stop_flags[tid]) {
            // Reclaim this query's blocks
first_token_ids[tid] = -1;
const int encoder_block_len = encoder_block_lens[tid];
const int decoder_used_len = used_list_len[tid];
if (decoder_used_len > 0) {
const int ori_free_list_len =
atomicAdd(free_list_len, decoder_used_len);
#ifdef DEBUG_STEP
printf(
"free block seq_id: %d, free block num: %d, "
"encoder_block_len: %d, ori_free_list_len: %d\n",
tid,
decoder_used_len,
encoder_block_len,
ori_free_list_len);
#endif
for (int i = 0; i < decoder_used_len; i++) {
free_list[ori_free_list_len + i] =
block_table_now[encoder_block_len + i];
block_table_now[encoder_block_len + i] = -1;
}
encoder_block_lens[tid] = 0;
used_list_len[tid] = 0;
}
    } else if (seq_lens_this_time[tid] != 0 &&
               max_possible_block_idx < block_num_per_seq &&
               block_table_now[max_possible_block_idx] == -1) {
        // Record the positions that need a new block and their total count
#ifdef DEBUG_STEP
printf("step seq_id:%d, ##### pin 1 #####\n", tid);
#endif
const int ori_need_block_len = atomicAdd(need_block_len, 1);
need_block_list[ori_need_block_len] = tid;
#ifdef DEBUG_STEP
printf("seq_id: %d need block\n", tid);
#endif
}
}
#ifdef DEBUG_STEP
printf("step seq_id:%d, ##### pin 2 #####\n", tid);
#endif
__syncthreads();
    // Reschedule blocks until need_block_len can be satisfied
while (need_block_len[0] > free_list_len[0]) {
if (tid == 0) {
printf("need_block_len: %d, free_list_len: %d\n",
need_block_len[0],
free_list_len[0]);
}
        // Evict blocks in descending order of used_list_len until
        // need_block_len is satisfied; queries already decoding in their last
        // block are about to finish and are left out of rescheduling.
const int used_block_num =
tid < bsz ? used_list_len[tid] : 0;
cub::KeyValuePair<int, int> kv_pair = {tid, used_block_num};
kv_pair = BlockReduce(temp_storage).Reduce(kv_pair, cub::ArgMax());
if (tid == 0) {
if (kv_pair.value == 0) {
step_max_block_flag = true;
} else {
const int encoder_block_len = encoder_block_lens[kv_pair.key];
printf("step max_id: %d, max_num: %d, encoder_block_len: %d\n",
kv_pair.key,
kv_pair.value,
encoder_block_len);
int *block_table_now =
block_tables + kv_pair.key * block_num_per_seq;
                // Reclaim the blocks of the evicted query
for (int i = 0; i < kv_pair.value; i++) {
free_list[free_list_len[0] + i] =
block_table_now[encoder_block_len + i];
block_table_now[encoder_block_len + i] = -1;
}
step_block_list[step_len[0]] = kv_pair.key;
                // If the evicted query also requested a block this step,
                // drop its request accordingly
if (in_need_block_list_schedule(
kv_pair.key,
need_block_list,
need_block_len[0] + in_need_block_list_len)) {
need_block_len[0] -= 1;
in_need_block_list_len += 1;
}
step_len[0] += 1;
free_list_len[0] += kv_pair.value;
stop_flags[kv_pair.key] = true;
seq_lens_this_time[kv_pair.key] = 0;
seq_lens_decoder[kv_pair.key] = 0;
encoder_block_lens[kv_pair.key] = 0;
used_list_len[kv_pair.key] = 0;
printf(
"free block seq_id: %d, free block num: %d, "
"now_free_list_len: %d\n",
(int)kv_pair.key,
(int)kv_pair.value,
(int)free_list_len[0]);
}
}
__syncthreads();
}
#ifdef DEBUG_STEP
printf("step seq_id:%d, ##### pin 3 #####\n", tid);
#endif
    // Allocate one block to each position that needs one
if (tid < need_block_len[0] + in_need_block_list_len) {
const int need_block_id = need_block_list[tid];
if (need_block_id != -1) {
if (!stop_flags[need_block_id]) {
                // If the requesting position was just freed above, skip it
used_list_len[need_block_id] += 1;
const int ori_free_list_len = atomicSub(free_list_len, 1);
int *block_table_now =
block_tables + need_block_id * block_num_per_seq;
block_table_now[(seq_lens_decoder[need_block_id] +
max_draft_tokens + 1) /
block_size] = free_list[ori_free_list_len - 1];
}
need_block_list[tid] = -1;
}
}
__syncthreads();
// reset need_block_len
if (tid == 0) {
need_block_len[0] = 0;
}
}
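// Illustrative sketch (hypothetical helper): the dispatch condition above
// boils down to "would position seq_len_decoder + max_draft_tokens + 1 land
// in a block that is still unallocated?". A scalar form of that predicate,
// under the same -1 sentinel convention:
static inline bool needs_new_block(const int *block_table_row,
                                   int seq_len_decoder,
                                   int max_draft_tokens,
                                   int block_size,
                                   int block_num_per_seq) {
    const int idx = (seq_len_decoder + max_draft_tokens + 1) / block_size;
    return idx < block_num_per_seq && block_table_row[idx] == -1;
}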
// Parameters are kept as-is so existing call sites do not need to change
void SpeculateStepSchedule(const paddle::Tensor &stop_flags,
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &ori_seq_lens_encoder,
const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &block_tables, // [bsz, block_num_per_seq]
const paddle::Tensor &encoder_block_lens,
const paddle::Tensor &is_block_step,
const paddle::Tensor &step_block_list,
const paddle::Tensor &step_lens,
const paddle::Tensor &recover_block_list,
const paddle::Tensor &recover_lens,
const paddle::Tensor &need_block_list,
const paddle::Tensor &need_block_len,
const paddle::Tensor &used_list_len,
const paddle::Tensor &free_list,
const paddle::Tensor &free_list_len,
const paddle::Tensor &input_ids,
const paddle::Tensor &pre_ids,
const paddle::Tensor &step_idx,
const paddle::Tensor &next_tokens,
const paddle::Tensor &first_token_ids,
const paddle::Tensor &accept_num,
const int block_size,
const int encoder_decoder_block_num,
const int max_draft_tokens) {
auto cu_stream = seq_lens_this_time.stream();
const int bsz = seq_lens_this_time.shape()[0];
const int block_num_per_seq = block_tables.shape()[1];
const int length = input_ids.shape()[1];
const int pre_id_length = pre_ids.shape()[1];
constexpr int BlockSize = 256; // bsz <= 256
    // Blocks implied by the max output length, minus the blocks the service
    // sets aside for encoder/decoder use.
    const int max_decoder_block_num =
        length / block_size - encoder_decoder_block_num;
auto step_lens_inkernel = paddle::full({1}, 0, paddle::DataType::INT32, stop_flags.place());
auto step_bs_list = GetEmptyTensor({bsz}, paddle::DataType::INT32, stop_flags.place());
#ifdef DEBUG_STEP
printf(
"bsz: %d, block_num_per_seq: %d, length: %d, max_decoder_block_num: "
"%d\n",
bsz,
block_num_per_seq,
length,
max_decoder_block_num);
#endif
speculate_free_and_reschedule<<<1, BlockSize, 0, cu_stream>>>(
const_cast<bool *>(stop_flags.data<bool>()),
const_cast<int *>(seq_lens_this_time.data<int>()),
const_cast<int *>(seq_lens_decoder.data<int>()),
const_cast<int *>(block_tables.data<int>()),
const_cast<int *>(encoder_block_lens.data<int>()),
const_cast<bool *>(is_block_step.data<bool>()),
const_cast<int *>(step_bs_list.data<int>()),
const_cast<int *>(step_lens_inkernel.data<int>()),
const_cast<int *>(recover_block_list.data<int>()),
const_cast<int *>(recover_lens.data<int>()),
const_cast<int *>(need_block_list.data<int>()),
const_cast<int *>(need_block_len.data<int>()),
const_cast<int *>(used_list_len.data<int>()),
const_cast<int *>(free_list.data<int>()),
const_cast<int *>(free_list_len.data<int>()),
const_cast<int64_t *>(first_token_ids.data<int64_t>()),
const_cast<int *>(accept_num.data<int>()),
bsz,
block_size,
block_num_per_seq,
max_decoder_block_num,
max_draft_tokens);
#ifdef DEBUG_STEP
cudaDeviceSynchronize();
#endif
// save output
auto step_lens_cpu = step_lens_inkernel.copy_to(paddle::CPUPlace(), false);
if (step_lens_cpu.data<int>()[0] > 0) {
auto step_bs_list_cpu = step_bs_list.copy_to(paddle::CPUPlace(), false);
        auto next_tokens_cpu = paddle::full(
            {bsz}, -1, paddle::DataType::INT64, paddle::CPUPlace());
        for (int i = 0; i < step_lens_cpu.data<int>()[0]; i++) {
            const int step_bid = step_bs_list_cpu.data<int>()[i];
            next_tokens_cpu.data<int64_t>()[step_bid] = -3;  // needs reschedule
        }
        const int rank_id = static_cast<int>(stop_flags.place().GetDeviceId());
        printf("reschedule rank_id: %d, step_lens: %d\n",
               rank_id,
               step_lens_cpu.data<int>()[0]);
        const int64_t* x_data = next_tokens_cpu.data<int64_t>();
static struct speculate_msgdata msg_sed;
int msg_queue_id = rank_id;
if (const char* inference_msg_queue_id_env_p =
std::getenv("INFERENCE_MSG_QUEUE_ID")) {
std::string inference_msg_queue_id_env_str(
inference_msg_queue_id_env_p);
int inference_msg_queue_id_from_env =
std::stoi(inference_msg_queue_id_env_str);
msg_queue_id = inference_msg_queue_id_from_env;
} else {
std::cout << "Failed to got INFERENCE_MSG_QUEUE_ID at env, use default."
<< std::endl;
}
int inference_msg_id_from_env = 1;
if (const char* inference_msg_id_env_p = std::getenv("INFERENCE_MSG_ID")) {
std::string inference_msg_id_env_str(inference_msg_id_env_p);
inference_msg_id_from_env = std::stoi(inference_msg_id_env_str);
if (inference_msg_id_from_env == 2) {
            // 2 and -2 are reserved for the no-output indication.
throw std::runtime_error(
" INFERENCE_MSG_ID cannot be 2, please use other number.");
}
if (inference_msg_id_from_env < 0) {
throw std::runtime_error(
" INFERENCE_MSG_ID cannot be negative, please use other "
"number.");
}
    }
// static key_t key = ftok("/dev/shm", msg_queue_id);
static key_t key = ftok("./", msg_queue_id);
static int msgid = msgget(key, IPC_CREAT | 0666);
msg_sed.mtype = 1;
msg_sed.mtext[0] = inference_msg_id_from_env;
msg_sed.mtext[1] = bsz;
for (int i = 2; i < bsz + 2; i++) {
msg_sed.mtext[i] = (int)x_data[i - 2];
}
if ((msgsnd(msgid, &msg_sed, (MAX_BSZ + 2) * 4, 0)) == -1) {
printf("full msg buffer\n");
}
}
}
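// Illustrative sketch (hypothetical consumer, shown for clarity only): the
// msgsnd above publishes [msg_id, bsz, token_0, ..., token_{bsz-1}] as
// 32-bit ints. Assuming speculate_msgdata and MAX_BSZ from speculate_msg.h
// and the same ftok("./", queue_id) key, a blocking reader might look like:
static inline void recv_reschedule_msg(int queue_id) {
    key_t key = ftok("./", queue_id);
    int msgid = msgget(key, IPC_CREAT | 0666);
    struct speculate_msgdata msg_rcv;
    // Block until a message of mtype 1 arrives; size excludes mtype.
    if (msgrcv(msgid, &msg_rcv, (MAX_BSZ + 2) * 4, 1, 0) != -1) {
        const int bsz = msg_rcv.mtext[1];
        for (int i = 0; i < bsz; i++) {
            if (msg_rcv.mtext[i + 2] == -3) {
                // -3 marks a request that must be rescheduled.
            }
        }
    }
}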
PD_BUILD_STATIC_OP(speculate_step_reschedule)
.Inputs({"stop_flags",
"seq_lens_this_time",
"ori_seq_lens_encoder",
"seq_lens_encoder",
"seq_lens_decoder",
"block_tables",
"encoder_block_lens",
"is_block_step",
"step_block_list",
"step_lens",
"recover_block_list",
"recover_lens",
"need_block_list",
"need_block_len",
"used_list_len",
"free_list",
"free_list_len",
"input_ids",
"pre_ids",
"step_idx",
"next_tokens",
"first_token_ids",
"accept_num"})
.Attrs({"block_size: int",
"encoder_decoder_block_num: int",
"max_draft_tokens: int"})
.Outputs({"stop_flags_out",
"seq_lens_this_time_out",
"seq_lens_encoder_out",
"seq_lens_decoder_out",
"block_tables_out",
"encoder_block_lens_out",
"is_block_step_out",
"step_block_list_out",
"step_lens_out",
"recover_block_list_out",
"recover_lens_out",
"need_block_list_out",
"need_block_len_out",
"used_list_len_out",
"free_list_out",
"free_list_len_out",
"input_ids_out",
"first_token_ids_out"})
.SetInplaceMap({{"stop_flags", "stop_flags_out"},
{"seq_lens_this_time", "seq_lens_this_time_out"},
{"seq_lens_encoder", "seq_lens_encoder_out"},
{"seq_lens_decoder", "seq_lens_decoder_out"},
{"block_tables", "block_tables_out"},
{"encoder_block_lens", "encoder_block_lens_out"},
{"is_block_step", "is_block_step_out"},
{"step_block_list", "step_block_list_out"},
{"step_lens", "step_lens_out"},
{"recover_block_list", "recover_block_list_out"},
{"recover_lens", "recover_lens_out"},
{"need_block_list", "need_block_list_out"},
{"need_block_len", "need_block_len_out"},
{"used_list_len", "used_list_len_out"},
{"free_list", "free_list_out"},
{"free_list_len", "free_list_len_out"},
{"input_ids", "input_ids_out"},
{"first_token_ids", "first_token_ids_out"}})
.SetKernelFn(PD_KERNEL(SpeculateStepSchedule));

View File

@@ -0,0 +1,268 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "helper.h" // NOLINT
// #define DEBUG_STEP
// Restore state for the query ids marked recoverable in the previous step
__global__ void speculate_recover_block(int *recover_block_list, // [bsz]
int *recover_len,
bool *stop_flags,
int *seq_lens_this_time,
int *ori_seq_lens_encoder,
int *ori_seq_lens_decoder,
int *seq_lens_encoder,
int *seq_lens_decoder,
int *block_tables,
int *free_list,
int *free_list_len,
int64_t *input_ids,
int64_t *pre_ids,
int64_t *step_idx,
int *encoder_block_lens,
int *used_list_len,
const int64_t *next_tokens,
const int64_t *first_token_ids,
const int bsz,
const int block_num_per_seq,
const int length,
const int pre_id_length) {
const int bid = blockIdx.x;
const int tid = threadIdx.x;
__shared__ int ori_free_list_len;
if (bid < recover_len[0]) {
const int recover_id = recover_block_list[bid];
const int ori_seq_len_encoder = ori_seq_lens_encoder[recover_id];
const int step_idx_now = step_idx[recover_id];
const int seq_len = ori_seq_len_encoder + step_idx_now;
const int encoder_block_len = encoder_block_lens[recover_id];
const int decoder_used_len = used_list_len[recover_id];
int *block_table_now = block_tables + recover_id * block_num_per_seq;
int64_t *input_ids_now = input_ids + recover_id * length;
int64_t *pre_ids_now = pre_ids + recover_id * pre_id_length;
if (tid == 0) {
seq_lens_this_time[recover_id] = seq_len;
seq_lens_encoder[recover_id] = seq_len;
seq_lens_decoder[recover_id] = ori_seq_lens_decoder[recover_id];
stop_flags[recover_id] = false;
// input_ids_now[ori_seq_len_encoder + step_idx_now - 1] =
// next_tokens[recover_id]; // next tokens
input_ids_now[0] =
first_token_ids[recover_id]; // set first prompt token
const int ori_free_list_len_tid0 =
atomicSub(free_list_len, decoder_used_len);
ori_free_list_len = ori_free_list_len_tid0;
#ifdef DEBUG_STEP
printf(
"seq_id: %d, ori_seq_len_encoder: %d, step_idx_now: %d, "
"seq_len: %d, ori_free_list_len_tid0: %d, "
"ori_free_list_len: %d\n",
recover_id,
ori_seq_len_encoder,
step_idx_now,
seq_len,
ori_free_list_len_tid0,
ori_free_list_len);
#endif
}
__syncthreads();
        // Restore the block table
for (int i = tid; i < decoder_used_len; i += blockDim.x) {
block_table_now[encoder_block_len + i] =
free_list[ori_free_list_len - i - 1];
}
        // Restore input_ids
for (int i = tid; i < step_idx_now; i += blockDim.x) {
input_ids_now[ori_seq_len_encoder + i] = pre_ids_now[i + 1];
}
}
if (bid == 0 && tid == 0) {
recover_len[0] = 0;
}
}
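// Illustrative sketch (hypothetical scalar version of the restore loops
// above): a recovered query re-seeds its first prompt token and copies the
// tokens generated so far from pre_ids back behind the prompt; generated
// tokens appear to start at pre_ids index 1.
static inline void restore_input_ids_host(int64_t *input_ids_row,
                                          const int64_t *pre_ids_row,
                                          int64_t first_token_id,
                                          int ori_seq_len_encoder,
                                          int step_idx_now) {
    input_ids_row[0] = first_token_id;
    for (int i = 0; i < step_idx_now; i++) {
        input_ids_row[ori_seq_len_encoder + i] = pre_ids_row[i + 1];
    }
}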
void SpeculateStepPaddle(
const paddle::Tensor &stop_flags,
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &ori_seq_lens_encoder,
const paddle::Tensor &ori_seq_lens_decoder,
const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &block_tables, // [bsz, block_num_per_seq]
const paddle::Tensor &encoder_block_lens,
const paddle::Tensor &is_block_step,
const paddle::Tensor &step_block_list,
const paddle::Tensor &step_lens,
const paddle::Tensor &recover_block_list,
const paddle::Tensor &recover_lens,
const paddle::Tensor &need_block_list,
const paddle::Tensor &need_block_len,
const paddle::Tensor &used_list_len,
const paddle::Tensor &free_list,
const paddle::Tensor &free_list_len,
const paddle::Tensor &input_ids,
const paddle::Tensor &pre_ids,
const paddle::Tensor &step_idx,
const paddle::Tensor &next_tokens,
const paddle::Tensor &first_token_ids,
const paddle::Tensor &accept_num,
const int block_size,
const int encoder_decoder_block_num,
const int max_draft_tokens) {
auto cu_stream = seq_lens_this_time.stream();
const int bsz = seq_lens_this_time.shape()[0];
const int block_num_per_seq = block_tables.shape()[1];
const int length = input_ids.shape()[1];
const int pre_id_length = pre_ids.shape()[1];
constexpr int BlockSize = 256; // bsz <= 256
const int max_decoder_block_num = length / block_size;
// const int max_decoder_block_num = 2048 / block_size -
// encoder_decoder_block_num;
#ifdef DEBUG_STEP
printf(
"bsz: %d, block_num_per_seq: %d, length: %d, max_decoder_block_num: "
"%d\n",
bsz,
block_num_per_seq,
length,
max_decoder_block_num);
#endif
speculate_free_and_dispatch_block<<<1, BlockSize, 0, cu_stream>>>(
const_cast<bool *>(stop_flags.data<bool>()),
const_cast<int *>(seq_lens_this_time.data<int>()),
const_cast<int *>(seq_lens_decoder.data<int>()),
const_cast<int *>(block_tables.data<int>()),
const_cast<int *>(encoder_block_lens.data<int>()),
const_cast<bool *>(is_block_step.data<bool>()),
const_cast<int *>(step_block_list.data<int>()),
const_cast<int *>(step_lens.data<int>()),
const_cast<int *>(recover_block_list.data<int>()),
const_cast<int *>(recover_lens.data<int>()),
const_cast<int *>(need_block_list.data<int>()),
const_cast<int *>(need_block_len.data<int>()),
const_cast<int *>(used_list_len.data<int>()),
const_cast<int *>(free_list.data<int>()),
const_cast<int *>(free_list_len.data<int>()),
const_cast<int64_t *>(first_token_ids.data<int64_t>()),
const_cast<int *>(accept_num.data<int>()),
bsz,
block_size,
block_num_per_seq,
max_decoder_block_num,
max_draft_tokens);
#ifdef DEBUG_STEP
cudaDeviceSynchronize();
#endif
auto cpu_recover_lens = recover_lens.copy_to(paddle::CPUPlace(), false);
const int grid_size = cpu_recover_lens.data<int>()[0];
#ifdef DEBUG_STEP
printf("grid_size2 %d\n", grid_size);
#endif
if (grid_size > 0) {
speculate_recover_block<<<grid_size, BlockSize, 0, cu_stream>>>(
const_cast<int *>(recover_block_list.data<int>()),
const_cast<int *>(recover_lens.data<int>()),
const_cast<bool *>(stop_flags.data<bool>()),
const_cast<int *>(seq_lens_this_time.data<int>()),
const_cast<int *>(ori_seq_lens_encoder.data<int>()),
const_cast<int *>(ori_seq_lens_decoder.data<int>()),
const_cast<int *>(seq_lens_encoder.data<int>()),
const_cast<int *>(seq_lens_decoder.data<int>()),
const_cast<int *>(block_tables.data<int>()),
const_cast<int *>(free_list.data<int>()),
const_cast<int *>(free_list_len.data<int>()),
const_cast<int64_t *>(input_ids.data<int64_t>()),
const_cast<int64_t *>(pre_ids.data<int64_t>()),
const_cast<int64_t *>(step_idx.data<int64_t>()),
const_cast<int *>(encoder_block_lens.data<int>()),
const_cast<int *>(used_list_len.data<int>()),
next_tokens.data<int64_t>(),
first_token_ids.data<int64_t>(),
bsz,
block_num_per_seq,
length,
pre_id_length);
#ifdef DEBUG_STEP
cudaDeviceSynchronize();
#endif
}
}
PD_BUILD_STATIC_OP(speculate_step_system_cache)
.Inputs({"stop_flags",
"seq_lens_this_time",
"ori_seq_lens_encoder",
"ori_seq_lens_decoder",
"seq_lens_encoder",
"seq_lens_decoder",
"block_tables",
"encoder_block_lens",
"is_block_step",
"step_block_list",
"step_lens",
"recover_block_list",
"recover_lens",
"need_block_list",
"need_block_len",
"used_list_len",
"free_list",
"free_list_len",
"input_ids",
"pre_ids",
"step_idx",
"next_tokens",
"first_token_ids",
"accept_num"})
.Attrs({"block_size: int",
"encoder_decoder_block_num: int",
"max_draft_tokens: int"})
.Outputs({"stop_flags_out",
"seq_lens_this_time_out",
"seq_lens_encoder_out",
"seq_lens_decoder_out",
"block_tables_out",
"encoder_block_lens_out",
"is_block_step_out",
"step_block_list_out",
"step_lens_out",
"recover_block_list_out",
"recover_lens_out",
"need_block_list_out",
"need_block_len_out",
"used_list_len_out",
"free_list_out",
"free_list_len_out",
"input_ids_out",
"first_token_ids_out"})
.SetInplaceMap({{"stop_flags", "stop_flags_out"},
{"seq_lens_this_time", "seq_lens_this_time_out"},
{"seq_lens_encoder", "seq_lens_encoder_out"},
{"seq_lens_decoder", "seq_lens_decoder_out"},
{"block_tables", "block_tables_out"},
{"encoder_block_lens", "encoder_block_lens_out"},
{"is_block_step", "is_block_step_out"},
{"step_block_list", "step_block_list_out"},
{"step_lens", "step_lens_out"},
{"recover_block_list", "recover_block_list_out"},
{"recover_lens", "recover_lens_out"},
{"need_block_list", "need_block_list_out"},
{"need_block_len", "need_block_len_out"},
{"used_list_len", "used_list_len_out"},
{"free_list", "free_list_out"},
{"free_list_len", "free_list_len_out"},
{"input_ids", "input_ids_out"},
{"first_token_ids", "first_token_ids_out"}})
.SetKernelFn(PD_KERNEL(SpeculateStepPaddle));

View File

@@ -0,0 +1,185 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/extension.h"
#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif
// #define DEBUG_SPEC_STOP_SEQS
__global__ void spec_set_value_by_stop_seqs(bool *stop_flags,
int64_t *accept_tokens,
int *accept_nums,
const int64_t *pre_ids,
const int64_t *step_idx,
const int64_t *stop_seqs,
const int *stop_seqs_len,
const int *seq_lens,
const int64_t *end_ids,
const int bs,
const int accept_tokens_len,
const int stop_seqs_bs,
const int stop_seqs_max_len,
const int pre_ids_len) {
const int bid = blockIdx.x;
const int tid = threadIdx.x;
if (tid >= stop_seqs_bs) return;
const int stop_seq_len = stop_seqs_len[tid];
if (stop_seq_len <= 0) return;
if (bid < bs) {
const int64_t *stop_seq_now = stop_seqs + tid * stop_seqs_max_len;
const int64_t *pre_ids_now = pre_ids + bid * pre_ids_len;
int64_t *accept_tokens_now = accept_tokens + bid * accept_tokens_len;
const int accept_num = accept_nums[bid];
const int64_t step_idx_now = step_idx[bid];
if (!stop_flags[bid]) {
int accept_idx = 0;
bool is_end = false;
            // Iterate over candidate start positions
for (; accept_idx <= accept_num - 1 && !is_end; accept_idx++) {
if (step_idx_now - accept_num + accept_idx + 1 < stop_seq_len) {
#ifdef DEBUG_SPEC_STOP_SEQS
printf("num %d < stop_seq_len %d\n",
step_idx_now - accept_num + accept_idx + 1,
stop_seq_len);
#endif
continue;
}
                // Walk one stop sequence
for (int i = stop_seq_len - 1; i >= 0; --i) {
int64_t cur_token_idx = -1;
                    // Use the current offset to decide whether the token
                    // lives in pre_ids or accept_tokens
if (stop_seq_len - 1 - i < accept_idx) {
#ifdef DEBUG_SPEC_STOP_SEQS
printf(
"AcceptTokens bid:%d. tid:%d, accept_idx:%d, "
"accept_token_idx: "
"%d\n",
bid,
tid,
accept_idx,
accept_idx - (stop_seq_len - 1 - i) - 1);
#endif
cur_token_idx =
accept_tokens_now[accept_idx -
(stop_seq_len - 1 - i) - 1];
} else {
#ifdef DEBUG_SPEC_STOP_SEQS
printf(
"PreIds bid:%d. tid:%d, step_idx_now:%ld. "
"accept_idx:%d. "
"pre_id_idx: %ld\n",
bid,
tid,
step_idx_now,
accept_idx,
step_idx_now - accept_num + accept_idx -
(stop_seq_len - 1 - i));
#endif
int pre_ids_idx = step_idx_now - accept_num +
accept_idx - (stop_seq_len - 1 - i);
                        // EC3: special concatenation can leave input_ids
                        // without a trailing special token, so pre_ids[0] may
                        // hold an ordinary token (e.g. 23) and trigger a
                        // spurious stop.
if (pre_ids_idx <= 0) {
break;
}
cur_token_idx = pre_ids_now[pre_ids_idx];
}
#ifdef DEBUG_SPEC_STOP_SEQS
printf(
"bid:%d. tid:%d, cur_token_idx: %ld. stop_seq_now "
"%ld\n",
bid,
tid,
cur_token_idx,
stop_seq_now[i]);
#endif
if (cur_token_idx != stop_seq_now[i]) {
break;
}
if (i == 0) {
is_end = true;
}
}
}
if (is_end) {
#ifdef DEBUG_SPEC_STOP_SEQS
printf("bid:%d end with accept_idx %d", bid, accept_idx);
#endif
accept_nums[bid] = accept_idx;
accept_tokens_now[accept_idx - 1] = end_ids[0];
stop_flags[bid] = true;
}
}
}
}
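// Illustrative sketch (hypothetical scalar form of the match above): a stop
// sequence of length stop_seq_len matches at accept position accept_idx iff
// the last stop_seq_len tokens ending there equal the sequence, where the
// freshest tokens come from accept_tokens and older ones from pre_ids.
static inline bool stop_seq_matches_host(const int64_t *accept_tokens_row,
                                         const int64_t *pre_ids_row,
                                         const int64_t *stop_seq,
                                         int stop_seq_len,
                                         int accept_idx,
                                         int accept_num,
                                         int64_t step_idx_now) {
    for (int i = stop_seq_len - 1; i >= 0; --i) {
        const int back = stop_seq_len - 1 - i;  // distance from newest token
        int64_t token;
        if (back < accept_idx) {
            token = accept_tokens_row[accept_idx - back - 1];
        } else {
            const int64_t p = step_idx_now - accept_num + accept_idx - back;
            if (p <= 0) return false;  // would cross the prompt boundary
            token = pre_ids_row[p];
        }
        if (token != stop_seq[i]) return false;
    }
    return true;
}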
void SpecGetStopFlagsMultiSeqs(const paddle::Tensor &accept_tokens,
const paddle::Tensor &accept_num,
const paddle::Tensor &pre_ids,
const paddle::Tensor &step_idx,
const paddle::Tensor &stop_flags,
const paddle::Tensor &seq_lens,
const paddle::Tensor &stop_seqs,
const paddle::Tensor &stop_seqs_len,
const paddle::Tensor &end_ids) {
PD_CHECK(accept_tokens.dtype() == paddle::DataType::INT64);
PD_CHECK(stop_flags.dtype() == paddle::DataType::BOOL);
auto cu_stream = accept_tokens.stream();
std::vector<int64_t> shape = accept_tokens.shape();
std::vector<int64_t> stop_seqs_shape = stop_seqs.shape();
int bs_now = shape[0];
int stop_seqs_bs = stop_seqs_shape[0];
int stop_seqs_max_len = stop_seqs_shape[1];
int pre_ids_len = pre_ids.shape()[1];
int accept_tokens_len = accept_tokens.shape()[1];
int block_size = (stop_seqs_bs + 31) / 32 * 32;
spec_set_value_by_stop_seqs<<<bs_now, block_size, 0, cu_stream>>>(
const_cast<bool *>(stop_flags.data<bool>()),
const_cast<int64_t *>(accept_tokens.data<int64_t>()),
const_cast<int *>(accept_num.data<int>()),
pre_ids.data<int64_t>(),
step_idx.data<int64_t>(),
stop_seqs.data<int64_t>(),
stop_seqs_len.data<int>(),
seq_lens.data<int>(),
end_ids.data<int64_t>(),
bs_now,
accept_tokens_len,
stop_seqs_bs,
stop_seqs_max_len,
pre_ids_len);
}
PD_BUILD_STATIC_OP(speculate_set_stop_value_multi_seqs)
.Inputs({"accept_tokens",
"accept_num",
"pre_ids",
"step_idx",
"stop_flags",
"seq_lens",
"stop_seqs",
"stop_seqs_len",
"end_ids"})
.Outputs({"accept_tokens_out", "stop_flags_out"})
.SetInplaceMap({{"accept_tokens", "accept_tokens_out"},
{"stop_flags", "stop_flags_out"}})
.SetKernelFn(PD_KERNEL(SpecGetStopFlagsMultiSeqs));

View File

@@ -0,0 +1,341 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "helper.h" // NOLINT
template <typename T>
__global__ inline void min_length_logits_process(
T *logits,
const int64_t *cur_len,
const int64_t *min_len,
const int64_t *eos_token_id,
const int *output_padding_offset,
const int *output_cum_offsets,
const int64_t token_num,
const int64_t bs,
const int64_t length,
const int64_t end_length,
const int max_seq_len) {
const int token_idx = threadIdx.x;
if (token_idx >= token_num) return;
const int bi = (token_idx + output_padding_offset[token_idx]) / max_seq_len;
if (bi >= bs) return;
const int query_start_token_idx = bi * max_seq_len - output_cum_offsets[bi];
if (cur_len[bi] < 0) {
return;
}
if (cur_len[bi] + (token_idx - query_start_token_idx) < min_len[bi]) {
for (int i = 0; i < end_length; i++) {
logits[token_idx * length + eos_token_id[i]] = -1e10;
}
}
}
template <>
__global__ inline void min_length_logits_process<half>(
half *logits,
const int64_t *cur_len,
const int64_t *min_len,
const int64_t *eos_token_id,
const int *output_padding_offset,
const int *output_cum_offsets,
const int64_t token_num,
const int64_t bs,
const int64_t length,
const int64_t end_length,
const int max_seq_len) {
const int token_idx = threadIdx.x;
if (token_idx >= token_num) return;
const int bi = (token_idx + output_padding_offset[token_idx]) / max_seq_len;
if (bi >= bs) return;
const int query_start_token_idx = bi * max_seq_len - output_cum_offsets[bi];
if (cur_len[bi] < 0) {
return;
}
if (cur_len[bi] + (token_idx - query_start_token_idx) < min_len[bi]) {
for (int i = 0; i < end_length; i++) {
logits[token_idx * length + eos_token_id[i]] = -1e4;
}
}
}
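// Note on the half specialization above: fp16 cannot represent -1e10 (its
// finite range tops out near 65504), so the eos logits are masked with -1e4
// instead, which is still small enough to zero them out after softmax.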
__global__ void update_repeat_times(const int64_t *pre_ids,
const int64_t *cur_len,
int *repeat_times,
const int *output_padding_offset,
const int64_t token_num,
const int64_t bs,
const int64_t length,
const int64_t length_id,
const int max_seq_len) {
const int token_idx = blockIdx.x;
if (token_idx >= token_num) return;
const int bi = (token_idx + output_padding_offset[token_idx]) / max_seq_len;
if (bi >= bs) return;
if (cur_len[bi] < 0) {
return;
}
int tid = threadIdx.x;
const int64_t *pre_ids_now = pre_ids + bi * length_id;
int *repeat_times_now = repeat_times + token_idx * length;
for (int i = tid; i < length_id; i += blockDim.x) {
int64_t id = pre_ids_now[i];
if (id < 0) break;
atomicAdd(&repeat_times_now[id], 1);
}
}
template <typename T>
__global__ void update_value_by_repeat_times(const int *repeat_times,
const T *penalty_scores,
const T *frequency_score,
const T *presence_score,
const float *temperatures,
T *logits,
const int *output_padding_offset,
const int64_t token_num,
const int64_t bs,
const int64_t length,
const int max_seq_len) {
const int token_idx = blockIdx.x;
if (token_idx >= token_num) return;
const int bi = (token_idx + output_padding_offset[token_idx]) / max_seq_len;
if (bi >= bs) return;
int tid = threadIdx.x;
T *logits_now = logits + token_idx * length;
const int *repeat_times_now = repeat_times + token_idx * length;
float alpha = static_cast<float>(penalty_scores[bi]);
float beta = static_cast<float>(frequency_score[bi]);
float gamma = static_cast<float>(presence_score[bi]);
for (int i = tid; i < length; i += blockDim.x) {
int times = repeat_times_now[i];
float logit_now = static_cast<float>(logits_now[i]);
if (times != 0) {
logit_now = logit_now < 0 ? logit_now * alpha : logit_now / alpha;
logit_now = logit_now - times * beta - gamma;
}
logits_now[i] = static_cast<T>(logit_now / temperatures[bi]);
}
}
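// Illustrative sketch (hypothetical scalar form of the update above): a
// token seen `times` times gets the repetition penalty alpha, the frequency
// penalty beta scaled by its count, and the presence penalty gamma; the
// temperature divide is applied to every token.
static inline float penalized_logit(float logit,
                                    int times,
                                    float alpha,
                                    float beta,
                                    float gamma,
                                    float temperature) {
    if (times != 0) {
        logit = logit < 0 ? logit * alpha : logit / alpha;
        logit = logit - times * beta - gamma;
    }
    return logit / temperature;
}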
template <typename T>
__global__ void ban_bad_words(T *logits,
const int64_t *bad_words_list,
const int *output_padding_offset,
const int64_t token_num,
const int64_t bs,
const int64_t length,
const int64_t bad_words_length,
const int max_seq_len) {
const int token_idx = blockIdx.x;
if (token_idx >= token_num) return;
const int bi = (token_idx + output_padding_offset[token_idx]) / max_seq_len;
if (bi >= bs) return;
int tid = threadIdx.x;
T *logits_now = logits + token_idx * length;
for (int i = tid; i < bad_words_length; i += blockDim.x) {
const int64_t bad_words_token_id = bad_words_list[i];
if (bad_words_token_id >= length || bad_words_token_id < 0) continue;
logits_now[bad_words_token_id] = -1e10;
}
}
template <paddle::DataType D>
void token_penalty_multi_scores_kernel(
const paddle::Tensor &pre_ids,
const paddle::Tensor &logits,
const paddle::Tensor &penalty_scores,
const paddle::Tensor &frequency_score,
const paddle::Tensor &presence_score,
const paddle::Tensor &temperatures,
const paddle::Tensor &bad_tokens,
const paddle::Tensor &cur_len,
const paddle::Tensor &min_len,
const paddle::Tensor &eos_token_id,
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &output_padding_offset,
const paddle::Tensor &output_cum_offsets,
const int max_seq_len) {
typedef PDTraits<D> traits_;
typedef typename traits_::DataType DataType_;
typedef typename traits_::data_t data_t;
auto cu_stream = logits.stream();
std::vector<int64_t> shape = logits.shape();
auto repeat_times =
paddle::full(shape, 0, paddle::DataType::INT32, pre_ids.place());
int64_t bs = seq_lens_this_time.shape()[0];
int64_t token_num = shape[0];
int64_t length = shape[1];
int64_t length_id = pre_ids.shape()[1];
int64_t length_bad_words = bad_tokens.shape()[0];
int64_t end_length = eos_token_id.shape()[0];
int block_size = (token_num + 32 - 1) / 32 * 32;
min_length_logits_process<<<1, block_size, 0, cu_stream>>>(
reinterpret_cast<DataType_ *>(
const_cast<data_t *>(logits.data<data_t>())),
cur_len.data<int64_t>(),
min_len.data<int64_t>(),
eos_token_id.data<int64_t>(),
output_padding_offset.data<int>(),
output_cum_offsets.data<int>(),
token_num,
bs,
length,
end_length,
max_seq_len);
block_size = (length_id + 32 - 1) / 32 * 32;
block_size = min(block_size, 512);
update_repeat_times<<<token_num, block_size, 0, cu_stream>>>(
pre_ids.data<int64_t>(),
cur_len.data<int64_t>(),
repeat_times.data<int>(),
output_padding_offset.data<int>(),
token_num,
bs,
length,
length_id,
max_seq_len);
block_size = (length + 32 - 1) / 32 * 32;
block_size = min(block_size, 512);
update_value_by_repeat_times<DataType_>
<<<token_num, block_size, 0, cu_stream>>>(
repeat_times.data<int>(),
reinterpret_cast<DataType_ *>(
const_cast<data_t *>(penalty_scores.data<data_t>())),
reinterpret_cast<DataType_ *>(
const_cast<data_t *>(frequency_score.data<data_t>())),
reinterpret_cast<DataType_ *>(
const_cast<data_t *>(presence_score.data<data_t>())),
temperatures.data<float>(),
reinterpret_cast<DataType_ *>(
const_cast<data_t *>(logits.data<data_t>())),
output_padding_offset.data<int>(),
token_num,
bs,
length,
max_seq_len);
block_size = (length_bad_words + 32 - 1) / 32 * 32;
block_size = min(block_size, 512);
ban_bad_words<DataType_><<<token_num, block_size, 0, cu_stream>>>(
reinterpret_cast<DataType_ *>(
const_cast<data_t *>(logits.data<data_t>())),
bad_tokens.data<int64_t>(),
output_padding_offset.data<int>(),
token_num,
bs,
length,
length_bad_words,
max_seq_len);
}
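// Illustrative sketch (hypothetical helper): the launch configs above round
// the work size up to a whole number of 32-thread warps and, for the larger
// kernels, cap the block at 512 threads (CUDA caps any block at 1024).
static inline int rounded_block_size(int64_t n, int cap = 512) {
    const int bs = static_cast<int>((n + 31) / 32 * 32);
    return bs < cap ? bs : cap;
}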
void TokenPenaltyMultiScores(const paddle::Tensor &pre_ids,
const paddle::Tensor &logits,
const paddle::Tensor &penalty_scores,
const paddle::Tensor &frequency_scores,
const paddle::Tensor &presence_scores,
const paddle::Tensor &temperatures,
const paddle::Tensor &bad_tokens,
const paddle::Tensor &cur_len,
const paddle::Tensor &min_len,
const paddle::Tensor &eos_token_id,
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &output_padding_offset,
const paddle::Tensor &output_cum_offsets,
const int max_seq_len) {
switch (logits.type()) {
case paddle::DataType::BFLOAT16: {
return token_penalty_multi_scores_kernel<
paddle::DataType::BFLOAT16>(pre_ids,
logits,
penalty_scores,
frequency_scores,
presence_scores,
temperatures,
bad_tokens,
cur_len,
min_len,
eos_token_id,
seq_lens_this_time,
output_padding_offset,
output_cum_offsets,
max_seq_len);
}
case paddle::DataType::FLOAT16: {
return token_penalty_multi_scores_kernel<paddle::DataType::FLOAT16>(
pre_ids,
logits,
penalty_scores,
frequency_scores,
presence_scores,
temperatures,
bad_tokens,
cur_len,
min_len,
eos_token_id,
seq_lens_this_time,
output_padding_offset,
output_cum_offsets,
max_seq_len);
}
case paddle::DataType::FLOAT32: {
return token_penalty_multi_scores_kernel<paddle::DataType::FLOAT32>(
pre_ids,
logits,
penalty_scores,
frequency_scores,
presence_scores,
temperatures,
bad_tokens,
cur_len,
min_len,
eos_token_id,
seq_lens_this_time,
output_padding_offset,
output_cum_offsets,
max_seq_len);
}
default: {
PD_THROW(
"NOT supported data type. "
"Only float16, bfloat16 and float32 are supported. ");
break;
}
}
}
PD_BUILD_STATIC_OP(speculate_get_token_penalty_multi_scores)
.Inputs({"pre_ids",
"logits",
"penalty_scores",
"frequency_scores",
"presence_scores",
"temperatures",
"bad_tokens",
"cur_len",
"min_len",
"eos_token_id",
"seq_lens_this_time",
"output_padding_offset",
"output_cum_offsets"})
.Outputs({"logits_out"})
.Attrs({"max_seq_len: int"})
.SetInplaceMap({{"logits", "logits_out"}})
.SetKernelFn(PD_KERNEL(TokenPenaltyMultiScores));

View File

@@ -0,0 +1,42 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/extension.h"
#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif
void UpdateInputIdsCPU(const paddle::Tensor& input_ids_cpu,
const std::vector<int64_t>& task_input_ids,
const int bid,
const int max_seq_len) {
int64_t* input_ids_cpu_data =
const_cast<int64_t*>(input_ids_cpu.data<int64_t>());
// printf("Input len is %d\n", task_input_ids.size());
    for (size_t i = 0; i < task_input_ids.size(); i++) {
// printf("%lld\n", task_input_ids[i]);
input_ids_cpu_data[bid * max_seq_len + i] = task_input_ids[i];
}
}
PD_BUILD_STATIC_OP(speculate_update_input_ids_cpu)
.Inputs({"input_ids_cpu"})
.Outputs({"input_ids_cpu_out"})
.Attrs({"task_input_ids: std::vector<int64_t>",
"bid: int",
"max_seq_len: int"})
.SetInplaceMap({{"input_ids_cpu", "input_ids_cpu_out"}})
.SetKernelFn(PD_KERNEL(UpdateInputIdsCPU));

View File

@@ -0,0 +1,55 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "helper.h" // NOLINT
__global__ void update_this_time(int* seq_lens_this_time,
const int* seq_lens_encoder,
const int* seq_lens_decoder,
int real_bsz,
int value) {
int linear_idx = threadIdx.x;
// verify and set stop flags
for (; linear_idx < real_bsz; linear_idx += blockDim.x) {
if (seq_lens_encoder[linear_idx] == 0 &&
seq_lens_decoder[linear_idx] != 0) {
seq_lens_this_time[linear_idx] = value;
} else if (seq_lens_encoder[linear_idx] == 0 &&
seq_lens_decoder[linear_idx] == 0) {
seq_lens_this_time[linear_idx] = 0;
}
}
}
void UpdateThisTime(const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const int real_bsz,
const int value) {
constexpr int BlockSize = 512;
update_this_time<<<1, BlockSize, 0, seq_lens_this_time.stream()>>>(
const_cast<int*>(seq_lens_this_time.data<int>()),
seq_lens_encoder.data<int>(),
seq_lens_decoder.data<int>(),
real_bsz,
value);
}
PD_BUILD_STATIC_OP(speculate_update_seq_lens_this_time)
.Inputs({"seq_lens_this_time", "seq_lens_encoder", "seq_lens_decoder"})
.Outputs({"seq_lens_this_time_out"})
.Attrs({"real_bsz: int", "value: int"})
.SetInplaceMap({{"seq_lens_this_time", "seq_lens_this_time_out"}})
.SetKernelFn(PD_KERNEL(UpdateThisTime));

View File

@@ -0,0 +1,146 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "helper.h" // NOLINT
template <int THREADBLOCK_SIZE>
__global__ void speculate_update(int *seq_lens_encoder,
int *seq_lens_decoder,
bool *not_need_stop,
int64_t *draft_tokens,
int *actual_draft_token_nums,
const int64_t *accept_tokens,
const int *accept_num,
const bool *stop_flags,
const int *seq_lens_this_time,
const bool *is_block_step,
const int real_bsz,
const int max_draft_tokens) {
const int bid = threadIdx.x;
const int accept_num_now = accept_num[bid];
int stop_flag_now_int = 0;
if (!(is_block_step[bid] || bid >= real_bsz)) {
if (stop_flags[bid]) {
stop_flag_now_int = 1;
}
if (seq_lens_encoder[bid] == 0) {
seq_lens_decoder[bid] += accept_num_now;
}
if (seq_lens_this_time[bid] > 1 &&
            seq_lens_encoder[bid] == 0) {
            // In append mode, the acceptance result decides whether the next
            // step's draft token budget grows or shrinks.
auto current_actual_draft_token_num = actual_draft_token_nums[bid];
if (accept_num_now - 1 == current_actual_draft_token_num) {
if (current_actual_draft_token_num + 2 <=
max_draft_tokens - 1) {
actual_draft_token_nums[bid] =
current_actual_draft_token_num + 2;
} else if (current_actual_draft_token_num + 1 <=
max_draft_tokens - 1) {
actual_draft_token_nums[bid] =
current_actual_draft_token_num + 1;
} else {
actual_draft_token_nums[bid] = max_draft_tokens - 1;
}
} else {
actual_draft_token_nums[bid] =
actual_draft_token_nums[bid] - 1 >= 1
? actual_draft_token_nums[bid] - 1
: 1;
}
}
if (seq_lens_encoder[bid] != 0) {
seq_lens_decoder[bid] += seq_lens_encoder[bid];
seq_lens_encoder[bid] = 0;
}
draft_tokens[bid * max_draft_tokens] =
accept_tokens[bid * max_draft_tokens + accept_num_now - 1];
if (stop_flag_now_int) {
seq_lens_decoder[bid] = 0;
}
}
__syncthreads();
typedef cub::BlockReduce<int64_t, THREADBLOCK_SIZE> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
// printf("stop_flag_now_int %d \n", stop_flag_now_int);
int64_t stop_sum = BlockReduce(temp_storage).Sum(stop_flag_now_int);
if (threadIdx.x == 0) {
// printf("stop_sum %d \n", stop_sum);
not_need_stop[0] = stop_sum < real_bsz;
}
}
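// Illustrative sketch (hypothetical scalar form of the adaptation above): if
// every draft token was accepted, grow the budget by up to 2, capped at
// max_draft_tokens - 1; otherwise shrink it by 1, floored at 1.
static inline int next_draft_token_num(int current,
                                       int accept_num_now,
                                       int max_draft_tokens) {
    if (accept_num_now - 1 == current) {
        const int grown = current + 2;
        return grown <= max_draft_tokens - 1 ? grown : max_draft_tokens - 1;
    }
    return current - 1 >= 1 ? current - 1 : 1;
}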
void SpeculateUpdateV2(const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &not_need_stop,
const paddle::Tensor &draft_tokens,
const paddle::Tensor &actual_draft_token_nums,
const paddle::Tensor &accept_tokens,
const paddle::Tensor &accept_num,
const paddle::Tensor &stop_flags,
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &is_block_step) {
int real_bsz = seq_lens_this_time.shape()[0];
auto max_draft_tokens = draft_tokens.shape()[1];
constexpr int BlockSize = 512;
auto not_need_stop_gpu = not_need_stop.copy_to(stop_flags.place(), false);
speculate_update<BlockSize><<<1, BlockSize, 0, accept_tokens.stream()>>>(
const_cast<int *>(seq_lens_encoder.data<int>()),
const_cast<int *>(seq_lens_decoder.data<int>()),
const_cast<bool *>(not_need_stop_gpu.data<bool>()),
const_cast<int64_t *>(draft_tokens.data<int64_t>()),
const_cast<int *>(actual_draft_token_nums.data<int>()),
accept_tokens.data<int64_t>(),
accept_num.data<int>(),
stop_flags.data<bool>(),
seq_lens_this_time.data<int>(),
is_block_step.data<bool>(),
real_bsz,
max_draft_tokens);
auto not_need_stop_cpu =
not_need_stop_gpu.copy_to(not_need_stop.place(), true);
bool *not_need_stop_data = const_cast<bool *>(not_need_stop.data<bool>());
not_need_stop_data[0] = not_need_stop_cpu.data<bool>()[0];
}
PD_BUILD_STATIC_OP(speculate_update_v2)
.Inputs({"seq_lens_encoder",
"seq_lens_decoder",
"not_need_stop",
"draft_tokens",
"actual_draft_token_nums",
"accept_tokens",
"accept_num",
"stop_flags",
"seq_lens_this_time",
"is_block_step"})
.Outputs({"seq_lens_encoder_out",
"seq_lens_decoder_out",
"not_need_stop_out",
"draft_tokens_out",
"actual_draft_token_nums_out"})
.SetInplaceMap({{"seq_lens_encoder", "seq_lens_encoder_out"},
{"seq_lens_decoder", "seq_lens_decoder_out"},
{"not_need_stop", "not_need_stop_out"},
{"draft_tokens", "draft_tokens_out"},
{"actual_draft_token_nums", "actual_draft_token_nums_out"}})
.SetKernelFn(PD_KERNEL(SpeculateUpdateV2));

View File

@@ -0,0 +1,155 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "helper.h"
template <int THREADBLOCK_SIZE>
__global__ void speculate_update_v3(int *seq_lens_encoder,
int *seq_lens_decoder,
bool *not_need_stop,
int64_t *draft_tokens,
int *actual_draft_token_nums,
const int64_t *accept_tokens,
const int *accept_num,
const bool *stop_flags,
const int *seq_lens_this_time,
const bool *is_block_step,
const int64_t *stop_nums,
const int real_bsz,
const int max_bsz,
const int max_draft_tokens) {
const int bid = threadIdx.x;
const int accept_num_now = accept_num[bid];
int stop_flag_now_int = 0;
if (!(is_block_step[bid] || bid >= real_bsz)) {
if (stop_flags[bid]) {
stop_flag_now_int = 1;
}
if (seq_lens_encoder[bid] == 0) {
seq_lens_decoder[bid] += accept_num_now;
}
if (seq_lens_this_time[bid] > 1 &&
            seq_lens_encoder[bid] == 0) {
            // In append mode, the acceptance result decides whether the next
            // step's draft token budget grows or shrinks.
auto current_actual_draft_token_num = actual_draft_token_nums[bid];
if (accept_num_now - 1 == current_actual_draft_token_num) {
if (current_actual_draft_token_num + 2 <=
max_draft_tokens - 1) {
actual_draft_token_nums[bid] =
current_actual_draft_token_num + 2;
} else if (current_actual_draft_token_num + 1 <=
max_draft_tokens - 1) {
actual_draft_token_nums[bid] =
current_actual_draft_token_num + 1;
} else {
actual_draft_token_nums[bid] = max_draft_tokens - 1;
}
} else {
actual_draft_token_nums[bid] =
actual_draft_token_nums[bid] - 1 >= 1
? actual_draft_token_nums[bid] - 1
: 1;
}
}
if (seq_lens_encoder[bid] != 0) {
seq_lens_decoder[bid] += seq_lens_encoder[bid];
seq_lens_encoder[bid] = 0;
}
draft_tokens[bid * max_draft_tokens] =
accept_tokens[bid * max_draft_tokens + accept_num_now - 1];
if (stop_flag_now_int) {
seq_lens_decoder[bid] = 0;
}
} else if (bid >= real_bsz && bid < max_bsz) {
stop_flag_now_int = 1;
}
__syncthreads();
typedef cub::BlockReduce<int64_t, THREADBLOCK_SIZE> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
// printf("stop_flag_now_int %d \n", stop_flag_now_int);
int64_t stop_sum = BlockReduce(temp_storage).Sum(stop_flag_now_int);
if (threadIdx.x == 0) {
// printf("stop_sum %d \n", stop_sum);
not_need_stop[0] = stop_sum < stop_nums[0];
}
}
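// Note: unlike the speculate_update kernel in the previous file, this
// variant also counts the padding slots bid in [real_bsz, max_bsz) as
// stopped and compares the reduced sum against stop_nums[0] rather than
// real_bsz, so the stop decision covers the full padded batch.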
void SpeculateUpdateV3(const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &not_need_stop,
const paddle::Tensor &draft_tokens,
const paddle::Tensor &actual_draft_token_nums,
const paddle::Tensor &accept_tokens,
const paddle::Tensor &accept_num,
const paddle::Tensor &stop_flags,
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &is_block_step,
const paddle::Tensor &stop_nums) {
const int real_bsz = seq_lens_this_time.shape()[0];
const int max_bsz = stop_flags.shape()[0];
auto max_draft_tokens = draft_tokens.shape()[1];
constexpr int BlockSize = 512;
auto not_need_stop_gpu = not_need_stop.copy_to(stop_flags.place(), false);
speculate_update_v3<BlockSize><<<1, BlockSize, 0, accept_tokens.stream()>>>(
const_cast<int *>(seq_lens_encoder.data<int>()),
const_cast<int *>(seq_lens_decoder.data<int>()),
const_cast<bool *>(not_need_stop_gpu.data<bool>()),
const_cast<int64_t *>(draft_tokens.data<int64_t>()),
const_cast<int *>(actual_draft_token_nums.data<int>()),
accept_tokens.data<int64_t>(),
accept_num.data<int>(),
stop_flags.data<bool>(),
seq_lens_this_time.data<int>(),
is_block_step.data<bool>(),
stop_nums.data<int64_t>(),
real_bsz,
max_bsz,
max_draft_tokens);
auto not_need_stop_cpu =
not_need_stop_gpu.copy_to(not_need_stop.place(), true);
bool *not_need_stop_data = const_cast<bool *>(not_need_stop.data<bool>());
not_need_stop_data[0] = not_need_stop_cpu.data<bool>()[0];
}
PD_BUILD_STATIC_OP(speculate_update_v3)
.Inputs({"seq_lens_encoder",
"seq_lens_decoder",
"not_need_stop",
"draft_tokens",
"actual_draft_token_nums",
"accept_tokens",
"accept_num",
"stop_flags",
"seq_lens_this_time",
"is_block_step",
"stop_nums"})
.Outputs({"seq_lens_encoder_out",
"seq_lens_decoder_out",
"not_need_stop_out",
"draft_tokens_out",
"actual_draft_token_nums_out"})
.SetInplaceMap({{"seq_lens_encoder", "seq_lens_encoder_out"},
{"seq_lens_decoder", "seq_lens_decoder_out"},
{"not_need_stop", "not_need_stop_out"},
{"draft_tokens", "draft_tokens_out"},
{"actual_draft_token_nums", "actual_draft_token_nums_out"}})
.SetKernelFn(PD_KERNEL(SpeculateUpdateV3));

View File

@@ -0,0 +1,478 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <curand_kernel.h>
#include <cstdlib>
#include <string>
#include "helper.h" // NOLINT
__device__ inline bool is_in(const int64_t *candidates,
const int64_t draft,
const int candidate_len) {
for (int i = 0; i < candidate_len; i++) {
if (draft == candidates[i]) {
return true;
}
}
return false;
}
static uint64_t seed = 0;
static uint64_t offset = 0;
__device__ int64_t topp_sampling_kernel(const int64_t *candidate_ids,
const float *candidate_scores,
curandState_t *dev_curand_states,
const int candidate_len,
const float topp) {
const int tid = threadIdx.x;
float sum_scores = 0.0f;
float rand_top_p = curand_uniform(dev_curand_states + tid) * topp;
for (int i = 0; i < candidate_len; i++) {
sum_scores += candidate_scores[i];
if (rand_top_p <= sum_scores) {
return candidate_ids[i];
}
}
return candidate_ids[0];
}
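// Illustrative sketch (hypothetical host-side analogue, assuming candidates
// arrive sorted by score, as the verify pipeline appears to provide): scale
// a uniform sample by top-p, then return the first candidate whose
// cumulative score covers it, falling back to the top candidate.
static inline int64_t topp_sampling_host(const int64_t *candidate_ids,
                                         const float *candidate_scores,
                                         int candidate_len,
                                         float topp,
                                         float uniform_sample) {
    const float rand_top_p = uniform_sample * topp;
    float sum_scores = 0.0f;
    for (int i = 0; i < candidate_len; i++) {
        sum_scores += candidate_scores[i];
        if (rand_top_p <= sum_scores) {
            return candidate_ids[i];
        }
    }
    return candidate_ids[0];
}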
__global__ void setup_kernel(curandState_t *state,
const uint64_t seed,
const uint64_t offset,
const int bs,
const bool need_batch_random) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
for (int i = idx; i < bs; i += gridDim.x * blockDim.x) {
if (need_batch_random) {
curand_init(seed, i, offset, &state[i]);
} else {
curand_init(seed, 0, offset, &state[i]);
}
}
}
template <bool ENABLE_TOPP, bool USE_TOPK>
__global__ void speculate_verify(int64_t *accept_tokens,
int *accept_num,
int64_t *step_idx,
bool *stop_flags,
const int *seq_lens_encoder,
const int *seq_lens_decoder,
const int64_t *draft_tokens,
const int *actual_draft_token_nums,
curandState_t *dev_curand_states,
const float *topp,
const int *seq_lens_this_time,
const int64_t *verify_tokens,
const float *verify_scores,
const int64_t *max_dec_len,
const int64_t *end_tokens,
const bool *is_block_step,
const int *output_cum_offsets,
const int *actual_candidate_len,
const int real_bsz,
const int max_draft_tokens,
const int end_length,
const int max_seq_len,
const int max_candidate_len,
const int verify_window,
const bool prefill_one_step_stop) {
const int bid = threadIdx.x;
const int start_token_id = bid * max_seq_len - output_cum_offsets[bid];
// verify and set stop flags
int accept_num_now = 1;
int stop_flag_now_int = 0;
if (!(is_block_step[bid] || bid >= real_bsz)) {
// printf("bid %d\n", bid);
if (stop_flags[bid]) {
stop_flag_now_int = 1;
        } else {
            // Prefill also enters this branch, but its draft tokens are
            // zeroed, so it falls straight through to the final sampling
            // stage below.
auto *verify_tokens_now =
verify_tokens + start_token_id * max_candidate_len;
auto *draft_tokens_now = draft_tokens + bid * max_draft_tokens;
auto *actual_candidate_len_now =
actual_candidate_len + start_token_id;
int i = 0;
// printf("seq_lens_this_time[%d]-1: %d \n",bid,
// seq_lens_this_time[bid]-1);
for (; i < seq_lens_this_time[bid] - 1; i++) {
if (seq_lens_encoder[bid] != 0) {
break;
}
if (USE_TOPK) {
if (verify_tokens_now[i * max_candidate_len] ==
draft_tokens_now[i + 1]) {
// accept_num_now++;
step_idx[bid]++;
auto accept_token = draft_tokens_now[i + 1];
// printf("[USE_TOPK] bid %d Top 1 verify write accept
// %d is %lld\n", bid, i, accept_token);
accept_tokens[bid * max_draft_tokens + i] =
accept_token;
if (is_in_end(accept_token, end_tokens, end_length) ||
step_idx[bid] >= max_dec_len[bid]) {
stop_flags[bid] = true;
stop_flag_now_int = 1;
if (step_idx[bid] >= max_dec_len[bid])
accept_tokens[bid * max_draft_tokens + i] =
end_tokens[0];
// printf("[USE_TOPK] bid %d Top 1 verify write
// accept %d is %lld\n", bid, i, accept_token);
break;
} else {
accept_num_now++;
}
} else {
break;
}
} else {
auto actual_candidate_len_value =
actual_candidate_len_now[i] > max_candidate_len
? max_candidate_len
: actual_candidate_len_now[i];
if (is_in(verify_tokens_now + i * max_candidate_len,
draft_tokens_now[i + 1],
actual_candidate_len_value)) {
// Top P verify
// accept_num_now++;
step_idx[bid]++;
auto accept_token = draft_tokens_now[i + 1];
accept_tokens[bid * max_draft_tokens + i] =
accept_token;
if (is_in_end(accept_token, end_tokens, end_length) ||
step_idx[bid] >= max_dec_len[bid]) {
stop_flags[bid] = true;
stop_flag_now_int = 1;
if (step_idx[bid] >= max_dec_len[bid])
accept_tokens[bid * max_draft_tokens + i] =
end_tokens[0];
// printf("bid %d Top P verify write accept %d is
// %lld\n", bid, i, accept_token);
break;
} else {
accept_num_now++;
}
} else {
// TopK verify
int ii = i;
if (max_candidate_len >= 2 &&
verify_tokens_now[ii * max_candidate_len + 1] ==
draft_tokens_now[ii + 1]) { // top-2
int j = 0;
ii += 1;
for (; j < verify_window &&
ii < seq_lens_this_time[bid] - 1;
j++, ii++) {
if (verify_tokens_now[ii * max_candidate_len] !=
draft_tokens_now[ii + 1]) {
break;
}
}
if (j >= verify_window) { // accept all
accept_num_now += verify_window + 1;
step_idx[bid] += verify_window + 1;
for (; i < ii; i++) {
auto accept_token = draft_tokens_now[i + 1];
accept_tokens[bid * max_draft_tokens + i] =
accept_token;
// printf(
// "bid %d TopK verify write accept %d
// is "
// "%lld\n",
// bid,
// i,
// accept_token);
if (is_in_end(accept_token,
end_tokens,
end_length) ||
step_idx[bid] >= max_dec_len[bid]) {
stop_flags[bid] = true;
stop_flag_now_int = 1;
if (step_idx[bid] >= max_dec_len[bid])
accept_tokens[bid *
max_draft_tokens +
i] = end_tokens[0];
// printf("bid %d TopK verify write
// accept %d is %lld\n", bid, i,
// end_tokens[0]);
accept_num_now--;
step_idx[bid]--;
break;
}
}
}
}
break;
}
}
}
            // Sampling stage. Case 1: draft_token[i + 1] was rejected, so
            // pick a token from verify_tokens_now[i]. Case 2: i reached
            // seq_lens_this_time[bid] - 1; also pick from
            // verify_tokens_now[i]. Stopped sequences are excluded.
if (!stop_flag_now_int) {
int64_t accept_token;
const float *verify_scores_now =
verify_scores + start_token_id * max_candidate_len;
step_idx[bid]++;
if (ENABLE_TOPP) {
auto actual_candidate_len_value =
actual_candidate_len_now[i] > max_candidate_len
? max_candidate_len
: actual_candidate_len_now[i];
accept_token = topp_sampling_kernel(
verify_tokens_now + i * max_candidate_len,
verify_scores_now + i * max_candidate_len,
dev_curand_states,
actual_candidate_len_value,
topp[bid]);
} else {
accept_token = verify_tokens_now[i * max_candidate_len];
}
accept_tokens[bid * max_draft_tokens + i] = accept_token;
if (prefill_one_step_stop) {
stop_flags[bid] = true;
}
if (is_in_end(accept_token, end_tokens, end_length) ||
step_idx[bid] >= max_dec_len[bid]) {
stop_flags[bid] = true;
stop_flag_now_int = 1;
if (step_idx[bid] >= max_dec_len[bid])
accept_tokens[bid * max_draft_tokens + i] =
end_tokens[0];
}
}
accept_num[bid] = accept_num_now;
}
}
}
void SpeculateVerify(const paddle::Tensor &accept_tokens,
const paddle::Tensor &accept_num,
const paddle::Tensor &step_idx,
const paddle::Tensor &stop_flags,
const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &draft_tokens,
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &verify_tokens,
const paddle::Tensor &verify_scores,
const paddle::Tensor &max_dec_len,
const paddle::Tensor &end_tokens,
const paddle::Tensor &is_block_step,
const paddle::Tensor &output_cum_offsets,
const paddle::Tensor &actual_candidate_len,
const paddle::Tensor &actual_draft_token_nums,
const paddle::Tensor &topp,
int max_seq_len,
int verify_window,
bool enable_topp) {
// printf("Enter speculate update\n");
auto bsz = accept_tokens.shape()[0];
int real_bsz = seq_lens_this_time.shape()[0];
auto max_draft_tokens = draft_tokens.shape()[1];
auto end_length = end_tokens.shape()[0];
auto max_candidate_len = verify_tokens.shape()[1];
constexpr int BlockSize = 512;
    // curand states are allocated per call and freed before returning;
    // the global seed/offset advance on every call so successive launches
    // do not replay the same random stream.
    curandState_t *dev_curand_states;
    cudaMalloc(&dev_curand_states, sizeof(curandState_t) * bsz);
    setup_kernel<<<1, BlockSize, 0, accept_tokens.stream()>>>(
        dev_curand_states, seed, offset, bsz, true);
    seed++;
    offset++;
    auto err = cudaDeviceSynchronize();
    if (err != cudaSuccess) {
        printf("setup_kernel sync error: %s\n", cudaGetErrorString(err));
    }
    err = cudaGetLastError();
    if (err != cudaSuccess) {
        printf("setup_kernel launch error: %s\n", cudaGetErrorString(err));
    }
// printf("inited curand\n");
    // SPECULATE_VERIFY_USE_TOPK selects the top-k verify path; any non-zero
    // value enables it.
    bool use_topk = false;
    if (const char *env_var = getenv("SPECULATE_VERIFY_USE_TOPK")) {
        use_topk = (std::stoi(env_var) != 0);
    }
    // PREFILL_NODE_ONE_STEP_STOP=1 makes prefill nodes stop every sequence
    // after a single step.
    bool prefill_one_step_stop = false;
    if (const char *env_p = std::getenv("PREFILL_NODE_ONE_STEP_STOP")) {
        if (env_p[0] == '1') {
            prefill_one_step_stop = true;
        }
    }
if (use_topk) {
// printf("use_topk \n");
if (enable_topp) {
speculate_verify<true, true>
<<<1, BlockSize, 0, accept_tokens.stream()>>>(
const_cast<int64_t *>(accept_tokens.data<int64_t>()),
const_cast<int *>(accept_num.data<int>()),
const_cast<int64_t *>(step_idx.data<int64_t>()),
const_cast<bool *>(stop_flags.data<bool>()),
seq_lens_encoder.data<int>(),
seq_lens_decoder.data<int>(),
draft_tokens.data<int64_t>(),
actual_draft_token_nums.data<int>(),
dev_curand_states,
topp.data<float>(),
seq_lens_this_time.data<int>(),
verify_tokens.data<int64_t>(),
verify_scores.data<float>(),
max_dec_len.data<int64_t>(),
end_tokens.data<int64_t>(),
is_block_step.data<bool>(),
output_cum_offsets.data<int>(),
actual_candidate_len.data<int>(),
real_bsz,
max_draft_tokens,
end_length,
max_seq_len,
max_candidate_len,
verify_window,
prefill_one_step_stop);
} else {
speculate_verify<false, true>
<<<1, BlockSize, 0, accept_tokens.stream()>>>(
const_cast<int64_t *>(accept_tokens.data<int64_t>()),
const_cast<int *>(accept_num.data<int>()),
const_cast<int64_t *>(step_idx.data<int64_t>()),
const_cast<bool *>(stop_flags.data<bool>()),
seq_lens_encoder.data<int>(),
seq_lens_decoder.data<int>(),
draft_tokens.data<int64_t>(),
actual_draft_token_nums.data<int>(),
dev_curand_states,
topp.data<float>(),
seq_lens_this_time.data<int>(),
verify_tokens.data<int64_t>(),
verify_scores.data<float>(),
max_dec_len.data<int64_t>(),
end_tokens.data<int64_t>(),
is_block_step.data<bool>(),
output_cum_offsets.data<int>(),
actual_candidate_len.data<int>(),
real_bsz,
max_draft_tokens,
end_length,
max_seq_len,
max_candidate_len,
verify_window,
prefill_one_step_stop);
}
} else {
if (enable_topp) {
speculate_verify<true, false>
<<<1, BlockSize, 0, accept_tokens.stream()>>>(
const_cast<int64_t *>(accept_tokens.data<int64_t>()),
const_cast<int *>(accept_num.data<int>()),
const_cast<int64_t *>(step_idx.data<int64_t>()),
const_cast<bool *>(stop_flags.data<bool>()),
seq_lens_encoder.data<int>(),
seq_lens_decoder.data<int>(),
draft_tokens.data<int64_t>(),
actual_draft_token_nums.data<int>(),
dev_curand_states,
topp.data<float>(),
seq_lens_this_time.data<int>(),
verify_tokens.data<int64_t>(),
verify_scores.data<float>(),
max_dec_len.data<int64_t>(),
end_tokens.data<int64_t>(),
is_block_step.data<bool>(),
output_cum_offsets.data<int>(),
actual_candidate_len.data<int>(),
real_bsz,
max_draft_tokens,
end_length,
max_seq_len,
max_candidate_len,
verify_window,
prefill_one_step_stop);
} else {
speculate_verify<false, false>
<<<1, BlockSize, 0, accept_tokens.stream()>>>(
const_cast<int64_t *>(accept_tokens.data<int64_t>()),
const_cast<int *>(accept_num.data<int>()),
const_cast<int64_t *>(step_idx.data<int64_t>()),
const_cast<bool *>(stop_flags.data<bool>()),
seq_lens_encoder.data<int>(),
seq_lens_decoder.data<int>(),
draft_tokens.data<int64_t>(),
actual_draft_token_nums.data<int>(),
dev_curand_states,
topp.data<float>(),
seq_lens_this_time.data<int>(),
verify_tokens.data<int64_t>(),
verify_scores.data<float>(),
max_dec_len.data<int64_t>(),
end_tokens.data<int64_t>(),
is_block_step.data<bool>(),
output_cum_offsets.data<int>(),
actual_candidate_len.data<int>(),
real_bsz,
max_draft_tokens,
end_length,
max_seq_len,
max_candidate_len,
verify_window,
prefill_one_step_stop);
}
}
cudaFree(dev_curand_states);
}
PD_BUILD_STATIC_OP(speculate_verify)
.Inputs({"accept_tokens",
"accept_num",
"step_idx",
"seq_lens_encoder",
"seq_lens_decoder",
"stop_flags",
"draft_tokens",
"seq_lens_this_time",
"verify_tokens",
"verify_scores",
"max_dec_len",
"end_tokens",
"is_block_step",
"output_cum_offsets",
"actual_candidate_len",
"actual_draft_token_nums",
"topp"})
.Outputs({"accept_tokens_out",
"accept_num_out",
"step_idx_out",
"stop_flags_out"})
.Attrs({"max_seq_len: int", "verify_window: int", "enable_topp: bool"})
.SetInplaceMap({{"accept_tokens", "accept_tokens_out"},
{"accept_num", "accept_num_out"},
{"step_idx", "step_idx_out"},
{"stop_flags", "stop_flags_out"}})
.SetKernelFn(PD_KERNEL(SpeculateVerify));

View File

@@ -0,0 +1,624 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "helper.h" // NOLINT
#define WARP_SIZE 32
template <typename T>
__forceinline__ __device__ T
CudaShuffleDownSync(unsigned mask, T val, int delta, int width = warpSize) {
return __shfl_down_sync(mask, val, static_cast<unsigned>(delta), width);
}
template <>
__forceinline__ __device__ phi::dtype::float16 CudaShuffleDownSync(
unsigned mask, phi::dtype::float16 val, int delta, int width) {
return paddle::float16(__shfl_down_sync(
mask, val.to_half(), static_cast<unsigned>(delta), width));
}
template <>
__forceinline__ __device__ phi::dtype::bfloat16 CudaShuffleDownSync(
unsigned mask, phi::dtype::bfloat16 val, int delta, int width) {
return paddle::bfloat16(__shfl_down_sync(
mask, val.to_nv_bfloat16(), static_cast<unsigned>(delta), width));
}
struct BlockPrefixCallbackOp {
// Running prefix
float running_total;
// Constructor
__device__ BlockPrefixCallbackOp(float running_total)
: running_total(running_total) {}
// Callback operator to be entered by the first warp of threads in the
// block. Thread-0 is responsible for returning a value for seeding the
// block-wide scan.
__device__ float operator()(float block_aggregate) {
float old_prefix = running_total;
running_total += block_aggregate;
return old_prefix;
}
};
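BlockPrefixCallbackOp is the standard CUB hook for scanning inputs longer than one tile: cub::BlockScan calls it once per tile with the tile's aggregate, and the value it returns seeds the next tile. A minimal self-contained sketch of that pattern (the kernel name and buffers are hypothetical, not part of this file):

template <int BLOCK>
__global__ void RunningInclusiveSum(const float* in, float* out, int n) {
    typedef cub::BlockScan<float, BLOCK> BlockScan;
    __shared__ typename BlockScan::TempStorage temp_storage;
    BlockPrefixCallbackOp prefix_op(0);  // running total starts at zero
    // Pad the range so every thread joins each collective call.
    int end = ((n + BLOCK - 1) / BLOCK) * BLOCK;
    for (int i = threadIdx.x; i < end; i += BLOCK) {
        float v = (i < n) ? in[i] : 0.f;
        float inclusive;
        // prefix_op feeds the previous tiles' total into this tile's scan.
        BlockScan(temp_storage).InclusiveSum(v, inclusive, prefix_op);
        if (i < n) out[i] = inclusive;
        __syncthreads();  // temp_storage is reused by the next tile
    }
}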
#define FINAL_MASK 0xFFFFFFFF
#define FIXED_BLOCK_DIM_BASE(dim, ...) \
case (dim): { \
constexpr auto kBlockDim = (dim); \
__VA_ARGS__; \
} break
#define FIXED_BLOCK_DIM(...) \
FIXED_BLOCK_DIM_BASE(1024, ##__VA_ARGS__); \
FIXED_BLOCK_DIM_BASE(512, ##__VA_ARGS__); \
FIXED_BLOCK_DIM_BASE(256, ##__VA_ARGS__); \
FIXED_BLOCK_DIM_BASE(128, ##__VA_ARGS__); \
FIXED_BLOCK_DIM_BASE(64, ##__VA_ARGS__); \
FIXED_BLOCK_DIM_BASE(32, ##__VA_ARGS__)
#define FIXED_TOPK_BASE(topk, ...) \
case (topk): { \
constexpr auto kTopK = topk; \
__VA_ARGS__; \
} break
#define FIXED_TOPK(...) \
FIXED_TOPK_BASE(2, ##__VA_ARGS__); \
FIXED_TOPK_BASE(3, ##__VA_ARGS__); \
FIXED_TOPK_BASE(4, ##__VA_ARGS__); \
FIXED_TOPK_BASE(5, ##__VA_ARGS__); \
FIXED_TOPK_BASE(8, ##__VA_ARGS__); \
FIXED_TOPK_BASE(10, ##__VA_ARGS__)
struct SegmentOffsetIter {
explicit SegmentOffsetIter(int num_cols) : num_cols_(num_cols) {}
__host__ __device__ __forceinline__ int operator()(int idx) const {
return idx * num_cols_;
}
int num_cols_;
};
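SegmentOffsetIter maps a row index to its flat offset, which lets CUB's segmented sort read row boundaries from a computed iterator instead of a materialized offsets array. A sketch of the usual wiring (a hypothetical workspace-size query; the real call sites live elsewhere in this op library):

// Hypothetical: query CUB's workspace size for sorting every row of a
// [num_rows, num_cols] float matrix in descending order, keys with ids.
inline size_t SegmentedSortTempBytes(const float* keys_in, float* keys_out,
                                     const int64_t* ids_in, int64_t* ids_out,
                                     int num_rows, int num_cols,
                                     cudaStream_t stream) {
    cub::CountingInputIterator<int> counting_iter(0);
    // Row r begins at flat offset r * num_cols and ends where row r + 1
    // begins, so the same iterator serves as begin and end offsets.
    cub::TransformInputIterator<int, SegmentOffsetIter,
                                cub::CountingInputIterator<int>>
        segment_offsets(counting_iter, SegmentOffsetIter(num_cols));
    size_t temp_bytes = 0;
    // A null d_temp_storage makes CUB report the required size only.
    cub::DeviceSegmentedRadixSort::SortPairsDescending(
        nullptr, temp_bytes, keys_in, keys_out, ids_in, ids_out,
        num_rows * num_cols, num_rows, segment_offsets, segment_offsets + 1,
        0, sizeof(float) * 8, stream);
    return temp_bytes;
}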
inline int div_up(int a, int n) { return (a + n - 1) / n; }
template <typename T>
__global__ void FillIndex(T* indices, T num_rows, T num_cols) {
int col_id = threadIdx.x;
int row_id = blockIdx.x;
for (T j = row_id; j < num_rows; j += gridDim.x) {
for (T i = col_id; i < num_cols; i += blockDim.x) {
indices[j * num_cols + i] = i;
}
}
}
__global__ void SetCountIter(int* count_iter, int num) {
int tid = threadIdx.x;
int bid = blockIdx.x;
int idx = bid * blockDim.x + tid;
for (int i = idx; i < num; i += gridDim.x * blockDim.x) {
count_iter[i] = i;
}
}
template <typename T, int BLOCK_SIZE>
__global__ void top_p_candidates_kernel(T* sorted_probs,
int64_t* sorted_id,
T* out_val,
int64_t* out_id,
int* actual_candidates_lens,
const int vocab_size,
const float topp,
const int candidates_len) {
    __shared__ int stop_shared;
const int tid = threadIdx.x;
const int bid = blockIdx.x;
constexpr int NUM_WARPS = BLOCK_SIZE / 32;
const int lane_id = tid % 32;
const int warp_id = tid / 32;
typedef cub::BlockScan<float, BLOCK_SIZE> BlockScan;
typedef cub::BlockReduce<int, BLOCK_SIZE> BlockReduce;
__shared__ typename BlockScan::TempStorage temp_storage;
__shared__ typename BlockReduce::TempStorage temp_storage_reduce;
__shared__ uint32_t selected_shared[NUM_WARPS];
    if (lane_id == 0) {
        selected_shared[warp_id] = 0;
    }
    if (tid == 0) {
        // stop_shared is read on every iteration below; initialize it
        // before the first __syncthreads().
        stop_shared = 0;
    }
// Initialize running total
BlockPrefixCallbackOp prefix_op(0);
__syncthreads();
int offset = bid * vocab_size;
int end = ((vocab_size + BLOCK_SIZE - 1) / BLOCK_SIZE) * BLOCK_SIZE;
int i_activate = 0;
float thread_offset = 0;
for (int i = tid; i < end; i += BLOCK_SIZE) {
float thread_count = (i < vocab_size)
? static_cast<float>(sorted_probs[offset + i])
: 0.f;
BlockScan(temp_storage)
.InclusiveSum(thread_count, thread_offset, prefix_op);
if (i < candidates_len) {
out_id[bid * candidates_len + i] = sorted_id[offset + i];
out_val[bid * candidates_len + i] = sorted_probs[offset + i];
}
uint32_t activate_mask =
__ballot_sync(FINAL_MASK, topp <= thread_offset);
i_activate = i;
if (activate_mask != 0 || i >= candidates_len) {
if (lane_id == 0) {
atomicAdd(&stop_shared, 1);
selected_shared[warp_id] = activate_mask;
}
}
__syncthreads();
if (stop_shared > 0) {
break;
}
}
__syncthreads();
    // Only the first warp that crossed the top-p threshold reports the
    // truncation point; warps after it are skipped.
    bool skip = (selected_shared[warp_id] == 0);
    for (int i = 0; i < warp_id; i++) {
        if (selected_shared[i] != 0) {
            // An earlier warp already stopped, so this warp's hit is ignored.
            skip = true;
        }
    }
if (!skip) {
        // The ballot is a suffix mask (lane-wise cumulative sums are
        // nondecreasing), so the first lane past the threshold is
        // WARP_SIZE minus the number of set bits.
        int active_lane_id = WARP_SIZE - __popc(selected_shared[warp_id]);
if (lane_id == active_lane_id) {
actual_candidates_lens[bid] = i_activate + 1;
}
}
__syncthreads();
    if (tid == 0) {
        // No warp crossed the threshold within candidates_len tokens;
        // fall back to the full candidate list.
        if (actual_candidates_lens[bid] == 0) {
            actual_candidates_lens[bid] = candidates_len;
        }
    }
}
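Per row, the kernel above reduces to a simple rule: report the first position whose inclusive cumulative probability reaches topp, capped at candidates_len, with candidates_len as the fallback when the threshold is never hit. A hypothetical host reference, handy for unit-testing the kernel:

// Hypothetical host-side reference for the truncation length above.
static int TopPCutoffRef(const float* sorted_probs, int vocab_size,
                         float topp, int candidates_len) {
    float cumsum = 0.f;
    for (int i = 0; i < vocab_size && i < candidates_len; ++i) {
        cumsum += sorted_probs[i];
        if (topp <= cumsum) return i + 1;  // threshold reached here
    }
    return candidates_len;  // threshold never reached: keep everything
}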
template <typename T>
struct Pair {
__device__ __forceinline__ Pair() {}
__device__ __forceinline__ Pair(T value, int id) : v(value), id(id) {}
__device__ __forceinline__ void set(T value, int id) {
this->v = value;
this->id = id;
}
__device__ __forceinline__ void operator=(const Pair<T>& in) {
v = in.v;
id = in.id;
}
__device__ __forceinline__ bool operator<(const T value) const {
return (static_cast<float>(v) < static_cast<float>(value));
}
__device__ __forceinline__ bool operator>(const T value) const {
return (static_cast<float>(v) > static_cast<float>(value));
}
__device__ __forceinline__ bool operator<(const Pair<T>& in) const {
return (static_cast<float>(v) < static_cast<float>(in.v)) ||
((static_cast<float>(v) == static_cast<float>(in.v)) &&
(id > in.id));
}
__device__ __forceinline__ bool operator>(const Pair<T>& in) const {
return (static_cast<float>(v) > static_cast<float>(in.v)) ||
((static_cast<float>(v) == static_cast<float>(in.v)) &&
(id < in.id));
}
T v;
int id;
};
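// AddTo inserts p into the descending-sorted register array
// topk[0..beam_size) by shifting smaller entries one slot toward the tail;
// equal values break ties toward the smaller id, matching Pair's operators.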
template <typename T>
__device__ __forceinline__ void AddTo(Pair<T> topk[],
const Pair<T>& p,
int beam_size) {
for (int k = beam_size - 2; k >= 0; k--) {
if (topk[k] < p) {
topk[k + 1] = topk[k];
} else {
topk[k + 1] = p;
return;
}
}
topk[0] = p;
}
template <typename T, int BlockSize>
__device__ __forceinline__ void GetTopK(
Pair<T> topk[], const T* src, int idx, int dim, int beam_size) {
while (idx < dim) {
if (topk[beam_size - 1] < src[idx]) {
Pair<T> tmp(src[idx], idx);
AddTo<T>(topk, tmp, beam_size);
}
idx += BlockSize;
}
}
template <typename T, int BlockSize>
__device__ __forceinline__ void GetTopK(Pair<T> topk[],
const T* src,
int idx,
int dim,
const Pair<T>& max,
int beam_size) {
while (idx < dim) {
if (topk[beam_size - 1] < src[idx]) {
Pair<T> tmp(src[idx], idx);
if (tmp < max) {
AddTo<T>(topk, tmp, beam_size);
}
}
idx += BlockSize;
}
}
template <typename T, int MaxLength, int BlockSize>
__device__ __forceinline__ void ThreadGetTopK(Pair<T> topk[],
int* beam,
int beam_size,
const T* src,
bool* firstStep,
bool* is_empty,
Pair<T>* max,
int dim,
const int tid) {
if (*beam > 0) {
int length = (*beam) < beam_size ? *beam : beam_size;
if (*firstStep) {
*firstStep = false;
GetTopK<T, BlockSize>(topk, src, tid, dim, length);
} else {
for (int k = 0; k < MaxLength; k++) {
if (k < MaxLength - (*beam)) {
topk[k] = topk[k + *beam];
} else {
topk[k].set(std::numeric_limits<T>::min(), -1);
}
}
if (!(*is_empty)) {
GetTopK<T, BlockSize>(
topk + MaxLength - *beam, src, tid, dim, *max, length);
}
}
*max = topk[MaxLength - 1];
if ((*max).id == -1) *is_empty = true;
*beam = 0;
}
}
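// WarpReduce is the standard __shfl_down_sync butterfly, moving the
// (value, id) pair together so that after offsets 16, 8, 4, 2, 1 lane 0
// holds the warp-wide maximum and its index.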
template <typename T>
__forceinline__ __device__ Pair<T> WarpReduce(Pair<T> input) {
#pragma unroll
for (int offset = 16; offset > 0; offset >>= 1) {
T tmp_val = CudaShuffleDownSync(FINAL_MASK, input.v, offset);
int tmp_id = CudaShuffleDownSync(FINAL_MASK, input.id, offset);
if (static_cast<float>(input.v) < static_cast<float>(tmp_val)) {
input.v = tmp_val;
input.id = tmp_id;
}
}
return input;
}
template <typename T, int MaxLength, int BlockSize>
__device__ __forceinline__ void BlockReduce(Pair<T> shared_max[],
Pair<T> topk[],
Pair<T> beam_max[],
int* beam,
int* k,
int* count,
const int tid,
const int wid,
const int lane) {
while (true) {
__syncthreads();
Pair<T> input_now = topk[0];
input_now = WarpReduce(input_now);
if (lane == 0) {
shared_max[wid] = input_now;
}
__syncthreads();
input_now = (tid < BlockSize / 32)
? shared_max[lane]
: Pair<T>(std::numeric_limits<T>::min(), -1);
if (wid == 0) {
input_now = WarpReduce(input_now);
if (lane == 0) shared_max[0] = input_now;
}
__syncthreads();
if (tid == 0) {
beam_max[*count] = shared_max[0];
(*count)++;
}
int tid_max = shared_max[0].id % BlockSize;
if (tid == tid_max) {
(*beam)++;
}
if (--(*k) == 0) break;
__syncthreads();
if (tid == tid_max) {
if (*beam < MaxLength) {
topk[0] = topk[*beam];
}
}
        if (MaxLength < 5) {
            if (*beam >= MaxLength) break;
        } else {
            // The ballot only re-converges the warp before the shuffle
            // below; its result is intentionally unused.
            __ballot_sync(FINAL_MASK, true);
            if (tid_max / 32 == wid) {
                if (__shfl_down_sync(FINAL_MASK, *beam, tid_max % 32, 32) ==
                    MaxLength)
                    break;
            }
        }
}
}
template <typename T, int MaxLength, int TopPBeamTopK, int BlockSize>
__global__ void KeMatrixTopPBeamTopKFt(
const T* src,
const T* top_ps,
const int* output_padding_offset,
    int64_t* out_id,  // [token_num, max_candidate_len], flattened
    T* out_val,       // [token_num, max_candidate_len], flattened
    int* actual_candidates_lens,
    int vocab_size,
    const int max_candidate_len,
const int max_seq_len) {
const int tid = threadIdx.x;
const int wid = tid / 32;
const int lane = tid % 32;
const int token_id = blockIdx.x;
const int ori_token_id = token_id + output_padding_offset[token_id];
const int bid = ori_token_id / max_seq_len;
int top_num = TopPBeamTopK;
float top_p_value = static_cast<float>(top_ps[bid]);
__shared__ Pair<T> shared_max[BlockSize / 32];
__shared__ Pair<T> beam_max[TopPBeamTopK];
Pair<T> topk[MaxLength];
int beam = MaxLength;
Pair<T> max;
bool is_empty = false;
bool firststep = true;
__shared__ int count;
if (tid == 0) {
count = 0;
}
for (int j = 0; j < MaxLength; j++) {
topk[j].set(std::numeric_limits<T>::min(), -1);
}
while (top_num) {
ThreadGetTopK<T, MaxLength, BlockSize>(topk,
&beam,
TopPBeamTopK,
src + token_id * vocab_size,
&firststep,
&is_empty,
&max,
vocab_size,
tid);
BlockReduce<T, MaxLength, BlockSize>(shared_max,
topk,
beam_max,
&beam,
&top_num,
&count,
tid,
wid,
lane);
}
    if (tid == 0) {
        float sum_prob = 0.0f;
        // Fallback: if the cumulative probability never reaches top_p
        // within TopPBeamTopK tokens, keep all of them.
        actual_candidates_lens[token_id] = TopPBeamTopK;
        for (int i = 0; i < TopPBeamTopK; i++) {
            out_id[token_id * max_candidate_len + i] =
                static_cast<int64_t>(beam_max[i].id);
            out_val[token_id * max_candidate_len + i] = beam_max[i].v;
            sum_prob += static_cast<float>(beam_max[i].v);
            if (sum_prob >= top_p_value) {
                actual_candidates_lens[token_id] = i + 1;
                break;
            }
        }
    }
}
template <typename T, int TopKMaxLength>
void DispatchTopK(const T* src,
const T* top_ps,
const int* output_padding_offset,
int64_t* out_id, // topk id
T* out_val, // topk val
int* actual_candidates_lens_data,
const int vocab_size,
const int token_num,
                  const int candidate_len,
const int max_seq_len,
const cudaStream_t& stream) {
int BlockSize = GetBlockSize(vocab_size);
    switch (candidate_len) {
FIXED_TOPK(switch (BlockSize) {
FIXED_BLOCK_DIM(
KeMatrixTopPBeamTopKFt<T, TopKMaxLength, kTopK, kBlockDim>
<<<token_num, kBlockDim, 0, stream>>>(
src,
top_ps,
output_padding_offset,
out_id,
out_val,
actual_candidates_lens_data,
vocab_size,
                        candidate_len,
max_seq_len));
        default:
            PD_THROW(
                "Unsupported block size derived from vocab_size in the "
                "topp_beam_topk kernel.");
});
        default:
            PD_THROW(
                "Unsupported candidate_len: only 2, 3, 4, 5, 8 and 10 "
                "are implemented.");
}
}
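Since the nested FIXED_TOPK/FIXED_BLOCK_DIM dispatch is easy to misread, here is what it expands to for one pair of constants (illustrative expansion only, not code that appears in this file):

// For candidate_len == 2 and BlockSize == 128 the switch expands to:
//
//   case 2: {
//       constexpr auto kTopK = 2;
//       switch (BlockSize) {
//           case 128: {
//               constexpr auto kBlockDim = 128;
//               KeMatrixTopPBeamTopKFt<T, TopKMaxLength, kTopK, kBlockDim>
//                   <<<token_num, kBlockDim, 0, stream>>>(/* ... */);
//           } break;
//           // ...cases for 1024, 512, 256, 64 and 32...
//       }
//   } break;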
template <paddle::DataType D>
std::vector<paddle::Tensor> LaunchTopPCandidates(
const paddle::Tensor& probs, // [token_num, vocab_size]
const paddle::Tensor& top_p, // [token_num]
const paddle::Tensor& output_padding_offset,
const int candidates_len,
const int max_seq_len) {
typedef PDTraits<D> traits_;
typedef typename traits_::DataType DataType_;
typedef typename traits_::data_t data_t;
std::vector<int64_t> input_shape = probs.shape();
const int token_num = input_shape[0];
const int vocab_size = input_shape[1];
auto verify_scores =
paddle::full({token_num, candidates_len}, 0, D, probs.place());
auto verify_tokens = paddle::full(
{token_num, candidates_len}, 0, paddle::DataType::INT64, probs.place());
auto actual_candidate_lens =
paddle::full({token_num}, 0, paddle::DataType::INT32, probs.place());
auto stream = probs.stream();
constexpr int TopKMaxLength = 2;
DispatchTopK<DataType_, TopKMaxLength>(
reinterpret_cast<const DataType_*>(probs.data<data_t>()),
reinterpret_cast<const DataType_*>(top_p.data<data_t>()),
output_padding_offset.data<int>(),
verify_tokens.data<int64_t>(),
reinterpret_cast<DataType_*>(verify_scores.data<data_t>()),
actual_candidate_lens.data<int>(),
vocab_size,
token_num,
candidates_len,
max_seq_len,
stream);
return {verify_scores, verify_tokens, actual_candidate_lens};
}
std::vector<paddle::Tensor> DispatchTopPCandidatesWithDtype(
const paddle::Tensor& probs,
const paddle::Tensor& top_p,
const paddle::Tensor& output_padding_offset,
int candidates_len,
int max_seq_len) {
    switch (probs.type()) {
        case paddle::DataType::BFLOAT16:
            return LaunchTopPCandidates<paddle::DataType::BFLOAT16>(
                probs, top_p, output_padding_offset, candidates_len,
                max_seq_len);
        case paddle::DataType::FLOAT16:
            return LaunchTopPCandidates<paddle::DataType::FLOAT16>(
                probs, top_p, output_padding_offset, candidates_len,
                max_seq_len);
        case paddle::DataType::FLOAT32:
            return LaunchTopPCandidates<paddle::DataType::FLOAT32>(
                probs, top_p, output_padding_offset, candidates_len,
                max_seq_len);
        default:
            PD_THROW(
                "NOT supported data type. "
                "Only bfloat16, float16 and float32 are supported. ");
    }
}
std::vector<paddle::Tensor> TopPCandidates(
const paddle::Tensor& probs,
const paddle::Tensor& top_p,
const paddle::Tensor& output_padding_offset,
int candidates_len,
int max_seq_len) {
return DispatchTopPCandidatesWithDtype(
probs, top_p, output_padding_offset, candidates_len, max_seq_len);
}
std::vector<std::vector<int64_t>> TopPCandidatesInferShape(
const std::vector<int64_t>& probs_shape,
const std::vector<int64_t>& top_p_shape,
const std::vector<int64_t>& output_padding_offset_shape,
int max_candidates_len) {
int token_num = probs_shape[0];
return {{token_num, max_candidates_len},
{token_num, max_candidates_len},
{token_num}};
}
std::vector<paddle::DataType> TopPCandidatesInferDtype(
const paddle::DataType& probs_dtype,
const paddle::DataType& top_p_dtype,
const paddle::DataType& output_padding_offset_dtype) {
return {probs_dtype, paddle::DataType::INT64, paddle::DataType::INT32};
}
PD_BUILD_STATIC_OP(top_p_candidates)
.Inputs({"probs", "top_p", "output_padding_offset"})
.Outputs({"verify_scores", "verify_tokens", "actual_candidate_lens"})
.Attrs({"candidates_len: int", "max_seq_len: int"})
.SetKernelFn(PD_KERNEL(TopPCandidates))
.SetInferShapeFn(PD_INFER_SHAPE(TopPCandidatesInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(TopPCandidatesInferDtype));