mirror of https://github.com/PaddlePaddle/FastDeploy.git
[LLM] First commit the llm deployment code
@@ -0,0 +1,80 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/extension.h"

#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif

__global__ void draft_model_update_seq_lens_this_time_kernel(
    const int64_t* base_model_draft_tokens,
    int* base_model_seq_lens_this_time,
    const int* base_model_seq_lens_encoder,
    const bool* base_model_stop_flags,
    int bsz,
    int base_model_draft_token_len) {
  int tid = threadIdx.x;
  if (tid < bsz) {
    if (!base_model_stop_flags[tid] &&
        base_model_seq_lens_encoder[tid] == 0) {
      const int64_t* base_model_draft_tokens_now =
          base_model_draft_tokens + tid * base_model_draft_token_len;
      int token_num = 0;

      for (int i = 0; i < base_model_draft_token_len; ++i) {
        if (base_model_draft_tokens_now[i] != -1) {
          token_num++;
        }
      }
      base_model_seq_lens_this_time[tid] = token_num;
    } else if (base_model_stop_flags[tid]) {
      base_model_seq_lens_this_time[tid] = 0;
    }
  }
}

void DraftModelPostprocess(const paddle::Tensor& base_model_draft_tokens,
                           const paddle::Tensor& base_model_seq_lens_this_time,
                           const paddle::Tensor& base_model_seq_lens_encoder,
                           const paddle::Tensor& base_model_stop_flags) {
  int real_bsz = base_model_seq_lens_this_time.shape()[0];
  auto cu_stream = base_model_seq_lens_this_time.stream();
  int block_size = (real_bsz + 32 - 1) / 32 * 32;
  int base_model_draft_token_len = base_model_draft_tokens.shape()[1];
  draft_model_update_seq_lens_this_time_kernel<<<1, block_size, 0, cu_stream>>>(
      base_model_draft_tokens.data<int64_t>(),
      const_cast<int*>(base_model_seq_lens_this_time.data<int>()),
      base_model_seq_lens_encoder.data<int>(),
      base_model_stop_flags.data<bool>(),
      real_bsz,
      base_model_draft_token_len);
}

PD_BUILD_STATIC_OP(draft_model_postprocess)
    .Inputs({"base_model_draft_tokens",
             "base_model_seq_lens_this_time",
             "base_model_seq_lens_encoder",
             "base_model_stop_flags"})
    .Outputs({"base_model_draft_tokens_out",
              "base_model_seq_lens_this_time_out",
              "base_model_stop_flags_out"})
    .SetInplaceMap({{"base_model_draft_tokens", "base_model_draft_tokens_out"},
                    {"base_model_seq_lens_this_time",
                     "base_model_seq_lens_this_time_out"},
                    {"base_model_stop_flags", "base_model_stop_flags_out"}})
    .SetKernelFn(PD_KERNEL(DraftModelPostprocess));
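
// Note on the kernel above: -1 is the padding value in
// base_model_draft_tokens, so for every still-running decode request the op
// rewrites seq_lens_this_time to the number of valid draft tokens. A minimal
// host-side sketch of that counting rule (a hypothetical standalone helper,
// not part of this op):
//
//   int count_valid_draft_tokens(const int64_t* row, int len) {
//     int n = 0;
//     for (int i = 0; i < len; ++i) {
//       if (row[i] != -1) ++n;  // -1 marks an unused draft slot
//     }
//     return n;
//   }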
@@ -0,0 +1,500 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "helper.h"
#include "paddle/extension.h"

template <int THREADBLOCK_SIZE, bool TRUNCATE_FIRST_TOKEN>
__global__ void process_splitwise_prefill(
    int64_t* draft_tokens,
    int64_t* input_ids,
    bool* stop_flags,
    int* seq_lens_this_time,
    int* seq_lens_encoder,
    int* seq_lens_decoder,
    int64_t* step_idx,
    int* seq_lens_encoder_record,
    int* seq_lens_decoder_record,
    bool* not_need_stop,
    bool* batch_drop,
    const int64_t* accept_tokens,
    const int* accept_num,
    const int* base_model_seq_lens_encoder,
    const int* base_model_seq_lens_decoder,
    const int64_t* base_model_step_idx,
    const bool* base_model_stop_flags,
    const bool* base_model_is_block_step,
    int64_t* base_model_draft_tokens,
    const int bsz,
    const int max_draft_token,
    const int accept_tokens_len,
    const int draft_tokens_len,
    const int input_ids_len,
    const int base_model_draft_tokens_len) {
  typedef cub::BlockReduce<int64_t, THREADBLOCK_SIZE> BlockReduce;
  __shared__ typename BlockReduce::TempStorage temp_storage;
  int64_t not_stop_flag = 0;

  int tid = threadIdx.x;
  if (tid < bsz) {
    int base_model_step_idx_now = base_model_step_idx[tid];
    auto* input_ids_now = input_ids + tid * input_ids_len;
    auto* accept_tokens_now = accept_tokens + tid * accept_tokens_len;
    if (base_model_step_idx_now == 1 && seq_lens_encoder_record[tid] > 0) {
      not_stop_flag = 1;
      int seq_len_encoder_record = seq_lens_encoder_record[tid];
      seq_lens_encoder[tid] = seq_len_encoder_record;
      seq_lens_encoder_record[tid] = -1;
      stop_flags[tid] = false;
      int64_t base_model_first_token = accept_tokens_now[0];
      int position = seq_len_encoder_record;
      if (TRUNCATE_FIRST_TOKEN) {
        input_ids_now[position - 1] = base_model_first_token;
        seq_lens_this_time[tid] = seq_len_encoder_record;
      } else {
        input_ids_now[position] = base_model_first_token;
        seq_lens_this_time[tid] = seq_len_encoder_record + 1;
      }
    } else {
      stop_flags[tid] = true;
      seq_lens_this_time[tid] = 0;
      seq_lens_decoder[tid] = 0;
      seq_lens_encoder[tid] = 0;
      not_stop_flag = 0;
    }
  }
  __syncthreads();

  int64_t not_stop_flag_sum = BlockReduce(temp_storage).Sum(not_stop_flag);
  if (tid == 0) {
    not_need_stop[0] = not_stop_flag_sum > 0;
  }
}
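
// Worked example of the TRUNCATE_FIRST_TOKEN switch above (assuming a prompt
// of length 4, i.e. seq_lens_encoder_record[tid] == 4, and the base model's
// first accepted token T):
//
//   TRUNCATE_FIRST_TOKEN == true:  input_ids = [p0, p1, p2, T]     len = 4
//     (T overwrites the last prompt position, truncating the prompt by one)
//   TRUNCATE_FIRST_TOKEN == false: input_ids = [p0, p1, p2, p3, T]  len = 5
//     (T is appended after the full prompt)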

template <int THREADBLOCK_SIZE, bool TRUNCATE_FIRST_TOKEN>
__global__ void draft_model_preprocess_kernel(
    int64_t* draft_tokens,
    int64_t* input_ids,
    bool* stop_flags,
    int* seq_lens_this_time,
    int* seq_lens_encoder,
    int* seq_lens_decoder,
    int64_t* step_idx,
    int* seq_lens_encoder_record,
    int* seq_lens_decoder_record,
    bool* not_need_stop,
    bool* batch_drop,
    const int64_t* accept_tokens,
    const int* accept_num,
    const int* base_model_seq_lens_encoder,
    const int* base_model_seq_lens_decoder,
    const int64_t* base_model_step_idx,
    const bool* base_model_stop_flags,
    const bool* base_model_is_block_step,
    int64_t* base_model_draft_tokens,
    const int bsz,
    const int max_draft_token,
    const int accept_tokens_len,
    const int draft_tokens_len,
    const int input_ids_len,
    const int base_model_draft_tokens_len) {
  typedef cub::BlockReduce<int64_t, THREADBLOCK_SIZE> BlockReduce;
  __shared__ typename BlockReduce::TempStorage temp_storage;
  int64_t not_stop_flag = 0;

  int tid = threadIdx.x;

  if (tid < bsz) {
    auto base_model_step_idx_now = base_model_step_idx[tid];
    auto* accept_tokens_now = accept_tokens + tid * accept_tokens_len;
    auto* draft_tokens_now = draft_tokens + tid * draft_tokens_len;
    auto accept_num_now = accept_num[tid];
    auto* input_ids_now = input_ids + tid * input_ids_len;
    auto* base_model_draft_tokens_now =
        base_model_draft_tokens + tid * base_model_draft_tokens_len;
#pragma unroll
    for (int i = 1; i < base_model_draft_tokens_len; i++) {
      base_model_draft_tokens_now[i] = -1;
    }
    // Handle the base model's recover logic: a request the base model has
    // stopped while in a block step is dropped from this batch.
    if (base_model_stop_flags[tid] && base_model_is_block_step[tid]) {
      batch_drop[tid] = true;
      stop_flags[tid] = true;
    }

    if (!(base_model_stop_flags[tid] || batch_drop[tid])) {
      not_stop_flag = 1;
      // 1. first token
      if (base_model_step_idx_now == 0) {
        seq_lens_this_time[tid] = 0;
        not_stop_flag = 0;
      } else if (base_model_step_idx_now == 1 &&
                 seq_lens_encoder_record[tid] > 0) {
        // Can be extended to the first few tokens
        int seq_len_encoder_record = seq_lens_encoder_record[tid];
        seq_lens_encoder[tid] = seq_len_encoder_record;
        seq_lens_encoder_record[tid] = -1;
        seq_lens_decoder[tid] = seq_lens_decoder_record[tid];
        seq_lens_decoder_record[tid] = 0;
        stop_flags[tid] = false;
        int64_t base_model_first_token = accept_tokens_now[0];
        int position = seq_len_encoder_record;
        if (TRUNCATE_FIRST_TOKEN) {
          input_ids_now[position - 1] = base_model_first_token;
          seq_lens_this_time[tid] = seq_len_encoder_record;
        } else {
          input_ids_now[position] = base_model_first_token;
          seq_lens_this_time[tid] = seq_len_encoder_record + 1;
        }
      } else if (accept_num_now <=
                 max_draft_token) /*Accept partial draft tokens*/ {
        // Base model rejected the stop: restore draft state from the base model.
        if (stop_flags[tid]) {
          stop_flags[tid] = false;
          seq_lens_decoder[tid] = base_model_seq_lens_decoder[tid];
          step_idx[tid] = base_model_step_idx[tid];
        } else {
          seq_lens_decoder[tid] -= max_draft_token - accept_num_now;
          step_idx[tid] -= max_draft_token - accept_num_now;
        }
        int64_t modified_token = accept_tokens_now[accept_num_now - 1];
        draft_tokens_now[0] = modified_token;
        seq_lens_this_time[tid] = 1;
      } else /*Accept all draft tokens*/ {
        draft_tokens_now[1] = accept_tokens_now[max_draft_token];
        seq_lens_this_time[tid] = 2;
      }
    } else {
      stop_flags[tid] = true;
      seq_lens_this_time[tid] = 0;
      seq_lens_decoder[tid] = 0;
      seq_lens_encoder[tid] = 0;
    }
  }
  __syncthreads();
  int64_t not_stop_flag_sum = BlockReduce(temp_storage).Sum(not_stop_flag);
  if (tid == 0) {
    not_need_stop[0] = not_stop_flag_sum > 0;
  }
}
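
// Worked example of the partial-accept rollback above: with
// max_draft_token == 4 and accept_num_now == 2, the base model rejected
// 4 - 2 = 2 speculative positions, so seq_lens_decoder[tid] and step_idx[tid]
// are both rewound by 2, and decoding restarts from the last accepted token
// accept_tokens_now[1] with seq_lens_this_time[tid] == 1.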

template <bool TRUNCATE_FIRST_TOKEN>
void DispatchRunner(const cudaStream_t& stream,
                    int64_t* draft_tokens,
                    int64_t* input_ids,
                    bool* stop_flags,
                    int* seq_lens_this_time,
                    int* seq_lens_encoder,
                    int* seq_lens_decoder,
                    int64_t* step_idx,
                    int* seq_lens_encoder_record,
                    int* seq_lens_decoder_record,
                    bool* not_need_stop,
                    bool* batch_drop,
                    const int64_t* accept_tokens,
                    const int* accept_num,
                    const int* base_model_seq_lens_encoder,
                    const int* base_model_seq_lens_decoder,
                    const int64_t* base_model_step_idx,
                    const bool* base_model_stop_flags,
                    const bool* base_model_is_block_step,
                    int64_t* base_model_draft_tokens,
                    const int bsz,
                    const int max_draft_token,
                    const int accept_tokens_len,
                    const int draft_tokens_len,
                    const int input_ids_len,
                    const int base_model_draft_tokens_len,
                    const bool splitwise_prefill) {
  constexpr int BlockSize = 512;
  if (splitwise_prefill) {
    process_splitwise_prefill<BlockSize, TRUNCATE_FIRST_TOKEN>
        <<<1, BlockSize, 0, stream>>>(draft_tokens,
                                      input_ids,
                                      stop_flags,
                                      seq_lens_this_time,
                                      seq_lens_encoder,
                                      seq_lens_decoder,
                                      step_idx,
                                      seq_lens_encoder_record,
                                      seq_lens_decoder_record,
                                      not_need_stop,
                                      batch_drop,
                                      accept_tokens,
                                      accept_num,
                                      base_model_seq_lens_encoder,
                                      base_model_seq_lens_decoder,
                                      base_model_step_idx,
                                      base_model_stop_flags,
                                      base_model_is_block_step,
                                      base_model_draft_tokens,
                                      bsz,
                                      max_draft_token,
                                      accept_tokens_len,
                                      draft_tokens_len,
                                      input_ids_len,
                                      base_model_draft_tokens_len);
  } else {
    draft_model_preprocess_kernel<BlockSize, TRUNCATE_FIRST_TOKEN>
        <<<1, BlockSize, 0, stream>>>(draft_tokens,
                                      input_ids,
                                      stop_flags,
                                      seq_lens_this_time,
                                      seq_lens_encoder,
                                      seq_lens_decoder,
                                      step_idx,
                                      seq_lens_encoder_record,
                                      seq_lens_decoder_record,
                                      not_need_stop,
                                      batch_drop,
                                      accept_tokens,
                                      accept_num,
                                      base_model_seq_lens_encoder,
                                      base_model_seq_lens_decoder,
                                      base_model_step_idx,
                                      base_model_stop_flags,
                                      base_model_is_block_step,
                                      base_model_draft_tokens,
                                      bsz,
                                      max_draft_token,
                                      accept_tokens_len,
                                      draft_tokens_len,
                                      input_ids_len,
                                      base_model_draft_tokens_len);
  }
}

void DispatchTokenMode(const cudaStream_t& stream,
                       int64_t* draft_tokens,
                       int64_t* input_ids,
                       bool* stop_flags,
                       int* seq_lens_this_time,
                       int* seq_lens_encoder,
                       int* seq_lens_decoder,
                       int64_t* step_idx,
                       int* seq_lens_encoder_record,
                       int* seq_lens_decoder_record,
                       bool* not_need_stop,
                       bool* batch_drop,
                       const int64_t* accept_tokens,
                       const int* accept_num,
                       const int* base_model_seq_lens_encoder,
                       const int* base_model_seq_lens_decoder,
                       const int64_t* base_model_step_idx,
                       const bool* base_model_stop_flags,
                       const bool* base_model_is_block_step,
                       int64_t* base_model_draft_tokens,
                       const int bsz,
                       const int max_draft_token,
                       const int accept_tokens_len,
                       const int draft_tokens_len,
                       const int input_ids_len,
                       const int base_model_draft_tokens_len,
                       const bool truncate_first_token,
                       const bool splitwise_prefill) {
  if (truncate_first_token) {
    DispatchRunner<true>(stream,
                         draft_tokens,
                         input_ids,
                         stop_flags,
                         seq_lens_this_time,
                         seq_lens_encoder,
                         seq_lens_decoder,
                         step_idx,
                         seq_lens_encoder_record,
                         seq_lens_decoder_record,
                         not_need_stop,
                         batch_drop,
                         accept_tokens,
                         accept_num,
                         base_model_seq_lens_encoder,
                         base_model_seq_lens_decoder,
                         base_model_step_idx,
                         base_model_stop_flags,
                         base_model_is_block_step,
                         base_model_draft_tokens,
                         bsz,
                         max_draft_token,
                         accept_tokens_len,
                         draft_tokens_len,
                         input_ids_len,
                         base_model_draft_tokens_len,
                         splitwise_prefill);
  } else {
    DispatchRunner<false>(stream,
                          draft_tokens,
                          input_ids,
                          stop_flags,
                          seq_lens_this_time,
                          seq_lens_encoder,
                          seq_lens_decoder,
                          step_idx,
                          seq_lens_encoder_record,
                          seq_lens_decoder_record,
                          not_need_stop,
                          batch_drop,
                          accept_tokens,
                          accept_num,
                          base_model_seq_lens_encoder,
                          base_model_seq_lens_decoder,
                          base_model_step_idx,
                          base_model_stop_flags,
                          base_model_is_block_step,
                          base_model_draft_tokens,
                          bsz,
                          max_draft_token,
                          accept_tokens_len,
                          draft_tokens_len,
                          input_ids_len,
                          base_model_draft_tokens_len,
                          splitwise_prefill);
  }
}

void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
                          const paddle::Tensor& input_ids,
                          const paddle::Tensor& stop_flags,
                          const paddle::Tensor& seq_lens_this_time,
                          const paddle::Tensor& seq_lens_encoder,
                          const paddle::Tensor& seq_lens_decoder,
                          const paddle::Tensor& step_idx,
                          const paddle::Tensor& seq_lens_encoder_record,
                          const paddle::Tensor& seq_lens_decoder_record,
                          const paddle::Tensor& not_need_stop,
                          const paddle::Tensor& batch_drop,
                          const paddle::Tensor& accept_tokens,
                          const paddle::Tensor& accept_num,
                          const paddle::Tensor& base_model_seq_lens_encoder,
                          const paddle::Tensor& base_model_seq_lens_decoder,
                          const paddle::Tensor& base_model_step_idx,
                          const paddle::Tensor& base_model_stop_flags,
                          const paddle::Tensor& base_model_is_block_step,
                          const paddle::Tensor& base_model_draft_tokens,
                          const int max_draft_token,
                          const bool truncate_first_token,
                          const bool splitwise_prefill) {
  int real_bsz = seq_lens_this_time.shape()[0];
  int accept_tokens_len = accept_tokens.shape()[1];
  int input_ids_len = input_ids.shape()[1];
  int draft_tokens_len = draft_tokens.shape()[1];
  auto cu_stream = seq_lens_this_time.stream();
  int base_model_draft_tokens_len = base_model_draft_tokens.shape()[1];
  auto not_need_stop_gpu =
      not_need_stop.copy_to(seq_lens_this_time.place(), false);

  DispatchTokenMode(
      cu_stream,
      const_cast<int64_t*>(draft_tokens.data<int64_t>()),
      const_cast<int64_t*>(input_ids.data<int64_t>()),
      const_cast<bool*>(stop_flags.data<bool>()),
      const_cast<int*>(seq_lens_this_time.data<int>()),
      const_cast<int*>(seq_lens_encoder.data<int>()),
      const_cast<int*>(seq_lens_decoder.data<int>()),
      const_cast<int64_t*>(step_idx.data<int64_t>()),
      const_cast<int*>(seq_lens_encoder_record.data<int>()),
      const_cast<int*>(seq_lens_decoder_record.data<int>()),
      const_cast<bool*>(not_need_stop_gpu.data<bool>()),
      const_cast<bool*>(batch_drop.data<bool>()),
      accept_tokens.data<int64_t>(),
      accept_num.data<int>(),
      base_model_seq_lens_encoder.data<int>(),
      base_model_seq_lens_decoder.data<int>(),
      base_model_step_idx.data<int64_t>(),
      base_model_stop_flags.data<bool>(),
      base_model_is_block_step.data<bool>(),
      const_cast<int64_t*>(base_model_draft_tokens.data<int64_t>()),
      real_bsz,
      max_draft_token,
      accept_tokens_len,
      draft_tokens_len,
      input_ids_len,
      base_model_draft_tokens_len,
      truncate_first_token,
      splitwise_prefill);

  auto not_need_stop_cpu =
      not_need_stop_gpu.copy_to(not_need_stop.place(), false);
  bool* not_need_stop_data = const_cast<bool*>(not_need_stop.data<bool>());
  not_need_stop_data[0] = not_need_stop_cpu.data<bool>()[0];
}
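
// Note: not_need_stop lives on the host (its place is where the copy above
// returns to), while the per-request stop flags are reduced on the GPU, so
// the function stages the flag through a device tensor, lets the kernel's
// block reduce write it, and then copies the single bool back into the
// original CPU buffer.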

PD_BUILD_STATIC_OP(draft_model_preprocess)
    .Inputs({"draft_tokens",
             "input_ids",
             "stop_flags",
             "seq_lens_this_time",
             "seq_lens_encoder",
             "seq_lens_decoder",
             "step_idx",
             "seq_lens_encoder_record",
             "seq_lens_decoder_record",
             "not_need_stop",
             "batch_drop",
             "accept_tokens",
             "accept_num",
             "base_model_seq_lens_encoder",
             "base_model_seq_lens_decoder",
             "base_model_step_idx",
             "base_model_stop_flags",
             "base_model_is_block_step",
             "base_model_draft_tokens"})
    .Outputs({"draft_tokens_out",
              "input_ids_out",
              "stop_flags_out",
              "seq_lens_this_time_out",
              "seq_lens_encoder_out",
              "seq_lens_decoder_out",
              "step_idx_out",
              "not_need_stop_out",
              "batch_drop_out",
              "seq_lens_encoder_record_out",
              "seq_lens_decoder_record_out"})
    .Attrs({"max_draft_token: int",
            "truncate_first_token: bool",
            "splitwise_prefill: bool"})
    .SetInplaceMap({{"draft_tokens", "draft_tokens_out"},
                    {"input_ids", "input_ids_out"},
                    {"stop_flags", "stop_flags_out"},
                    {"seq_lens_this_time", "seq_lens_this_time_out"},
                    {"seq_lens_encoder", "seq_lens_encoder_out"},
                    {"seq_lens_decoder", "seq_lens_decoder_out"},
                    {"step_idx", "step_idx_out"},
                    {"not_need_stop", "not_need_stop_out"},
                    {"batch_drop", "batch_drop_out"},
                    {"seq_lens_encoder_record", "seq_lens_encoder_record_out"},
                    {"seq_lens_decoder_record", "seq_lens_decoder_record_out"}})
    .SetKernelFn(PD_KERNEL(DraftModelPreprocess));
@@ -0,0 +1,76 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "helper.h"

__global__ void update_pre_ids_kernel(const int64_t* draft_tokens,
                                      int64_t* pre_ids_all,
                                      const bool* stop_flags,
                                      int* seq_lens_this_time,
                                      const int64_t* step_idx,
                                      int bs,
                                      int pre_id_length,
                                      int max_draft_token) {
  int tid = threadIdx.x;
  if (tid < bs && seq_lens_this_time[tid] != 0 && !stop_flags[tid]) {
    int64_t* pre_ids_all_now = pre_ids_all + tid * pre_id_length;
    const int64_t* draft_token_now = draft_tokens + tid * max_draft_token;
    const int seq_len_this_time = seq_lens_this_time[tid];
    if (step_idx[tid] - 1 > 0 /*Decoder Step*/) {
      for (int i = 0; i < seq_len_this_time; ++i) {
        pre_ids_all_now[step_idx[tid] - i] =
            draft_token_now[seq_len_this_time - 1 - i];
      }
    } else if (step_idx[tid] == 1 /*Encoder Step*/) {
      pre_ids_all_now[1] = draft_token_now[0];
    }
    seq_lens_this_time[tid] = 1;
  }
}

void SpeculateDraftModelUpdate(const paddle::Tensor& draft_tokens,
                               const paddle::Tensor& pre_ids_all,
                               const paddle::Tensor& stop_flags,
                               const paddle::Tensor& seq_lens_this_time,
                               const paddle::Tensor& seq_lens_encoder,
                               const paddle::Tensor& seq_lens_decoder,
                               const paddle::Tensor& step_idx) {
  int64_t real_bs = seq_lens_this_time.shape()[0];
  int64_t pre_id_length = pre_ids_all.shape()[1];
  auto cu_stream = seq_lens_this_time.stream();
  int64_t max_draft_token = draft_tokens.shape()[1];

  int block_size = (real_bs + 32 - 1) / 32 * 32;
  update_pre_ids_kernel<<<1, block_size, 0, cu_stream>>>(
      draft_tokens.data<int64_t>(),
      const_cast<int64_t*>(pre_ids_all.data<int64_t>()),
      stop_flags.data<bool>(),
      const_cast<int*>(seq_lens_this_time.data<int>()),
      step_idx.data<int64_t>(),
      real_bs,
      pre_id_length,
      max_draft_token);
}

PD_BUILD_STATIC_OP(draft_model_set_value_by_flags)
    .Inputs({"draft_tokens",
             "pre_ids_all",
             "stop_flags",
             "seq_lens_this_time",
             "seq_lens_encoder",
             "seq_lens_decoder",
             "step_idx"})
    .Outputs({"pre_ids_all_out"})
    .SetInplaceMap({{"pre_ids_all", "pre_ids_all_out"}})
    .SetKernelFn(PD_KERNEL(SpeculateDraftModelUpdate));
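
// Worked example of the pre_ids update above: in a decoder step with
// step_idx[tid] == 7 and seq_len_this_time == 3, the loop writes the three
// newest draft tokens as
//
//   pre_ids_all_now[7] = draft_token_now[2];
//   pre_ids_all_now[6] = draft_token_now[1];
//   pre_ids_all_now[5] = draft_token_now[0];
//
// so pre_ids_all stays indexed by absolute step position, newest token at the
// highest index.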
@@ -0,0 +1,215 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "helper.h"
#include "paddle/extension.h"

template <int THREADBLOCK_SIZE>
__global__ void draft_model_update_kernel(const int64_t* inter_next_tokens,
                                          int64_t* draft_tokens,
                                          int64_t* pre_ids,
                                          int* seq_lens_this_time,
                                          int* seq_lens_encoder,
                                          int* seq_lens_decoder,
                                          int64_t* step_idx,
                                          const int* output_cum_offsets,
                                          bool* stop_flags,
                                          bool* not_need_stop,
                                          const int64_t* max_dec_len,
                                          const int64_t* end_ids,
                                          int64_t* base_model_draft_tokens,
                                          const int bsz,
                                          const int max_draft_token,
                                          const int pre_id_length,
                                          const int max_base_model_draft_token,
                                          const int end_ids_len,
                                          const int max_seq_len,
                                          const int substep,
                                          const bool prefill_one_step_stop) {
  typedef cub::BlockReduce<int64_t, THREADBLOCK_SIZE> BlockReduce;
  __shared__ typename BlockReduce::TempStorage temp_storage;
  int64_t stop_flag_now_int = 0;

  int tid = threadIdx.x;
  if (tid < bsz) {
    auto* draft_token_now = draft_tokens + tid * max_draft_token;
    auto* pre_ids_now = pre_ids + tid * pre_id_length;
    auto* base_model_draft_tokens_now =
        base_model_draft_tokens + tid * max_base_model_draft_token;
    const int next_tokens_start_id =
        tid * max_seq_len - output_cum_offsets[tid];
    auto* next_tokens_start = inter_next_tokens + next_tokens_start_id;
    auto seq_len_this_time = seq_lens_this_time[tid];
    auto seq_len_encoder = seq_lens_encoder[tid];
    auto seq_len_decoder = seq_lens_decoder[tid];

    // 1. update step_idx && seq_lens_decoder
    if (!stop_flags[tid] /* seq_lens_decoder > 0 or seq_lens_encoder > 0 */) {
      int64_t token_this_time = -1;
      // decoder step
      if (seq_len_decoder > 0 && seq_len_encoder <= 0) {
        seq_lens_decoder[tid] += seq_len_this_time;
        token_this_time = next_tokens_start[seq_len_this_time - 1];
        draft_token_now[0] = next_tokens_start[seq_len_this_time - 1];
        base_model_draft_tokens_now[substep + 1] = token_this_time;
        for (int i = 0; i < seq_len_this_time; ++i) {
          pre_ids_now[step_idx[tid] + 1 + i] = next_tokens_start[i];
        }
        step_idx[tid] += seq_len_this_time;
      } else {
        // encoder step
        token_this_time = next_tokens_start[0];
        seq_lens_decoder[tid] = seq_len_encoder + seq_len_decoder;
        seq_lens_encoder[tid] = 0;
        pre_ids_now[1] = token_this_time;
        step_idx[tid] += 1;
        draft_token_now[0] = token_this_time;
        base_model_draft_tokens_now[substep + 1] = token_this_time;
      }

      // multi_end
      if (is_in_end(token_this_time, end_ids, end_ids_len) ||
          prefill_one_step_stop) {
        stop_flags[tid] = true;
        stop_flag_now_int = 1;
        // max_dec_len
      } else if (step_idx[tid] >= max_dec_len[tid]) {
        stop_flags[tid] = true;
        draft_token_now[seq_len_this_time - 1] = end_ids[0];
        base_model_draft_tokens_now[substep + 1] = end_ids[0];
        stop_flag_now_int = 1;
      }
    } else {
      draft_token_now[0] = -1;
      base_model_draft_tokens_now[substep + 1] = -1;
      stop_flag_now_int = 1;
    }

    // 2. set end
    if (!stop_flags[tid]) {
      seq_lens_this_time[tid] = 1;
    } else {
      seq_lens_this_time[tid] = 0;
      seq_lens_encoder[tid] = 0;
    }
  }
  __syncthreads();
  int64_t stop_sum = BlockReduce(temp_storage).Sum(stop_flag_now_int);
  if (tid == 0) {
    not_need_stop[0] = stop_sum < bsz;
  }
}

void DraftModelUpdate(const paddle::Tensor& inter_next_tokens,
                      const paddle::Tensor& draft_tokens,
                      const paddle::Tensor& pre_ids,
                      const paddle::Tensor& seq_lens_this_time,
                      const paddle::Tensor& seq_lens_encoder,
                      const paddle::Tensor& seq_lens_decoder,
                      const paddle::Tensor& step_idx,
                      const paddle::Tensor& output_cum_offsets,
                      const paddle::Tensor& stop_flags,
                      const paddle::Tensor& not_need_stop,
                      const paddle::Tensor& max_dec_len,
                      const paddle::Tensor& end_ids,
                      const paddle::Tensor& base_model_draft_tokens,
                      const int max_seq_len,
                      const int substep) {
  auto seq_lens_this_time_shape = seq_lens_this_time.shape();
  auto cu_stream = seq_lens_this_time.stream();
  const int real_bsz = seq_lens_this_time_shape[0];
  auto not_need_stop_gpu =
      not_need_stop.copy_to(seq_lens_this_time.place(), false);
  const int end_ids_len = end_ids.shape()[0];
  const int max_draft_token = draft_tokens.shape()[1];
  const int pre_id_length = pre_ids.shape()[1];
  const int max_base_model_draft_token = base_model_draft_tokens.shape()[1];
  constexpr int BlockSize = 512;

  bool prefill_one_step_stop = false;
  if (const char* env_p = std::getenv("PREFILL_NODE_ONE_STEP_STOP")) {
    if (env_p[0] == '1') {
      prefill_one_step_stop = true;
    }
  }

  draft_model_update_kernel<BlockSize><<<1, BlockSize, 0, cu_stream>>>(
      inter_next_tokens.data<int64_t>(),
      const_cast<int64_t*>(draft_tokens.data<int64_t>()),
      const_cast<int64_t*>(pre_ids.data<int64_t>()),
      const_cast<int*>(seq_lens_this_time.data<int>()),
      const_cast<int*>(seq_lens_encoder.data<int>()),
      const_cast<int*>(seq_lens_decoder.data<int>()),
      const_cast<int64_t*>(step_idx.data<int64_t>()),
      output_cum_offsets.data<int>(),
      const_cast<bool*>(stop_flags.data<bool>()),
      not_need_stop_gpu.data<bool>(),
      max_dec_len.data<int64_t>(),
      end_ids.data<int64_t>(),
      const_cast<int64_t*>(base_model_draft_tokens.data<int64_t>()),
      real_bsz,
      max_draft_token,
      pre_id_length,
      max_base_model_draft_token,
      end_ids_len,
      max_seq_len,
      substep,
      prefill_one_step_stop);

  auto not_need_stop_cpu =
      not_need_stop_gpu.copy_to(not_need_stop.place(), false);
  bool* not_need_stop_data = const_cast<bool*>(not_need_stop.data<bool>());
  not_need_stop_data[0] = not_need_stop_cpu.data<bool>()[0];
}

PD_BUILD_STATIC_OP(draft_model_update)
    .Inputs({"inter_next_tokens",
             "draft_tokens",
             "pre_ids",
             "seq_lens_this_time",
             "seq_lens_encoder",
             "seq_lens_decoder",
             "step_idx",
             "output_cum_offsets",
             "stop_flags",
             "not_need_stop",
             "max_dec_len",
             "end_ids",
             "base_model_draft_tokens"})
    .Attrs({"max_seq_len: int", "substep: int"})
    .Outputs({"draft_tokens_out",
              "pre_ids_out",
              "seq_lens_this_time_out",
              "seq_lens_encoder_out",
              "seq_lens_decoder_out",
              "step_idx_out",
              "stop_flags_out",
              "not_need_stop_out",
              "base_model_draft_tokens_out"})
    .SetInplaceMap({{"draft_tokens", "draft_tokens_out"},
                    {"pre_ids", "pre_ids_out"},
                    {"seq_lens_this_time", "seq_lens_this_time_out"},
                    {"seq_lens_encoder", "seq_lens_encoder_out"},
                    {"seq_lens_decoder", "seq_lens_decoder_out"},
                    {"step_idx", "step_idx_out"},
                    {"stop_flags", "stop_flags_out"},
                    {"not_need_stop", "not_need_stop_out"},
                    {"base_model_draft_tokens", "base_model_draft_tokens_out"}})
    .SetKernelFn(PD_KERNEL(DraftModelUpdate));
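
// Note on the indexing above: inter_next_tokens appears to be a packed token
// buffer derived from a padded [bsz, max_seq_len] layout, with
// output_cum_offsets[tid] holding the cumulative padding removed before
// request tid. Under that assumption,
//
//   next_tokens_start = inter_next_tokens
//                     + tid * max_seq_len - output_cum_offsets[tid];
//
// points at the first token generated for request tid in this substep.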
@@ -0,0 +1,240 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/extension.h"

#include "helper.h"

// #define DEBUG_EAGLE_KERNEL

__global__ void ComputeOrderKernel(const int* seq_lens_this_time,
                                   const int* seq_lens_encoder,
                                   const int* base_model_seq_lens_this_time,
                                   const int* base_model_seq_lens_encoder,
                                   const int* accept_nums,
                                   int* position_map,
                                   int* output_token_num,
                                   const int bsz,
                                   const int actual_draft_token_num,
                                   const int input_token_num) {
  int in_offset = 0;   // input offset (long, padded sequence)
  int out_offset = 0;  // output offset (short, gathered sequence)
  if (threadIdx.x == 0) {
    for (int i = 0; i < bsz; ++i) {
      int cur_base_model_seq_lens_this_time = base_model_seq_lens_this_time[i];
      int cur_base_model_seq_lens_encoder = base_model_seq_lens_encoder[i];
      int cur_seq_lens_this_time = seq_lens_this_time[i];
      int accept_num = accept_nums[i];
      int cur_seq_lens_encoder = seq_lens_encoder[i];
#ifdef DEBUG_EAGLE_KERNEL
      printf(
          "batch %d: cur_base_model_seq_lens_this_time: %d, "
          "cur_seq_lens_this_time: %d, accept_num: %d\n",
          i, cur_base_model_seq_lens_this_time, cur_seq_lens_this_time,
          accept_num);
#endif
      // 1. eagle encoder. Base step = 1
      if (cur_seq_lens_encoder > 0) {
#ifdef DEBUG_EAGLE_KERNEL
        printf("batch %d: cur_seq_lens_encoder > 0\n", i);
#endif
        for (int j = 0; j < cur_seq_lens_encoder; j++) {
          position_map[in_offset++] = out_offset++;
        }
        // 2. base model encoder. Base step = 0: nothing to gather yet.
      } else if (cur_base_model_seq_lens_encoder != 0) {
        // 3. New end
      } else if (cur_base_model_seq_lens_this_time != 0 &&
                 cur_seq_lens_this_time == 0) {
#ifdef DEBUG_EAGLE_KERNEL
        printf("batch %d: base != 0, draft == 0\n", i);
#endif
        in_offset += cur_base_model_seq_lens_this_time;
        // 4. stopped
      } else if (cur_base_model_seq_lens_this_time == 0 &&
                 cur_seq_lens_this_time == 0) /* end */ {
      } else {
        if (accept_num <=
            actual_draft_token_num) /*Accept partial draft tokens*/ {
#ifdef DEBUG_EAGLE_KERNEL
          printf("batch %d: accept_num <= actual_draft_token_num\n", i);
#endif
          position_map[in_offset + accept_num - 1] = out_offset++;
          in_offset += cur_base_model_seq_lens_this_time;
        } else /*Accept all draft tokens*/ {
#ifdef DEBUG_EAGLE_KERNEL
          printf("batch %d: accept_num > actual_draft_token_num\n", i);
#endif
          position_map[in_offset + accept_num - 2] = out_offset++;
          position_map[in_offset + accept_num - 1] = out_offset++;
          in_offset += cur_base_model_seq_lens_this_time;
        }
      }
    }
    output_token_num[0] = out_offset;
#ifdef DEBUG_EAGLE_KERNEL
    printf("position map, output_token_num: %d\n", output_token_num[0]);
    for (int i = 0; i < output_token_num[0]; i++) {
      printf("%d ", position_map[i]);
    }
    printf("\n");
#endif
  }
}

template <typename T, int VecSize>
__global__ void rebuildHiddenStatesKernel(const T* input,
                                          const int* position_map,
                                          T* out,
                                          const int dim_embed,
                                          const int elem_cnt) {
  using LoadT = AlignedVector<T, VecSize>;
  LoadT src_vec;

  int global_thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
  for (int elem_idx = global_thread_idx * VecSize; elem_idx < elem_cnt;
       elem_idx += blockDim.x * gridDim.x * VecSize) {
    int ori_token_idx = elem_idx / dim_embed;
    int token_idx = position_map[ori_token_idx];
    // Rows still mapped to -1 were not selected and are skipped.
    if (token_idx >= 0) {
      int offset = elem_idx % dim_embed;
      Load<T, VecSize>(input + ori_token_idx * dim_embed + offset, &src_vec);
      Store<T, VecSize>(src_vec, out + token_idx * dim_embed + offset);
    }
  }
}

template <paddle::DataType D>
std::vector<paddle::Tensor> DispatchDtype(
    const paddle::Tensor& input,
    const paddle::Tensor& seq_lens_this_time,
    const paddle::Tensor& seq_lens_encoder,
    const paddle::Tensor& seq_lens_decoder,
    const paddle::Tensor& stop_flags,
    const paddle::Tensor& accept_nums,
    const paddle::Tensor& base_model_seq_lens_this_time,
    const paddle::Tensor& base_model_seq_lens_encoder,
    const int actual_draft_token_num) {
  typedef PDTraits<D> traits_;
  typedef typename traits_::DataType DataType_;
  typedef typename traits_::data_t data_t;

  auto input_token_num = input.shape()[0];
  auto dim_embed = input.shape()[1];
  int bsz = seq_lens_this_time.shape()[0];

  auto position_map = paddle::empty(
      {input_token_num}, seq_lens_this_time.dtype(), input.place());
  // Fill the map with 0xFF bytes, i.e. -1, so unmapped tokens are skipped.
  cudaMemsetAsync(position_map.data<int>(),
                  0xFF,
                  input_token_num * sizeof(int),
                  seq_lens_this_time.stream());

  auto output_token_num = paddle::empty(
      {1}, seq_lens_this_time.dtype(), seq_lens_this_time.place());
  ComputeOrderKernel<<<1, 1, 0, seq_lens_this_time.stream()>>>(
      seq_lens_this_time.data<int>(),
      seq_lens_encoder.data<int>(),
      base_model_seq_lens_this_time.data<int>(),
      base_model_seq_lens_encoder.data<int>(),
      accept_nums.data<int>(),
      position_map.data<int>(),
      output_token_num.data<int>(),
      bsz,
      actual_draft_token_num,
      input_token_num);

  int output_token_num_cpu =
      output_token_num.copy_to(paddle::CPUPlace(), false).data<int>()[0];

  auto out = paddle::empty(
      {output_token_num_cpu, dim_embed}, input.dtype(), input.place());

  constexpr int packSize = VEC_16B / (sizeof(DataType_));
  int elem_cnt = input_token_num * dim_embed;
  assert(elem_cnt % packSize == 0);
  int pack_num = elem_cnt / packSize;

  int grid_size = 1;
  GetNumBlocks(pack_num, &grid_size);
  constexpr int thread_per_block = 128;

  rebuildHiddenStatesKernel<DataType_, packSize>
      <<<grid_size, thread_per_block, 0, input.stream()>>>(
          reinterpret_cast<const DataType_*>(input.data<data_t>()),
          position_map.data<int>(),
          reinterpret_cast<DataType_*>(out.data<data_t>()),
          dim_embed,
          elem_cnt);

  return {out};
}

std::vector<paddle::Tensor> EagleGetHiddenStates(
    const paddle::Tensor& input,
    const paddle::Tensor& seq_lens_this_time,
    const paddle::Tensor& seq_lens_encoder,
    const paddle::Tensor& seq_lens_decoder,
    const paddle::Tensor& stop_flags,
    const paddle::Tensor& accept_nums,
    const paddle::Tensor& base_model_seq_lens_this_time,
    const paddle::Tensor& base_model_seq_lens_encoder,
    const int actual_draft_token_num) {
  switch (input.dtype()) {
    case paddle::DataType::FLOAT16: {
      return DispatchDtype<paddle::DataType::FLOAT16>(
          input,
          seq_lens_this_time,
          seq_lens_encoder,
          seq_lens_decoder,
          stop_flags,
          accept_nums,
          base_model_seq_lens_this_time,
          base_model_seq_lens_encoder,
          actual_draft_token_num);
    }
    case paddle::DataType::BFLOAT16: {
      return DispatchDtype<paddle::DataType::BFLOAT16>(
          input,
          seq_lens_this_time,
          seq_lens_encoder,
          seq_lens_decoder,
          stop_flags,
          accept_nums,
          base_model_seq_lens_this_time,
          base_model_seq_lens_encoder,
          actual_draft_token_num);
    }
    default: {
      PD_THROW("Not support this data type");
    }
  }
}

PD_BUILD_STATIC_OP(eagle_get_hidden_states)
    .Inputs({"input",
             "seq_lens_this_time",
             "seq_lens_encoder",
             "seq_lens_decoder",
             "stop_flags",
             "accept_nums",
             "base_model_seq_lens_this_time",
             "base_model_seq_lens_encoder"})
    .Attrs({"actual_draft_token_num: int"})
    .Outputs({"out"})
    .SetKernelFn(PD_KERNEL(EagleGetHiddenStates));
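
// Worked example of the position map above (assuming bsz == 2,
// actual_draft_token_num == 4, both requests in the accept-partial branch
// with base_model_seq_lens_this_time == 5 and accept_num == {2, 3}): request
// 0 keeps input row 1 (its last accepted token) and request 1 keeps input row
// 5 + 2 = 7, so
//
//   position_map     = [-1, 0, -1, -1, -1, -1, -1, 1, -1, -1]
//   output_token_num = 2
//
// and rebuildHiddenStatesKernel gathers input rows {1, 7} into output rows
// {0, 1}, skipping every row still mapped to -1.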
@@ -0,0 +1,190 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/extension.h"
#include "helper.h"

// #define DEBUG_EAGLE_KERNEL

__global__ void computeOrderKernel(const int* last_seq_lens_this_time,
                                   const int* seq_lens_this_time,
                                   const int64_t* step_idx,
                                   int* src_map,
                                   int* output_token_num,
                                   int bsz) {
  int in_offset = 0;
  int out_offset = 0;
  if (threadIdx.x == 0) {
    for (int i = 0; i < bsz; ++i) {
      int cur_seq_lens_this_time = seq_lens_this_time[i];
      int cur_last_seq_lens_this_time = last_seq_lens_this_time[i];
#ifdef DEBUG_EAGLE_KERNEL
      printf(
          "batch %d: cur_seq_lens_this_time: %d, "
          "cur_last_seq_lens_this_time: %d\n",
          i, cur_seq_lens_this_time, cur_last_seq_lens_this_time);
#endif
      // 1. encoder
      if (step_idx[i] == 1 && cur_seq_lens_this_time > 0) {
#ifdef DEBUG_EAGLE_KERNEL
        printf("batch %d last_step is encoder\n", i);
#endif
        in_offset += 1;
        src_map[out_offset++] = in_offset - 1;
#ifdef DEBUG_EAGLE_KERNEL
        printf("batch %d finish. src_map[%d]=%d\n",
               i, out_offset - 1, in_offset - 1);
#endif
        // 2. decoder
      } else if (cur_seq_lens_this_time > 0) /* == 1 */ {
#ifdef DEBUG_EAGLE_KERNEL
        printf("batch %d is decoder\n", i);
#endif
        in_offset += cur_last_seq_lens_this_time;
        src_map[out_offset++] = in_offset - 1;
        // 3. stop
      } else {
        // first token end
        if (step_idx[i] == 1) {
#ifdef DEBUG_EAGLE_KERNEL
          printf("batch %d finished at the first token\n", i);
#endif
          in_offset += cur_last_seq_lens_this_time > 0 ? 1 : 0;
          // normal end
        } else {
#ifdef DEBUG_EAGLE_KERNEL
          printf("batch %d finished at a non-first token\n", i);
#endif
          in_offset += cur_last_seq_lens_this_time;
        }
      }
    }
    output_token_num[0] = out_offset;
#ifdef DEBUG_EAGLE_KERNEL
    printf("src map, output_token_num: %d\n", output_token_num[0]);
    for (int i = 0; i < output_token_num[0]; i++) {
      printf("%d ", src_map[i]);
    }
    printf("\n");
#endif
  }
}

template <typename T, int PackSize>
__global__ void rebuildSelfHiddenStatesKernel(const T* input,
                                              const int* src_map,
                                              T* output,
                                              int dim_embed,
                                              int elem_cnt) {
  using LoadT = AlignedVector<T, PackSize>;
  LoadT src_vec;

  int global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
  for (int elem_id = global_thread_idx * PackSize; elem_id < elem_cnt;
       elem_id += blockDim.x * gridDim.x * PackSize) {
    int output_token_idx = elem_id / dim_embed;
    int input_token_idx = src_map[output_token_idx];
    int offset = elem_id % dim_embed;
    Load<T, PackSize>(input + input_token_idx * dim_embed + offset, &src_vec);
    Store<T, PackSize>(src_vec, output + output_token_idx * dim_embed + offset);
  }
}

template <paddle::DataType D>
std::vector<paddle::Tensor> DispatchDtype(
    const paddle::Tensor& input,
    const paddle::Tensor& last_seq_lens_this_time,
    const paddle::Tensor& seq_lens_this_time,
    const paddle::Tensor& step_idx) {
  typedef PDTraits<D> traits_;
  typedef typename traits_::DataType DataType_;
  typedef typename traits_::data_t data_t;

  int input_token_num = input.shape()[0];
  int dim_embed = input.shape()[1];
  int bsz = seq_lens_this_time.shape()[0];
  auto src_map = paddle::full({input_token_num},
                              -1,
                              seq_lens_this_time.dtype(),
                              seq_lens_this_time.place());
  auto output_token_num = paddle::full(
      {1}, 0, seq_lens_this_time.dtype(), seq_lens_this_time.place());

  computeOrderKernel<<<1, 1, 0, seq_lens_this_time.stream()>>>(
      last_seq_lens_this_time.data<int>(),
      seq_lens_this_time.data<int>(),
      step_idx.data<int64_t>(),
      src_map.data<int>(),
      output_token_num.data<int>(),
      bsz);

  int output_token_num_cpu =
      output_token_num.copy_to(paddle::CPUPlace(), false).data<int>()[0];

  auto out = paddle::full(
      {output_token_num_cpu, dim_embed}, -1, input.dtype(), input.place());

  constexpr int packSize = VEC_16B / (sizeof(DataType_));
  int elem_cnt = output_token_num_cpu * dim_embed;
  assert(elem_cnt % packSize == 0);
  int pack_num = elem_cnt / packSize;

  int grid_size = 1;
  GetNumBlocks(pack_num, &grid_size);
  constexpr int threadPerBlock = 128;

  rebuildSelfHiddenStatesKernel<DataType_, packSize>
      <<<grid_size, threadPerBlock, 0, input.stream()>>>(
          reinterpret_cast<const DataType_*>(input.data<data_t>()),
          src_map.data<int>(),
          reinterpret_cast<DataType_*>(out.data<data_t>()),
          dim_embed,
          elem_cnt);

  return {out};
}

std::vector<paddle::Tensor> EagleGetSelfHiddenStates(
    const paddle::Tensor& input,
    const paddle::Tensor& last_seq_lens_this_time,
    const paddle::Tensor& seq_lens_this_time,
    const paddle::Tensor& step_idx) {
  switch (input.dtype()) {
    case paddle::DataType::BFLOAT16:
      return DispatchDtype<paddle::DataType::BFLOAT16>(
          input, last_seq_lens_this_time, seq_lens_this_time, step_idx);
    case paddle::DataType::FLOAT16:
      return DispatchDtype<paddle::DataType::FLOAT16>(
          input, last_seq_lens_this_time, seq_lens_this_time, step_idx);
    default:
      PD_THROW("Not support this data type");
  }
}

PD_BUILD_STATIC_OP(eagle_get_self_hidden_states)
    .Inputs({"input",
             "last_seq_lens_this_time",
             "seq_lens_this_time",
             "step_idx"})
    .Outputs({"out"})
    .SetKernelFn(PD_KERNEL(EagleGetSelfHiddenStates));
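
// Worked example of src_map above (bsz == 2): request 0 is in a decoder step
// (step_idx > 1) with last_seq_lens_this_time == 3, and request 1 has just
// finished its encoder step (step_idx == 1) with seq_lens_this_time == 1.
// Request 0 keeps the hidden state of its last input row (index 2), request 1
// keeps its single row (index 3), so src_map = [2, 3], output_token_num = 2,
// and both rows are gathered densely into the output.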
@@ -0,0 +1,136 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "helper.h"
#include "paddle/extension.h"

template <typename T, int VecSize>
__global__ void HydraFetchHiddenStatesKernel(const T* output_hidden_states,
                                             const int* output_padding_offset,
                                             const int* accept_token_num,
                                             T* hidden_states,
                                             const int bsz,
                                             const int max_seq_len,
                                             const int hidden_size) {
  const int token_id = blockIdx.x;
  const int ori_token_id = token_id + output_padding_offset[token_id];
  const int bid = ori_token_id / max_seq_len;
  const int start_ori_token_id = bid * max_seq_len;
  const int local_token_id = ori_token_id - start_ori_token_id;

  // Only the last accepted token of each request contributes its hidden state.
  if (local_token_id != accept_token_num[bid] - 1) return;

  using LoadT = AlignedVector<T, VecSize>;

  LoadT vec;

  for (int idx = threadIdx.x * VecSize; idx < hidden_size;
       idx += blockDim.x * VecSize) {
    Load(&output_hidden_states[token_id * hidden_size + idx], &vec);
    Store(vec, &hidden_states[bid * hidden_size + idx]);
  }
}

template <paddle::DataType D>
std::vector<paddle::Tensor> HydraFetchHiddenStatesImpl(
    const paddle::Tensor& output_hidden_states,
    const paddle::Tensor& output_padding_offset,
    const paddle::Tensor& accept_token_num,
    const int bsz,
    const int max_seq_length) {
  typedef PDTraits<D> traits_;
  typedef typename traits_::DataType DataType_;
  typedef typename traits_::data_t data_t;

  auto cu_stream = output_hidden_states.stream();

  auto output_token_num = output_hidden_states.shape()[0];
  auto hidden_size = output_hidden_states.shape()[1];

  auto hidden_states = paddle::full({bsz, hidden_size},
                                    0,
                                    output_hidden_states.dtype(),
                                    output_hidden_states.place());

  constexpr int VecSize = 16 / sizeof(data_t);

  HydraFetchHiddenStatesKernel<data_t, VecSize>
      <<<output_token_num, 256, 0, cu_stream>>>(
          output_hidden_states.data<data_t>(),
          output_padding_offset.data<int>(),
          accept_token_num.data<int>(),
          hidden_states.data<data_t>(),
          bsz,
          max_seq_length,
          hidden_size);
  return {hidden_states};
}

std::vector<paddle::Tensor> HydraFetchHiddenStates(
    const paddle::Tensor& output_hidden_states,
    const paddle::Tensor& output_padding_offset,
    const paddle::Tensor& accept_token_num,
    const int bsz,
    const int max_seq_length) {
  switch (output_hidden_states.dtype()) {
    case paddle::DataType::BFLOAT16: {
      return HydraFetchHiddenStatesImpl<paddle::DataType::BFLOAT16>(
          output_hidden_states,
          output_padding_offset,
          accept_token_num,
          bsz,
          max_seq_length);
    }
    case paddle::DataType::FLOAT16: {
      return HydraFetchHiddenStatesImpl<paddle::DataType::FLOAT16>(
          output_hidden_states,
          output_padding_offset,
          accept_token_num,
          bsz,
          max_seq_length);
    }
    default: {
      PD_THROW(
          "NOT supported data type. "
          "Only float16 and bfloat16 are supported.");
      break;
    }
  }
}

std::vector<std::vector<int64_t>> HydraFetchHiddenStatesInferShape(
    const std::vector<int64_t>& output_hidden_states_shape,
    const std::vector<int64_t>& output_padding_offset_shape,
    const std::vector<int64_t>& accept_token_num_shape,
    const int bsz,
    const int max_seq_length) {
  return {{bsz, output_hidden_states_shape[1]}};
}

std::vector<paddle::DataType> HydraFetchHiddenStatesInferDtype(
    const paddle::DataType& output_hidden_states_dtype,
    const paddle::DataType& output_padding_offset_dtype,
    const paddle::DataType& accept_token_num_dtype) {
  return {output_hidden_states_dtype};
}

PD_BUILD_STATIC_OP(hydra_fetch_hidden_states)
    .Inputs({"output_hidden_states",
             "output_padding_offset",
             "accept_token_num"})
    .Outputs({"hidden_states"})
    .Attrs({"bsz: int", "max_seq_length: int"})
    .SetKernelFn(PD_KERNEL(HydraFetchHiddenStates))
    .SetInferShapeFn(PD_INFER_SHAPE(HydraFetchHiddenStatesInferShape))
    .SetInferDtypeFn(PD_INFER_DTYPE(HydraFetchHiddenStatesInferDtype));
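
// Note on the vectorized copy above: VecSize is 16 / sizeof(data_t)
// (8 elements for fp16/bf16), so each Load/Store moves one aligned 16-byte
// vector and a 256-thread block strides through hidden_size in chunks of
// 256 * VecSize elements, one block per packed token. This implicitly assumes
// hidden_size is a multiple of VecSize.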
@@ -0,0 +1,160 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <stdio.h>
#include <string.h>
#include <sys/ipc.h>
#include <sys/msg.h>
#include <sys/types.h>
#include "paddle/extension.h"

#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif

#define MAX_BSZ 512

// #define SAVE_WITH_OUTPUT_DEBUG
#define MAX_DRAFT_TOKENS 6

struct msgdata {
    long mtype;
    int mtext[2 + MAX_BSZ + MAX_BSZ * MAX_DRAFT_TOKENS];  // stop_flag, token_num, tokens
};

void MTPSaveFirstToken(const paddle::Tensor& x,
                       const paddle::Tensor& not_need_stop,
                       int64_t rank_id,
                       int msg_queue_id,
                       bool save_each_rank) {
    if (!save_each_rank && rank_id > 0) {
        return;
    }
    int x_dim = x.shape()[1];
    auto x_cpu = x.copy_to(paddle::CPUPlace(), false);
    int64_t* x_data = x_cpu.data<int64_t>();
    static struct msgdata msg_sed;

    if (const char* inference_msg_queue_id_env_p =
            std::getenv("INFERENCE_MSG_QUEUE_ID")) {
        std::string inference_msg_queue_id_env_str(
            inference_msg_queue_id_env_p);
        int inference_msg_queue_id_from_env =
            std::stoi(inference_msg_queue_id_env_str);
#ifdef SAVE_WITH_OUTPUT_DEBUG
        std::cout << "Your INFERENCE_MSG_QUEUE_ID is: "
                  << inference_msg_queue_id_from_env << std::endl;
#endif
        msg_queue_id = inference_msg_queue_id_from_env;
    }

    static key_t key = ftok("./", msg_queue_id);
    static int msgid = msgget(key, IPC_CREAT | 0666);

    msg_sed.mtype = 1;
    bool not_need_stop_data = not_need_stop.data<bool>()[0];
    int inference_msg_id_from_env = 1;
    if (const char* inference_msg_id_env_p = std::getenv("INFERENCE_MSG_ID")) {
        std::string inference_msg_id_env_str(inference_msg_id_env_p);
        inference_msg_id_from_env = std::stoi(inference_msg_id_env_str);
        if (inference_msg_id_from_env == 2) {
            // 2 and -2 are reserved for the no-output indication.
            throw std::runtime_error(
                " INFERENCE_MSG_ID cannot be 2, please use another number.");
        }
        if (inference_msg_id_from_env < 0) {
            throw std::runtime_error(
                " INFERENCE_MSG_ID cannot be negative, please use another "
                "number.");
        }

#ifdef SAVE_WITH_OUTPUT_DEBUG
        std::cout << "Your INFERENCE_MSG_ID is: " << inference_msg_id_from_env
                  << std::endl;
#endif
    } else {
#ifdef SAVE_WITH_OUTPUT_DEBUG
        std::cout
            << "Failed to get INFERENCE_MSG_ID from env, using (int)1 as default."
            << std::endl;
#endif
    }
#ifdef SAVE_WITH_OUTPUT_DEBUG
    std::cout << "save_output_key: " << key << std::endl;
    std::cout << "save msgid: " << msgid << std::endl;
#endif
    msg_sed.mtext[0] = not_need_stop_data ? inference_msg_id_from_env
                                          : -inference_msg_id_from_env;
    int bsz = x.shape()[0];
    msg_sed.mtext[1] = bsz;
    for (int i = 0; i < bsz; i++) {
#ifdef SAVE_WITH_OUTPUT_DEBUG
        printf("bid: %d. 1: %d. 2: %d.\n", i, (int)x_data[i * x_dim], (int)x_data[i * x_dim + 1]);
#endif
        msg_sed.mtext[i + 2] = 2;
        msg_sed.mtext[i * MAX_DRAFT_TOKENS + 2 + MAX_BSZ] = (int)x_data[i * x_dim];
        msg_sed.mtext[i * MAX_DRAFT_TOKENS + 1 + 2 + MAX_BSZ] = (int)x_data[i * x_dim + 1];
#ifdef SAVE_WITH_OUTPUT_DEBUG
        printf("mtext[%d]:%d. mtext[%d]:%d. \n", i * MAX_DRAFT_TOKENS + 2 + MAX_BSZ,
               msg_sed.mtext[i * MAX_DRAFT_TOKENS + 2 + MAX_BSZ],
               i * MAX_DRAFT_TOKENS + 1 + 2 + MAX_BSZ,
               msg_sed.mtext[i * MAX_DRAFT_TOKENS + 1 + 2 + MAX_BSZ]);
#endif
    }

#ifdef SAVE_WITH_OUTPUT_DEBUG
    std::cout << "msg data: ";
    for (int i = 0; i < bsz; i++) {
        std::cout << " " << (int)x_data[2 * i] << " ";
        std::cout << " " << (int)x_data[2 * i + 1];
    }
    std::cout << std::endl;
#endif
    if ((msgsnd(msgid,
                &msg_sed,
                (2 + MAX_BSZ + MAX_BSZ * MAX_DRAFT_TOKENS) * 4, 0)) == -1) {
        printf("full msg buffer\n");
    }
    return;
}

void MTPSaveFirstTokenStatic(const paddle::Tensor& x,
                             const paddle::Tensor& not_need_stop,
                             int64_t rank_id,
                             bool save_each_rank) {
    MTPSaveFirstToken(x, not_need_stop, rank_id, 1, save_each_rank);
}

void MTPSaveFirstTokenDynamic(const paddle::Tensor& x,
                              const paddle::Tensor& not_need_stop,
                              int64_t rank_id,
                              int msg_queue_id,
                              bool save_each_rank) {
    MTPSaveFirstToken(x, not_need_stop, rank_id, msg_queue_id, save_each_rank);
}

PD_BUILD_STATIC_OP(mtp_save_first_token)
    .Inputs({"x", "not_need_stop"})
    .Attrs({"rank_id: int64_t",
            "save_each_rank: bool"})
    .Outputs({"x_out"})
    .SetInplaceMap({{"x", "x_out"}})
    .SetKernelFn(PD_KERNEL(MTPSaveFirstTokenStatic));

PD_BUILD_STATIC_OP(mtp_save_first_token_dynamic)
    .Inputs({"x", "not_need_stop"})
    .Attrs({"rank_id: int64_t", "msg_queue_id: int", "save_each_rank: bool"})
    .Outputs({"x_out"})
    .SetInplaceMap({{"x", "x_out"}})
    .SetKernelFn(PD_KERNEL(MTPSaveFirstTokenDynamic));
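For reference, the mtext payload above is a single flat int array with three regions: slot 0 carries the signed message id (negated when generation should stop), slot 1 the batch size, slots 2..MAX_BSZ+1 the per-query token counts, and the remaining MAX_BSZ * MAX_DRAFT_TOKENS slots the tokens themselves. A minimal host-side sketch of the same index arithmetic (the helper name token_slot is illustrative, not part of the op):

// Sketch only: mirrors the indexing used by MTPSaveFirstToken above.
#include <cassert>

constexpr int kMaxBsz = 512;        // matches MAX_BSZ
constexpr int kMaxDraftTokens = 6;  // matches MAX_DRAFT_TOKENS

// Flat offset of token `t` of query `b` inside msg_sed.mtext.
inline int token_slot(int b, int t) {
    assert(b < kMaxBsz && t < kMaxDraftTokens);
    return 2 + kMaxBsz + b * kMaxDraftTokens + t;
}

int main() {
    // Query 3, second token lands at 2 + 512 + 3*6 + 1 = 533.
    assert(token_slot(3, 1) == 533);
    return 0;
}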
@@ -0,0 +1,226 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "helper.h"  // NOLINT

template <int NUM_THREADS, int MAX_BATCH_SIZE = 256>
__global__ void mtp_free_and_dispatch_block(
    bool *base_model_stop_flags,
    bool *stop_flags,
    bool *batch_drop,
    int *seq_lens_this_time,
    int *seq_lens_decoder,
    int *block_tables,
    int *encoder_block_lens,
    int *used_list_len,
    int *free_list,
    int *free_list_len,
    const int bsz,
    const int block_size,
    const int block_num_per_seq,
    const int max_draft_tokens) {
    typedef cub::BlockReduce<cub::KeyValuePair<int, int>, NUM_THREADS> BlockReduce;
    __shared__ typename BlockReduce::TempStorage temp_storage;

    __shared__ int need_block_len;
    __shared__ int need_block_list[MAX_BATCH_SIZE];
    const int tid = threadIdx.x;

    if (tid < bsz) {
        if (tid == 0) {
            need_block_len = 0;
        }
        need_block_list[tid] = 0;
        int *block_table_now = block_tables + tid * block_num_per_seq;
        if (base_model_stop_flags[tid] || batch_drop[tid]) {
            // Reclaim the blocks held by this stopped/dropped sequence.
            const int encoder_block_len = encoder_block_lens[tid];
            const int decoder_used_len = used_list_len[tid];
            if (decoder_used_len > 0) {
                const int ori_free_list_len =
                    atomicAdd(free_list_len, decoder_used_len);
#ifdef DEBUG_STEP
                printf(
                    "free block seq_id: %d, free block num: %d, "
                    "encoder_block_len: %d, ori_free_list_len: %d\n",
                    tid,
                    decoder_used_len,
                    encoder_block_len,
                    ori_free_list_len);
#endif
                for (int i = 0; i < decoder_used_len; i++) {
                    free_list[ori_free_list_len + i] =
                        block_table_now[encoder_block_len + i];
                    block_table_now[encoder_block_len + i] = -1;
                }
                encoder_block_lens[tid] = 0;
                used_list_len[tid] = 0;
            }
        }
    }
    __syncthreads();
    if (tid < bsz) {
        int *block_table_now = block_tables + tid * block_num_per_seq;
        int max_possible_block_idx = (seq_lens_decoder[tid] + max_draft_tokens + 1) / block_size;
        if (!base_model_stop_flags[tid] && !batch_drop[tid] && max_possible_block_idx < block_num_per_seq &&
            block_table_now[max_possible_block_idx] == -1) {
            // Record which sequences need a new block and how many in total;
            // the actual allocation happens after the capacity check below.
            int ori_need_block_len = atomicAdd(&need_block_len, 1);
            need_block_list[ori_need_block_len] = tid;
#ifdef DEBUG_STEP
            printf("seq_id: %d need block\n", tid);
#endif
        }
    }
    __syncthreads();
#ifdef DEBUG_STEP
    if (tid == 0) {
        printf("need_block_len:%d, free_list_len: %d\n", need_block_len, free_list_len[0]);
    }
#endif
    // Not enough free blocks: repeatedly drop the sequence holding the most
    // decoder blocks (the scan starts from bid 0) until demand fits supply.
    while (need_block_len > free_list_len[0]) {
#ifdef DEBUG_STEP
        if (tid == 0) {
            printf("in while need_block_len:%d, free_list_len: %d\n", need_block_len, free_list_len[0]);
        }
#endif
        const int used_block_num =
            tid < bsz && !base_model_stop_flags[tid] ? used_list_len[tid] : 0;
        cub::KeyValuePair<int, int> kv_pair = {tid, used_block_num};
        kv_pair = BlockReduce(temp_storage).Reduce(kv_pair, cub::ArgMax());

        if (tid == 0) {
            const int encoder_block_len = encoder_block_lens[kv_pair.key];
            int *block_table_now =
                block_tables + kv_pair.key * block_num_per_seq;
            for (int i = 0; i < kv_pair.value; i++) {
                free_list[free_list_len[0] + i] =
                    block_table_now[encoder_block_len + i];
                block_table_now[encoder_block_len + i] = -1;
            }
            const int ori_free_list_len = atomicAdd(free_list_len, kv_pair.value);

            printf(
                "MTP STEP need_block_len: %d. free_list_len: %d. "
                "Drop bid: %d, free block num: %d, "
                "encoder_block_len: %d, "
                "after drop free_list_len: %d\n",
                need_block_len,
                ori_free_list_len,
                kv_pair.key,
                kv_pair.value,
                encoder_block_len,
                free_list_len[0]);
            stop_flags[kv_pair.key] = true;
            batch_drop[kv_pair.key] = true;
            seq_lens_this_time[kv_pair.key] = 0;
            seq_lens_decoder[kv_pair.key] = 0;
            used_list_len[kv_pair.key] = 0;
        }
        __syncthreads();
    }

    if (tid < need_block_len) {
        const int need_block_id = need_block_list[tid];
        // Must check batch_drop here, not stop_flags.
        if (!batch_drop[need_block_id]) {
            used_list_len[need_block_id] += 1;
            const int ori_free_list_len = atomicSub(free_list_len, 1);
            int *block_table_now =
                block_tables + need_block_id * block_num_per_seq;
#ifdef DEBUG_STEP
            printf("bid: %d allocate block_id %d. seq_lens_decoder:%d \n", need_block_id, free_list[ori_free_list_len - 1], seq_lens_decoder[need_block_id]);
#endif
            block_table_now[(seq_lens_decoder[need_block_id] +
                             max_draft_tokens + 1) /
                            block_size] = free_list[ori_free_list_len - 1];
        }
    }
}

void MTPStepPaddle(
    const paddle::Tensor &base_model_stop_flags,
    const paddle::Tensor &stop_flags,
    const paddle::Tensor &batch_drop,
    const paddle::Tensor &seq_lens_this_time,
    const paddle::Tensor &seq_lens_encoder,
    const paddle::Tensor &seq_lens_decoder,
    const paddle::Tensor &block_tables,  // [bsz, block_num_per_seq]
    const paddle::Tensor &encoder_block_lens,
    const paddle::Tensor &used_list_len,
    const paddle::Tensor &free_list,
    const paddle::Tensor &free_list_len,
    const int block_size,
    const int max_draft_tokens) {
    auto cu_stream = seq_lens_this_time.stream();
    const int bsz = seq_lens_this_time.shape()[0];
    const int block_num_per_seq = block_tables.shape()[1];
    constexpr int BlockSize = 512;  // supports bsz <= 512
#ifdef DEBUG_STEP
    printf("bsz: %d, block_num_per_seq: %d, block_size: %d, max_draft_tokens: %d\n",
           bsz,
           block_num_per_seq,
           block_size,
           max_draft_tokens);
#endif
    mtp_free_and_dispatch_block<BlockSize, BlockSize><<<1, BlockSize, 0, cu_stream>>>(
        const_cast<bool *>(base_model_stop_flags.data<bool>()),
        const_cast<bool *>(stop_flags.data<bool>()),
        const_cast<bool *>(batch_drop.data<bool>()),
        const_cast<int *>(seq_lens_this_time.data<int>()),
        const_cast<int *>(seq_lens_decoder.data<int>()),
        const_cast<int *>(block_tables.data<int>()),
        const_cast<int *>(encoder_block_lens.data<int>()),
        const_cast<int *>(used_list_len.data<int>()),
        const_cast<int *>(free_list.data<int>()),
        const_cast<int *>(free_list_len.data<int>()),
        bsz,
        block_size,
        block_num_per_seq,
        max_draft_tokens);
}

PD_BUILD_STATIC_OP(mtp_step_paddle)
    .Inputs({"base_model_stop_flags",
             "stop_flags",
             "batch_drop",
             "seq_lens_this_time",
             "seq_lens_encoder",
             "seq_lens_decoder",
             "block_tables",
             "encoder_block_lens",
             "used_list_len",
             "free_list",
             "free_list_len"})
    .Attrs({"block_size: int",
            "max_draft_tokens: int"})
    .Outputs({"block_tables_out",
              "stop_flags_out",
              "used_list_len_out",
              "free_list_out",
              "free_list_len_out"})
    .SetInplaceMap({{"block_tables", "block_tables_out"},
                    {"stop_flags", "stop_flags_out"},
                    {"used_list_len", "used_list_len_out"},
                    {"free_list", "free_list_out"},
                    {"free_list_len", "free_list_len_out"}})
    .SetKernelFn(PD_KERNEL(MTPStepPaddle));
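The kernel above manages KV-cache blocks through one shared free list: stopped or dropped sequences return their decoder blocks, sequences about to cross a block boundary take one, and when demand exceeds supply the sequence holding the most blocks is dropped. A single-threaded host sketch of that free-list discipline (illustrative only; the real op does this in parallel with atomics):

// Illustrative host model of the free list that
// mtp_free_and_dispatch_block manipulates with atomics.
#include <cassert>
#include <vector>

struct FreeList {
    std::vector<int> blocks;
    void give_back(int b) { blocks.push_back(b); }  // ~ atomicAdd + store
    int take() {                                    // ~ atomicSub + load
        assert(!blocks.empty());
        int b = blocks.back();
        blocks.pop_back();
        return b;
    }
};

int main() {
    FreeList fl;
    fl.give_back(7);            // a stopped sequence returns block 7
    int allocated = fl.take();  // a growing sequence claims it
    assert(allocated == 7);
    return 0;
}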
212
custom_ops/gpu_ops/speculate_decoding/ngram_match.cc
Normal file
@@ -0,0 +1,212 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <iostream>
#include <vector>
#include <string>
#include <cstring>
#include <algorithm>
#include <chrono>
#include <cstdlib>
#include "paddle/extension.h"

#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif

// Inclusive prefix sum: adds up value[0..num].
int sum(const int *value, int num) {
    int sum_value = 0;
    for (int i = 0; i <= num; i++) {
        sum_value += value[i];
    }
    return sum_value;
}

void find_candidate_pred_tokens(const int64_t *input_ids,
                                const int64_t *input_ids_len,
                                const int64_t *pre_ids,
                                const int64_t *step_idx,
                                const int *draft_token_num,
                                int64_t *draft_tokens,
                                int32_t *seq_lens_this_time,
                                int32_t *seq_lens_encoder,
                                int32_t *seq_lens_decoder,
                                int64_t *max_dec_len,
                                int64_t input_ids_stride,
                                int64_t pre_ids_stride,
                                int64_t draft_tokens_stride,
                                const int real_batch_size,
                                int max_ngram_size = 3,
                                int max_draft_tokens = 10) {
    int threshold = 128;
    char *env_var = getenv("INFER_WITH_REFERENCE_TOKENUM_THRESHOLD");
    if (env_var) {
        threshold = std::stoi(env_var);
    }
    bool is_insert = false;
    for (int batch_idx = 0; batch_idx < real_batch_size; batch_idx++) {
        if (seq_lens_encoder[batch_idx] > 0) {
            is_insert = true;
        }
    }
    for (int batch_idx = 0; batch_idx < real_batch_size; batch_idx++) {
        max_draft_tokens = std::min(static_cast<int64_t>(
            draft_token_num[batch_idx]), max_dec_len[batch_idx] - step_idx[batch_idx] - 1);
        if (seq_lens_encoder[batch_idx] > 0) {
            continue;
        } else if (seq_lens_decoder[batch_idx] == 0) {
            seq_lens_this_time[batch_idx] = 0;
            continue;
        }
        const int64_t *cur_input_ids = input_ids + batch_idx * input_ids_stride;
        int64_t *cur_draft_tokens = draft_tokens + batch_idx * draft_tokens_stride;
        const int64_t *cur_pre_ids = pre_ids + batch_idx * pre_ids_stride;
        const int64_t cur_step_idx = step_idx[batch_idx];
        const int64_t cur_input_ids_len = input_ids_len[batch_idx];
        seq_lens_this_time[batch_idx] = 1;
        if (!is_insert) {
            auto sum_token_num = sum(seq_lens_this_time, batch_idx);
            int left_min_token_num = real_batch_size - batch_idx;

            if (sum_token_num + max_draft_tokens + left_min_token_num > threshold) {
                int tmp_max_draft_tokens = threshold - sum_token_num - left_min_token_num;
                max_draft_tokens = tmp_max_draft_tokens < max_draft_tokens ? tmp_max_draft_tokens : max_draft_tokens;
            }

            if (sum_token_num + left_min_token_num >= threshold - 1) {
                continue;
            }
        }

        for (int ngram_size = max_ngram_size; ngram_size > 0; --ngram_size) {
            // Extract the last n tokens as our search ngram
            if (cur_step_idx < ngram_size) {
                continue;
            }
            const int64_t *ngram = cur_pre_ids + (cur_step_idx + 1 - ngram_size);

            // Iterate through sliding windows of size ngram_size
            bool match_input = false;
            for (int64_t i = 0; i <= cur_input_ids_len - ngram_size; ++i) {
                // Check if the current window matches the ngram
                bool match = true;
                for (int j = 0; j < ngram_size; j++) {
                    if (ngram[j] != cur_input_ids[i + j]) {
                        match = false;
                        break;
                    }
                }
                if (match) {
                    int64_t start_idx = i + ngram_size;
                    int64_t end_idx = std::min(start_idx + max_draft_tokens, cur_input_ids_len);
                    if (start_idx >= end_idx)
                        continue;

                    int64_t cur_draft_token_num = end_idx - start_idx;

                    seq_lens_this_time[batch_idx] = cur_draft_token_num + 1;
                    memcpy(cur_draft_tokens + 1, cur_input_ids + start_idx, sizeof(int64_t) * cur_draft_token_num);
                    // To break the enclosing ngram_size for-loop as well
                    ngram_size = 0;
                    match_input = true;
                    break;
                }
            }
            if (!match_input) {
                for (int64_t i = 0; i <= cur_step_idx - ngram_size; ++i) {
                    // Check if the current window matches the ngram
                    bool match = true;

                    for (int j = 0; j < ngram_size; j++) {
                        if (ngram[j] != cur_pre_ids[i + j]) {
                            match = false;
                            break;
                        }
                    }

                    if (match) {
                        int64_t start_idx = i + ngram_size;
                        int64_t end_idx = std::min(start_idx + max_draft_tokens, cur_step_idx);
                        int64_t cur_draft_token_num = end_idx - start_idx;
                        if (start_idx >= end_idx)
                            continue;

                        seq_lens_this_time[batch_idx] = cur_draft_token_num + 1;
                        memcpy(cur_draft_tokens + 1, cur_pre_ids + start_idx, sizeof(int64_t) * cur_draft_token_num);
                        ngram_size = 0;
                        break;
                    }
                }
            }
        }
    }
}

void NgramMatch(const paddle::Tensor &input_ids,
                const paddle::Tensor &input_ids_len,
                const paddle::Tensor &pre_ids,
                const paddle::Tensor &step_idx,
                const paddle::Tensor &draft_token_num,
                const paddle::Tensor &draft_tokens,
                const paddle::Tensor &seq_lens_this_time,
                const paddle::Tensor &seq_lens_encoder,
                const paddle::Tensor &seq_lens_decoder,
                const paddle::Tensor &max_dec_len,
                const int real_batch_size,
                const int max_ngram_size,
                const int max_draft_tokens) {
    auto input_ids_shape = input_ids.shape();
    const int64_t input_ids_stride = input_ids_shape[1];

    auto pre_ids_shape = pre_ids.shape();
    const int64_t pre_ids_stride = pre_ids_shape[1];

    auto draft_tokens_shape = draft_tokens.shape();
    const int64_t draft_tokens_stride = draft_tokens_shape[1];

    find_candidate_pred_tokens(input_ids.data<int64_t>(),
                               input_ids_len.data<int64_t>(),
                               pre_ids.data<int64_t>(),
                               step_idx.data<int64_t>(),
                               draft_token_num.data<int>(),
                               const_cast<int64_t *>(draft_tokens.data<int64_t>()),
                               const_cast<int32_t *>(seq_lens_this_time.data<int32_t>()),
                               const_cast<int32_t *>(seq_lens_encoder.data<int32_t>()),
                               const_cast<int32_t *>(seq_lens_decoder.data<int32_t>()),
                               const_cast<int64_t *>(max_dec_len.data<int64_t>()),
                               input_ids_stride,
                               pre_ids_stride,
                               draft_tokens_stride,
                               real_batch_size,
                               max_ngram_size,
                               max_draft_tokens);
}

PD_BUILD_STATIC_OP(ngram_match)
    .Inputs({"input_ids",
             "input_ids_len",
             "pre_ids",
             "step_idx",
             "draft_token_num",
             "draft_tokens",
             "seq_lens_this_time",
             "seq_lens_encoder",
             "seq_lens_decoder",
             "max_dec_len"})
    .Attrs({"real_batch_size: int", "max_ngram_size: int", "max_draft_tokens: int"})
    .Outputs({"draft_tokens_out", "seq_lens_this_time_out"})
    .SetKernelFn(PD_KERNEL(NgramMatch))
    .SetInplaceMap({{"draft_tokens", "draft_tokens_out"}, {"seq_lens_this_time", "seq_lens_this_time_out"}});
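The search above is plain suffix matching: take the last ngram_size generated ids, scan the prompt (and then the generated prefix) for the same window, and propose the tokens that followed the first match as draft tokens. A compact host-side restatement of that idea (a sketch with illustrative names, not the op itself):

#include <algorithm>
#include <cstdint>
#include <vector>

// Returns up to max_draft tokens following the first occurrence of the
// trailing n-gram of `generated` inside `corpus`, or an empty vector.
std::vector<int64_t> ngram_draft(const std::vector<int64_t>& corpus,
                                 const std::vector<int64_t>& generated,
                                 int n, int max_draft) {
    if (static_cast<int>(generated.size()) < n) return {};
    const int64_t* ngram = generated.data() + generated.size() - n;
    for (size_t i = 0; i + n <= corpus.size(); ++i) {
        if (std::equal(ngram, ngram + n, corpus.data() + i)) {
            size_t start = i + n;
            size_t end = std::min(start + static_cast<size_t>(max_draft), corpus.size());
            return {corpus.begin() + start, corpus.begin() + end};
        }
    }
    return {};
}

int main() {
    std::vector<int64_t> prompt = {5, 6, 7, 8, 9};
    std::vector<int64_t> gen = {6, 7};  // trailing 2-gram is {6, 7}
    auto draft = ngram_draft(prompt, gen, 2, 2);
    // {6, 7} occurs at prompt[1..2]; the next tokens {8, 9} become the draft.
    return (draft.size() == 2 && draft[0] == 8) ? 0 : 1;
}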
@@ -0,0 +1,80 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "helper.h"

template <int THREADBLOCK_SIZE>
__global__ void CalculateKernel(int32_t* sum_draft_num,
                                int32_t* sum_accept_num,
                                const int32_t* accept_nums,
                                const int32_t* seq_lens_this_time,
                                const int32_t* seq_lens_decoder,
                                const bool* stop_flags,
                                int real_bsz) {
    int tid = threadIdx.x;
    int draft_num = 0, accept_num = 0;
    if (tid < real_bsz) {
        if (seq_lens_decoder[tid] > 0 &&
            seq_lens_this_time[tid] != seq_lens_decoder[tid]) {
            draft_num = seq_lens_this_time[tid] - 1;
            accept_num = accept_nums[tid] - 1;
        } else if (seq_lens_this_time[tid] > 0 &&
                   stop_flags[tid]) {  // last step
            draft_num = seq_lens_this_time[tid] - 1;
            accept_num = accept_nums[tid];
        }
    }
    __syncthreads();
    typedef cub::BlockReduce<int64_t, THREADBLOCK_SIZE> BlockReduce;
    __shared__ typename BlockReduce::TempStorage temp_storage;

    int draft_nums_sum = BlockReduce(temp_storage).Sum(draft_num);
    // temp_storage is reused by the second reduction, so synchronize first.
    __syncthreads();
    int accept_nums_sum = BlockReduce(temp_storage).Sum(accept_num);

    if (tid == 0 && draft_nums_sum != 0) {
        sum_draft_num[0] += draft_nums_sum;
        sum_accept_num[0] += accept_nums_sum;
    }
}

void Calculate(const paddle::Tensor& sum_draft_num,
               const paddle::Tensor& sum_accept_num,
               const paddle::Tensor& accept_nums,
               const paddle::Tensor& seq_lens_this_time,
               const paddle::Tensor& seq_lens_decoder,
               const paddle::Tensor& stop_flags) {
    int real_bsz = seq_lens_this_time.shape()[0];
    constexpr int BLOCK_SIZE = 512;

    CalculateKernel<BLOCK_SIZE><<<1, BLOCK_SIZE, 0, accept_nums.stream()>>>(
        const_cast<int*>(sum_draft_num.data<int32_t>()),
        const_cast<int*>(sum_accept_num.data<int32_t>()),
        accept_nums.data<int32_t>(),
        seq_lens_this_time.data<int32_t>(),
        seq_lens_decoder.data<int32_t>(),
        stop_flags.data<bool>(),
        real_bsz);
}

PD_BUILD_STATIC_OP(speculate_calcu_accept_ratio)
    .Inputs({"sum_draft_num",
             "sum_accept_num",
             "accept_nums",
             "seq_lens_this_time",
             "seq_lens_decoder",
             "stop_flags"})
    .Outputs({"sum_draft_num_out", "sum_accept_num_out"})
    .SetInplaceMap({{"sum_draft_num", "sum_draft_num_out"},
                    {"sum_accept_num", "sum_accept_num_out"}})
    .SetKernelFn(PD_KERNEL(Calculate));
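The two counters accumulated above give the running acceptance ratio of speculative decoding: accepted draft tokens divided by proposed draft tokens. Reading it back is a one-liner once the counters are on the host (a sketch; the function name is illustrative):

// Sketch: acceptance ratio from the two accumulators above,
// assuming both have already been copied to host memory.
double accept_ratio(int sum_accept_num, int sum_draft_num) {
    return sum_draft_num > 0
               ? static_cast<double>(sum_accept_num) / sum_draft_num
               : 0.0;
}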
@@ -0,0 +1,39 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "helper.h"  // NOLINT

__global__ void speculate_clear_accept_nums_kernel(int* accept_num,
                                                   const int* seq_lens_decoder,
                                                   const int max_bsz) {
    const int bid = threadIdx.x;
    if (bid >= max_bsz) return;
    accept_num[bid] = seq_lens_decoder[bid] == 0 ? 0 : accept_num[bid];
}

void SpeculateClearAcceptNums(const paddle::Tensor& accept_num,
                              const paddle::Tensor& seq_lens_decoder) {
    const int max_bsz = seq_lens_decoder.shape()[0];
    speculate_clear_accept_nums_kernel<<<1, 1024, 0, accept_num.stream()>>>(
        const_cast<int*>(accept_num.data<int>()),
        seq_lens_decoder.data<int>(),
        max_bsz);
}

PD_BUILD_STATIC_OP(speculate_clear_accept_nums)
    .Inputs({"accept_num", "seq_lens_decoder"})
    .Outputs({"seq_lens_decoder_out"})
    .SetInplaceMap({{"seq_lens_decoder", "seq_lens_decoder_out"}})
    .SetKernelFn(PD_KERNEL(SpeculateClearAcceptNums));
118
custom_ops/gpu_ops/speculate_decoding/speculate_get_output.cc
Normal file
@@ -0,0 +1,118 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <stdio.h>
#include <string.h>
#include <sys/ipc.h>
#include <sys/msg.h>
#include <sys/types.h>
#include "paddle/extension.h"

#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif

#define MAX_BSZ 512
#define MAX_DRAFT_TOKENS 6

struct msgdata {
    long mtype;
    int mtext[MAX_BSZ * MAX_DRAFT_TOKENS + MAX_BSZ +
              2];  // stop_flag, bsz, accept_num*bsz, tokens...
};

void SpeculateGetOutput(const paddle::Tensor& x,
                        int64_t rank_id,
                        bool wait_flag,
                        int msg_queue_id,
                        bool get_each_rank) {
    if (!get_each_rank && rank_id > 0) {
        return;
    }

    if (const char* inference_msg_queue_id_env_p =
            std::getenv("INFERENCE_MSG_QUEUE_ID")) {
        std::string inference_msg_queue_id_env_str(
            inference_msg_queue_id_env_p);
        int inference_msg_queue_id_from_env =
            std::stoi(inference_msg_queue_id_env_str);
#ifdef GET_OUTPUT_DEBUG
        std::cout << "Your INFERENCE_MSG_QUEUE_ID is: "
                  << inference_msg_queue_id_from_env << std::endl;
#endif
        msg_queue_id = inference_msg_queue_id_from_env;
    }

    static struct msgdata msg_rcv;

    static key_t key = ftok("./", msg_queue_id);

    static int msgid = msgget(key, IPC_CREAT | 0666);

    int64_t* out_data = const_cast<int64_t*>(x.data<int64_t>());
    int ret = -1;
    if (!wait_flag) {
        ret = msgrcv(msgid,
                     &msg_rcv,
                     (MAX_BSZ * MAX_DRAFT_TOKENS + MAX_BSZ + 2) * 4,
                     0,
                     IPC_NOWAIT);
    } else {
        ret = msgrcv(msgid,
                     &msg_rcv,
                     (MAX_BSZ * MAX_DRAFT_TOKENS + MAX_BSZ + 2) * 4,
                     0,
                     0);
    }
    if (ret == -1) {
        // No message available: signal "no output" to the caller.
        out_data[0] = -2;
        out_data[1] = 0;
        return;
    }

    for (int64_t i = 0; i < MAX_BSZ * MAX_DRAFT_TOKENS + MAX_BSZ + 2; i++) {
        out_data[i] = (int64_t)msg_rcv.mtext[i];
    }
    return;
}

void SpeculateGetOutputStatic(const paddle::Tensor& x,
                              int64_t rank_id,
                              bool wait_flag,
                              bool get_each_rank) {
    SpeculateGetOutput(x, rank_id, wait_flag, 1, get_each_rank);
}

void SpeculateGetOutputDynamic(const paddle::Tensor& x,
                               int64_t rank_id,
                               bool wait_flag,
                               int msg_queue_id,
                               bool get_each_rank) {
    SpeculateGetOutput(x, rank_id, wait_flag, msg_queue_id, get_each_rank);
}

PD_BUILD_STATIC_OP(speculate_get_output)
    .Inputs({"x"})
    .Attrs({"rank_id: int64_t", "wait_flag: bool", "get_each_rank: bool"})
    .Outputs({"x_out"})
    .SetInplaceMap({{"x", "x_out"}})
    .SetKernelFn(PD_KERNEL(SpeculateGetOutputStatic));

PD_BUILD_STATIC_OP(speculate_get_output_dynamic)
    .Inputs({"x"})
    .Attrs({"rank_id: int64_t", "wait_flag: bool", "msg_queue_id: int", "get_each_rank: bool"})
    .Outputs({"x_out"})
    .SetInplaceMap({{"x", "x_out"}})
    .SetKernelFn(PD_KERNEL(SpeculateGetOutputDynamic));
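This op is the consumer half of a System V message queue; speculate_save_output (further below) is the producer that msgsnd()s the packed buffer which msgrcv() drains here, optionally with IPC_NOWAIT. A minimal standalone round trip over the same primitives (a hedged sketch; the queue id, message shape, and struct name are illustrative, not the op's wire format):

// Sketch: one send/receive cycle over a System V message queue.
#include <sys/ipc.h>
#include <sys/msg.h>
#include <cstdio>

struct demo_msg { long mtype; int mtext[4]; };

int main() {
    key_t key = ftok("./", 42);                  // illustrative queue id
    int msgid = msgget(key, IPC_CREAT | 0666);   // create or attach
    demo_msg out{1, {7, 8, 9, 10}}, in{};
    msgsnd(msgid, &out, sizeof(out.mtext), 0);   // producer side
    if (msgrcv(msgid, &in, sizeof(in.mtext), 0, IPC_NOWAIT) != -1) {
        std::printf("got %d %d\n", in.mtext[0], in.mtext[1]);
    }
    msgctl(msgid, IPC_RMID, nullptr);            // remove the demo queue
    return 0;
}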
@@ -0,0 +1,88 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/extension.h"

#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif

__global__ void SpeculateGetOutputPaddingOffsetKernel(
    int* output_padding_offset,
    int* output_cum_offsets,
    const int* output_cum_offsets_tmp,
    const int* seq_lens_output,
    const int max_seq_len) {
    // get padding offset of each batch
    const int bi = blockIdx.x;
    const int ti = threadIdx.x;
    int cum_offset = bi == 0 ? 0 : output_cum_offsets_tmp[bi - 1];
    for (int i = ti; i < seq_lens_output[bi]; i += blockDim.x) {
        output_padding_offset[bi * max_seq_len - cum_offset + i] = cum_offset;
    }
    if (ti == 0) {
        output_cum_offsets[bi] = cum_offset;
    }
}

std::vector<paddle::Tensor> SpeculateGetOutputPaddingOffset(
    const paddle::Tensor& output_cum_offsets_tmp,
    const paddle::Tensor& out_token_num,
    const paddle::Tensor& seq_lens_output,
    const int max_seq_len) {
    auto cu_stream = output_cum_offsets_tmp.stream();
    std::vector<int64_t> output_cum_offsets_tmp_shape =
        output_cum_offsets_tmp.shape();
    const int bsz = output_cum_offsets_tmp_shape[0];
    auto cpu_out_token_num = out_token_num.copy_to(paddle::CPUPlace(), false);

    auto output_padding_offset = paddle::full({cpu_out_token_num},
                                              0,
                                              paddle::DataType::INT32,
                                              output_cum_offsets_tmp.place());
    auto output_cum_offsets =
        output_cum_offsets_tmp.copy_to(output_cum_offsets_tmp.place(), false);

    SpeculateGetOutputPaddingOffsetKernel<<<bsz, 256, 0, cu_stream>>>(
        output_padding_offset.data<int>(),
        output_cum_offsets.data<int>(),
        output_cum_offsets_tmp.data<int>(),
        seq_lens_output.data<int>(),
        max_seq_len);

    return {output_padding_offset, output_cum_offsets};
}

std::vector<std::vector<int64_t>> SpeculateGetOutputPaddingOffsetInferShape(
    const std::vector<int64_t>& output_cum_offsets_tmp_shape,
    const std::vector<int64_t>& out_token_num_shape,
    const std::vector<int64_t>& seq_lens_output_shape) {
    int64_t bsz = output_cum_offsets_tmp_shape[0];
    return {{-1}, {bsz}};
}

std::vector<paddle::DataType> SpeculateGetOutputPaddingOffsetInferDtype(
    const paddle::DataType& output_cum_offsets_tmp_dtype,
    const paddle::DataType& out_token_num_dtype,
    const paddle::DataType& seq_lens_output_dtype) {
    return {output_cum_offsets_tmp_dtype, output_cum_offsets_tmp_dtype};
}

PD_BUILD_STATIC_OP(speculate_get_output_padding_offset)
    .Inputs({"output_cum_offsets_tmp", "out_token_num", "seq_lens_output"})
    .Outputs({"output_padding_offset", "output_cum_offsets"})
    .Attrs({"max_seq_len: int"})
    .SetKernelFn(PD_KERNEL(SpeculateGetOutputPaddingOffset))
    .SetInferShapeFn(PD_INFER_SHAPE(SpeculateGetOutputPaddingOffsetInferShape))
    .SetInferDtypeFn(PD_INFER_DTYPE(SpeculateGetOutputPaddingOffsetInferDtype));
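The offset bookkeeping here follows one invariant: a token at position i of batch bi lives at packed index bi * max_seq_len + i - cum_offset[bi], and the kernel stores cum_offset[bi] at exactly that packed index so the mapping can be inverted later. A tiny worked example of the arithmetic (a sketch; the helper name is illustrative):

// Sketch: padded -> packed index mapping used by the offset kernels above.
// Two sequences, max_seq_len = 4, lengths = {2, 3}.
// cum_offsets_tmp = {2, 3} (cumulative padding), so:
//   batch 0, token i: packed = 0*4 + i - 0  -> 0, 1
//   batch 1, token i: packed = 1*4 + i - 2  -> 2, 3, 4
int packed_index(int bi, int i, int max_seq_len, const int* cum_offsets_tmp) {
    int cum = bi == 0 ? 0 : cum_offsets_tmp[bi - 1];
    return bi * max_seq_len + i - cum;
}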
@@ -0,0 +1,155 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/extension.h"

#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif

__global__ void SpeculateRemovePadding(int64_t* output_data,
                                       const int64_t* input_data,
                                       const int64_t* draft_tokens,
                                       const int* seq_lens,
                                       const int* seq_lens_encoder,
                                       const int* cum_offsets,
                                       const int sequence_length,
                                       const int max_draft_tokens) {
    const int bi = blockIdx.x;
    const int tid = threadIdx.x;

    for (int i = tid; i < seq_lens[bi]; i += blockDim.x) {
        const int tgt_seq_id = bi * sequence_length - cum_offsets[bi] + i;
        if (seq_lens_encoder[bi] > 0) {
            const int src_seq_id = bi * sequence_length + i;
            output_data[tgt_seq_id] = input_data[src_seq_id];
        } else {
            const int src_seq_id = bi * max_draft_tokens + i;
            output_data[tgt_seq_id] = draft_tokens[src_seq_id];
        }
    }
}

__global__ void SpeculateGetPaddingOffsetKernel(int* padding_offset,
                                                int* cum_offsets_out,
                                                int* cu_seqlens_q,
                                                int* cu_seqlens_k,
                                                const int* cum_offsets,
                                                const int* seq_lens,
                                                const int max_seq_len) {
    // get padding offset of each batch
    const int bi = blockIdx.x;
    const int ti = threadIdx.x;
    int cum_offset = bi == 0 ? 0 : cum_offsets[bi - 1];
    for (int i = ti; i < seq_lens[bi]; i += blockDim.x) {
        padding_offset[bi * max_seq_len - cum_offset + i] = cum_offset;
    }
    if (ti == 0) {
        cum_offsets_out[bi] = cum_offset;
        int cum_seq_len = (bi + 1) * max_seq_len - cum_offsets[bi];
        cu_seqlens_q[bi + 1] = cum_seq_len;
        cu_seqlens_k[bi + 1] = cum_seq_len;
    }
}

std::vector<paddle::Tensor> SpeculateGetPaddingOffset(
    const paddle::Tensor& input_ids,
    const paddle::Tensor& draft_tokens,
    const paddle::Tensor& cum_offsets,
    const paddle::Tensor& token_num,
    const paddle::Tensor& seq_len,
    const paddle::Tensor& seq_lens_encoder) {
    auto cu_stream = input_ids.stream();
    std::vector<int64_t> input_ids_shape = input_ids.shape();
    const int bsz = seq_len.shape()[0];
    const int seq_length = input_ids_shape[1];
    const int max_draft_tokens = draft_tokens.shape()[1];
    auto cum_offsets_out = cum_offsets.copy_to(cum_offsets.place(), false);
    auto cpu_token_num = token_num.copy_to(paddle::CPUPlace(), false);

    const int token_num_data = cpu_token_num.data<int64_t>()[0];
    auto x_remove_padding = paddle::full(
        {token_num_data}, 0, paddle::DataType::INT64, input_ids.place());
    auto padding_offset = paddle::full(
        {token_num_data}, 0, paddle::DataType::INT32, input_ids.place());
    auto cu_seqlens_q =
        paddle::full({bsz + 1}, 0, paddle::DataType::INT32, input_ids.place());
    auto cu_seqlens_k =
        paddle::full({bsz + 1}, 0, paddle::DataType::INT32, input_ids.place());
    int blockSize = min((token_num_data + 32 - 1) / 32 * 32, 128);
    SpeculateGetPaddingOffsetKernel<<<bsz, 128, 0, cu_stream>>>(
        padding_offset.data<int>(),
        cum_offsets_out.data<int>(),
        cu_seqlens_q.data<int>(),
        cu_seqlens_k.data<int>(),
        cum_offsets.data<int>(),
        seq_len.data<int>(),
        seq_length);
    SpeculateRemovePadding<<<bsz, blockSize, 0, cu_stream>>>(
        x_remove_padding.data<int64_t>(),
        input_ids.data<int64_t>(),
        draft_tokens.data<int64_t>(),
        seq_len.data<int>(),
        seq_lens_encoder.data<int>(),
        cum_offsets_out.data<int>(),
        seq_length,
        max_draft_tokens);
    return {x_remove_padding,
            cum_offsets_out,
            padding_offset,
            cu_seqlens_q,
            cu_seqlens_k};
}

std::vector<std::vector<int64_t>> SpeculateGetPaddingOffsetInferShape(
    const std::vector<int64_t>& input_ids_shape,
    const std::vector<int64_t>& draft_tokens_shape,
    const std::vector<int64_t>& cum_offsets_shape,
    const std::vector<int64_t>& token_num_shape,
    const std::vector<int64_t>& seq_len_shape,
    const std::vector<int64_t>& seq_lens_encoder_shape) {
    int64_t bsz = seq_len_shape[0];
    return {{-1}, {bsz}, {-1}, {bsz + 1}, {bsz + 1}};
}

std::vector<paddle::DataType> SpeculateGetPaddingOffsetInferDtype(
    const paddle::DataType& input_ids_dtype,
    const paddle::DataType& draft_tokens_dtype,
    const paddle::DataType& cum_offsets_dtype,
    const paddle::DataType& token_num_dtype,
    const paddle::DataType& seq_len_dtype,
    const paddle::DataType& seq_lens_encoder_dtype) {
    return {input_ids_dtype,
            seq_len_dtype,
            seq_len_dtype,
            seq_len_dtype,
            seq_len_dtype};
}

PD_BUILD_STATIC_OP(speculate_get_padding_offset)
    .Inputs({"input_ids",
             "draft_tokens",
             "cum_offsets",
             "token_num",
             "seq_len",
             "seq_lens_encoder"})
    .Outputs({"x_remove_padding",
              "cum_offsets_out",
              "padding_offset",
              "cu_seqlens_q",
              "cu_seqlens_k"})
    .SetKernelFn(PD_KERNEL(SpeculateGetPaddingOffset))
    .SetInferShapeFn(PD_INFER_SHAPE(SpeculateGetPaddingOffsetInferShape))
    .SetInferDtypeFn(PD_INFER_DTYPE(SpeculateGetPaddingOffsetInferDtype));
@@ -0,0 +1,80 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/extension.h"

#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif

__global__ void SpeculateGetSeqLensOutputKernel(int* seq_lens_output,
                                                const int* seq_lens_this_time,
                                                const int* seq_lens_encoder,
                                                const int* seq_lens_decoder,
                                                const int real_bsz) {
    for (int bid = threadIdx.x; bid < real_bsz; bid += blockDim.x) {
        if (seq_lens_this_time[bid] == 0) {
            continue;
        } else if (seq_lens_this_time[bid] == 1) {
            seq_lens_output[bid] = 1;
        } else if (seq_lens_encoder[bid] != 0) {
            seq_lens_output[bid] = 1;
        } else {
            seq_lens_output[bid] = seq_lens_this_time[bid];
        }
    }
}

std::vector<paddle::Tensor> SpeculateGetSeqLensOutput(
    const paddle::Tensor& seq_lens_this_time,
    const paddle::Tensor& seq_lens_encoder,
    const paddle::Tensor& seq_lens_decoder) {
    auto cu_stream = seq_lens_this_time.stream();
    std::vector<int64_t> seq_lens_this_time_shape = seq_lens_this_time.shape();
    const int bsz = seq_lens_this_time_shape[0];

    auto seq_lens_output = paddle::full(
        {bsz}, 0, paddle::DataType::INT32, seq_lens_this_time.place());

    SpeculateGetSeqLensOutputKernel<<<1, 256, 0, cu_stream>>>(
        seq_lens_output.data<int>(),
        seq_lens_this_time.data<int>(),
        seq_lens_encoder.data<int>(),
        seq_lens_decoder.data<int>(),
        bsz);

    return {seq_lens_output};
}

std::vector<std::vector<int64_t>> SpeculateGetSeqLensOutputInferShape(
    const std::vector<int64_t>& seq_lens_this_time_shape,
    const std::vector<int64_t>& seq_lens_encoder_shape,
    const std::vector<int64_t>& seq_lens_decoder_shape) {
    int64_t bsz = seq_lens_this_time_shape[0];
    return {{bsz}};
}

std::vector<paddle::DataType> SpeculateGetSeqLensOutputInferDtype(
    const paddle::DataType& seq_lens_this_time_dtype,
    const paddle::DataType& seq_lens_encoder_dtype,
    const paddle::DataType& seq_lens_decoder_dtype) {
    return {seq_lens_this_time_dtype};
}

PD_BUILD_STATIC_OP(speculate_get_seq_lens_output)
    .Inputs({"seq_lens_this_time", "seq_lens_encoder", "seq_lens_decoder"})
    .Outputs({"seq_lens_output"})
    .SetKernelFn(PD_KERNEL(SpeculateGetSeqLensOutput))
    .SetInferShapeFn(PD_INFER_SHAPE(SpeculateGetSeqLensOutputInferShape))
    .SetInferDtypeFn(PD_INFER_DTYPE(SpeculateGetSeqLensOutputInferDtype));
@@ -0,0 +1,69 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/extension.h"

#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif

__global__ void SpeculateHydraSetScoreThresholdKernel(
    float* threshold,
    const int* seq_lens_this_time,
    const int* seq_lens_encoder,
    const int* accept_num,
    const int real_bsz,
    const float default_threshold = 0.3,
    const float upper_threshold = 0.8,
    const float lower_threshold = 0.0,
    const float threshold_step = 0.1,
    const float threshold_step_fac = 0.5) {
    for (int bid = threadIdx.x; bid < real_bsz; bid += blockDim.x) {
        if (seq_lens_encoder[bid] > 0) {
            threshold[bid] = default_threshold;
        } else if (seq_lens_this_time[bid] <= 1) {
            continue;
        } else if (accept_num[bid] >= seq_lens_this_time[bid] &&
                   threshold[bid] >
                       lower_threshold + threshold_step * threshold_step_fac) {
            threshold[bid] -= threshold_step * threshold_step_fac;
        } else if (accept_num[bid] < seq_lens_this_time[bid] &&
                   threshold[bid] < upper_threshold - threshold_step) {
            threshold[bid] += threshold_step;
        }
    }
}

void SpeculateHydraSetScoreThreshold(const paddle::Tensor& seq_lens_this_time,
                                     const paddle::Tensor& seq_lens_encoder,
                                     const paddle::Tensor& accept_num,
                                     const paddle::Tensor& threshold) {
    auto cu_stream = seq_lens_this_time.stream();
    std::vector<int64_t> seq_lens_this_time_shape = seq_lens_this_time.shape();
    const int bsz = seq_lens_this_time_shape[0];

    SpeculateHydraSetScoreThresholdKernel<<<1, 256, 0, cu_stream>>>(
        const_cast<float*>(threshold.data<float>()),
        seq_lens_this_time.data<int>(),
        seq_lens_encoder.data<int>(),
        accept_num.data<int>(),
        bsz);
}

PD_BUILD_STATIC_OP(speculate_hydra_set_score_threshold)
    .Inputs(
        {"seq_lens_this_time", "seq_lens_encoder", "accept_num", "threshold"})
    .Outputs({"threshold_out"})
    .SetInplaceMap({{"threshold", "threshold_out"}})
    .SetKernelFn(PD_KERNEL(SpeculateHydraSetScoreThreshold));
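The update above is an asymmetric hysteresis controller: a fully accepted step lowers the score threshold by half a step (drafting more aggressively next time), any rejection raises it by a full step, and both moves stay clamped inside (lower, upper). A host restatement of that rule (a sketch; the constants mirror the kernel's defaults and the function name is illustrative):

// Sketch of the controller in SpeculateHydraSetScoreThresholdKernel.
float update_threshold(float t, bool all_accepted,
                       float lower = 0.0f, float upper = 0.8f,
                       float step = 0.1f, float step_fac = 0.5f) {
    if (all_accepted && t > lower + step * step_fac) {
        return t - step * step_fac;  // loosen: propose more draft tokens
    }
    if (!all_accepted && t < upper - step) {
        return t + step;             // tighten: propose fewer draft tokens
    }
    return t;
}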
@@ -0,0 +1,68 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "helper.h"

__global__ void hydra_update_this_time(int* seq_lens_this_time,
                                       const int* seq_lens_encoder,
                                       const int* seq_lens_decoder,
                                       const float* topk_scores,
                                       const float* score_threshold,
                                       int real_bsz,
                                       int idx) {
    int linear_idx = threadIdx.x;
    // Verify the top-k scores and update this step's token count per sequence.
    for (; linear_idx < real_bsz; linear_idx += blockDim.x) {
        if (seq_lens_encoder[linear_idx] == 0 &&
            seq_lens_decoder[linear_idx] != 0) {
            if (topk_scores[linear_idx] > score_threshold[linear_idx] &&
                seq_lens_this_time[linear_idx] == idx + 1) {
                seq_lens_this_time[linear_idx]++;
            }
        } else if (seq_lens_encoder[linear_idx] == 0 &&
                   seq_lens_decoder[linear_idx] == 0) {
            seq_lens_this_time[linear_idx] = 0;
        }
    }
}

void HydraUpdateThisTime(const paddle::Tensor& seq_lens_this_time,
                         const paddle::Tensor& seq_lens_encoder,
                         const paddle::Tensor& seq_lens_decoder,
                         const paddle::Tensor& topk_scores,
                         const paddle::Tensor& score_threshold,
                         const int real_bsz,
                         const int idx) {
    constexpr int BlockSize = 512;

    hydra_update_this_time<<<1, BlockSize, 0, seq_lens_this_time.stream()>>>(
        const_cast<int*>(seq_lens_this_time.data<int>()),
        seq_lens_encoder.data<int>(),
        seq_lens_decoder.data<int>(),
        topk_scores.data<float>(),
        score_threshold.data<float>(),
        real_bsz,
        idx);
}

PD_BUILD_STATIC_OP(speculate_hydra_update_seqlens_this_time)
    .Inputs({"seq_lens_this_time",
             "seq_lens_encoder",
             "seq_lens_decoder",
             "topk_scores",
             "score_threshold"})
    .Outputs({"seq_lens_this_time_out"})
    .Attrs({"real_bsz: int", "idx: int"})
    .SetInplaceMap({{"seq_lens_this_time", "seq_lens_this_time_out"}})
    .SetKernelFn(PD_KERNEL(HydraUpdateThisTime));
32
custom_ops/gpu_ops/speculate_decoding/speculate_msg.h
Normal file
@@ -0,0 +1,32 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <stdio.h>
#include <string.h>
#include <sys/ipc.h>
#include <sys/msg.h>
#include <sys/types.h>
#include "paddle/extension.h"

#define MAX_BSZ 256
#define MAX_DRAFT_TOKENS 6

// TODO: replace all msgdata definitions in speculate_decoding with this one.
struct speculate_msgdata {
    long mtype;
    int mtext[MAX_BSZ * MAX_DRAFT_TOKENS + MAX_BSZ +
              2];  // stop_flag, bsz, tokens
};
@@ -0,0 +1,149 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/extension.h"
#include "helper.h"

template <typename T, int VecSize>
__global__ void RebuildAppendPaddingKernel(
    T *out,
    const T *full_hidden_states,
    const int *cum_offset,
    const int *seq_len_encoder,
    const int *seq_len_decoder,
    const int *output_padding_offset,
    const int seq_len,
    const int dim_embed,
    const size_t elem_nums) {
    using LoadT = AlignedVector<T, VecSize>;
    LoadT src_vec;
    const int64_t global_idx = blockDim.x * blockIdx.x + threadIdx.x;
    for (int64_t i = global_idx * VecSize; i < elem_nums; i += gridDim.x * blockDim.x * VecSize) {
        const int out_token_id = i / dim_embed;
        const int ori_token_id = out_token_id + output_padding_offset[out_token_id];
        const int bi = ori_token_id / seq_len;
        int seq_id = 0;

        if (seq_len_decoder[bi] == 0 && seq_len_encoder[bi] == 0) {
            continue;
        } else if (seq_len_encoder[bi] != 0) {
            seq_id = seq_len_encoder[bi] - 1;
        }

        const int input_token_id = ori_token_id - cum_offset[bi] + seq_id;
        const int bias_idx = i % dim_embed;

        Load<T, VecSize>(&full_hidden_states[input_token_id * dim_embed + bias_idx], &src_vec);
        Store<T, VecSize>(src_vec, &out[i]);
    }
}

template <paddle::DataType D>
std::vector<paddle::Tensor> DispatchDtype(
    const paddle::Tensor& full_hidden_states,
    const paddle::Tensor& cum_offsets,
    const paddle::Tensor& seq_len_encoder,
    const paddle::Tensor& seq_len_decoder,
    const paddle::Tensor& output_padding_offset,
    const int max_seq_len) {
    // src: [token_num, dim_embed]
    // dst: [batch_size, 1, dim_embed]
    typedef PDTraits<D> traits_;
    typedef typename traits_::DataType DataType_;
    typedef typename traits_::data_t data_t;

    int dim_embed = full_hidden_states.shape()[1];
    int output_token_num = output_padding_offset.shape()[0];
    int elem_nums = output_token_num * dim_embed;
    constexpr int PackSize = VEC_16B / sizeof(DataType_);
    assert(elem_nums % PackSize == 0);

    auto out = paddle::full({output_token_num, dim_embed}, 0, full_hidden_states.dtype(), full_hidden_states.place());

    int pack_num = elem_nums / PackSize;
    const int threads_per_block = 128;
    int grid_size = 1;
    GetNumBlocks(pack_num, &grid_size);

    RebuildAppendPaddingKernel<DataType_, PackSize><<<grid_size, threads_per_block, 0, full_hidden_states.stream()>>>(
        reinterpret_cast<DataType_*>(out.data<data_t>()),
        reinterpret_cast<const DataType_*>(full_hidden_states.data<data_t>()),
        cum_offsets.data<int32_t>(),
        seq_len_encoder.data<int32_t>(),
        seq_len_decoder.data<int32_t>(),
        output_padding_offset.data<int32_t>(),
        max_seq_len,
        dim_embed,
        elem_nums);
    return {out};
}

std::vector<paddle::Tensor> RebuildAppendPadding(
    const paddle::Tensor& full_hidden_states,
    const paddle::Tensor& cum_offsets,
    const paddle::Tensor& seq_len_encoder,
    const paddle::Tensor& seq_len_decoder,
    const paddle::Tensor& output_padding_offset,
    const int max_seq_len) {
    switch (full_hidden_states.dtype()) {
        case paddle::DataType::BFLOAT16:
            return DispatchDtype<paddle::DataType::BFLOAT16>(
                full_hidden_states, cum_offsets, seq_len_encoder, seq_len_decoder, output_padding_offset, max_seq_len);
        case paddle::DataType::FLOAT16:
            return DispatchDtype<paddle::DataType::FLOAT16>(
                full_hidden_states, cum_offsets, seq_len_encoder, seq_len_decoder, output_padding_offset, max_seq_len);
        default:
            PD_THROW("Unsupported data type.");
    }
}

std::vector<std::vector<int64_t>> RebuildAppendPaddingInferShape(
    const std::vector<int64_t>& full_hidden_states_shape,
    const std::vector<int64_t>& cum_offsets_shape,
    const std::vector<int64_t>& seq_len_encoder_shape,
    const std::vector<int64_t>& seq_len_decoder_shape,
    const std::vector<int64_t>& output_padding_offset_shape) {
    const int64_t output_token_num = output_padding_offset_shape[0];
    const int64_t dim_embed = full_hidden_states_shape[1];
    std::vector<int64_t> out_shape = {output_token_num, dim_embed};
    return {out_shape};
}

std::vector<paddle::DataType> RebuildAppendPaddingInferDtype(
    const paddle::DataType& full_hidden_states_dtype,
    const paddle::DataType& cum_offsets_dtype,
    const paddle::DataType& seq_len_encoder_dtype,
    const paddle::DataType& seq_len_decoder_dtype,
    const paddle::DataType& output_padding_offset_dtype) {
    return {full_hidden_states_dtype};
}

PD_BUILD_STATIC_OP(speculate_rebuild_append_padding)
    .Inputs({"full_hidden_states",
             "cum_offsets",
             "seq_len_encoder",
             "seq_len_decoder",
             "output_padding_offset"})
    .Attrs({"max_seq_len: int"})
    .Outputs({"out"})
    .SetKernelFn(PD_KERNEL(RebuildAppendPadding))
    .SetInferShapeFn(PD_INFER_SHAPE(RebuildAppendPaddingInferShape))
    .SetInferDtypeFn(PD_INFER_DTYPE(RebuildAppendPaddingInferDtype));
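The dtype dispatch above picks the widest vector that fits one 16-byte aligned load, so PackSize is 8 for half/bfloat16 and the grid-stride loop moves PackSize elements per iteration; the elem_nums assertion is the price of that choice. A sketch of the arithmetic (VEC_16B comes from helper.h; unsigned short stands in for a 2-byte element type here):

// Sketch: 16-byte vectorized pack width used by RebuildAppendPaddingKernel.
constexpr int kVecBytes = 16;  // mirrors VEC_16B
template <typename T>
constexpr int pack_size() { return kVecBytes / sizeof(T); }

static_assert(pack_size<unsigned short>() == 8,
              "a 16B load carries 8 two-byte (fp16/bf16) elements");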
163
custom_ops/gpu_ops/speculate_decoding/speculate_save_output.cc
Normal file
@@ -0,0 +1,163 @@
|
||||
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <stdio.h>
|
||||
#include <string.h>
|
||||
#include <sys/ipc.h>
|
||||
#include <sys/msg.h>
|
||||
#include <sys/types.h>
|
||||
#include "paddle/extension.h"
|
||||
|
||||
#ifndef PD_BUILD_STATIC_OP
|
||||
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
|
||||
#endif
|
||||
|
||||
#define MAX_BSZ 512
|
||||
#define MAX_DRAFT_TOKENS 6
|
||||
|
||||
struct msgdata {
|
||||
long mtype;
|
||||
int mtext[MAX_BSZ * MAX_DRAFT_TOKENS + MAX_BSZ +
|
||||
2]; // stop_flag, bsz, tokens
|
||||
};
void SpeculateSaveWithOutputMsg(const paddle::Tensor& accept_tokens,
                                const paddle::Tensor& accept_num,
                                const paddle::Tensor& not_need_stop,
                                int64_t rank_id,
                                int msg_queue_id,
                                int save_each_rank) {
    // printf("enter save output");
    if (!save_each_rank && rank_id > 0) {
        return;
    }

    int max_draft_tokens = accept_tokens.shape()[1];

    auto accept_tokens_cpu = accept_tokens.copy_to(paddle::CPUPlace(), true);
    auto accept_num_cpu = accept_num.copy_to(paddle::CPUPlace(), true);
    int64_t* accept_tokens_data = accept_tokens_cpu.data<int64_t>();
    int* accept_num_data = accept_num_cpu.data<int>();

    if (const char* inference_msg_queue_id_env_p =
            std::getenv("INFERENCE_MSG_QUEUE_ID")) {
        std::string inference_msg_queue_id_env_str(
            inference_msg_queue_id_env_p);
        int inference_msg_queue_id_from_env =
            std::stoi(inference_msg_queue_id_env_str);
#ifdef GET_OUTPUT_DEBUG
        std::cout << "Your INFERENCE_MSG_QUEUE_ID is: "
                  << inference_msg_queue_id_from_env << std::endl;
#endif
        msg_queue_id = inference_msg_queue_id_from_env;
    }
    static struct msgdata msg_sed;
    static key_t key = ftok("./", msg_queue_id);
    static int msgid = msgget(key, IPC_CREAT | 0666);

    msg_sed.mtype = 1;
    bool not_need_stop_data = not_need_stop.data<bool>()[0];

    int inference_msg_id_from_env = 1;
    if (const char* inference_msg_id_env_p = std::getenv("INFERENCE_MSG_ID")) {
        std::string inference_msg_id_env_str(inference_msg_id_env_p);
        inference_msg_id_from_env = std::stoi(inference_msg_id_env_str);
        if (inference_msg_id_from_env == 2) {
            // 2 and -2 are reserved for the no-output indication.
            throw std::runtime_error(
                " INFERENCE_MSG_ID cannot be 2, please use another number.");
        }
        if (inference_msg_id_from_env < 0) {
            throw std::runtime_error(
                " INFERENCE_MSG_ID cannot be negative, please use another "
                "number.");
        }

#ifdef SAVE_WITH_OUTPUT_DEBUG
        std::cout << "Your INFERENCE_MSG_ID is: " << inference_msg_id_from_env
                  << std::endl;
#endif
    } else {
#ifdef SAVE_WITH_OUTPUT_DEBUG
        std::cout
            << "Failed to get INFERENCE_MSG_ID from env, using (int)1 as default."
            << std::endl;
#endif
    }

    msg_sed.mtext[0] = not_need_stop_data ? inference_msg_id_from_env
                                          : -inference_msg_id_from_env;
    int bsz = accept_tokens.shape()[0];
    msg_sed.mtext[1] = bsz;

    for (int i = 2; i < MAX_BSZ + 2; i++) {
        if (i - 2 >= bsz) {
            msg_sed.mtext[i] = 0;
        } else {
            msg_sed.mtext[i] = (int)accept_num_data[i - 2];
        }
    }
    for (int i = MAX_BSZ + 2; i < MAX_BSZ * MAX_DRAFT_TOKENS + MAX_BSZ + 2;
         i++) {
        int token_id = i - MAX_BSZ - 2;
        int bid = token_id / MAX_DRAFT_TOKENS;
        int local_token_id = token_id % MAX_DRAFT_TOKENS;
        if (token_id / MAX_DRAFT_TOKENS >= bsz) {
            msg_sed.mtext[i] = 0;
        } else {
            msg_sed.mtext[i] =
                accept_tokens_data[bid * max_draft_tokens + local_token_id];
        }
    }
    if ((msgsnd(msgid,
                &msg_sed,
                (MAX_BSZ * MAX_DRAFT_TOKENS + MAX_BSZ + 2) * 4,
                0)) == -1) {
        printf("full msg buffer\n");
    }
    return;
}

void SpeculateSaveWithOutputMsgStatic(const paddle::Tensor& accept_tokens,
                                      const paddle::Tensor& accept_num,
                                      const paddle::Tensor& not_need_stop,
                                      int64_t rank_id,
                                      bool save_each_rank) {
    SpeculateSaveWithOutputMsg(
        accept_tokens, accept_num, not_need_stop, rank_id, 1, save_each_rank);
}

void SpeculateSaveWithOutputMsgDynamic(const paddle::Tensor& accept_tokens,
                                       const paddle::Tensor& accept_num,
                                       const paddle::Tensor& not_need_stop,
                                       int64_t rank_id,
                                       int msg_queue_id,
                                       bool save_each_rank) {
    SpeculateSaveWithOutputMsg(
        accept_tokens, accept_num, not_need_stop, rank_id, msg_queue_id, save_each_rank);
}

PD_BUILD_STATIC_OP(speculate_save_output)
    .Inputs({"accept_tokens", "accept_num", "not_need_stop"})
    .Attrs({"rank_id: int64_t", "save_each_rank: bool"})
    .Outputs({"x_out"})
    .SetInplaceMap({{"accept_tokens", "x_out"}})
    .SetKernelFn(PD_KERNEL(SpeculateSaveWithOutputMsgStatic));

PD_BUILD_STATIC_OP(speculate_save_output_dynamic)
    .Inputs({"accept_tokens", "accept_num", "not_need_stop"})
    .Attrs({"rank_id: int64_t", "msg_queue_id: int", "save_each_rank: bool"})
    .Outputs({"x_out"})
    .SetInplaceMap({{"accept_tokens", "x_out"}})
    .SetKernelFn(PD_KERNEL(SpeculateSaveWithOutputMsgDynamic));
@@ -0,0 +1,91 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/extension.h"

#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif

__global__ void speculate_set_value_by_flag_and_id(int64_t *pre_ids_all,
                                                   const int64_t *accept_tokens,
                                                   const int *accept_num,
                                                   const bool *stop_flags,
                                                   const int *seq_lens_encoder,
                                                   const int *seq_lens_decoder,
                                                   const int64_t *step_idx,
                                                   int bs,
                                                   int length,
                                                   int max_draft_tokens) {
    int tid = threadIdx.x;
    if (tid < bs && !stop_flags[tid]) {
        int64_t *pre_ids_all_now = pre_ids_all + tid * length;
        const int64_t *accept_tokens_now =
            accept_tokens + tid * max_draft_tokens;
        const int seq_len_dec = seq_lens_decoder[tid];
        const int seq_len_enc = seq_lens_encoder[tid];
        if (seq_len_dec == 0 && seq_len_enc == 0) return;  // stopped
        // printf("step_idx[tid] %d\n", step_idx[tid]);
        if (step_idx[tid] >= 0) {
            for (int i = 0; i < accept_num[tid]; i++) {
                pre_ids_all_now[step_idx[tid] - i] =
                    accept_tokens_now[accept_num[tid] - 1 - i];
                // printf("pre_ids_all_now[step_idx[tid] - i] %d \n",
                //        pre_ids_all_now[step_idx[tid] - i]);
            }
        }
    }
}

void SpeculateSetValueByFlagsAndIdx(const paddle::Tensor &pre_ids_all,
                                    const paddle::Tensor &accept_tokens,
                                    const paddle::Tensor &accept_num,
                                    const paddle::Tensor &stop_flags,
                                    const paddle::Tensor &seq_lens_this_time,
                                    const paddle::Tensor &seq_lens_encoder,
                                    const paddle::Tensor &seq_lens_decoder,
                                    const paddle::Tensor &step_idx) {
    // printf("enter set value \n");
    auto cu_stream = stop_flags.stream();
    std::vector<int64_t> pre_ids_all_shape = pre_ids_all.shape();

    int bs = seq_lens_this_time.shape()[0];
    int length = pre_ids_all_shape[1];
    int max_draft_tokens = accept_tokens.shape()[1];
    int block_size = (bs + 32 - 1) / 32 * 32;
    speculate_set_value_by_flag_and_id<<<1, block_size, 0, cu_stream>>>(
        const_cast<int64_t *>(pre_ids_all.data<int64_t>()),
        accept_tokens.data<int64_t>(),
        accept_num.data<int>(),
        stop_flags.data<bool>(),
        seq_lens_encoder.data<int>(),
        seq_lens_decoder.data<int>(),
        step_idx.data<int64_t>(),
        bs,
        length,
        max_draft_tokens);
}

PD_BUILD_STATIC_OP(speculate_set_value_by_flags_and_idx)
    .Inputs({"pre_ids_all",
             "accept_tokens",
             "accept_num",
             "stop_flags",
             "seq_lens_this_time",
             "seq_lens_encoder",
             "seq_lens_decoder",
             "step_idx"})
    .Outputs({"pre_ids_all_out"})
    .SetInplaceMap({{"pre_ids_all", "pre_ids_all_out"}})
    .SetKernelFn(PD_KERNEL(SpeculateSetValueByFlagsAndIdx));
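The copy loop above writes the newest accepted token at position step_idx[tid] and fills earlier positions backwards from there. A host-side rehearsal of the same index math (purely illustrative, not part of the op) makes the pattern easy to verify:

// Illustrative check of the pre_ids write pattern used by the kernel:
// with step_idx = 9 and accept_num = 3, tokens [100, 101, 102] land at
// pre_ids[7], pre_ids[8], pre_ids[9] respectively.
#include <cstdio>
#include <vector>

int main() {
    const int step_idx = 9, accept_num = 3;
    std::vector<long long> pre_ids(16, -1), accept = {100, 101, 102};
    for (int i = 0; i < accept_num; ++i) {
        pre_ids[step_idx - i] = accept[accept_num - 1 - i];
    }
    for (int p = 7; p <= 9; ++p) printf("pre_ids[%d] = %lld\n", p, pre_ids[p]);
    return 0;  // prints 100, 101, 102
}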
custom_ops/gpu_ops/speculate_decoding/speculate_step.cu (new file, 481 lines)
@@ -0,0 +1,481 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "helper.h"  // NOLINT

// #define DEBUG_STEP

__device__ bool speculate_free_and_dispatch_block(const int &qid,
                                                  int *need_block_list,
                                                  const int &need_block_len) {
    bool res = false;
    for (int i = 0; i < need_block_len; i++) {
        if (qid == need_block_list[i]) {
            res = true;
            need_block_list[i] = -1;
            break;
        }
    }
    return res;
}

__global__ void speculate_free_and_dispatch_block(
    bool *stop_flags,
    int *seq_lens_this_time,
    int *seq_lens_decoder,
    int *block_tables,
    int *encoder_block_lens,
    bool *is_block_step,
    int *step_block_list,  // [bsz]
    int *step_len,
    int *recover_block_list,
    int *recover_len,
    int *need_block_list,
    int *need_block_len,
    int *used_list_len,
    int *free_list,
    int *free_list_len,
    int64_t *first_token_ids,
    int *accept_num,
    const int bsz,
    const int block_size,
    const int block_num_per_seq,
    const int max_decoder_block_num,
    const int max_draft_tokens) {
    typedef cub::BlockReduce<cub::KeyValuePair<int, int>, 256> BlockReduce;
    __shared__ typename BlockReduce::TempStorage temp_storage;
    __shared__ bool step_max_block_flag;
    __shared__ int in_need_block_list_len;
    const int tid = threadIdx.x;
    if (tid < bsz) {
        if (tid == 0) {
            step_max_block_flag = false;
            in_need_block_list_len = 0;
        }
        int *block_table_now = block_tables + tid * block_num_per_seq;
        int max_possible_block_idx =
            (seq_lens_decoder[tid] + max_draft_tokens + 1) / block_size;
        if (stop_flags[tid] && !is_block_step[tid]) {
            // Reclaim this query's blocks
            first_token_ids[tid] = -1;
            const int encoder_block_len = encoder_block_lens[tid];
            const int decoder_used_len = used_list_len[tid];
            if (decoder_used_len > 0) {
                const int ori_free_list_len =
                    atomicAdd(free_list_len, decoder_used_len);
#ifdef DEBUG_STEP
                printf(
                    "free block seq_id: %d, free block num: %d, "
                    "encoder_block_len: %d, ori_free_list_len: %d\n",
                    tid,
                    decoder_used_len,
                    encoder_block_len,
                    ori_free_list_len);
#endif
                for (int i = 0; i < decoder_used_len; i++) {
                    free_list[ori_free_list_len + i] =
                        block_table_now[encoder_block_len + i];
                    block_table_now[encoder_block_len + i] = -1;
                }
                encoder_block_lens[tid] = 0;
                used_list_len[tid] = 0;
            }
        } else if (seq_lens_this_time[tid] != 0 &&
                   max_possible_block_idx < block_num_per_seq &&
                   block_table_now[(seq_lens_decoder[tid] + max_draft_tokens +
                                    1) /
                                   block_size] == -1) {
            // Record which queries need a block and the total demand
            const int ori_need_block_len = atomicAdd(need_block_len, 1);
            need_block_list[ori_need_block_len] = tid;
#ifdef DEBUG_STEP
            printf("seq_id: %d need block\n", tid);
#endif
        }
    }
    __syncthreads();

    while (need_block_len[0] > free_list_len[0]) {
#ifdef DEBUG_STEP
        if (tid == 0) {
            printf("need_block_len: %d, free_list_len: %d\n",
                   need_block_len[0],
                   free_list_len[0]);
        }
#endif
        // Schedule: reclaim blocks from the query with the largest
        // used_list_len downwards until need_block_len is satisfied; queries
        // already decoding in their last block do not take part (they are
        // about to finish)
        const int used_block_num =
            tid < bsz && !is_block_step[tid] &&
                    (step_max_block_flag ||
                     used_list_len[tid] != max_decoder_block_num)
                ? used_list_len[tid]
                : 0;
        cub::KeyValuePair<int, int> kv_pair = {tid, used_block_num};
        kv_pair = BlockReduce(temp_storage).Reduce(kv_pair, cub::ArgMax());

        if (tid == 0) {
            if (kv_pair.value == 0) {
                step_max_block_flag = true;
            } else {
                const int encoder_block_len = encoder_block_lens[kv_pair.key];
                // #ifdef DEBUG_STEP
                printf("max_id: %d, max_num: %d, encoder_block_len: %d\n",
                       kv_pair.key,
                       kv_pair.value,
                       encoder_block_len);
                // #endif
                int *block_table_now =
                    block_tables + kv_pair.key * block_num_per_seq;
                for (int i = 0; i < kv_pair.value; i++) {
                    free_list[free_list_len[0] + i] =
                        block_table_now[encoder_block_len + i];
                    block_table_now[encoder_block_len + i] = -1;
                }
                step_block_list[step_len[0]] = kv_pair.key;
                if (speculate_free_and_dispatch_block(
                        kv_pair.key,
                        need_block_list,
                        need_block_len[0] + in_need_block_list_len)) {
                    need_block_len[0] -= 1;
                    in_need_block_list_len += 1;
                }
                step_len[0] += 1;
                free_list_len[0] += kv_pair.value;
                stop_flags[kv_pair.key] = true;
                is_block_step[kv_pair.key] = true;
                seq_lens_this_time[kv_pair.key] = 0;
                seq_lens_decoder[kv_pair.key] = 0;
                // Note(@wufeisheng): when step, accept num will not be 0 so
                // that next step even if this batch member is stepped, save
                // output still stream output, so accept num should be set to 0
                accept_num[kv_pair.key] = 0;
            }
        }
        __syncthreads();
    }

    // Hand out one block to each position that asked for one
    if (tid < need_block_len[0] + in_need_block_list_len) {
        const int need_block_id = need_block_list[tid];
        if (need_block_id != -1) {
            if (!stop_flags[need_block_id]) {
                // If this position was just freed in the loop above, leave it
                // alone
                used_list_len[need_block_id] += 1;
                const int ori_free_list_len = atomicSub(free_list_len, 1);
                int *block_table_now =
                    block_tables + need_block_id * block_num_per_seq;
#ifdef DEBUG_STEP
                printf("need_block_id %d\n", need_block_id);
                printf("ori_free_list_len %d\n", ori_free_list_len);
                printf("max_draft_tokens %d\n", max_draft_tokens);
                printf("seq_lens_decoder[need_block_id] %d\n",
                       seq_lens_decoder[need_block_id]);
                printf("free_list[ori_free_list_len - 1] %d\n",
                       free_list[ori_free_list_len - 1]);
#endif
                block_table_now[(seq_lens_decoder[need_block_id] +
                                 max_draft_tokens + 1) /
                                block_size] = free_list[ori_free_list_len - 1];
            }
            need_block_list[tid] = -1;
        }
    }
    __syncthreads();

    // Work out which query ids can be recovered
    if (tid == 0) {
        int ori_free_list_len = free_list_len[0];
        int ori_step_len = step_len[0];
        if (ori_step_len > 0) {
            int ori_step_block_id = step_block_list[ori_step_len - 1];
            int tmp_used_len = used_list_len[ori_step_block_id];
            // Grant one block more than at eviction time, so a just-recovered
            // query is not immediately re-evicted (e.g. when the reclaimed
            // seq_id is still in need_block_list)
            const int max_decoder_block_num_this_seq =
                max_decoder_block_num - encoder_block_lens[ori_step_block_id];
            int used_len = tmp_used_len + 1 < max_decoder_block_num_this_seq
                               ? tmp_used_len + 1
                               : max_decoder_block_num_this_seq;
            if (ori_step_len > 0 && ori_free_list_len >= used_len) {
                // #ifdef DEBUG_STEP
                printf(
                    "recover seq_id: %d, free_list_len: %d, used_list_len: "
                    "%d\n",
                    ori_step_block_id,
                    ori_free_list_len,
                    used_len);
                // #endif
                recover_block_list[recover_len[0]] = ori_step_block_id;
                is_block_step[ori_step_block_id] = false;
                used_list_len[ori_step_block_id] = used_len;
                ori_free_list_len -= used_len;
                step_block_list[ori_step_len - 1] = -1;
                step_len[0] -= 1;
                recover_len[0] += 1;
                ori_step_len = step_len[0];
                if (ori_step_len > 0) {
                    ori_step_block_id = step_block_list[ori_step_len - 1];
                    tmp_used_len = used_list_len[ori_step_block_id];
                    used_len = tmp_used_len + 1 < max_decoder_block_num_this_seq
                                   ? tmp_used_len + 1
                                   : max_decoder_block_num_this_seq;
                }
            }
        }
        need_block_len[0] = 0;
    }
}

// Restore the state of the query ids marked recoverable in the previous step
__global__ void speculate_recover_block(int *recover_block_list,  // [bsz]
                                        int *recover_len,
                                        bool *stop_flags,
                                        int *seq_lens_this_time,
                                        int *ori_seq_lens_encoder,
                                        int *seq_lens_encoder,
                                        int *seq_lens_decoder,
                                        int *block_tables,
                                        int *free_list,
                                        int *free_list_len,
                                        int64_t *input_ids,
                                        int64_t *pre_ids,
                                        int64_t *step_idx,
                                        int *encoder_block_lens,
                                        int *used_list_len,
                                        const int64_t *next_tokens,
                                        const int64_t *first_token_ids,
                                        const int bsz,
                                        const int block_num_per_seq,
                                        const int length,
                                        const int pre_id_length) {
    const int bid = blockIdx.x;
    const int tid = threadIdx.x;
    __shared__ int ori_free_list_len;
    if (bid < recover_len[0]) {
        const int recover_id = recover_block_list[bid];
        const int ori_seq_len_encoder = ori_seq_lens_encoder[recover_id];
        const int step_idx_now = step_idx[recover_id];
        const int seq_len = ori_seq_len_encoder + step_idx_now;
        const int encoder_block_len = encoder_block_lens[recover_id];
        const int decoder_used_len = used_list_len[recover_id];
        int *block_table_now = block_tables + recover_id * block_num_per_seq;
        int64_t *input_ids_now = input_ids + recover_id * length;
        int64_t *pre_ids_now = pre_ids + recover_id * pre_id_length;
        if (tid == 0) {
            seq_lens_this_time[recover_id] = seq_len;
            seq_lens_encoder[recover_id] = seq_len;
            stop_flags[recover_id] = false;
            // input_ids_now[ori_seq_len_encoder + step_idx_now - 1] =
            //     next_tokens[recover_id];  // next tokens
            input_ids_now[0] =
                first_token_ids[recover_id];  // set first prompt token
            const int ori_free_list_len_tid0 =
                atomicSub(free_list_len, decoder_used_len);
            ori_free_list_len = ori_free_list_len_tid0;
#ifdef DEBUG_STEP
            printf(
                "seq_id: %d, ori_seq_len_encoder: %d, step_idx_now: %d, "
                "seq_len: %d, ori_free_list_len_tid0: %d, "
                "ori_free_list_len: %d\n",
                recover_id,
                ori_seq_len_encoder,
                step_idx_now,
                seq_len,
                ori_free_list_len_tid0,
                ori_free_list_len);
#endif
        }
        __syncthreads();
        // Restore the block table
        for (int i = tid; i < decoder_used_len; i += blockDim.x) {
            block_table_now[encoder_block_len + i] =
                free_list[ori_free_list_len - i - 1];
        }
        // Restore input_ids
        for (int i = tid; i < step_idx_now; i += blockDim.x) {
            input_ids_now[ori_seq_len_encoder + i] = pre_ids_now[i + 1];
        }
    }

    if (bid == 0 && tid == 0) {
        recover_len[0] = 0;
    }
}

void SpeculateStepPaddle(
    const paddle::Tensor &stop_flags,
    const paddle::Tensor &seq_lens_this_time,
    const paddle::Tensor &ori_seq_lens_encoder,
    const paddle::Tensor &seq_lens_encoder,
    const paddle::Tensor &seq_lens_decoder,
    const paddle::Tensor &block_tables,  // [bsz, block_num_per_seq]
    const paddle::Tensor &encoder_block_lens,
    const paddle::Tensor &is_block_step,
    const paddle::Tensor &step_block_list,
    const paddle::Tensor &step_lens,
    const paddle::Tensor &recover_block_list,
    const paddle::Tensor &recover_lens,
    const paddle::Tensor &need_block_list,
    const paddle::Tensor &need_block_len,
    const paddle::Tensor &used_list_len,
    const paddle::Tensor &free_list,
    const paddle::Tensor &free_list_len,
    const paddle::Tensor &input_ids,
    const paddle::Tensor &pre_ids,
    const paddle::Tensor &step_idx,
    const paddle::Tensor &next_tokens,
    const paddle::Tensor &first_token_ids,
    const paddle::Tensor &accept_num,
    const int block_size,
    const int encoder_decoder_block_num,
    const int max_draft_tokens) {
    auto cu_stream = seq_lens_this_time.stream();
    const int bsz = seq_lens_this_time.shape()[0];
    const int block_num_per_seq = block_tables.shape()[1];
    const int length = input_ids.shape()[1];
    const int pre_id_length = pre_ids.shape()[1];
    constexpr int BlockSize = 256;  // bsz <= 256
    const int max_decoder_block_num = length / block_size;
    // const int max_decoder_block_num = 2048 / block_size -
    //     encoder_decoder_block_num;
#ifdef DEBUG_STEP
    printf(
        "bsz: %d, block_num_per_seq: %d, length: %d, max_decoder_block_num: "
        "%d\n",
        bsz,
        block_num_per_seq,
        length,
        max_decoder_block_num);
#endif
    speculate_free_and_dispatch_block<<<1, BlockSize, 0, cu_stream>>>(
        const_cast<bool *>(stop_flags.data<bool>()),
        const_cast<int *>(seq_lens_this_time.data<int>()),
        const_cast<int *>(seq_lens_decoder.data<int>()),
        const_cast<int *>(block_tables.data<int>()),
        const_cast<int *>(encoder_block_lens.data<int>()),
        const_cast<bool *>(is_block_step.data<bool>()),
        const_cast<int *>(step_block_list.data<int>()),
        const_cast<int *>(step_lens.data<int>()),
        const_cast<int *>(recover_block_list.data<int>()),
        const_cast<int *>(recover_lens.data<int>()),
        const_cast<int *>(need_block_list.data<int>()),
        const_cast<int *>(need_block_len.data<int>()),
        const_cast<int *>(used_list_len.data<int>()),
        const_cast<int *>(free_list.data<int>()),
        const_cast<int *>(free_list_len.data<int>()),
        const_cast<int64_t *>(first_token_ids.data<int64_t>()),
        const_cast<int *>(accept_num.data<int>()),
        bsz,
        block_size,
        block_num_per_seq,
        max_decoder_block_num,
        max_draft_tokens);
#ifdef DEBUG_STEP
    cudaDeviceSynchronize();
#endif
    auto cpu_recover_lens = recover_lens.copy_to(paddle::CPUPlace(), false);
    const int grid_size = cpu_recover_lens.data<int>()[0];
#ifdef DEBUG_STEP
    printf("grid_size2 %d\n", grid_size);
#endif
    if (grid_size > 0) {
        speculate_recover_block<<<grid_size, BlockSize, 0, cu_stream>>>(
            const_cast<int *>(recover_block_list.data<int>()),
            const_cast<int *>(recover_lens.data<int>()),
            const_cast<bool *>(stop_flags.data<bool>()),
            const_cast<int *>(seq_lens_this_time.data<int>()),
            const_cast<int *>(ori_seq_lens_encoder.data<int>()),
            const_cast<int *>(seq_lens_encoder.data<int>()),
            const_cast<int *>(seq_lens_decoder.data<int>()),
            const_cast<int *>(block_tables.data<int>()),
            const_cast<int *>(free_list.data<int>()),
            const_cast<int *>(free_list_len.data<int>()),
            const_cast<int64_t *>(input_ids.data<int64_t>()),
            const_cast<int64_t *>(pre_ids.data<int64_t>()),
            const_cast<int64_t *>(step_idx.data<int64_t>()),
            const_cast<int *>(encoder_block_lens.data<int>()),
            const_cast<int *>(used_list_len.data<int>()),
            next_tokens.data<int64_t>(),
            first_token_ids.data<int64_t>(),
            bsz,
            block_num_per_seq,
            length,
            pre_id_length);
#ifdef DEBUG_STEP
        cudaDeviceSynchronize();
#endif
    }
}

PD_BUILD_STATIC_OP(speculate_step_paddle)
    .Inputs({"stop_flags",
             "seq_lens_this_time",
             "ori_seq_lens_encoder",
             "seq_lens_encoder",
             "seq_lens_decoder",
             "block_tables",
             "encoder_block_lens",
             "is_block_step",
             "step_block_list",
             "step_lens",
             "recover_block_list",
             "recover_lens",
             "need_block_list",
             "need_block_len",
             "used_list_len",
             "free_list",
             "free_list_len",
             "input_ids",
             "pre_ids",
             "step_idx",
             "next_tokens",
             "first_token_ids",
             "accept_num"})
    .Attrs({"block_size: int",
            "encoder_decoder_block_num: int",
            "max_draft_tokens: int"})
    .Outputs({"stop_flags_out",
              "seq_lens_this_time_out",
              "seq_lens_encoder_out",
              "seq_lens_decoder_out",
              "block_tables_out",
              "encoder_block_lens_out",
              "is_block_step_out",
              "step_block_list_out",
              "step_lens_out",
              "recover_block_list_out",
              "recover_lens_out",
              "need_block_list_out",
              "need_block_len_out",
              "used_list_len_out",
              "free_list_out",
              "free_list_len_out",
              "input_ids_out",
              "first_token_ids_out"})
    .SetInplaceMap({{"stop_flags", "stop_flags_out"},
                    {"seq_lens_this_time", "seq_lens_this_time_out"},
                    {"seq_lens_encoder", "seq_lens_encoder_out"},
                    {"seq_lens_decoder", "seq_lens_decoder_out"},
                    {"block_tables", "block_tables_out"},
                    {"encoder_block_lens", "encoder_block_lens_out"},
                    {"is_block_step", "is_block_step_out"},
                    {"step_block_list", "step_block_list_out"},
                    {"step_lens", "step_lens_out"},
                    {"recover_block_list", "recover_block_list_out"},
                    {"recover_lens", "recover_lens_out"},
                    {"need_block_list", "need_block_list_out"},
                    {"need_block_len", "need_block_len_out"},
                    {"used_list_len", "used_list_len_out"},
                    {"free_list", "free_list_out"},
                    {"free_list_len", "free_list_len_out"},
                    {"input_ids", "input_ids_out"},
                    {"first_token_ids", "first_token_ids_out"}})
    .SetKernelFn(PD_KERNEL(SpeculateStepPaddle));
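Both the eviction check and the allocation above key off the slot index (seq_lens_decoder + max_draft_tokens + 1) / block_size: one speculative step may append up to max_draft_tokens draft tokens plus the bonus token, so that slot must already be mapped before decoding proceeds. A small standalone check of the arithmetic (illustrative numbers only):

// With block_size = 64 and max_draft_tokens = 5, a query at
// seq_lens_decoder = 58 may write positions 58..63 plus the bonus token at
// 64, touching slot (58 + 5 + 1) / 64 = 1, so block_table[1] must be valid.
#include <cassert>

int main() {
    const int block_size = 64, max_draft_tokens = 5;
    int seq_len_decoder = 58;
    int slot = (seq_len_decoder + max_draft_tokens + 1) / block_size;
    assert(slot == 1);  // a second block is required before this step
    seq_len_decoder = 10;
    assert((seq_len_decoder + max_draft_tokens + 1) / block_size == 0);
    return 0;
}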
@@ -0,0 +1,389 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "helper.h"
#include "speculate_msg.h"

#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif

__device__ __forceinline__ bool in_need_block_list_schedule(
    const int &qid, int *need_block_list, const int &need_block_len) {
    bool res = false;
    for (int i = 0; i < need_block_len; i++) {
        if (qid == need_block_list[i]) {
            res = true;
            need_block_list[i] = -1;
            break;
        }
    }
    return res;
}

__global__ void speculate_free_and_reschedule(bool *stop_flags,
                                              int *seq_lens_this_time,
                                              int *seq_lens_decoder,
                                              int *block_tables,
                                              int *encoder_block_lens,
                                              bool *is_block_step,
                                              int *step_block_list,  // [bsz]
                                              int *step_len,
                                              int *recover_block_list,
                                              int *recover_len,
                                              int *need_block_list,
                                              int *need_block_len,
                                              int *used_list_len,
                                              int *free_list,
                                              int *free_list_len,
                                              int64_t *first_token_ids,
                                              int *accept_num,
                                              const int bsz,
                                              const int block_size,
                                              const int block_num_per_seq,
                                              const int max_decoder_block_num,
                                              const int max_draft_tokens) {
    typedef cub::BlockReduce<cub::KeyValuePair<int, int>, 256> BlockReduce;
    __shared__ typename BlockReduce::TempStorage temp_storage;
    __shared__ bool step_max_block_flag;
    __shared__ int in_need_block_list_len;
    const int tid = threadIdx.x;
    if (tid < bsz) {
        if (tid == 0) {
            step_max_block_flag = false;
            in_need_block_list_len = 0;
        }
        int *block_table_now = block_tables + tid * block_num_per_seq;
        int max_possible_block_idx =
            (seq_lens_decoder[tid] + max_draft_tokens + 1) / block_size;
        if (stop_flags[tid]) {
            // Reclaim this query's blocks
            first_token_ids[tid] = -1;
            const int encoder_block_len = encoder_block_lens[tid];
            const int decoder_used_len = used_list_len[tid];
            if (decoder_used_len > 0) {
                const int ori_free_list_len =
                    atomicAdd(free_list_len, decoder_used_len);
#ifdef DEBUG_STEP
                printf(
                    "free block seq_id: %d, free block num: %d, "
                    "encoder_block_len: %d, ori_free_list_len: %d\n",
                    tid,
                    decoder_used_len,
                    encoder_block_len,
                    ori_free_list_len);
#endif
                for (int i = 0; i < decoder_used_len; i++) {
                    free_list[ori_free_list_len + i] =
                        block_table_now[encoder_block_len + i];
                    block_table_now[encoder_block_len + i] = -1;
                }
                encoder_block_lens[tid] = 0;
                used_list_len[tid] = 0;
            }
        } else if (seq_lens_this_time[tid] != 0 &&
                   max_possible_block_idx < block_num_per_seq &&
                   block_table_now[(seq_lens_decoder[tid] + max_draft_tokens +
                                    1) /
                                   block_size] == -1) {
            // Record which queries need a block and the total demand
#ifdef DEBUG_STEP
            printf("step seq_id:%d, ##### pin 1 #####\n", tid);
#endif
            const int ori_need_block_len = atomicAdd(need_block_len, 1);
            need_block_list[ori_need_block_len] = tid;
#ifdef DEBUG_STEP
            printf("seq_id: %d need block\n", tid);
#endif
        }
    }
#ifdef DEBUG_STEP
    printf("step seq_id:%d, ##### pin 2 #####\n", tid);
#endif
    __syncthreads();

    // Schedule blocks until need_block_len is satisfied
    while (need_block_len[0] > free_list_len[0]) {
        if (tid == 0) {
            printf("need_block_len: %d, free_list_len: %d\n",
                   need_block_len[0],
                   free_list_len[0]);
        }
        // Reclaim blocks from the query with the largest used_list_len
        // downwards until need_block_len is satisfied; queries already
        // decoding in their last block do not take part (they are about to
        // finish)
        const int used_block_num = tid < bsz ? used_list_len[tid] : 0;
        cub::KeyValuePair<int, int> kv_pair = {tid, used_block_num};
        kv_pair = BlockReduce(temp_storage).Reduce(kv_pair, cub::ArgMax());
        if (tid == 0) {
            if (kv_pair.value == 0) {
                step_max_block_flag = true;
            } else {
                const int encoder_block_len = encoder_block_lens[kv_pair.key];
                printf("step max_id: %d, max_num: %d, encoder_block_len: %d\n",
                       kv_pair.key,
                       kv_pair.value,
                       encoder_block_len);
                int *block_table_now =
                    block_tables + kv_pair.key * block_num_per_seq;
                // Reclaim the evicted slot's blocks
                for (int i = 0; i < kv_pair.value; i++) {
                    free_list[free_list_len[0] + i] =
                        block_table_now[encoder_block_len + i];
                    block_table_now[encoder_block_len + i] = -1;
                }
                step_block_list[step_len[0]] = kv_pair.key;
                // If the evicted slot itself requested a block this step,
                // account for that here
                if (in_need_block_list_schedule(
                        kv_pair.key,
                        need_block_list,
                        need_block_len[0] + in_need_block_list_len)) {
                    need_block_len[0] -= 1;
                    in_need_block_list_len += 1;
                }
                step_len[0] += 1;
                free_list_len[0] += kv_pair.value;
                stop_flags[kv_pair.key] = true;
                seq_lens_this_time[kv_pair.key] = 0;
                seq_lens_decoder[kv_pair.key] = 0;
                encoder_block_lens[kv_pair.key] = 0;
                used_list_len[kv_pair.key] = 0;
                printf(
                    "free block seq_id: %d, free block num: %d, "
                    "now_free_list_len: %d\n",
                    (int)kv_pair.key,
                    (int)kv_pair.value,
                    (int)free_list_len[0]);
            }
        }
        __syncthreads();
    }
#ifdef DEBUG_STEP
    printf("step seq_id:%d, ##### pin 3 #####\n", tid);
#endif
    // Hand out one block to each position that asked for one
    if (tid < need_block_len[0] + in_need_block_list_len) {
        const int need_block_id = need_block_list[tid];
        if (need_block_id != -1) {
            if (!stop_flags[need_block_id]) {
                // If this position was just freed in the loop above, leave it
                // alone
                used_list_len[need_block_id] += 1;
                const int ori_free_list_len = atomicSub(free_list_len, 1);
                int *block_table_now =
                    block_tables + need_block_id * block_num_per_seq;
                block_table_now[(seq_lens_decoder[need_block_id] +
                                 max_draft_tokens + 1) /
                                block_size] = free_list[ori_free_list_len - 1];
            }
            need_block_list[tid] = -1;
        }
    }
    __syncthreads();
    // reset need_block_len
    if (tid == 0) {
        need_block_len[0] = 0;
    }
}

// To keep the existing call sites working, the input list is left unchanged
// for now
void SpeculateStepSchedule(const paddle::Tensor &stop_flags,
                           const paddle::Tensor &seq_lens_this_time,
                           const paddle::Tensor &ori_seq_lens_encoder,
                           const paddle::Tensor &seq_lens_encoder,
                           const paddle::Tensor &seq_lens_decoder,
                           const paddle::Tensor &block_tables,  // [bsz, block_num_per_seq]
                           const paddle::Tensor &encoder_block_lens,
                           const paddle::Tensor &is_block_step,
                           const paddle::Tensor &step_block_list,
                           const paddle::Tensor &step_lens,
                           const paddle::Tensor &recover_block_list,
                           const paddle::Tensor &recover_lens,
                           const paddle::Tensor &need_block_list,
                           const paddle::Tensor &need_block_len,
                           const paddle::Tensor &used_list_len,
                           const paddle::Tensor &free_list,
                           const paddle::Tensor &free_list_len,
                           const paddle::Tensor &input_ids,
                           const paddle::Tensor &pre_ids,
                           const paddle::Tensor &step_idx,
                           const paddle::Tensor &next_tokens,
                           const paddle::Tensor &first_token_ids,
                           const paddle::Tensor &accept_num,
                           const int block_size,
                           const int encoder_decoder_block_num,
                           const int max_draft_tokens) {
    auto cu_stream = seq_lens_this_time.stream();
    const int bsz = seq_lens_this_time.shape()[0];
    const int block_num_per_seq = block_tables.shape()[1];
    const int length = input_ids.shape()[1];
    const int pre_id_length = pre_ids.shape()[1];
    constexpr int BlockSize = 256;  // bsz <= 256
    // blocks implied by the maximum output length, minus the blocks the
    // service reserves for decoding
    const int max_decoder_block_num =
        length / block_size - encoder_decoder_block_num;
    auto step_lens_inkernel =
        paddle::full({1}, 0, paddle::DataType::INT32, stop_flags.place());
    auto step_bs_list =
        GetEmptyTensor({bsz}, paddle::DataType::INT32, stop_flags.place());
#ifdef DEBUG_STEP
    printf(
        "bsz: %d, block_num_per_seq: %d, length: %d, max_decoder_block_num: "
        "%d\n",
        bsz,
        block_num_per_seq,
        length,
        max_decoder_block_num);
#endif
    speculate_free_and_reschedule<<<1, BlockSize, 0, cu_stream>>>(
        const_cast<bool *>(stop_flags.data<bool>()),
        const_cast<int *>(seq_lens_this_time.data<int>()),
        const_cast<int *>(seq_lens_decoder.data<int>()),
        const_cast<int *>(block_tables.data<int>()),
        const_cast<int *>(encoder_block_lens.data<int>()),
        const_cast<bool *>(is_block_step.data<bool>()),
        const_cast<int *>(step_bs_list.data<int>()),
        const_cast<int *>(step_lens_inkernel.data<int>()),
        const_cast<int *>(recover_block_list.data<int>()),
        const_cast<int *>(recover_lens.data<int>()),
        const_cast<int *>(need_block_list.data<int>()),
        const_cast<int *>(need_block_len.data<int>()),
        const_cast<int *>(used_list_len.data<int>()),
        const_cast<int *>(free_list.data<int>()),
        const_cast<int *>(free_list_len.data<int>()),
        const_cast<int64_t *>(first_token_ids.data<int64_t>()),
        const_cast<int *>(accept_num.data<int>()),
        bsz,
        block_size,
        block_num_per_seq,
        max_decoder_block_num,
        max_draft_tokens);
#ifdef DEBUG_STEP
    cudaDeviceSynchronize();
#endif
    // save output
    auto step_lens_cpu = step_lens_inkernel.copy_to(paddle::CPUPlace(), false);
    if (step_lens_cpu.data<int>()[0] > 0) {
        auto step_bs_list_cpu = step_bs_list.copy_to(paddle::CPUPlace(), false);
        auto next_tokens =
            paddle::full({bsz}, -1, paddle::DataType::INT64, paddle::CPUPlace());
        for (int i = 0; i < step_lens_cpu.data<int>()[0]; i++) {
            const int step_bid = step_bs_list_cpu.data<int>()[i];
            next_tokens.data<int64_t>()[step_bid] = -3;  // need reschedule
        }
        const int rank_id = static_cast<int>(stop_flags.place().GetDeviceId());
        printf("reschedule rank_id: %d, step_lens: %d\n",
               rank_id,
               step_lens_cpu.data<int>()[0]);
        const int64_t* x_data = next_tokens.data<int64_t>();
        static struct speculate_msgdata msg_sed;
        int msg_queue_id = rank_id;
        if (const char* inference_msg_queue_id_env_p =
                std::getenv("INFERENCE_MSG_QUEUE_ID")) {
            std::string inference_msg_queue_id_env_str(
                inference_msg_queue_id_env_p);
            int inference_msg_queue_id_from_env =
                std::stoi(inference_msg_queue_id_env_str);
            msg_queue_id = inference_msg_queue_id_from_env;
        } else {
            std::cout << "Failed to get INFERENCE_MSG_QUEUE_ID from env, "
                         "using the default."
                      << std::endl;
        }
        int inference_msg_id_from_env = 1;
        if (const char* inference_msg_id_env_p = std::getenv("INFERENCE_MSG_ID")) {
            std::string inference_msg_id_env_str(inference_msg_id_env_p);
            inference_msg_id_from_env = std::stoi(inference_msg_id_env_str);
            if (inference_msg_id_from_env == 2) {
                // 2 and -2 are reserved for the no-output indication.
                throw std::runtime_error(
                    " INFERENCE_MSG_ID cannot be 2, please use another number.");
            }
            if (inference_msg_id_from_env < 0) {
                throw std::runtime_error(
                    " INFERENCE_MSG_ID cannot be negative, please use another "
                    "number.");
            }
        }
        // static key_t key = ftok("/dev/shm", msg_queue_id);
        static key_t key = ftok("./", msg_queue_id);

        static int msgid = msgget(key, IPC_CREAT | 0666);
        msg_sed.mtype = 1;
        msg_sed.mtext[0] = inference_msg_id_from_env;
        msg_sed.mtext[1] = bsz;
        for (int i = 2; i < bsz + 2; i++) {
            msg_sed.mtext[i] = (int)x_data[i - 2];
        }
        if ((msgsnd(msgid, &msg_sed, (MAX_BSZ + 2) * 4, 0)) == -1) {
            printf("full msg buffer\n");
        }
    }
}

PD_BUILD_STATIC_OP(speculate_step_reschedule)
    .Inputs({"stop_flags",
             "seq_lens_this_time",
             "ori_seq_lens_encoder",
             "seq_lens_encoder",
             "seq_lens_decoder",
             "block_tables",
             "encoder_block_lens",
             "is_block_step",
             "step_block_list",
             "step_lens",
             "recover_block_list",
             "recover_lens",
             "need_block_list",
             "need_block_len",
             "used_list_len",
             "free_list",
             "free_list_len",
             "input_ids",
             "pre_ids",
             "step_idx",
             "next_tokens",
             "first_token_ids",
             "accept_num"})
    .Attrs({"block_size: int",
            "encoder_decoder_block_num: int",
            "max_draft_tokens: int"})
    .Outputs({"stop_flags_out",
              "seq_lens_this_time_out",
              "seq_lens_encoder_out",
              "seq_lens_decoder_out",
              "block_tables_out",
              "encoder_block_lens_out",
              "is_block_step_out",
              "step_block_list_out",
              "step_lens_out",
              "recover_block_list_out",
              "recover_lens_out",
              "need_block_list_out",
              "need_block_len_out",
              "used_list_len_out",
              "free_list_out",
              "free_list_len_out",
              "input_ids_out",
              "first_token_ids_out"})
    .SetInplaceMap({{"stop_flags", "stop_flags_out"},
                    {"seq_lens_this_time", "seq_lens_this_time_out"},
                    {"seq_lens_encoder", "seq_lens_encoder_out"},
                    {"seq_lens_decoder", "seq_lens_decoder_out"},
                    {"block_tables", "block_tables_out"},
                    {"encoder_block_lens", "encoder_block_lens_out"},
                    {"is_block_step", "is_block_step_out"},
                    {"step_block_list", "step_block_list_out"},
                    {"step_lens", "step_lens_out"},
                    {"recover_block_list", "recover_block_list_out"},
                    {"recover_lens", "recover_lens_out"},
                    {"need_block_list", "need_block_list_out"},
                    {"need_block_len", "need_block_len_out"},
                    {"used_list_len", "used_list_len_out"},
                    {"free_list", "free_list_out"},
                    {"free_list_len", "free_list_len_out"},
                    {"input_ids", "input_ids_out"},
                    {"first_token_ids", "first_token_ids_out"}})
    .SetKernelFn(PD_KERNEL(SpeculateStepSchedule));
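Unlike speculate_save_output, the reschedule path sends only (MAX_BSZ + 2) ints: the message id, the batch size, and one slot per query, with -3 marking a query that was evicted and must be rescheduled. A hypothetical consumer-side helper (illustrative; the real reader sits in the serving layer) could pick out the evicted queries like this:

// Hypothetical decode of one reschedule message with the layout sent above:
// mtext[0] = msg id, mtext[1] = bsz, mtext[2 + bid] = next token or -3.
#include <vector>

std::vector<int> rescheduled_bids(const int* mtext) {
    std::vector<int> bids;
    const int bsz = mtext[1];
    for (int bid = 0; bid < bsz; ++bid) {
        if (mtext[2 + bid] == -3) bids.push_back(bid);  // -3 => reschedule
    }
    return bids;
}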
@@ -0,0 +1,268 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "helper.h"  // NOLINT

// #define DEBUG_STEP

// Restore the state of the query ids marked recoverable in the previous step
__global__ void speculate_recover_block(int *recover_block_list,  // [bsz]
                                        int *recover_len,
                                        bool *stop_flags,
                                        int *seq_lens_this_time,
                                        int *ori_seq_lens_encoder,
                                        int *ori_seq_lens_decoder,
                                        int *seq_lens_encoder,
                                        int *seq_lens_decoder,
                                        int *block_tables,
                                        int *free_list,
                                        int *free_list_len,
                                        int64_t *input_ids,
                                        int64_t *pre_ids,
                                        int64_t *step_idx,
                                        int *encoder_block_lens,
                                        int *used_list_len,
                                        const int64_t *next_tokens,
                                        const int64_t *first_token_ids,
                                        const int bsz,
                                        const int block_num_per_seq,
                                        const int length,
                                        const int pre_id_length) {
    const int bid = blockIdx.x;
    const int tid = threadIdx.x;
    __shared__ int ori_free_list_len;
    if (bid < recover_len[0]) {
        const int recover_id = recover_block_list[bid];
        const int ori_seq_len_encoder = ori_seq_lens_encoder[recover_id];
        const int step_idx_now = step_idx[recover_id];
        const int seq_len = ori_seq_len_encoder + step_idx_now;
        const int encoder_block_len = encoder_block_lens[recover_id];
        const int decoder_used_len = used_list_len[recover_id];
        int *block_table_now = block_tables + recover_id * block_num_per_seq;
        int64_t *input_ids_now = input_ids + recover_id * length;
        int64_t *pre_ids_now = pre_ids + recover_id * pre_id_length;
        if (tid == 0) {
            seq_lens_this_time[recover_id] = seq_len;
            seq_lens_encoder[recover_id] = seq_len;
            seq_lens_decoder[recover_id] = ori_seq_lens_decoder[recover_id];
            stop_flags[recover_id] = false;
            // input_ids_now[ori_seq_len_encoder + step_idx_now - 1] =
            //     next_tokens[recover_id];  // next tokens
            input_ids_now[0] =
                first_token_ids[recover_id];  // set first prompt token
            const int ori_free_list_len_tid0 =
                atomicSub(free_list_len, decoder_used_len);
            ori_free_list_len = ori_free_list_len_tid0;
#ifdef DEBUG_STEP
            printf(
                "seq_id: %d, ori_seq_len_encoder: %d, step_idx_now: %d, "
                "seq_len: %d, ori_free_list_len_tid0: %d, "
                "ori_free_list_len: %d\n",
                recover_id,
                ori_seq_len_encoder,
                step_idx_now,
                seq_len,
                ori_free_list_len_tid0,
                ori_free_list_len);
#endif
        }
        __syncthreads();
        // Restore the block table
        for (int i = tid; i < decoder_used_len; i += blockDim.x) {
            block_table_now[encoder_block_len + i] =
                free_list[ori_free_list_len - i - 1];
        }
        // Restore input_ids
        for (int i = tid; i < step_idx_now; i += blockDim.x) {
            input_ids_now[ori_seq_len_encoder + i] = pre_ids_now[i + 1];
        }
    }

    if (bid == 0 && tid == 0) {
        recover_len[0] = 0;
    }
}

void SpeculateStepPaddle(
    const paddle::Tensor &stop_flags,
    const paddle::Tensor &seq_lens_this_time,
    const paddle::Tensor &ori_seq_lens_encoder,
    const paddle::Tensor &ori_seq_lens_decoder,
    const paddle::Tensor &seq_lens_encoder,
    const paddle::Tensor &seq_lens_decoder,
    const paddle::Tensor &block_tables,  // [bsz, block_num_per_seq]
    const paddle::Tensor &encoder_block_lens,
    const paddle::Tensor &is_block_step,
    const paddle::Tensor &step_block_list,
    const paddle::Tensor &step_lens,
    const paddle::Tensor &recover_block_list,
    const paddle::Tensor &recover_lens,
    const paddle::Tensor &need_block_list,
    const paddle::Tensor &need_block_len,
    const paddle::Tensor &used_list_len,
    const paddle::Tensor &free_list,
    const paddle::Tensor &free_list_len,
    const paddle::Tensor &input_ids,
    const paddle::Tensor &pre_ids,
    const paddle::Tensor &step_idx,
    const paddle::Tensor &next_tokens,
    const paddle::Tensor &first_token_ids,
    const paddle::Tensor &accept_num,
    const int block_size,
    const int encoder_decoder_block_num,
    const int max_draft_tokens) {
    auto cu_stream = seq_lens_this_time.stream();
    const int bsz = seq_lens_this_time.shape()[0];
    const int block_num_per_seq = block_tables.shape()[1];
    const int length = input_ids.shape()[1];
    const int pre_id_length = pre_ids.shape()[1];
    constexpr int BlockSize = 256;  // bsz <= 256
    const int max_decoder_block_num = length / block_size;
    // const int max_decoder_block_num = 2048 / block_size -
    //     encoder_decoder_block_num;
#ifdef DEBUG_STEP
    printf(
        "bsz: %d, block_num_per_seq: %d, length: %d, max_decoder_block_num: "
        "%d\n",
        bsz,
        block_num_per_seq,
        length,
        max_decoder_block_num);
#endif
    speculate_free_and_dispatch_block<<<1, BlockSize, 0, cu_stream>>>(
        const_cast<bool *>(stop_flags.data<bool>()),
        const_cast<int *>(seq_lens_this_time.data<int>()),
        const_cast<int *>(seq_lens_decoder.data<int>()),
        const_cast<int *>(block_tables.data<int>()),
        const_cast<int *>(encoder_block_lens.data<int>()),
        const_cast<bool *>(is_block_step.data<bool>()),
        const_cast<int *>(step_block_list.data<int>()),
        const_cast<int *>(step_lens.data<int>()),
        const_cast<int *>(recover_block_list.data<int>()),
        const_cast<int *>(recover_lens.data<int>()),
        const_cast<int *>(need_block_list.data<int>()),
        const_cast<int *>(need_block_len.data<int>()),
        const_cast<int *>(used_list_len.data<int>()),
        const_cast<int *>(free_list.data<int>()),
        const_cast<int *>(free_list_len.data<int>()),
        const_cast<int64_t *>(first_token_ids.data<int64_t>()),
        const_cast<int *>(accept_num.data<int>()),
        bsz,
        block_size,
        block_num_per_seq,
        max_decoder_block_num,
        max_draft_tokens);
#ifdef DEBUG_STEP
    cudaDeviceSynchronize();
#endif
    auto cpu_recover_lens = recover_lens.copy_to(paddle::CPUPlace(), false);
    const int grid_size = cpu_recover_lens.data<int>()[0];
#ifdef DEBUG_STEP
    printf("grid_size2 %d\n", grid_size);
#endif
    if (grid_size > 0) {
        speculate_recover_block<<<grid_size, BlockSize, 0, cu_stream>>>(
            const_cast<int *>(recover_block_list.data<int>()),
            const_cast<int *>(recover_lens.data<int>()),
            const_cast<bool *>(stop_flags.data<bool>()),
            const_cast<int *>(seq_lens_this_time.data<int>()),
            const_cast<int *>(ori_seq_lens_encoder.data<int>()),
            const_cast<int *>(ori_seq_lens_decoder.data<int>()),
            const_cast<int *>(seq_lens_encoder.data<int>()),
            const_cast<int *>(seq_lens_decoder.data<int>()),
            const_cast<int *>(block_tables.data<int>()),
            const_cast<int *>(free_list.data<int>()),
            const_cast<int *>(free_list_len.data<int>()),
            const_cast<int64_t *>(input_ids.data<int64_t>()),
            const_cast<int64_t *>(pre_ids.data<int64_t>()),
            const_cast<int64_t *>(step_idx.data<int64_t>()),
            const_cast<int *>(encoder_block_lens.data<int>()),
            const_cast<int *>(used_list_len.data<int>()),
            next_tokens.data<int64_t>(),
            first_token_ids.data<int64_t>(),
            bsz,
            block_num_per_seq,
            length,
            pre_id_length);
#ifdef DEBUG_STEP
        cudaDeviceSynchronize();
#endif
    }
}

PD_BUILD_STATIC_OP(speculate_step_system_cache)
    .Inputs({"stop_flags",
             "seq_lens_this_time",
             "ori_seq_lens_encoder",
             "ori_seq_lens_decoder",
             "seq_lens_encoder",
             "seq_lens_decoder",
             "block_tables",
             "encoder_block_lens",
             "is_block_step",
             "step_block_list",
             "step_lens",
             "recover_block_list",
             "recover_lens",
             "need_block_list",
             "need_block_len",
             "used_list_len",
             "free_list",
             "free_list_len",
             "input_ids",
             "pre_ids",
             "step_idx",
             "next_tokens",
             "first_token_ids",
             "accept_num"})
    .Attrs({"block_size: int",
            "encoder_decoder_block_num: int",
            "max_draft_tokens: int"})
    .Outputs({"stop_flags_out",
              "seq_lens_this_time_out",
              "seq_lens_encoder_out",
              "seq_lens_decoder_out",
              "block_tables_out",
              "encoder_block_lens_out",
              "is_block_step_out",
              "step_block_list_out",
              "step_lens_out",
              "recover_block_list_out",
              "recover_lens_out",
              "need_block_list_out",
              "need_block_len_out",
              "used_list_len_out",
              "free_list_out",
              "free_list_len_out",
              "input_ids_out",
              "first_token_ids_out"})
    .SetInplaceMap({{"stop_flags", "stop_flags_out"},
                    {"seq_lens_this_time", "seq_lens_this_time_out"},
                    {"seq_lens_encoder", "seq_lens_encoder_out"},
                    {"seq_lens_decoder", "seq_lens_decoder_out"},
                    {"block_tables", "block_tables_out"},
                    {"encoder_block_lens", "encoder_block_lens_out"},
                    {"is_block_step", "is_block_step_out"},
                    {"step_block_list", "step_block_list_out"},
                    {"step_lens", "step_lens_out"},
                    {"recover_block_list", "recover_block_list_out"},
                    {"recover_lens", "recover_lens_out"},
                    {"need_block_list", "need_block_list_out"},
                    {"need_block_len", "need_block_len_out"},
                    {"used_list_len", "used_list_len_out"},
                    {"free_list", "free_list_out"},
                    {"free_list_len", "free_list_len_out"},
                    {"input_ids", "input_ids_out"},
                    {"first_token_ids", "first_token_ids_out"}})
    .SetKernelFn(PD_KERNEL(SpeculateStepPaddle));
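Apart from the extra ori_seq_lens_decoder input, this variant appears to differ from speculate_step_paddle only in the recovery path: its recover kernel additionally executes seq_lens_decoder[recover_id] = ori_seq_lens_decoder[recover_id], presumably so that a prefix already held in the system cache is restored rather than treated as needing a full re-prefill.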
@@ -0,0 +1,185 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/extension.h"

#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif

// #define DEBUG_SPEC_STOP_SEQS

__global__ void spec_set_value_by_stop_seqs(bool *stop_flags,
                                            int64_t *accept_tokens,
                                            int *accept_nums,
                                            const int64_t *pre_ids,
                                            const int64_t *step_idx,
                                            const int64_t *stop_seqs,
                                            const int *stop_seqs_len,
                                            const int *seq_lens,
                                            const int64_t *end_ids,
                                            const int bs,
                                            const int accept_tokens_len,
                                            const int stop_seqs_bs,
                                            const int stop_seqs_max_len,
                                            const int pre_ids_len) {
    const int bid = blockIdx.x;
    const int tid = threadIdx.x;
    if (tid >= stop_seqs_bs) return;
    const int stop_seq_len = stop_seqs_len[tid];
    if (stop_seq_len <= 0) return;
    if (bid < bs) {
        const int64_t *stop_seq_now = stop_seqs + tid * stop_seqs_max_len;
        const int64_t *pre_ids_now = pre_ids + bid * pre_ids_len;
        int64_t *accept_tokens_now = accept_tokens + bid * accept_tokens_len;
        const int accept_num = accept_nums[bid];
        const int64_t step_idx_now = step_idx[bid];
        if (!stop_flags[bid]) {
            int accept_idx = 0;
            bool is_end = false;
            // Iterate over candidate start positions
            for (; accept_idx <= accept_num - 1 && !is_end; accept_idx++) {
                if (step_idx_now - accept_num + accept_idx + 1 < stop_seq_len) {
#ifdef DEBUG_SPEC_STOP_SEQS
                    printf("num %ld < stop_seq_len %d\n",
                           step_idx_now - accept_num + accept_idx + 1,
                           stop_seq_len);
#endif
                    continue;
                }
                // Walk one stop sequence back to front
                for (int i = stop_seq_len - 1; i >= 0; --i) {
                    int64_t cur_token_idx = -1;

                    // Decide from the current offset whether the token lives
                    // in pre_ids or in accept_tokens
                    if (stop_seq_len - 1 - i < accept_idx) {
#ifdef DEBUG_SPEC_STOP_SEQS
                        printf(
                            "AcceptTokens bid:%d. tid:%d, accept_idx:%d, "
                            "accept_token_idx: "
                            "%d\n",
                            bid,
                            tid,
                            accept_idx,
                            accept_idx - (stop_seq_len - 1 - i) - 1);
#endif
                        cur_token_idx =
                            accept_tokens_now[accept_idx -
                                              (stop_seq_len - 1 - i) - 1];
                    } else {
#ifdef DEBUG_SPEC_STOP_SEQS
                        printf(
                            "PreIds bid:%d. tid:%d, step_idx_now:%ld. "
                            "accept_idx:%d. "
                            "pre_id_idx: %ld\n",
                            bid,
                            tid,
                            step_idx_now,
                            accept_idx,
                            step_idx_now - accept_num + accept_idx -
                                (stop_seq_len - 1 - i));
#endif
                        int pre_ids_idx = step_idx_now - accept_num +
                                          accept_idx - (stop_seq_len - 1 - i);
                        // EC3: special concatenation can leave input_ids
                        // without a special token at the end, i.e. pre_ids[0]
                        // may be 23, which would trigger an abnormal stop
                        if (pre_ids_idx <= 0) {
                            break;
                        }
                        cur_token_idx = pre_ids_now[pre_ids_idx];
                    }
#ifdef DEBUG_SPEC_STOP_SEQS
                    printf(
                        "bid:%d. tid:%d, cur_token_idx: %ld. stop_seq_now "
                        "%ld\n",
                        bid,
                        tid,
                        cur_token_idx,
                        stop_seq_now[i]);
#endif
                    if (cur_token_idx != stop_seq_now[i]) {
                        break;
                    }
                    if (i == 0) {
                        is_end = true;
                    }
                }
            }
            if (is_end) {
#ifdef DEBUG_SPEC_STOP_SEQS
                printf("bid:%d end with accept_idx %d\n", bid, accept_idx);
#endif

                accept_nums[bid] = accept_idx;
                accept_tokens_now[accept_idx - 1] = end_ids[0];
                stop_flags[bid] = true;
            }
        }
    }
}

void SpecGetStopFlagsMultiSeqs(const paddle::Tensor &accept_tokens,
                               const paddle::Tensor &accept_num,
                               const paddle::Tensor &pre_ids,
                               const paddle::Tensor &step_idx,
                               const paddle::Tensor &stop_flags,
                               const paddle::Tensor &seq_lens,
                               const paddle::Tensor &stop_seqs,
                               const paddle::Tensor &stop_seqs_len,
                               const paddle::Tensor &end_ids) {
    PD_CHECK(accept_tokens.dtype() == paddle::DataType::INT64);
    PD_CHECK(stop_flags.dtype() == paddle::DataType::BOOL);

    auto cu_stream = accept_tokens.stream();
    std::vector<int64_t> shape = accept_tokens.shape();
    std::vector<int64_t> stop_seqs_shape = stop_seqs.shape();
    int bs_now = shape[0];
    int stop_seqs_bs = stop_seqs_shape[0];
    int stop_seqs_max_len = stop_seqs_shape[1];
    int pre_ids_len = pre_ids.shape()[1];
    int accept_tokens_len = accept_tokens.shape()[1];

    int block_size = (stop_seqs_bs + 31) / 32 * 32;
    spec_set_value_by_stop_seqs<<<bs_now, block_size, 0, cu_stream>>>(
        const_cast<bool *>(stop_flags.data<bool>()),
        const_cast<int64_t *>(accept_tokens.data<int64_t>()),
        const_cast<int *>(accept_num.data<int>()),
        pre_ids.data<int64_t>(),
        step_idx.data<int64_t>(),
        stop_seqs.data<int64_t>(),
        stop_seqs_len.data<int>(),
        seq_lens.data<int>(),
        end_ids.data<int64_t>(),
        bs_now,
        accept_tokens_len,
        stop_seqs_bs,
        stop_seqs_max_len,
        pre_ids_len);
}

PD_BUILD_STATIC_OP(speculate_set_stop_value_multi_seqs)
    .Inputs({"accept_tokens",
             "accept_num",
             "pre_ids",
             "step_idx",
             "stop_flags",
             "seq_lens",
             "stop_seqs",
             "stop_seqs_len",
             "end_ids"})
    .Outputs({"accept_tokens_out", "stop_flags_out"})
    .SetInplaceMap({{"accept_tokens", "accept_tokens_out"},
                    {"stop_flags", "stop_flags_out"}})
    .SetKernelFn(PD_KERNEL(SpecGetStopFlagsMultiSeqs));
|
@@ -0,0 +1,341 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "helper.h"  // NOLINT

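// min_length_logits_process suppresses early termination: while a request's
// generated length is still below min_len, the logits of every EOS token are
// pushed to a large negative value. The half specialization below uses -1e4
// rather than -1e10 so the mask stays representable in fp16.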
template <typename T>
__global__ inline void min_length_logits_process(
    T *logits,
    const int64_t *cur_len,
    const int64_t *min_len,
    const int64_t *eos_token_id,
    const int *output_padding_offset,
    const int *output_cum_offsets,
    const int64_t token_num,
    const int64_t bs,
    const int64_t length,
    const int64_t end_length,
    const int max_seq_len) {
    const int token_idx = threadIdx.x;
    if (token_idx >= token_num) return;
    const int bi = (token_idx + output_padding_offset[token_idx]) / max_seq_len;
    if (bi >= bs) return;
    const int query_start_token_idx = bi * max_seq_len - output_cum_offsets[bi];

    if (cur_len[bi] < 0) {
        return;
    }
    if (cur_len[bi] + (token_idx - query_start_token_idx) < min_len[bi]) {
        for (int i = 0; i < end_length; i++) {
            logits[token_idx * length + eos_token_id[i]] = -1e10;
        }
    }
}

template <>
__global__ inline void min_length_logits_process<half>(
    half *logits,
    const int64_t *cur_len,
    const int64_t *min_len,
    const int64_t *eos_token_id,
    const int *output_padding_offset,
    const int *output_cum_offsets,
    const int64_t token_num,
    const int64_t bs,
    const int64_t length,
    const int64_t end_length,
    const int max_seq_len) {
    const int token_idx = threadIdx.x;
    if (token_idx >= token_num) return;
    const int bi = (token_idx + output_padding_offset[token_idx]) / max_seq_len;
    if (bi >= bs) return;
    const int query_start_token_idx = bi * max_seq_len - output_cum_offsets[bi];

    if (cur_len[bi] < 0) {
        return;
    }
    if (cur_len[bi] + (token_idx - query_start_token_idx) < min_len[bi]) {
        for (int i = 0; i < end_length; i++) {
            logits[token_idx * length + eos_token_id[i]] = -1e4;
        }
    }
}
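
// update_repeat_times builds a per-token histogram over the vocabulary: for
// every previously generated id in pre_ids, the matching counter in
// repeat_times is incremented with atomicAdd, one thread block per output
// token.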
__global__ void update_repeat_times(const int64_t *pre_ids,
                                    const int64_t *cur_len,
                                    int *repeat_times,
                                    const int *output_padding_offset,
                                    const int64_t token_num,
                                    const int64_t bs,
                                    const int64_t length,
                                    const int64_t length_id,
                                    const int max_seq_len) {
    const int token_idx = blockIdx.x;
    if (token_idx >= token_num) return;
    const int bi = (token_idx + output_padding_offset[token_idx]) / max_seq_len;
    if (bi >= bs) return;
    if (cur_len[bi] < 0) {
        return;
    }
    int tid = threadIdx.x;
    const int64_t *pre_ids_now = pre_ids + bi * length_id;
    int *repeat_times_now = repeat_times + token_idx * length;
    for (int i = tid; i < length_id; i += blockDim.x) {
        int64_t id = pre_ids_now[i];
        if (id < 0) break;
        atomicAdd(&repeat_times_now[id], 1);
    }
}
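
// update_value_by_repeat_times applies the combined repetition penalty
// (multiplicative alpha), frequency penalty (beta * count), presence penalty
// (gamma), and temperature scaling to each logit in one pass.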
template <typename T>
__global__ void update_value_by_repeat_times(const int *repeat_times,
                                             const T *penalty_scores,
                                             const T *frequency_score,
                                             const T *presence_score,
                                             const float *temperatures,
                                             T *logits,
                                             const int *output_padding_offset,
                                             const int64_t token_num,
                                             const int64_t bs,
                                             const int64_t length,
                                             const int max_seq_len) {
    const int token_idx = blockIdx.x;
    if (token_idx >= token_num) return;
    const int bi = (token_idx + output_padding_offset[token_idx]) / max_seq_len;
    if (bi >= bs) return;
    int tid = threadIdx.x;
    T *logits_now = logits + token_idx * length;
    const int *repeat_times_now = repeat_times + token_idx * length;
    float alpha = static_cast<float>(penalty_scores[bi]);
    float beta = static_cast<float>(frequency_score[bi]);
    float gamma = static_cast<float>(presence_score[bi]);
    for (int i = tid; i < length; i += blockDim.x) {
        int times = repeat_times_now[i];
        float logit_now = static_cast<float>(logits_now[i]);
        if (times != 0) {
            logit_now = logit_now < 0 ? logit_now * alpha : logit_now / alpha;
            logit_now = logit_now - times * beta - gamma;
        }
        logits_now[i] = static_cast<T>(logit_now / temperatures[bi]);
    }
}
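
// ban_bad_words forces the logits of every id in bad_words_list to a large
// negative value so those tokens can never be sampled; out-of-range ids are
// skipped.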
template <typename T>
__global__ void ban_bad_words(T *logits,
                              const int64_t *bad_words_list,
                              const int *output_padding_offset,
                              const int64_t token_num,
                              const int64_t bs,
                              const int64_t length,
                              const int64_t bad_words_length,
                              const int max_seq_len) {
    const int token_idx = blockIdx.x;
    if (token_idx >= token_num) return;
    const int bi = (token_idx + output_padding_offset[token_idx]) / max_seq_len;
    if (bi >= bs) return;
    int tid = threadIdx.x;
    T *logits_now = logits + token_idx * length;
    for (int i = tid; i < bad_words_length; i += blockDim.x) {
        const int64_t bad_words_token_id = bad_words_list[i];
        if (bad_words_token_id >= length || bad_words_token_id < 0) continue;
        logits_now[bad_words_token_id] = -1e10;
    }
}
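
// token_penalty_multi_scores_kernel chains the kernels above on the logits
// stream: EOS masking for min length, repetition counting into a zeroed
// repeat_times buffer, penalty plus temperature application, and finally
// bad-word banning. Grid-stride block sizes are rounded up to a multiple of
// 32 and capped at 512.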
template <paddle::DataType D>
void token_penalty_multi_scores_kernel(
    const paddle::Tensor &pre_ids,
    const paddle::Tensor &logits,
    const paddle::Tensor &penalty_scores,
    const paddle::Tensor &frequency_score,
    const paddle::Tensor &presence_score,
    const paddle::Tensor &temperatures,
    const paddle::Tensor &bad_tokens,
    const paddle::Tensor &cur_len,
    const paddle::Tensor &min_len,
    const paddle::Tensor &eos_token_id,
    const paddle::Tensor &seq_lens_this_time,
    const paddle::Tensor &output_padding_offset,
    const paddle::Tensor &output_cum_offsets,
    const int max_seq_len) {
    typedef PDTraits<D> traits_;
    typedef typename traits_::DataType DataType_;
    typedef typename traits_::data_t data_t;
    auto cu_stream = logits.stream();
    std::vector<int64_t> shape = logits.shape();
    auto repeat_times =
        paddle::full(shape, 0, paddle::DataType::INT32, pre_ids.place());
    int64_t bs = seq_lens_this_time.shape()[0];
    int64_t token_num = shape[0];
    int64_t length = shape[1];
    int64_t length_id = pre_ids.shape()[1];
    int64_t length_bad_words = bad_tokens.shape()[0];

    int64_t end_length = eos_token_id.shape()[0];

    int block_size = (token_num + 32 - 1) / 32 * 32;
    min_length_logits_process<<<1, block_size, 0, cu_stream>>>(
        reinterpret_cast<DataType_ *>(
            const_cast<data_t *>(logits.data<data_t>())),
        cur_len.data<int64_t>(),
        min_len.data<int64_t>(),
        eos_token_id.data<int64_t>(),
        output_padding_offset.data<int>(),
        output_cum_offsets.data<int>(),
        token_num,
        bs,
        length,
        end_length,
        max_seq_len);

    block_size = (length_id + 32 - 1) / 32 * 32;
    block_size = min(block_size, 512);
    update_repeat_times<<<token_num, block_size, 0, cu_stream>>>(
        pre_ids.data<int64_t>(),
        cur_len.data<int64_t>(),
        repeat_times.data<int>(),
        output_padding_offset.data<int>(),
        token_num,
        bs,
        length,
        length_id,
        max_seq_len);

    block_size = (length + 32 - 1) / 32 * 32;
    block_size = min(block_size, 512);
    update_value_by_repeat_times<DataType_>
        <<<token_num, block_size, 0, cu_stream>>>(
            repeat_times.data<int>(),
            reinterpret_cast<DataType_ *>(
                const_cast<data_t *>(penalty_scores.data<data_t>())),
            reinterpret_cast<DataType_ *>(
                const_cast<data_t *>(frequency_score.data<data_t>())),
            reinterpret_cast<DataType_ *>(
                const_cast<data_t *>(presence_score.data<data_t>())),
            temperatures.data<float>(),
            reinterpret_cast<DataType_ *>(
                const_cast<data_t *>(logits.data<data_t>())),
            output_padding_offset.data<int>(),
            token_num,
            bs,
            length,
            max_seq_len);

    block_size = (length_bad_words + 32 - 1) / 32 * 32;
    block_size = min(block_size, 512);
    ban_bad_words<DataType_><<<token_num, block_size, 0, cu_stream>>>(
        reinterpret_cast<DataType_ *>(
            const_cast<data_t *>(logits.data<data_t>())),
        bad_tokens.data<int64_t>(),
        output_padding_offset.data<int>(),
        token_num,
        bs,
        length,
        length_bad_words,
        max_seq_len);
}

void TokenPenaltyMultiScores(const paddle::Tensor &pre_ids,
                             const paddle::Tensor &logits,
                             const paddle::Tensor &penalty_scores,
                             const paddle::Tensor &frequency_scores,
                             const paddle::Tensor &presence_scores,
                             const paddle::Tensor &temperatures,
                             const paddle::Tensor &bad_tokens,
                             const paddle::Tensor &cur_len,
                             const paddle::Tensor &min_len,
                             const paddle::Tensor &eos_token_id,
                             const paddle::Tensor &seq_lens_this_time,
                             const paddle::Tensor &output_padding_offset,
                             const paddle::Tensor &output_cum_offsets,
                             const int max_seq_len) {
    switch (logits.type()) {
        case paddle::DataType::BFLOAT16: {
            return token_penalty_multi_scores_kernel<
                paddle::DataType::BFLOAT16>(pre_ids,
                                            logits,
                                            penalty_scores,
                                            frequency_scores,
                                            presence_scores,
                                            temperatures,
                                            bad_tokens,
                                            cur_len,
                                            min_len,
                                            eos_token_id,
                                            seq_lens_this_time,
                                            output_padding_offset,
                                            output_cum_offsets,
                                            max_seq_len);
        }
        case paddle::DataType::FLOAT16: {
            return token_penalty_multi_scores_kernel<paddle::DataType::FLOAT16>(
                pre_ids,
                logits,
                penalty_scores,
                frequency_scores,
                presence_scores,
                temperatures,
                bad_tokens,
                cur_len,
                min_len,
                eos_token_id,
                seq_lens_this_time,
                output_padding_offset,
                output_cum_offsets,
                max_seq_len);
        }
        case paddle::DataType::FLOAT32: {
            return token_penalty_multi_scores_kernel<paddle::DataType::FLOAT32>(
                pre_ids,
                logits,
                penalty_scores,
                frequency_scores,
                presence_scores,
                temperatures,
                bad_tokens,
                cur_len,
                min_len,
                eos_token_id,
                seq_lens_this_time,
                output_padding_offset,
                output_cum_offsets,
                max_seq_len);
        }
        default: {
            PD_THROW(
                "Unsupported data type. "
                "Only float16, bfloat16 and float32 are supported.");
            break;
        }
    }
}

PD_BUILD_STATIC_OP(speculate_get_token_penalty_multi_scores)
    .Inputs({"pre_ids",
             "logits",
             "penalty_scores",
             "frequency_scores",
             "presence_scores",
             "temperatures",
             "bad_tokens",
             "cur_len",
             "min_len",
             "eos_token_id",
             "seq_lens_this_time",
             "output_padding_offset",
             "output_cum_offsets"})
    .Outputs({"logits_out"})
    .Attrs({"max_seq_len: int"})
    .SetInplaceMap({{"logits", "logits_out"}})
    .SetKernelFn(PD_KERNEL(TokenPenaltyMultiScores));
@@ -0,0 +1,42 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/extension.h"

#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif

void UpdateInputIdsCPU(const paddle::Tensor& input_ids_cpu,
                       const std::vector<int64_t>& task_input_ids,
                       const int bid,
                       const int max_seq_len) {
    int64_t* input_ids_cpu_data =
        const_cast<int64_t*>(input_ids_cpu.data<int64_t>());
    // Copy the task's prompt ids into row `bid` of the CPU buffer.
    for (size_t i = 0; i < task_input_ids.size(); i++) {
        input_ids_cpu_data[bid * max_seq_len + i] = task_input_ids[i];
    }
}

PD_BUILD_STATIC_OP(speculate_update_input_ids_cpu)
    .Inputs({"input_ids_cpu"})
    .Outputs({"input_ids_cpu_out"})
    .Attrs({"task_input_ids: std::vector<int64_t>",
            "bid: int",
            "max_seq_len: int"})
    .SetInplaceMap({{"input_ids_cpu", "input_ids_cpu_out"}})
    .SetKernelFn(PD_KERNEL(UpdateInputIdsCPU));
@@ -0,0 +1,55 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "helper.h"  // NOLINT

__global__ void update_this_time(int* seq_lens_this_time,
                                 const int* seq_lens_encoder,
                                 const int* seq_lens_decoder,
                                 int real_bsz,
                                 int value) {
    int linear_idx = threadIdx.x;
    // Decode-phase requests get `value` tokens this step; finished requests
    // (encoder and decoder lengths both zero) get zero.
    for (; linear_idx < real_bsz; linear_idx += blockDim.x) {
        if (seq_lens_encoder[linear_idx] == 0 &&
            seq_lens_decoder[linear_idx] != 0) {
            seq_lens_this_time[linear_idx] = value;
        } else if (seq_lens_encoder[linear_idx] == 0 &&
                   seq_lens_decoder[linear_idx] == 0) {
            seq_lens_this_time[linear_idx] = 0;
        }
    }
}

void UpdateThisTime(const paddle::Tensor& seq_lens_this_time,
                    const paddle::Tensor& seq_lens_encoder,
                    const paddle::Tensor& seq_lens_decoder,
                    const int real_bsz,
                    const int value) {
    constexpr int BlockSize = 512;

    update_this_time<<<1, BlockSize, 0, seq_lens_this_time.stream()>>>(
        const_cast<int*>(seq_lens_this_time.data<int>()),
        seq_lens_encoder.data<int>(),
        seq_lens_decoder.data<int>(),
        real_bsz,
        value);
}

PD_BUILD_STATIC_OP(speculate_update_seq_lens_this_time)
    .Inputs({"seq_lens_this_time", "seq_lens_encoder", "seq_lens_decoder"})
    .Outputs({"seq_lens_this_time_out"})
    .Attrs({"real_bsz: int", "value: int"})
    .SetInplaceMap({{"seq_lens_this_time", "seq_lens_this_time_out"}})
    .SetKernelFn(PD_KERNEL(UpdateThisTime));

custom_ops/gpu_ops/speculate_decoding/speculate_update_v2.cu  (new file, 146 lines)
@@ -0,0 +1,146 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "helper.h"  // NOLINT

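// speculate_update folds this step's accepted tokens into the decoder state:
// it advances seq_lens_decoder by the number of accepted tokens, grows or
// shrinks actual_draft_token_nums depending on whether the whole draft was
// accepted, seeds the next draft with the last accepted token, and
// block-reduces the stop flags so not_need_stop reflects whether any request
// is still running.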
template <int THREADBLOCK_SIZE>
__global__ void speculate_update(int *seq_lens_encoder,
                                 int *seq_lens_decoder,
                                 bool *not_need_stop,
                                 int64_t *draft_tokens,
                                 int *actual_draft_token_nums,
                                 const int64_t *accept_tokens,
                                 const int *accept_num,
                                 const bool *stop_flags,
                                 const int *seq_lens_this_time,
                                 const bool *is_block_step,
                                 const int real_bsz,
                                 const int max_draft_tokens) {
    const int bid = threadIdx.x;
    const int accept_num_now = accept_num[bid];
    int stop_flag_now_int = 0;
    if (!(is_block_step[bid] || bid >= real_bsz)) {
        if (stop_flags[bid]) {
            stop_flag_now_int = 1;
        }
        if (seq_lens_encoder[bid] == 0) {
            seq_lens_decoder[bid] += accept_num_now;
        }

        if (seq_lens_this_time[bid] > 1 &&
            seq_lens_encoder[bid] ==
                0) {  // In append mode, decide from the acceptance result
                      // whether to lower the draft token count next step.
            auto current_actual_draft_token_num = actual_draft_token_nums[bid];
            if (accept_num_now - 1 == current_actual_draft_token_num) {
                if (current_actual_draft_token_num + 2 <=
                    max_draft_tokens - 1) {
                    actual_draft_token_nums[bid] =
                        current_actual_draft_token_num + 2;
                } else if (current_actual_draft_token_num + 1 <=
                           max_draft_tokens - 1) {
                    actual_draft_token_nums[bid] =
                        current_actual_draft_token_num + 1;
                } else {
                    actual_draft_token_nums[bid] = max_draft_tokens - 1;
                }
            } else {
                actual_draft_token_nums[bid] =
                    actual_draft_token_nums[bid] - 1 >= 1
                        ? actual_draft_token_nums[bid] - 1
                        : 1;
            }
        }

        if (seq_lens_encoder[bid] != 0) {
            seq_lens_decoder[bid] += seq_lens_encoder[bid];
            seq_lens_encoder[bid] = 0;
        }
        draft_tokens[bid * max_draft_tokens] =
            accept_tokens[bid * max_draft_tokens + accept_num_now - 1];
        if (stop_flag_now_int) {
            seq_lens_decoder[bid] = 0;
        }
    }
    __syncthreads();
    typedef cub::BlockReduce<int64_t, THREADBLOCK_SIZE> BlockReduce;
    __shared__ typename BlockReduce::TempStorage temp_storage;

    int64_t stop_sum = BlockReduce(temp_storage).Sum(stop_flag_now_int);

    if (threadIdx.x == 0) {
        not_need_stop[0] = stop_sum < real_bsz;
    }
}

void SpeculateUpdateV2(const paddle::Tensor &seq_lens_encoder,
                       const paddle::Tensor &seq_lens_decoder,
                       const paddle::Tensor &not_need_stop,
                       const paddle::Tensor &draft_tokens,
                       const paddle::Tensor &actual_draft_token_nums,
                       const paddle::Tensor &accept_tokens,
                       const paddle::Tensor &accept_num,
                       const paddle::Tensor &stop_flags,
                       const paddle::Tensor &seq_lens_this_time,
                       const paddle::Tensor &is_block_step) {
    int real_bsz = seq_lens_this_time.shape()[0];
    auto max_draft_tokens = draft_tokens.shape()[1];

    constexpr int BlockSize = 512;

    auto not_need_stop_gpu = not_need_stop.copy_to(stop_flags.place(), false);
    speculate_update<BlockSize><<<1, BlockSize, 0, accept_tokens.stream()>>>(
        const_cast<int *>(seq_lens_encoder.data<int>()),
        const_cast<int *>(seq_lens_decoder.data<int>()),
        const_cast<bool *>(not_need_stop_gpu.data<bool>()),
        const_cast<int64_t *>(draft_tokens.data<int64_t>()),
        const_cast<int *>(actual_draft_token_nums.data<int>()),
        accept_tokens.data<int64_t>(),
        accept_num.data<int>(),
        stop_flags.data<bool>(),
        seq_lens_this_time.data<int>(),
        is_block_step.data<bool>(),
        real_bsz,
        max_draft_tokens);

    auto not_need_stop_cpu =
        not_need_stop_gpu.copy_to(not_need_stop.place(), true);
    bool *not_need_stop_data = const_cast<bool *>(not_need_stop.data<bool>());
    not_need_stop_data[0] = not_need_stop_cpu.data<bool>()[0];
}

PD_BUILD_STATIC_OP(speculate_update_v2)
    .Inputs({"seq_lens_encoder",
             "seq_lens_decoder",
             "not_need_stop",
             "draft_tokens",
             "actual_draft_token_nums",
             "accept_tokens",
             "accept_num",
             "stop_flags",
             "seq_lens_this_time",
             "is_block_step"})
    .Outputs({"seq_lens_encoder_out",
              "seq_lens_decoder_out",
              "not_need_stop_out",
              "draft_tokens_out",
              "actual_draft_token_nums_out"})
    .SetInplaceMap({{"seq_lens_encoder", "seq_lens_encoder_out"},
                    {"seq_lens_decoder", "seq_lens_decoder_out"},
                    {"not_need_stop", "not_need_stop_out"},
                    {"draft_tokens", "draft_tokens_out"},
                    {"actual_draft_token_nums", "actual_draft_token_nums_out"}})
    .SetKernelFn(PD_KERNEL(SpeculateUpdateV2));

custom_ops/gpu_ops/speculate_decoding/speculate_update_v3.cu  (new file, 155 lines)
@@ -0,0 +1,155 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "helper.h"

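// speculate_update_v3 mirrors speculate_update_v2's kernel; the differences
// are that padded slots (real_bsz <= bid < max_bsz) are counted as stopped
// and the final reduction compares the stop count against stop_nums[0]
// instead of real_bsz.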
template <int THREADBLOCK_SIZE>
__global__ void speculate_update_v3(int *seq_lens_encoder,
                                    int *seq_lens_decoder,
                                    bool *not_need_stop,
                                    int64_t *draft_tokens,
                                    int *actual_draft_token_nums,
                                    const int64_t *accept_tokens,
                                    const int *accept_num,
                                    const bool *stop_flags,
                                    const int *seq_lens_this_time,
                                    const bool *is_block_step,
                                    const int64_t *stop_nums,
                                    const int real_bsz,
                                    const int max_bsz,
                                    const int max_draft_tokens) {
    const int bid = threadIdx.x;
    const int accept_num_now = accept_num[bid];
    int stop_flag_now_int = 0;
    if (!(is_block_step[bid] || bid >= real_bsz)) {
        if (stop_flags[bid]) {
            stop_flag_now_int = 1;
        }
        if (seq_lens_encoder[bid] == 0) {
            seq_lens_decoder[bid] += accept_num_now;
        }

        if (seq_lens_this_time[bid] > 1 &&
            seq_lens_encoder[bid] ==
                0) {  // In append mode, decide from the acceptance result
                      // whether to lower the draft token count next step.
            auto current_actual_draft_token_num = actual_draft_token_nums[bid];
            if (accept_num_now - 1 == current_actual_draft_token_num) {
                if (current_actual_draft_token_num + 2 <=
                    max_draft_tokens - 1) {
                    actual_draft_token_nums[bid] =
                        current_actual_draft_token_num + 2;
                } else if (current_actual_draft_token_num + 1 <=
                           max_draft_tokens - 1) {
                    actual_draft_token_nums[bid] =
                        current_actual_draft_token_num + 1;
                } else {
                    actual_draft_token_nums[bid] = max_draft_tokens - 1;
                }
            } else {
                actual_draft_token_nums[bid] =
                    actual_draft_token_nums[bid] - 1 >= 1
                        ? actual_draft_token_nums[bid] - 1
                        : 1;
            }
        }

        if (seq_lens_encoder[bid] != 0) {
            seq_lens_decoder[bid] += seq_lens_encoder[bid];
            seq_lens_encoder[bid] = 0;
        }
        draft_tokens[bid * max_draft_tokens] =
            accept_tokens[bid * max_draft_tokens + accept_num_now - 1];
        if (stop_flag_now_int) {
            seq_lens_decoder[bid] = 0;
        }
    } else if (bid >= real_bsz && bid < max_bsz) {
        stop_flag_now_int = 1;
    }
    __syncthreads();
    typedef cub::BlockReduce<int64_t, THREADBLOCK_SIZE> BlockReduce;
    __shared__ typename BlockReduce::TempStorage temp_storage;

    int64_t stop_sum = BlockReduce(temp_storage).Sum(stop_flag_now_int);

    if (threadIdx.x == 0) {
        not_need_stop[0] = stop_sum < stop_nums[0];
    }
}

void SpeculateUpdateV3(const paddle::Tensor &seq_lens_encoder,
                       const paddle::Tensor &seq_lens_decoder,
                       const paddle::Tensor &not_need_stop,
                       const paddle::Tensor &draft_tokens,
                       const paddle::Tensor &actual_draft_token_nums,
                       const paddle::Tensor &accept_tokens,
                       const paddle::Tensor &accept_num,
                       const paddle::Tensor &stop_flags,
                       const paddle::Tensor &seq_lens_this_time,
                       const paddle::Tensor &is_block_step,
                       const paddle::Tensor &stop_nums) {
    const int real_bsz = seq_lens_this_time.shape()[0];
    const int max_bsz = stop_flags.shape()[0];
    auto max_draft_tokens = draft_tokens.shape()[1];

    constexpr int BlockSize = 512;

    auto not_need_stop_gpu = not_need_stop.copy_to(stop_flags.place(), false);
    speculate_update_v3<BlockSize><<<1, BlockSize, 0, accept_tokens.stream()>>>(
        const_cast<int *>(seq_lens_encoder.data<int>()),
        const_cast<int *>(seq_lens_decoder.data<int>()),
        const_cast<bool *>(not_need_stop_gpu.data<bool>()),
        const_cast<int64_t *>(draft_tokens.data<int64_t>()),
        const_cast<int *>(actual_draft_token_nums.data<int>()),
        accept_tokens.data<int64_t>(),
        accept_num.data<int>(),
        stop_flags.data<bool>(),
        seq_lens_this_time.data<int>(),
        is_block_step.data<bool>(),
        stop_nums.data<int64_t>(),
        real_bsz,
        max_bsz,
        max_draft_tokens);

    auto not_need_stop_cpu =
        not_need_stop_gpu.copy_to(not_need_stop.place(), true);
    bool *not_need_stop_data = const_cast<bool *>(not_need_stop.data<bool>());
    not_need_stop_data[0] = not_need_stop_cpu.data<bool>()[0];
}

PD_BUILD_STATIC_OP(speculate_update_v3)
    .Inputs({"seq_lens_encoder",
             "seq_lens_decoder",
             "not_need_stop",
             "draft_tokens",
             "actual_draft_token_nums",
             "accept_tokens",
             "accept_num",
             "stop_flags",
             "seq_lens_this_time",
             "is_block_step",
             "stop_nums"})
    .Outputs({"seq_lens_encoder_out",
              "seq_lens_decoder_out",
              "not_need_stop_out",
              "draft_tokens_out",
              "actual_draft_token_nums_out"})
    .SetInplaceMap({{"seq_lens_encoder", "seq_lens_encoder_out"},
                    {"seq_lens_decoder", "seq_lens_decoder_out"},
                    {"not_need_stop", "not_need_stop_out"},
                    {"draft_tokens", "draft_tokens_out"},
                    {"actual_draft_token_nums", "actual_draft_token_nums_out"}})
    .SetKernelFn(PD_KERNEL(SpeculateUpdateV3));

custom_ops/gpu_ops/speculate_decoding/speculate_verify.cu  (new file, 478 lines)
@@ -0,0 +1,478 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <curand_kernel.h>
#include <cstdlib>
#include <string>
#include "helper.h"  // NOLINT

__device__ inline bool is_in(const int64_t *candidates,
                             const int64_t draft,
                             const int candidate_len) {
    for (int i = 0; i < candidate_len; i++) {
        if (draft == candidates[i]) {
            return true;
        }
    }
    return false;
}

// Host-side RNG state, advanced on every verify call to reseed the curand
// states.
static uint64_t seed = 0;
static uint64_t offset = 0;

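// topp_sampling_kernel draws one token by inverse-CDF sampling: a uniform
// draw scaled by topp is compared against the running sum of candidate
// scores, with the top candidate as the fallback when the sum never reaches
// the draw.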
__device__ int64_t topp_sampling_kernel(const int64_t *candidate_ids,
                                        const float *candidate_scores,
                                        curandState_t *dev_curand_states,
                                        const int candidate_len,
                                        const float topp) {
    const int tid = threadIdx.x;

    float sum_scores = 0.0f;
    float rand_top_p = curand_uniform(dev_curand_states + tid) * topp;
    for (int i = 0; i < candidate_len; i++) {
        sum_scores += candidate_scores[i];
        if (rand_top_p <= sum_scores) {
            return candidate_ids[i];
        }
    }
    return candidate_ids[0];
}

__global__ void setup_kernel(curandState_t *state,
                             const uint64_t seed,
                             const uint64_t offset,
                             const int bs,
                             const bool need_batch_random) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    for (int i = idx; i < bs; i += gridDim.x * blockDim.x) {
        if (need_batch_random) {
            curand_init(seed, i, offset, &state[i]);
        } else {
            curand_init(seed, 0, offset, &state[i]);
        }
    }
}

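// speculate_verify walks each request's draft tokens in order. With USE_TOPK,
// a draft token is accepted only if it matches the target model's top-1
// candidate; otherwise acceptance checks membership in the top-p candidate
// set, with a top-2 rescue path that accepts a whole verify_window of
// follow-up matches. The first rejected (or final) position is then resampled
// from the target distribution, greedily or via top-p when ENABLE_TOPP is set.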
template <bool ENABLE_TOPP, bool USE_TOPK>
__global__ void speculate_verify(int64_t *accept_tokens,
                                 int *accept_num,
                                 int64_t *step_idx,
                                 bool *stop_flags,
                                 const int *seq_lens_encoder,
                                 const int *seq_lens_decoder,
                                 const int64_t *draft_tokens,
                                 const int *actual_draft_token_nums,
                                 curandState_t *dev_curand_states,
                                 const float *topp,
                                 const int *seq_lens_this_time,
                                 const int64_t *verify_tokens,
                                 const float *verify_scores,
                                 const int64_t *max_dec_len,
                                 const int64_t *end_tokens,
                                 const bool *is_block_step,
                                 const int *output_cum_offsets,
                                 const int *actual_candidate_len,
                                 const int real_bsz,
                                 const int max_draft_tokens,
                                 const int end_length,
                                 const int max_seq_len,
                                 const int max_candidate_len,
                                 const int verify_window,
                                 const bool prefill_one_step_stop) {
    const int bid = threadIdx.x;
    const int start_token_id = bid * max_seq_len - output_cum_offsets[bid];
    // verify and set stop flags
    int accept_num_now = 1;
    int stop_flag_now_int = 0;

    if (!(is_block_step[bid] || bid >= real_bsz)) {
        if (stop_flags[bid]) {
            stop_flag_now_int = 1;
        } else {  // The prefill stage also reaches this branch, but its draft
                  // tokens are zeroed, so it falls straight through to the
                  // final sampling stage.
            auto *verify_tokens_now =
                verify_tokens + start_token_id * max_candidate_len;
            auto *draft_tokens_now = draft_tokens + bid * max_draft_tokens;
            auto *actual_candidate_len_now =
                actual_candidate_len + start_token_id;

            int i = 0;
            for (; i < seq_lens_this_time[bid] - 1; i++) {
                if (seq_lens_encoder[bid] != 0) {
                    break;
                }
                if (USE_TOPK) {
                    if (verify_tokens_now[i * max_candidate_len] ==
                        draft_tokens_now[i + 1]) {
                        step_idx[bid]++;
                        auto accept_token = draft_tokens_now[i + 1];
                        accept_tokens[bid * max_draft_tokens + i] =
                            accept_token;
                        if (is_in_end(accept_token, end_tokens, end_length) ||
                            step_idx[bid] >= max_dec_len[bid]) {
                            stop_flags[bid] = true;
                            stop_flag_now_int = 1;
                            if (step_idx[bid] >= max_dec_len[bid])
                                accept_tokens[bid * max_draft_tokens + i] =
                                    end_tokens[0];
                            break;
                        } else {
                            accept_num_now++;
                        }
                    } else {
                        break;
                    }
                } else {
                    auto actual_candidate_len_value =
                        actual_candidate_len_now[i] > max_candidate_len
                            ? max_candidate_len
                            : actual_candidate_len_now[i];
                    if (is_in(verify_tokens_now + i * max_candidate_len,
                              draft_tokens_now[i + 1],
                              actual_candidate_len_value)) {
                        // Top-p verify
                        step_idx[bid]++;
                        auto accept_token = draft_tokens_now[i + 1];
                        accept_tokens[bid * max_draft_tokens + i] =
                            accept_token;

                        if (is_in_end(accept_token, end_tokens, end_length) ||
                            step_idx[bid] >= max_dec_len[bid]) {
                            stop_flags[bid] = true;
                            stop_flag_now_int = 1;
                            if (step_idx[bid] >= max_dec_len[bid])
                                accept_tokens[bid * max_draft_tokens + i] =
                                    end_tokens[0];
                            break;
                        } else {
                            accept_num_now++;
                        }
                    } else {
                        // Top-k verify
                        int ii = i;
                        if (max_candidate_len >= 2 &&
                            verify_tokens_now[ii * max_candidate_len + 1] ==
                                draft_tokens_now[ii + 1]) {  // top-2
                            int j = 0;
                            ii += 1;
                            for (; j < verify_window &&
                                   ii < seq_lens_this_time[bid] - 1;
                                 j++, ii++) {
                                if (verify_tokens_now[ii * max_candidate_len] !=
                                    draft_tokens_now[ii + 1]) {
                                    break;
                                }
                            }
                            if (j >= verify_window) {  // accept all
                                accept_num_now += verify_window + 1;
                                step_idx[bid] += verify_window + 1;
                                for (; i < ii; i++) {
                                    auto accept_token = draft_tokens_now[i + 1];
                                    accept_tokens[bid * max_draft_tokens + i] =
                                        accept_token;
                                    if (is_in_end(accept_token,
                                                  end_tokens,
                                                  end_length) ||
                                        step_idx[bid] >= max_dec_len[bid]) {
                                        stop_flags[bid] = true;
                                        stop_flag_now_int = 1;
                                        if (step_idx[bid] >= max_dec_len[bid])
                                            accept_tokens[bid *
                                                              max_draft_tokens +
                                                          i] = end_tokens[0];
                                        accept_num_now--;
                                        step_idx[bid]--;
                                        break;
                                    }
                                }
                            }
                        }
                        break;
                    }
                }
            }
            // Sampling stage. Two ways to get here:
            //   1. draft_token[i+1] was rejected, so one token is picked from
            //      verify_tokens_now[i];
            //   2. i == seq_lens_this_time[bid] - 1, which also samples from
            //      verify_tokens_now[i]. Requests that already stopped are
            //      excluded.
            if (!stop_flag_now_int) {
                int64_t accept_token;
                const float *verify_scores_now =
                    verify_scores + start_token_id * max_candidate_len;
                step_idx[bid]++;
                if (ENABLE_TOPP) {
                    auto actual_candidate_len_value =
                        actual_candidate_len_now[i] > max_candidate_len
                            ? max_candidate_len
                            : actual_candidate_len_now[i];

                    accept_token = topp_sampling_kernel(
                        verify_tokens_now + i * max_candidate_len,
                        verify_scores_now + i * max_candidate_len,
                        dev_curand_states,
                        actual_candidate_len_value,
                        topp[bid]);
                } else {
                    accept_token = verify_tokens_now[i * max_candidate_len];
                }
                accept_tokens[bid * max_draft_tokens + i] = accept_token;
                if (prefill_one_step_stop) {
                    stop_flags[bid] = true;
                }
                if (is_in_end(accept_token, end_tokens, end_length) ||
                    step_idx[bid] >= max_dec_len[bid]) {
                    stop_flags[bid] = true;
                    stop_flag_now_int = 1;
                    if (step_idx[bid] >= max_dec_len[bid])
                        accept_tokens[bid * max_draft_tokens + i] =
                            end_tokens[0];
                }
            }
            accept_num[bid] = accept_num_now;
        }
    }
}

void SpeculateVerify(const paddle::Tensor &accept_tokens,
                     const paddle::Tensor &accept_num,
                     const paddle::Tensor &step_idx,
                     const paddle::Tensor &stop_flags,
                     const paddle::Tensor &seq_lens_encoder,
                     const paddle::Tensor &seq_lens_decoder,
                     const paddle::Tensor &draft_tokens,
                     const paddle::Tensor &seq_lens_this_time,
                     const paddle::Tensor &verify_tokens,
                     const paddle::Tensor &verify_scores,
                     const paddle::Tensor &max_dec_len,
                     const paddle::Tensor &end_tokens,
                     const paddle::Tensor &is_block_step,
                     const paddle::Tensor &output_cum_offsets,
                     const paddle::Tensor &actual_candidate_len,
                     const paddle::Tensor &actual_draft_token_nums,
                     const paddle::Tensor &topp,
                     int max_seq_len,
                     int verify_window,
                     bool enable_topp) {
    auto bsz = accept_tokens.shape()[0];
    int real_bsz = seq_lens_this_time.shape()[0];
    auto max_draft_tokens = draft_tokens.shape()[1];
    auto end_length = end_tokens.shape()[0];
    auto max_candidate_len = verify_tokens.shape()[1];

    constexpr int BlockSize = 512;

    curandState_t *dev_curand_states;
    cudaMalloc(&dev_curand_states, sizeof(curandState_t) * bsz);
    setup_kernel<<<1, BlockSize, 0, accept_tokens.stream()>>>(
        dev_curand_states, seed, offset, bsz, true);
    seed++;
    offset++;

    auto err = cudaDeviceSynchronize();
    if (err != 0) {
        printf("err %d\n", err);
    }

    err = cudaGetLastError();

    if (err != 0) {
        printf("err %d\n", err);
    }

    bool use_topk = false;
    char *env_var = getenv("SPECULATE_VERIFY_USE_TOPK");
    if (env_var) {
        use_topk = static_cast<bool>(std::stoi(env_var));
    }
    bool prefill_one_step_stop = false;
    if (const char *env_p = std::getenv("PREFILL_NODE_ONE_STEP_STOP")) {
        if (env_p[0] == '1') {
            prefill_one_step_stop = true;
        }
    }
    if (use_topk) {
        if (enable_topp) {
            speculate_verify<true, true>
                <<<1, BlockSize, 0, accept_tokens.stream()>>>(
                    const_cast<int64_t *>(accept_tokens.data<int64_t>()),
                    const_cast<int *>(accept_num.data<int>()),
                    const_cast<int64_t *>(step_idx.data<int64_t>()),
                    const_cast<bool *>(stop_flags.data<bool>()),
                    seq_lens_encoder.data<int>(),
                    seq_lens_decoder.data<int>(),
                    draft_tokens.data<int64_t>(),
                    actual_draft_token_nums.data<int>(),
                    dev_curand_states,
                    topp.data<float>(),
                    seq_lens_this_time.data<int>(),
                    verify_tokens.data<int64_t>(),
                    verify_scores.data<float>(),
                    max_dec_len.data<int64_t>(),
                    end_tokens.data<int64_t>(),
                    is_block_step.data<bool>(),
                    output_cum_offsets.data<int>(),
                    actual_candidate_len.data<int>(),
                    real_bsz,
                    max_draft_tokens,
                    end_length,
                    max_seq_len,
                    max_candidate_len,
                    verify_window,
                    prefill_one_step_stop);
        } else {
            speculate_verify<false, true>
                <<<1, BlockSize, 0, accept_tokens.stream()>>>(
                    const_cast<int64_t *>(accept_tokens.data<int64_t>()),
                    const_cast<int *>(accept_num.data<int>()),
                    const_cast<int64_t *>(step_idx.data<int64_t>()),
                    const_cast<bool *>(stop_flags.data<bool>()),
                    seq_lens_encoder.data<int>(),
                    seq_lens_decoder.data<int>(),
                    draft_tokens.data<int64_t>(),
                    actual_draft_token_nums.data<int>(),
                    dev_curand_states,
                    topp.data<float>(),
                    seq_lens_this_time.data<int>(),
                    verify_tokens.data<int64_t>(),
                    verify_scores.data<float>(),
                    max_dec_len.data<int64_t>(),
                    end_tokens.data<int64_t>(),
                    is_block_step.data<bool>(),
                    output_cum_offsets.data<int>(),
                    actual_candidate_len.data<int>(),
                    real_bsz,
                    max_draft_tokens,
                    end_length,
                    max_seq_len,
                    max_candidate_len,
                    verify_window,
                    prefill_one_step_stop);
        }
    } else {
        if (enable_topp) {
            speculate_verify<true, false>
                <<<1, BlockSize, 0, accept_tokens.stream()>>>(
                    const_cast<int64_t *>(accept_tokens.data<int64_t>()),
                    const_cast<int *>(accept_num.data<int>()),
                    const_cast<int64_t *>(step_idx.data<int64_t>()),
                    const_cast<bool *>(stop_flags.data<bool>()),
                    seq_lens_encoder.data<int>(),
                    seq_lens_decoder.data<int>(),
                    draft_tokens.data<int64_t>(),
                    actual_draft_token_nums.data<int>(),
                    dev_curand_states,
                    topp.data<float>(),
                    seq_lens_this_time.data<int>(),
                    verify_tokens.data<int64_t>(),
                    verify_scores.data<float>(),
                    max_dec_len.data<int64_t>(),
                    end_tokens.data<int64_t>(),
                    is_block_step.data<bool>(),
                    output_cum_offsets.data<int>(),
                    actual_candidate_len.data<int>(),
                    real_bsz,
                    max_draft_tokens,
                    end_length,
                    max_seq_len,
                    max_candidate_len,
                    verify_window,
                    prefill_one_step_stop);
        } else {
            speculate_verify<false, false>
                <<<1, BlockSize, 0, accept_tokens.stream()>>>(
                    const_cast<int64_t *>(accept_tokens.data<int64_t>()),
                    const_cast<int *>(accept_num.data<int>()),
                    const_cast<int64_t *>(step_idx.data<int64_t>()),
                    const_cast<bool *>(stop_flags.data<bool>()),
                    seq_lens_encoder.data<int>(),
                    seq_lens_decoder.data<int>(),
                    draft_tokens.data<int64_t>(),
                    actual_draft_token_nums.data<int>(),
                    dev_curand_states,
                    topp.data<float>(),
                    seq_lens_this_time.data<int>(),
                    verify_tokens.data<int64_t>(),
                    verify_scores.data<float>(),
                    max_dec_len.data<int64_t>(),
                    end_tokens.data<int64_t>(),
                    is_block_step.data<bool>(),
                    output_cum_offsets.data<int>(),
                    actual_candidate_len.data<int>(),
                    real_bsz,
                    max_draft_tokens,
                    end_length,
                    max_seq_len,
                    max_candidate_len,
                    verify_window,
                    prefill_one_step_stop);
        }
    }

    cudaFree(dev_curand_states);
}

PD_BUILD_STATIC_OP(speculate_verify)
    .Inputs({"accept_tokens",
             "accept_num",
             "step_idx",
             "seq_lens_encoder",
             "seq_lens_decoder",
             "stop_flags",
             "draft_tokens",
             "seq_lens_this_time",
             "verify_tokens",
             "verify_scores",
             "max_dec_len",
             "end_tokens",
             "is_block_step",
             "output_cum_offsets",
             "actual_candidate_len",
             "actual_draft_token_nums",
             "topp"})
    .Outputs({"accept_tokens_out",
              "accept_num_out",
              "step_idx_out",
              "stop_flags_out"})
    .Attrs({"max_seq_len: int", "verify_window: int", "enable_topp: bool"})
    .SetInplaceMap({{"accept_tokens", "accept_tokens_out"},
                    {"accept_num", "accept_num_out"},
                    {"step_idx", "step_idx_out"},
                    {"stop_flags", "stop_flags_out"}})
    .SetKernelFn(PD_KERNEL(SpeculateVerify));

custom_ops/gpu_ops/speculate_decoding/top_p_candidates.cu  (new file, 624 lines)
@@ -0,0 +1,624 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "helper.h"  // NOLINT

#define WARP_SIZE 32

template <typename T>
__forceinline__ __device__ T
CudaShuffleDownSync(unsigned mask, T val, int delta, int width = warpSize) {
    return __shfl_down_sync(mask, val, static_cast<unsigned>(delta), width);
}

template <>
__forceinline__ __device__ phi::dtype::float16 CudaShuffleDownSync(
    unsigned mask, phi::dtype::float16 val, int delta, int width) {
    return paddle::float16(__shfl_down_sync(
        mask, val.to_half(), static_cast<unsigned>(delta), width));
}

template <>
__forceinline__ __device__ phi::dtype::bfloat16 CudaShuffleDownSync(
    unsigned mask, phi::dtype::bfloat16 val, int delta, int width) {
    return paddle::bfloat16(__shfl_down_sync(
        mask, val.to_nv_bfloat16(), static_cast<unsigned>(delta), width));
}

struct BlockPrefixCallbackOp {
    // Running prefix
    float running_total;
    // Constructor
    __device__ BlockPrefixCallbackOp(float running_total)
        : running_total(running_total) {}
    // Callback operator to be entered by the first warp of threads in the
    // block. Thread-0 is responsible for returning a value for seeding the
    // block-wide scan.
    __device__ float operator()(float block_aggregate) {
        float old_prefix = running_total;
        running_total += block_aggregate;
        return old_prefix;
    }
};

#define FINAL_MASK 0xFFFFFFFF

#define FIXED_BLOCK_DIM_BASE(dim, ...)    \
    case (dim): {                         \
        constexpr auto kBlockDim = (dim); \
        __VA_ARGS__;                      \
    } break

#define FIXED_BLOCK_DIM(...)                   \
    FIXED_BLOCK_DIM_BASE(1024, ##__VA_ARGS__); \
    FIXED_BLOCK_DIM_BASE(512, ##__VA_ARGS__);  \
    FIXED_BLOCK_DIM_BASE(256, ##__VA_ARGS__);  \
    FIXED_BLOCK_DIM_BASE(128, ##__VA_ARGS__);  \
    FIXED_BLOCK_DIM_BASE(64, ##__VA_ARGS__);   \
    FIXED_BLOCK_DIM_BASE(32, ##__VA_ARGS__)

#define FIXED_TOPK_BASE(topk, ...)   \
    case (topk): {                   \
        constexpr auto kTopK = topk; \
        __VA_ARGS__;                 \
    } break

#define FIXED_TOPK(...)                \
    FIXED_TOPK_BASE(2, ##__VA_ARGS__); \
    FIXED_TOPK_BASE(3, ##__VA_ARGS__); \
    FIXED_TOPK_BASE(4, ##__VA_ARGS__); \
    FIXED_TOPK_BASE(5, ##__VA_ARGS__); \
    FIXED_TOPK_BASE(8, ##__VA_ARGS__); \
    FIXED_TOPK_BASE(10, ##__VA_ARGS__)

struct SegmentOffsetIter {
    explicit SegmentOffsetIter(int num_cols) : num_cols_(num_cols) {}

    __host__ __device__ __forceinline__ int operator()(int idx) const {
        return idx * num_cols_;
    }

    int num_cols_;
};

inline int div_up(int a, int n) { return (a + n - 1) / n; }

template <typename T>
__global__ void FillIndex(T* indices, T num_rows, T num_cols) {
    int col_id = threadIdx.x;
    int row_id = blockIdx.x;

    for (T j = row_id; j < num_rows; j += gridDim.x) {
        for (T i = col_id; i < num_cols; i += blockDim.x) {
            indices[j * num_cols + i] = i;
        }
    }
}

__global__ void SetCountIter(int* count_iter, int num) {
    int tid = threadIdx.x;
    int bid = blockIdx.x;
    int idx = bid * blockDim.x + tid;
    for (int i = idx; i < num; i += gridDim.x * blockDim.x) {
        count_iter[i] = i;
    }
}

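// top_p_candidates_kernel emits the first candidates_len sorted probabilities
// and ids per row while a block-wide inclusive prefix sum finds the earliest
// position where the cumulative probability reaches topp; that position
// determines the row's actual candidate count.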
template <typename T, int BLOCK_SIZE>
__global__ void top_p_candidates_kernel(T* sorted_probs,
                                        int64_t* sorted_id,
                                        T* out_val,
                                        int64_t* out_id,
                                        int* actual_candidates_lens,
                                        const int vocab_size,
                                        const float topp,
                                        const int candidates_len) {
    __shared__ int stop_shared;
    const int tid = threadIdx.x;
    const int bid = blockIdx.x;
    constexpr int NUM_WARPS = BLOCK_SIZE / 32;
    const int lane_id = tid % 32;
    const int warp_id = tid / 32;

    typedef cub::BlockScan<float, BLOCK_SIZE> BlockScan;
    typedef cub::BlockReduce<int, BLOCK_SIZE> BlockReduce;
    __shared__ typename BlockScan::TempStorage temp_storage;
    __shared__ typename BlockReduce::TempStorage temp_storage_reduce;
    __shared__ uint32_t selected_shared[NUM_WARPS];

    // Shared memory is not zero-initialized; reset the stop counter before
    // any warp can vote to stop (it was previously read uninitialized).
    if (tid == 0) {
        stop_shared = 0;
    }
    if (lane_id == 0) {
        selected_shared[warp_id] = 0;
    }

    // Initialize running total
    BlockPrefixCallbackOp prefix_op(0);

    __syncthreads();

    int offset = bid * vocab_size;
    int end = ((vocab_size + BLOCK_SIZE - 1) / BLOCK_SIZE) * BLOCK_SIZE;
    int i_activate = 0;
    float thread_offset = 0;
    for (int i = tid; i < end; i += BLOCK_SIZE) {
        float thread_count = (i < vocab_size)
                                 ? static_cast<float>(sorted_probs[offset + i])
                                 : 0.f;

        BlockScan(temp_storage)
            .InclusiveSum(thread_count, thread_offset, prefix_op);

        if (i < candidates_len) {
            out_id[bid * candidates_len + i] = sorted_id[offset + i];
            out_val[bid * candidates_len + i] = sorted_probs[offset + i];
        }

        uint32_t activate_mask =
            __ballot_sync(FINAL_MASK, topp <= thread_offset);
        i_activate = i;
        if (activate_mask != 0 || i >= candidates_len) {
            if (lane_id == 0) {
                atomicAdd(&stop_shared, 1);
                selected_shared[warp_id] = activate_mask;
            }
        }
        __syncthreads();
        if (stop_shared > 0) {
            break;
        }
    }
    __syncthreads();
    bool skip = (selected_shared[warp_id] > 0) ? false : true;
    for (int i = 0; i < warp_id; i++) {
        if (selected_shared[i] != 0) {
            // If a previous warp has already stopped, skip the current warp.
            skip = true;
        }
    }
    if (!skip) {
        int active_lane_id =
            WARP_SIZE - __popc(selected_shared[warp_id]);  // first not 0
        if (lane_id == active_lane_id) {
            actual_candidates_lens[bid] = i_activate + 1;
        }
    }
    __syncthreads();
    if (tid == 0) {
        if (actual_candidates_lens[bid] == 0) {
            actual_candidates_lens[bid] = candidates_len;
        }
    }
}

template <typename T>
struct Pair {
    __device__ __forceinline__ Pair() {}
    __device__ __forceinline__ Pair(T value, int id) : v(value), id(id) {}

    __device__ __forceinline__ void set(T value, int id) {
        this->v = value;
        this->id = id;
    }

    __device__ __forceinline__ void operator=(const Pair<T>& in) {
        v = in.v;
        id = in.id;
    }

    __device__ __forceinline__ bool operator<(const T value) const {
        return (static_cast<float>(v) < static_cast<float>(value));
    }

    __device__ __forceinline__ bool operator>(const T value) const {
        return (static_cast<float>(v) > static_cast<float>(value));
    }
    __device__ __forceinline__ bool operator<(const Pair<T>& in) const {
        return (static_cast<float>(v) < static_cast<float>(in.v)) ||
               ((static_cast<float>(v) == static_cast<float>(in.v)) &&
                (id > in.id));
    }

    __device__ __forceinline__ bool operator>(const Pair<T>& in) const {
        return (static_cast<float>(v) > static_cast<float>(in.v)) ||
               ((static_cast<float>(v) == static_cast<float>(in.v)) &&
                (id < in.id));
    }

    T v;
    int id;
};

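// AddTo keeps a thread-local top-k array sorted in descending order; the
// GetTopK overloads scan a strided slice of the row to fill it, optionally
// bounded above by the maximum found in a previous round.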
template <typename T>
__device__ __forceinline__ void AddTo(Pair<T> topk[],
                                      const Pair<T>& p,
                                      int beam_size) {
    for (int k = beam_size - 2; k >= 0; k--) {
        if (topk[k] < p) {
            topk[k + 1] = topk[k];
        } else {
            topk[k + 1] = p;
            return;
        }
    }
    topk[0] = p;
}

template <typename T, int BlockSize>
__device__ __forceinline__ void GetTopK(
    Pair<T> topk[], const T* src, int idx, int dim, int beam_size) {
    while (idx < dim) {
        if (topk[beam_size - 1] < src[idx]) {
            Pair<T> tmp(src[idx], idx);
            AddTo<T>(topk, tmp, beam_size);
        }
        idx += BlockSize;
    }
}

template <typename T, int BlockSize>
__device__ __forceinline__ void GetTopK(Pair<T> topk[],
                                        const T* src,
                                        int idx,
                                        int dim,
                                        const Pair<T>& max,
                                        int beam_size) {
    while (idx < dim) {
        if (topk[beam_size - 1] < src[idx]) {
            Pair<T> tmp(src[idx], idx);
            if (tmp < max) {
                AddTo<T>(topk, tmp, beam_size);
            }
        }
        idx += BlockSize;
    }
}

template <typename T, int MaxLength, int BlockSize>
__device__ __forceinline__ void ThreadGetTopK(Pair<T> topk[],
                                              int* beam,
                                              int beam_size,
                                              const T* src,
                                              bool* firstStep,
                                              bool* is_empty,
                                              Pair<T>* max,
                                              int dim,
                                              const int tid) {
    if (*beam > 0) {
        int length = (*beam) < beam_size ? *beam : beam_size;
        if (*firstStep) {
            *firstStep = false;
            GetTopK<T, BlockSize>(topk, src, tid, dim, length);
        } else {
            for (int k = 0; k < MaxLength; k++) {
                if (k < MaxLength - (*beam)) {
                    topk[k] = topk[k + *beam];
                } else {
                    topk[k].set(std::numeric_limits<T>::min(), -1);
                }
            }
            if (!(*is_empty)) {
                GetTopK<T, BlockSize>(
                    topk + MaxLength - *beam, src, tid, dim, *max, length);
            }
        }

        *max = topk[MaxLength - 1];
        if ((*max).id == -1) *is_empty = true;
        *beam = 0;
    }
}

template <typename T>
__forceinline__ __device__ Pair<T> WarpReduce(Pair<T> input) {
#pragma unroll
    for (int offset = 16; offset > 0; offset >>= 1) {
        T tmp_val = CudaShuffleDownSync(FINAL_MASK, input.v, offset);
        int tmp_id = CudaShuffleDownSync(FINAL_MASK, input.id, offset);
        if (static_cast<float>(input.v) < static_cast<float>(tmp_val)) {
            input.v = tmp_val;
            input.id = tmp_id;
        }
    }
    return input;
}

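// BlockReduce repeatedly extracts the block-wide maximum pair: warps reduce
// locally, the first warp reduces across warps, thread 0 appends the winner
// to beam_max, and the owning thread advances its local beam until k winners
// have been gathered.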
template <typename T, int MaxLength, int BlockSize>
__device__ __forceinline__ void BlockReduce(Pair<T> shared_max[],
                                            Pair<T> topk[],
                                            Pair<T> beam_max[],
                                            int* beam,
                                            int* k,
                                            int* count,
                                            const int tid,
                                            const int wid,
                                            const int lane) {
  while (true) {
    __syncthreads();
    Pair<T> input_now = topk[0];
    input_now = WarpReduce(input_now);

    if (lane == 0) {
      shared_max[wid] = input_now;
    }
    __syncthreads();
    input_now = (tid < BlockSize / 32)
                    ? shared_max[lane]
                    : Pair<T>(std::numeric_limits<T>::min(), -1);
    if (wid == 0) {
      input_now = WarpReduce(input_now);
      if (lane == 0) shared_max[0] = input_now;
    }
    __syncthreads();
    if (tid == 0) {
      beam_max[*count] = shared_max[0];
      (*count)++;
    }
    // Recover the winning thread from the flat index: thread `tid` only
    // ever scans positions congruent to tid modulo BlockSize.
    int tid_max = shared_max[0].id % BlockSize;
    if (tid == tid_max) {
      (*beam)++;
    }
    if (--(*k) == 0) break;
    __syncthreads();

    if (tid == tid_max) {
      if (*beam < MaxLength) {
        topk[0] = topk[*beam];
      }
    }

    if (MaxLength < 5) {
      if (*beam >= MaxLength) break;
    } else {
      if (tid_max / 32 == wid) {
        if (__shfl_down_sync(FINAL_MASK, *beam, tid_max % 32, 32) ==
            MaxLength)
          break;
      }
    }
  }
}

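// One thread block per packed (unpadded) token: the block selects the
// TopPBeamTopK highest-probability entries of the token's vocabulary row,
// writes them to `out_id` / `out_val`, and records in
// `actual_candidates_lens` how many of them are needed to cover that
// request's top-p probability mass (`top_ps` is indexed by batch id, which
// is recovered from the padding offsets).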
template <typename T, int MaxLength, int TopPBeamTopK, int BlockSize>
__global__ void KeMatrixTopPBeamTopKFt(
    const T* src,
    const T* top_ps,
    const int* output_padding_offset,
    int64_t* out_id,  // [token_num, max_candidate_len]
    T* out_val,       // [token_num, max_candidate_len]
    int* actual_candidates_lens,
    int vocab_size,
    const int max_candidate_len,
    const int max_seq_len) {
  const int tid = threadIdx.x;
  const int wid = tid / 32;
  const int lane = tid % 32;
  const int token_id = blockIdx.x;
  const int ori_token_id = token_id + output_padding_offset[token_id];
  const int bid = ori_token_id / max_seq_len;

  int top_num = TopPBeamTopK;
  float top_p_value = static_cast<float>(top_ps[bid]);

  __shared__ Pair<T> shared_max[BlockSize / 32];
  __shared__ Pair<T> beam_max[TopPBeamTopK];

  Pair<T> topk[MaxLength];
  int beam = MaxLength;
  Pair<T> max;
  bool is_empty = false;
  bool firststep = true;
  __shared__ int count;

  if (tid == 0) {
    count = 0;
  }

  for (int j = 0; j < MaxLength; j++) {
    topk[j].set(std::numeric_limits<T>::min(), -1);
  }

  while (top_num) {
    ThreadGetTopK<T, MaxLength, BlockSize>(topk,
                                           &beam,
                                           TopPBeamTopK,
                                           src + token_id * vocab_size,
                                           &firststep,
                                           &is_empty,
                                           &max,
                                           vocab_size,
                                           tid);
    BlockReduce<T, MaxLength, BlockSize>(shared_max,
                                         topk,
                                         beam_max,
                                         &beam,
                                         &top_num,
                                         &count,
                                         tid,
                                         wid,
                                         lane);
  }
  if (tid == 0) {
    float sum_prob = 0.0f;
    bool reached_top_p = false;
    for (int i = 0; i < TopPBeamTopK; i++) {
      out_id[token_id * max_candidate_len + i] =
          static_cast<int64_t>(beam_max[i].id);
      out_val[token_id * max_candidate_len + i] = beam_max[i].v;
      float val = static_cast<float>(beam_max[i].v);
      sum_prob += val;

      if (sum_prob >= top_p_value) {
        actual_candidates_lens[token_id] = i + 1;
        reached_top_p = true;
        break;
      }
    }
    // Fallback: if the accumulated mass never reaches top_p, use all
    // TopPBeamTopK candidates rather than leaving the length at zero.
    if (!reached_top_p) {
      actual_candidates_lens[token_id] = TopPBeamTopK;
    }
  }
}

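// Maps the runtime candidate length and block size onto the kernel's
// compile-time template parameters. FIXED_TOPK / FIXED_BLOCK_DIM are
// assumed to expand to `case` labels that bind `kTopK` and `kBlockDim`;
// they and GetBlockSize are defined earlier in this file.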
template <typename T, int TopKMaxLength>
void DispatchTopK(const T* src,
                  const T* top_ps,
                  const int* output_padding_offset,
                  int64_t* out_id,  // topk id
                  T* out_val,       // topk val
                  int* actual_candidates_lens_data,
                  const int vocab_size,
                  const int token_num,
                  const int candidate_len,
                  const int max_seq_len,
                  const cudaStream_t& stream) {
  int BlockSize = GetBlockSize(vocab_size);
  switch (candidate_len) {
    FIXED_TOPK(switch (BlockSize) {
      FIXED_BLOCK_DIM(
          KeMatrixTopPBeamTopKFt<T, TopKMaxLength, kTopK, kBlockDim>
          <<<token_num, kBlockDim, 0, stream>>>(
              src,
              top_ps,
              output_padding_offset,
              out_id,
              out_val,
              actual_candidates_lens_data,
              vocab_size,
              candidate_len,
              max_seq_len));
      default:
        PD_THROW(
            "Unsupported block size in the topp_beam_topk kernel.");
    });
    default:
      PD_THROW("Unsupported candidate length in the topp_beam_topk kernel.");
  }
}

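// Typed launcher: allocates the three outputs (candidate scores, candidate
// token ids, and per-token candidate counts) and hands their raw pointers
// to DispatchTopK. TopKMaxLength is the per-thread buffer size (MaxLength)
// used inside the kernel.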
template <paddle::DataType D>
std::vector<paddle::Tensor> LaunchTopPCandidates(
    const paddle::Tensor& probs,  // [token_num, vocab_size]
    const paddle::Tensor& top_p,  // per-request thresholds, indexed by batch id
    const paddle::Tensor& output_padding_offset,
    const int candidates_len,
    const int max_seq_len) {
  typedef PDTraits<D> traits_;
  typedef typename traits_::DataType DataType_;
  typedef typename traits_::data_t data_t;

  std::vector<int64_t> input_shape = probs.shape();
  const int token_num = input_shape[0];
  const int vocab_size = input_shape[1];

  auto verify_scores =
      paddle::full({token_num, candidates_len}, 0, D, probs.place());
  auto verify_tokens = paddle::full(
      {token_num, candidates_len}, 0, paddle::DataType::INT64, probs.place());
  auto actual_candidate_lens =
      paddle::full({token_num}, 0, paddle::DataType::INT32, probs.place());

  auto stream = probs.stream();

  constexpr int TopKMaxLength = 2;
  DispatchTopK<DataType_, TopKMaxLength>(
      reinterpret_cast<const DataType_*>(probs.data<data_t>()),
      reinterpret_cast<const DataType_*>(top_p.data<data_t>()),
      output_padding_offset.data<int>(),
      verify_tokens.data<int64_t>(),
      reinterpret_cast<DataType_*>(verify_scores.data<data_t>()),
      actual_candidate_lens.data<int>(),
      vocab_size,
      token_num,
      candidates_len,
      max_seq_len,
      stream);

  return {verify_scores, verify_tokens, actual_candidate_lens};
}

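// Dtype dispatch: routes to the typed launcher for the element types the
// kernel supports.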
std::vector<paddle::Tensor> DispatchTopPCandidatesWithDtype(
    const paddle::Tensor& probs,
    const paddle::Tensor& top_p,
    const paddle::Tensor& output_padding_offset,
    int candidates_len,
    int max_seq_len) {
  switch (probs.type()) {
    case paddle::DataType::BFLOAT16:
      return LaunchTopPCandidates<paddle::DataType::BFLOAT16>(
          probs, top_p, output_padding_offset, candidates_len, max_seq_len);
    case paddle::DataType::FLOAT16:
      return LaunchTopPCandidates<paddle::DataType::FLOAT16>(
          probs, top_p, output_padding_offset, candidates_len, max_seq_len);
    case paddle::DataType::FLOAT32:
      return LaunchTopPCandidates<paddle::DataType::FLOAT32>(
          probs, top_p, output_padding_offset, candidates_len, max_seq_len);
    default:
      PD_THROW(
          "Unsupported data type: only bfloat16, float16 and float32 are "
          "supported.");
  }
}

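// Public entry point for the custom op; the operator name and attribute
// list are declared in the registration block at the end of this file.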
std::vector<paddle::Tensor> TopPCandidates(
    const paddle::Tensor& probs,
    const paddle::Tensor& top_p,
    const paddle::Tensor& output_padding_offset,
    int candidates_len,
    int max_seq_len) {
  return DispatchTopPCandidatesWithDtype(
      probs, top_p, output_padding_offset, candidates_len, max_seq_len);
}

std::vector<std::vector<int64_t>> TopPCandidatesInferShape(
    const std::vector<int64_t>& probs_shape,
    const std::vector<int64_t>& top_p_shape,
    const std::vector<int64_t>& output_padding_offset_shape,
    int max_candidates_len) {
  int token_num = probs_shape[0];
  return {{token_num, max_candidates_len},
          {token_num, max_candidates_len},
          {token_num}};
}

std::vector<paddle::DataType> TopPCandidatesInferDtype(
    const paddle::DataType& probs_dtype,
    const paddle::DataType& top_p_dtype,
    const paddle::DataType& output_padding_offset_dtype) {
  return {probs_dtype, paddle::DataType::INT64, paddle::DataType::INT32};
}

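// Registers the `top_p_candidates` static-graph op: given per-token
// probability rows, per-request top-p thresholds, and the padding offsets
// of the packed token batch, it emits candidate scores, candidate token
// ids, and the number of candidates actually needed per token.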
PD_BUILD_STATIC_OP(top_p_candidates)
    .Inputs({"probs", "top_p", "output_padding_offset"})
    .Outputs({"verify_scores", "verify_tokens", "actual_candidate_lens"})
    .Attrs({"candidates_len: int", "max_seq_len: int"})
    .SetKernelFn(PD_KERNEL(TopPCandidates))
    .SetInferShapeFn(PD_INFER_SHAPE(TopPCandidatesInferShape))
    .SetInferDtypeFn(PD_INFER_DTYPE(TopPCandidatesInferDtype));
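
// For reference, a minimal host-side sketch of what each block computes per
// token (a hypothetical helper, for illustration only; not used by the op):
// take the k = TopPBeamTopK largest probabilities, then keep the shortest
// prefix whose mass reaches top_p, falling back to k if it never does.
//
//   static int TopPCandidateLenRef(std::vector<float> probs, float top_p,
//                                  int k) {
//     std::partial_sort(probs.begin(), probs.begin() + k, probs.end(),
//                       std::greater<float>());
//     float sum = 0.0f;
//     for (int i = 0; i < k; ++i) {
//       sum += probs[i];
//       if (sum >= top_p) return i + 1;  // shortest prefix covering top_p
//     }
//     return k;  // top-p mass not reached within k candidates
//   }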