[LLM] First commit of the LLM deployment code

This commit is contained in:
jiangjiajun
2025-06-09 19:20:15 +08:00
parent 980c0a1d2c
commit 684703fd72
11814 changed files with 127294 additions and 1293102 deletions

View File

@@ -0,0 +1,80 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/extension.h"
#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif
__global__ void draft_model_update_seq_lens_this_time_kernel(
const int64_t* base_model_draft_tokens,
int* base_model_seq_lens_this_time,
const int* base_model_seq_lens_encoder,
const bool* base_model_stop_flags,
int bsz,
int base_model_draft_token_len) {
int tid = threadIdx.x;
if (tid < bsz) {
if (!base_model_stop_flags[tid] &&
base_model_seq_lens_encoder[tid] == 0) {
const int64_t* base_model_draft_tokens_now =
base_model_draft_tokens + tid * base_model_draft_token_len;
int token_num = 0;
for (int i = 0; i < base_model_draft_token_len; ++i) {
if (base_model_draft_tokens_now[i] != -1) {
token_num++;
}
}
base_model_seq_lens_this_time[tid] = token_num;
} else if (base_model_stop_flags[tid]) {
base_model_seq_lens_this_time[tid] = 0;
}
}
}
void DraftModelPostprocess(const paddle::Tensor& base_model_draft_tokens,
const paddle::Tensor& base_model_seq_lens_this_time,
const paddle::Tensor& base_model_seq_lens_encoder,
const paddle::Tensor& base_model_stop_flags) {
int real_bsz = base_model_seq_lens_this_time.shape()[0];
auto cu_stream = base_model_seq_lens_this_time.stream();
int block_size = (real_bsz + 32 - 1) / 32 * 32;
int base_model_draft_token_len = base_model_draft_tokens.shape()[1];
draft_model_update_seq_lens_this_time_kernel<<<1,
block_size,
0,
cu_stream>>>(
base_model_draft_tokens.data<int64_t>(),
const_cast<int*>(base_model_seq_lens_this_time.data<int>()),
base_model_seq_lens_encoder.data<int>(),
base_model_stop_flags.data<bool>(),
real_bsz,
base_model_draft_token_len);
}
PD_BUILD_STATIC_OP(draft_model_postprocess)
.Inputs({"base_model_draft_tokens",
"base_model_seq_lens_this_time",
"base_model_seq_lens_encoder",
"base_model_stop_flags"})
.Outputs({"base_model_draft_tokens_out",
"base_model_seq_lens_this_time_out",
"base_model_stop_flags_out"})
.SetInplaceMap({{"base_model_draft_tokens", "base_model_draft_tokens_out"},
{"base_model_seq_lens_this_time",
"base_model_seq_lens_this_time_out"},
{"base_model_stop_flags", "base_model_stop_flags_out"}})
.SetKernelFn(PD_KERNEL(DraftModelPostprocess));
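Note: a host-side reference of this kernel's per-sequence logic makes the CUDA path easier to unit-test. A minimal sketch, illustrative only and not part of the commit; the helper names are hypothetical:

// Hypothetical CPU reference for draft_model_update_seq_lens_this_time_kernel:
// count the draft tokens that are not the -1 sentinel.
#include <cstdint>
#include <vector>

int count_valid_draft_tokens(const std::vector<int64_t>& draft_row) {
    int token_num = 0;
    for (int64_t t : draft_row) {
        if (t != -1) token_num++;
    }
    return token_num;
}

// The launch rounds the thread count up to a multiple of the warp size (32),
// matching: int block_size = (real_bsz + 32 - 1) / 32 * 32;
int round_up_to_warp(int bsz) { return (bsz + 32 - 1) / 32 * 32; }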

View File

@@ -0,0 +1,500 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "helper.h"
#include "paddle/extension.h"
template <int THREADBLOCK_SIZE, bool TRUNCATE_FIRST_TOKEN>
__global__ void process_splitwise_prefill(
int64_t* draft_tokens,
int64_t* input_ids,
bool* stop_flags,
int* seq_lens_this_time,
int* seq_lens_encoder,
int* seq_lens_decoder,
int64_t* step_idx,
int* seq_lens_encoder_record,
int* seq_lens_decoder_record,
bool* not_need_stop,
bool* batch_drop,
const int64_t* accept_tokens,
const int* accept_num,
const int* base_model_seq_lens_encoder,
const int* base_model_seq_lens_decoder,
const int64_t* base_model_step_idx,
const bool* base_model_stop_flags,
const bool* base_model_is_block_step,
int64_t* base_model_draft_tokens,
const int bsz,
const int max_draft_token,
const int accept_tokens_len,
const int draft_tokens_len,
const int input_ids_len,
const int base_model_draft_tokens_len) {
typedef cub::BlockReduce<int64_t, THREADBLOCK_SIZE> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
int64_t not_stop_flag = 0;
int tid = threadIdx.x;
if (tid < bsz) {
int base_model_step_idx_now = base_model_step_idx[tid];
auto* input_ids_now = input_ids + tid * input_ids_len;
auto* accept_tokens_now = accept_tokens + tid * accept_tokens_len;
// printf("bid: %d, base_model_step_idx_now: %d seq_lens_encoder_record: %d\n", tid, base_model_step_idx_now, seq_lens_encoder_record[tid]);
if (base_model_step_idx_now == 1 && seq_lens_encoder_record[tid] > 0) {
not_stop_flag = 1;
int seq_len_encoder_record = seq_lens_encoder_record[tid];
seq_lens_encoder[tid] = seq_len_encoder_record;
seq_lens_encoder_record[tid] = -1;
stop_flags[tid] = false;
int64_t base_model_first_token = accept_tokens_now[0];
int position = seq_len_encoder_record;
if (TRUNCATE_FIRST_TOKEN) {
input_ids_now[position - 1] = base_model_first_token;
seq_lens_this_time[tid] = seq_len_encoder_record;
} else {
input_ids_now[position] = base_model_first_token;
seq_lens_this_time[tid] = seq_len_encoder_record + 1;
}
} else {
stop_flags[tid] = true;
seq_lens_this_time[tid] = 0;
seq_lens_decoder[tid] = 0;
seq_lens_encoder[tid] = 0;
not_stop_flag = 0;
}
}
__syncthreads();
int64_t not_stop_flag_sum = BlockReduce(temp_storage).Sum(not_stop_flag);
if (tid == 0) {
not_need_stop[0] = not_stop_flag_sum > 0;
}
}
template <int THREADBLOCK_SIZE, bool TRUNCATE_FIRST_TOKEN>
__global__ void draft_model_preprocess_kernel(
int64_t* draft_tokens,
int64_t* input_ids,
bool* stop_flags,
int* seq_lens_this_time,
int* seq_lens_encoder,
int* seq_lens_decoder,
int64_t* step_idx,
int* seq_lens_encoder_record,
int* seq_lens_decoder_record,
bool* not_need_stop,
bool* batch_drop,
const int64_t* accept_tokens,
const int* accept_num,
const int* base_model_seq_lens_encoder,
const int* base_model_seq_lens_decoder,
const int64_t* base_model_step_idx,
const bool* base_model_stop_flags,
const bool* base_model_is_block_step,
int64_t* base_model_draft_tokens,
const int bsz,
const int max_draft_token,
const int accept_tokens_len,
const int draft_tokens_len,
const int input_ids_len,
const int base_model_draft_tokens_len) {
typedef cub::BlockReduce<int64_t, THREADBLOCK_SIZE> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
int64_t not_stop_flag = 0;
int tid = threadIdx.x;
if (tid < bsz) {
auto base_model_step_idx_now = base_model_step_idx[tid];
auto* accept_tokens_now = accept_tokens + tid * accept_tokens_len;
auto* draft_tokens_now = draft_tokens + tid * draft_tokens_len;
auto accept_num_now = accept_num[tid];
auto* input_ids_now = input_ids + tid * input_ids_len;
auto* base_model_draft_tokens_now =
base_model_draft_tokens + tid * base_model_draft_tokens_len;
#pragma unroll
for (int i = 1; i < base_model_draft_tokens_len; i++) {
base_model_draft_tokens_now[i] = -1;
}
// Handle base_model recover logic
// 1. already in the recover state
if (base_model_stop_flags[tid] && base_model_is_block_step[tid]) {
batch_drop[tid] = true;
stop_flags[tid] = true;
}
if (!(base_model_stop_flags[tid] || batch_drop[tid])) {
not_stop_flag = 1;
// 1. first token
if (base_model_step_idx_now == 0) {
seq_lens_this_time[tid] = 0;
not_stop_flag = 0;
} else if (base_model_step_idx_now == 1 && seq_lens_encoder_record[tid] > 0) {
// Can be extended to first few tokens
int seq_len_encoder_record = seq_lens_encoder_record[tid];
seq_lens_encoder[tid] = seq_len_encoder_record;
seq_lens_encoder_record[tid] = -1;
seq_lens_decoder[tid] = seq_lens_decoder_record[tid];
seq_lens_decoder_record[tid] = 0;
stop_flags[tid] = false;
int64_t base_model_first_token = accept_tokens_now[0];
int position = seq_len_encoder_record;
if (TRUNCATE_FIRST_TOKEN) {
input_ids_now[position - 1] = base_model_first_token;
seq_lens_this_time[tid] = seq_len_encoder_record;
} else {
input_ids_now[position] = base_model_first_token;
seq_lens_this_time[tid] = seq_len_encoder_record + 1;
}
} else if (accept_num_now <=
max_draft_token) /*Accept partial draft tokens*/ {
// Base Model reject stop
if (stop_flags[tid]) {
stop_flags[tid] = false;
seq_lens_decoder[tid] = base_model_seq_lens_decoder[tid];
step_idx[tid] = base_model_step_idx[tid];
} else {
seq_lens_decoder[tid] -= max_draft_token - accept_num_now;
step_idx[tid] -= max_draft_token - accept_num_now;
}
int64_t modified_token = accept_tokens_now[accept_num_now - 1];
draft_tokens_now[0] = modified_token;
seq_lens_this_time[tid] = 1;
} else /*Accept all draft tokens*/ {
draft_tokens_now[1] = accept_tokens_now[max_draft_token];
seq_lens_this_time[tid] = 2;
}
} else {
stop_flags[tid] = true;
seq_lens_this_time[tid] = 0;
seq_lens_decoder[tid] = 0;
seq_lens_encoder[tid] = 0;
}
}
__syncthreads();
int64_t not_stop_flag_sum = BlockReduce(temp_storage).Sum(not_stop_flag);
if (tid == 0) {
not_need_stop[0] = not_stop_flag_sum > 0;
}
}
template <bool TRUNCATE_FIRST_TOKEN>
void DispatchRunner(
const cudaStream_t& stream,
int64_t* draft_tokens,
int64_t* input_ids,
bool* stop_flags,
int* seq_lens_this_time,
int* seq_lens_encoder,
int* seq_lens_decoder,
int64_t* step_idx,
int* seq_lens_encoder_record,
int* seq_lens_decoder_record,
bool* not_need_stop,
bool* batch_drop,
const int64_t* accept_tokens,
const int* accept_num,
const int* base_model_seq_lens_encoder,
const int* base_model_seq_lens_decoder,
const int64_t* base_model_step_idx,
const bool* base_model_stop_flags,
const bool* base_model_is_block_step,
int64_t* base_model_draft_tokens,
const int bsz,
const int max_draft_token,
const int accept_tokens_len,
const int draft_tokens_len,
const int input_ids_len,
const int base_model_draft_tokens_len,
const bool splitwise_prefill) {
constexpr int BlockSize = 512;
if (splitwise_prefill) {
process_splitwise_prefill<BlockSize, TRUNCATE_FIRST_TOKEN>
<<<1, BlockSize, 0, stream>>>(
draft_tokens,
input_ids,
stop_flags,
seq_lens_this_time,
seq_lens_encoder,
seq_lens_decoder,
step_idx,
seq_lens_encoder_record,
seq_lens_decoder_record,
not_need_stop,
batch_drop,
accept_tokens,
accept_num,
base_model_seq_lens_encoder,
base_model_seq_lens_decoder,
base_model_step_idx,
base_model_stop_flags,
base_model_is_block_step,
base_model_draft_tokens,
bsz,
max_draft_token,
accept_tokens_len,
draft_tokens_len,
input_ids_len,
base_model_draft_tokens_len);
} else {
draft_model_preprocess_kernel<BlockSize, TRUNCATE_FIRST_TOKEN>
<<<1, BlockSize, 0, stream>>>(
draft_tokens,
input_ids,
stop_flags,
seq_lens_this_time,
seq_lens_encoder,
seq_lens_decoder,
step_idx,
seq_lens_encoder_record,
seq_lens_decoder_record,
not_need_stop,
batch_drop,
accept_tokens,
accept_num,
base_model_seq_lens_encoder,
base_model_seq_lens_decoder,
base_model_step_idx,
base_model_stop_flags,
base_model_is_block_step,
base_model_draft_tokens,
bsz,
max_draft_token,
accept_tokens_len,
draft_tokens_len,
input_ids_len,
base_model_draft_tokens_len);
}
}
void DispatchTokenMode(
const cudaStream_t &stream,
int64_t* draft_tokens,
int64_t* input_ids,
bool* stop_flags,
int* seq_lens_this_time,
int* seq_lens_encoder,
int* seq_lens_decoder,
int64_t* step_idx,
int* seq_lens_encoder_record,
int* seq_lens_decoder_record,
bool* not_need_stop,
bool* batch_drop,
const int64_t* accept_tokens,
const int* accept_num,
const int* base_model_seq_lens_encoder,
const int* base_model_seq_lens_decoder,
const int64_t* base_model_step_idx,
const bool* base_model_stop_flags,
const bool* base_model_is_block_step,
int64_t* base_model_draft_tokens,
const int bsz,
const int max_draft_token,
const int accept_tokens_len,
const int draft_tokens_len,
const int input_ids_len,
const int base_model_draft_tokens_len,
const bool truncate_first_token,
const bool splitwise_prefill) {
if (truncate_first_token) {
DispatchRunner<true>(
stream,
draft_tokens,
input_ids,
stop_flags,
seq_lens_this_time,
seq_lens_encoder,
seq_lens_decoder,
step_idx,
seq_lens_encoder_record,
seq_lens_decoder_record,
not_need_stop,
batch_drop,
accept_tokens,
accept_num,
base_model_seq_lens_encoder,
base_model_seq_lens_decoder,
base_model_step_idx,
base_model_stop_flags,
base_model_is_block_step,
base_model_draft_tokens,
bsz,
max_draft_token,
accept_tokens_len,
draft_tokens_len,
input_ids_len,
base_model_draft_tokens_len,
splitwise_prefill
);
} else {
DispatchRunner<false>(
stream,
draft_tokens,
input_ids,
stop_flags,
seq_lens_this_time,
seq_lens_encoder,
seq_lens_decoder,
step_idx,
seq_lens_encoder_record,
seq_lens_decoder_record,
not_need_stop,
batch_drop,
accept_tokens,
accept_num,
base_model_seq_lens_encoder,
base_model_seq_lens_decoder,
base_model_step_idx,
base_model_stop_flags,
base_model_is_block_step,
base_model_draft_tokens,
bsz,
max_draft_token,
accept_tokens_len,
draft_tokens_len,
input_ids_len,
base_model_draft_tokens_len,
splitwise_prefill
);
}
}
void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
const paddle::Tensor& input_ids,
const paddle::Tensor& stop_flags,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& step_idx,
const paddle::Tensor& seq_lens_encoder_record,
const paddle::Tensor& seq_lens_decoder_record,
const paddle::Tensor& not_need_stop,
const paddle::Tensor& batch_drop,
const paddle::Tensor& accept_tokens,
const paddle::Tensor& accept_num,
const paddle::Tensor& base_model_seq_lens_encoder,
const paddle::Tensor& base_model_seq_lens_decoder,
const paddle::Tensor& base_model_step_idx,
const paddle::Tensor& base_model_stop_flags,
const paddle::Tensor& base_model_is_block_step,
const paddle::Tensor& base_model_draft_tokens,
const int max_draft_token,
const bool truncate_first_token,
const bool splitwise_prefill) {
int real_bsz = seq_lens_this_time.shape()[0];
int accept_tokens_len = accept_tokens.shape()[1];
int input_ids_len = input_ids.shape()[1];
int draft_tokens_len = draft_tokens.shape()[1];
auto cu_stream = seq_lens_this_time.stream();
int base_model_draft_tokens_len = base_model_draft_tokens.shape()[1];
auto not_need_stop_gpu =
not_need_stop.copy_to(seq_lens_this_time.place(), false);
DispatchTokenMode(
cu_stream,
const_cast<int64_t*>(draft_tokens.data<int64_t>()),
const_cast<int64_t*>(input_ids.data<int64_t>()),
const_cast<bool*>(stop_flags.data<bool>()),
const_cast<int*>(seq_lens_this_time.data<int>()),
const_cast<int*>(seq_lens_encoder.data<int>()),
const_cast<int*>(seq_lens_decoder.data<int>()),
const_cast<int64_t*>(step_idx.data<int64_t>()),
const_cast<int*>(seq_lens_encoder_record.data<int>()),
const_cast<int*>(seq_lens_decoder_record.data<int>()),
const_cast<bool*>(not_need_stop_gpu.data<bool>()),
const_cast<bool*>(batch_drop.data<bool>()),
accept_tokens.data<int64_t>(),
accept_num.data<int>(),
base_model_seq_lens_encoder.data<int>(),
base_model_seq_lens_decoder.data<int>(),
base_model_step_idx.data<int64_t>(),
base_model_stop_flags.data<bool>(),
base_model_is_block_step.data<bool>(),
const_cast<int64_t*>(base_model_draft_tokens.data<int64_t>()),
real_bsz,
max_draft_token,
accept_tokens_len,
draft_tokens_len,
input_ids_len,
base_model_draft_tokens_len,
truncate_first_token,
splitwise_prefill);
auto not_need_stop_cpu =
not_need_stop_gpu.copy_to(not_need_stop.place(), false);
bool* not_need_stop_data = const_cast<bool*>(not_need_stop.data<bool>());
not_need_stop_data[0] = not_need_stop_cpu.data<bool>()[0];
}
PD_BUILD_STATIC_OP(draft_model_preprocess)
.Inputs({"draft_tokens",
"input_ids",
"stop_flags",
"seq_lens_this_time",
"seq_lens_encoder",
"seq_lens_decoder",
"step_idx",
"seq_lens_encoder_record",
"seq_lens_decoder_record",
"not_need_stop",
"batch_drop",
"accept_tokens",
"accept_num",
"base_model_seq_lens_encoder",
"base_model_seq_lens_decoder",
"base_model_step_idx",
"base_model_stop_flags",
"base_model_is_block_step",
"base_model_draft_tokens"})
.Outputs({"draft_tokens_out",
"input_ids_out",
"stop_flags_out",
"seq_lens_this_time_out",
"seq_lens_encoder_out",
"seq_lens_decoder_out",
"step_idx_out",
"not_need_stop_out",
"batch_drop_out",
"seq_lens_encoder_record_out",
"seq_lens_decoder_record_out"})
.Attrs({"max_draft_token: int", "truncate_first_token: bool", "splitwise_prefill: bool"})
.SetInplaceMap({{"draft_tokens", "draft_tokens_out"},
{"input_ids", "input_ids_out"},
{"stop_flags", "stop_flags_out"},
{"seq_lens_this_time", "seq_lens_this_time_out"},
{"seq_lens_encoder", "seq_lens_encoder_out"},
{"seq_lens_decoder", "seq_lens_decoder_out"},
{"step_idx", "step_idx_out"},
{"not_need_stop", "not_need_stop_out"},
{"batch_drop", "batch_drop_out"},
{"seq_lens_encoder_record", "seq_lens_encoder_record_out"},
{"seq_lens_decoder_record", "seq_lens_decoder_record_out"}})
.SetKernelFn(PD_KERNEL(DraftModelPreprocess));
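Note: DispatchTokenMode and DispatchRunner turn the runtime flags truncate_first_token and splitwise_prefill into compile-time choices, so the kernel never branches on them at runtime. A stripped-down sketch of the same pattern; the names are illustrative, not from this commit:

#include <cuda_runtime.h>

// The flag becomes a non-type template parameter: inside the kernel the
// `if` is a compile-time constant and the dead branch is eliminated.
template <bool TRUNCATE_FIRST_TOKEN>
__global__ void demo_kernel(int* out) {
    if (TRUNCATE_FIRST_TOKEN) {
        out[threadIdx.x] = 1;
    } else {
        out[threadIdx.x] = 2;
    }
}

// Host-side dispatch: branch once, instantiate the kernel twice.
void demo_dispatch(int* out, bool truncate_first_token, cudaStream_t stream) {
    if (truncate_first_token) {
        demo_kernel<true><<<1, 32, 0, stream>>>(out);
    } else {
        demo_kernel<false><<<1, 32, 0, stream>>>(out);
    }
}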

View File

@@ -0,0 +1,76 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "helper.h"
__global__ void update_pre_ids_kernel(const int64_t* draft_tokens,
int64_t* pre_ids_all,
const bool* stop_flags,
int* seq_lens_this_time,
const int64_t* step_idx,
int bs,
int pre_id_length,
int max_draft_token) {
int tid = threadIdx.x;
if (tid < bs && seq_lens_this_time[tid] != 0 && !stop_flags[tid]) {
int64_t* pre_ids_all_now = pre_ids_all + tid * pre_id_length;
const int64_t* draft_token_now = draft_tokens + tid * max_draft_token;
const int seq_len_this_time = seq_lens_this_time[tid];
if (step_idx[tid] - 1 > 0 /*Decoder Step*/) {
for (int i = 0; i < seq_len_this_time; ++i) {
pre_ids_all_now[step_idx[tid] - i] =
draft_token_now[seq_len_this_time - 1 - i];
}
} else if (step_idx[tid] == 1 /*Encoder Step*/) {
pre_ids_all_now[1] = draft_token_now[0];
}
seq_lens_this_time[tid] = 1;
}
}
void SpeculateDraftModelUpdate(const paddle::Tensor& draft_tokens,
const paddle::Tensor& pre_ids_all,
const paddle::Tensor& stop_flags,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& step_idx) {
int64_t real_bs = seq_lens_this_time.shape()[0];
int64_t pre_id_length = pre_ids_all.shape()[1];
auto cu_stream = seq_lens_this_time.stream();
int64_t max_draft_token = draft_tokens.shape()[1];
int block_size = (real_bs + 32 - 1) / 32 * 32;
update_pre_ids_kernel<<<1, block_size, 0, cu_stream>>>(
draft_tokens.data<int64_t>(),
const_cast<int64_t*>(pre_ids_all.data<int64_t>()),
stop_flags.data<bool>(),
const_cast<int*>(seq_lens_this_time.data<int>()),
step_idx.data<int64_t>(),
real_bs,
pre_id_length,
max_draft_token);
}
PD_BUILD_STATIC_OP(draft_model_set_value_by_flags)
.Inputs({"draft_tokens",
"pre_ids_all",
"stop_flags",
"seq_lens_this_time",
"seq_lens_encoder",
"seq_lens_decoder",
"step_idx"})
.Outputs({"pre_ids_all_out"})
.SetInplaceMap({{"pre_ids_all", "pre_ids_all_out"}})
.SetKernelFn(PD_KERNEL(SpeculateDraftModelUpdate));
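Note: in the decoder branch of update_pre_ids_kernel, the accepted tokens are written back to front, so pre_ids_all[step_idx - i] receives draft_tokens[seq_len_this_time - 1 - i]. A hypothetical CPU check of that index math, not part of the commit:

#include <cassert>
#include <cstdint>
#include <vector>

// CPU model of update_pre_ids_kernel's decoder branch (illustrative only).
void update_pre_ids_cpu(const std::vector<int64_t>& draft_row,
                        std::vector<int64_t>& pre_ids_row,
                        int64_t step_idx, int seq_len_this_time) {
    for (int i = 0; i < seq_len_this_time; ++i)
        pre_ids_row[step_idx - i] = draft_row[seq_len_this_time - 1 - i];
}

int main() {
    std::vector<int64_t> draft = {10, 11, 12};
    std::vector<int64_t> pre_ids(8, -1);
    update_pre_ids_cpu(draft, pre_ids, /*step_idx=*/5, /*seq_len_this_time=*/3);
    // pre_ids[3..5] now holds {10, 11, 12} in order.
    assert(pre_ids[3] == 10 && pre_ids[4] == 11 && pre_ids[5] == 12);
    return 0;
}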

View File

@@ -0,0 +1,215 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "helper.h"
#include "paddle/extension.h"
template <int THREADBLOCK_SIZE>
__global__ void draft_model_update_kernel(const int64_t* inter_next_tokens,
int64_t* draft_tokens,
int64_t* pre_ids,
int* seq_lens_this_time,
int* seq_lens_encoder,
int* seq_lens_decoder,
int64_t* step_idx,
const int* output_cum_offsets,
bool* stop_flags,
bool* not_need_stop,
const int64_t* max_dec_len,
const int64_t* end_ids,
int64_t* base_model_draft_tokens,
const int bsz,
const int max_draft_token,
const int pre_id_length,
const int max_base_model_draft_token,
const int end_ids_len,
const int max_seq_len,
const int substep,
const bool prefill_one_step_stop) {
typedef cub::BlockReduce<int64_t, THREADBLOCK_SIZE> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
int64_t stop_flag_now_int = 0;
int tid = threadIdx.x;
if (tid < bsz) {
auto* draft_token_now = draft_tokens + tid * max_draft_token;
auto* pre_ids_now = pre_ids + tid * pre_id_length;
auto* base_model_draft_tokens_now =
base_model_draft_tokens + tid * max_base_model_draft_token;
const int next_tokens_start_id =
tid * max_seq_len - output_cum_offsets[tid];
auto* next_tokens_start = inter_next_tokens + next_tokens_start_id;
auto seq_len_this_time = seq_lens_this_time[tid];
auto seq_len_encoder = seq_lens_encoder[tid];
auto seq_len_decoder = seq_lens_decoder[tid];
// 1. update step_idx && seq_lens_decoder
if (!stop_flags[tid] /* seq_lens_decoder > 0 or seq_lens_encoder > 0 */) {
int64_t token_this_time = -1;
// decoder step
if (seq_len_decoder > 0 && seq_len_encoder <= 0) {
seq_lens_decoder[tid] += seq_len_this_time;
token_this_time = next_tokens_start[seq_len_this_time - 1];
draft_token_now[0] = token_this_time;
base_model_draft_tokens_now[substep + 1] = token_this_time;
for (int i = 0; i < seq_len_this_time; ++i) {
pre_ids_now[step_idx[tid] + 1 + i] = next_tokens_start[i];
}
step_idx[tid] += seq_len_this_time;
} else {
token_this_time = next_tokens_start[0];
seq_lens_decoder[tid] = seq_len_encoder + seq_len_decoder;
seq_lens_encoder[tid] = 0;
pre_ids_now[1] = token_this_time;
step_idx[tid] += 1;
draft_token_now[0] = token_this_time;
base_model_draft_tokens_now[substep + 1] = token_this_time;
}
// multi_end
if (is_in_end(token_this_time, end_ids, end_ids_len) || prefill_one_step_stop) {
stop_flags[tid] = true;
stop_flag_now_int = 1;
// max_dec_len
} else if (step_idx[tid] >= max_dec_len[tid]) {
stop_flags[tid] = true;
draft_token_now[seq_len_this_time - 1] = end_ids[0];
base_model_draft_tokens_now[substep + 1] = end_ids[0];
stop_flag_now_int = 1;
}
} else {
draft_token_now[0] = -1;
base_model_draft_tokens_now[substep + 1] = -1;
stop_flag_now_int = 1;
}
// 2. set end
if (!stop_flags[tid]) {
seq_lens_this_time[tid] = 1;
} else {
seq_lens_this_time[tid] = 0;
seq_lens_encoder[tid] = 0;
}
}
__syncthreads();
int64_t stop_sum = BlockReduce(temp_storage).Sum(stop_flag_now_int);
if (tid == 0) {
not_need_stop[0] = stop_sum < bsz;
}
}
void DraftModelUpdate(const paddle::Tensor& inter_next_tokens,
const paddle::Tensor& draft_tokens,
const paddle::Tensor& pre_ids,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& step_idx,
const paddle::Tensor& output_cum_offsets,
const paddle::Tensor& stop_flags,
const paddle::Tensor& not_need_stop,
const paddle::Tensor& max_dec_len,
const paddle::Tensor& end_ids,
const paddle::Tensor& base_model_draft_tokens,
const int max_seq_len,
const int substep) {
auto seq_lens_this_time_shape = seq_lens_this_time.shape();
auto cu_stream = seq_lens_this_time.stream();
const int real_bsz = seq_lens_this_time_shape[0];
auto not_need_stop_gpu =
not_need_stop.copy_to(seq_lens_this_time.place(), false);
const int end_ids_len = end_ids.shape()[0];
const int max_draft_token = draft_tokens.shape()[1];
const int pre_id_length = pre_ids.shape()[1];
const int max_base_model_draft_token = base_model_draft_tokens.shape()[1];
constexpr int BlockSize = 512;
bool prefill_one_step_stop = false;
if (const char *env_p = std::getenv("PREFILL_NODE_ONE_STEP_STOP")) {
if (env_p[0] == '1') {
prefill_one_step_stop = true;
}
}
draft_model_update_kernel<BlockSize><<<1, BlockSize, 0, cu_stream>>>(
inter_next_tokens.data<int64_t>(),
const_cast<int64_t*>(draft_tokens.data<int64_t>()),
const_cast<int64_t*>(pre_ids.data<int64_t>()),
const_cast<int*>(seq_lens_this_time.data<int>()),
const_cast<int*>(seq_lens_encoder.data<int>()),
const_cast<int*>(seq_lens_decoder.data<int>()),
const_cast<int64_t*>(step_idx.data<int64_t>()),
output_cum_offsets.data<int>(),
const_cast<bool*>(stop_flags.data<bool>()),
not_need_stop_gpu.data<bool>(),
max_dec_len.data<int64_t>(),
end_ids.data<int64_t>(),
const_cast<int64_t*>(base_model_draft_tokens.data<int64_t>()),
real_bsz,
max_draft_token,
pre_id_length,
max_base_model_draft_token,
end_ids_len,
max_seq_len,
substep,
prefill_one_step_stop);
auto not_need_stop_cpu =
not_need_stop_gpu.copy_to(not_need_stop.place(), false);
bool* not_need_stop_data = const_cast<bool*>(not_need_stop.data<bool>());
not_need_stop_data[0] = not_need_stop_cpu.data<bool>()[0];
}
PD_BUILD_STATIC_OP(draft_model_update)
.Inputs({"inter_next_tokens",
"draft_tokens",
"pre_ids",
"seq_lens_this_time",
"seq_lens_encoder",
"seq_lens_decoder",
"step_idx",
"output_cum_offsets",
"stop_flags",
"not_need_stop",
"max_dec_len",
"end_ids",
"base_model_draft_tokens"})
.Attrs({"max_seq_len: int", "substep: int"})
.Outputs({"draft_tokens_out",
"pre_ids_out",
"seq_lens_this_time_out",
"seq_lens_encoder_out",
"seq_lens_decoder_out",
"step_idx_out",
"stop_flags_out",
"not_need_stop_out",
"base_model_draft_tokens_out"})
.SetInplaceMap({{"draft_tokens", "draft_tokens_out"},
{"pre_ids", "pre_ids_out"},
{"seq_lens_this_time", "seq_lens_this_time_out"},
{"seq_lens_encoder", "seq_lens_encoder_out"},
{"seq_lens_decoder", "seq_lens_decoder_out"},
{"step_idx", "step_idx_out"},
{"stop_flags", "stop_flags_out"},
{"not_need_stop", "not_need_stop_out"},
{"base_model_draft_tokens", "base_model_draft_tokens_out"}})
.SetKernelFn(PD_KERNEL(DraftModelUpdate));
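Note: draft_model_update_kernel, draft_model_preprocess_kernel, and process_splitwise_prefill all end with the same pattern: every lane contributes a 0/1 flag to a cub::BlockReduce sum and lane 0 writes not_need_stop. A self-contained sketch of just that reduction, with an illustrative kernel name:

#include <cub/cub.cuh>

// Standalone sketch of the "is any sequence still running?" reduction used
// at the tail of the kernels in this commit. All threads in the block must
// call Sum(), including lanes with tid >= bsz (they contribute 0).
template <int BLOCK_SIZE>
__global__ void any_active_kernel(const bool* stop_flags, int bsz,
                                  bool* not_need_stop) {
    typedef cub::BlockReduce<int, BLOCK_SIZE> BlockReduce;
    __shared__ typename BlockReduce::TempStorage temp_storage;
    int tid = threadIdx.x;
    int stopped = (tid < bsz && stop_flags[tid]) ? 1 : 0;
    int stop_sum = BlockReduce(temp_storage).Sum(stopped);
    if (tid == 0) not_need_stop[0] = stop_sum < bsz;  // some seq still active
}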

View File

@@ -0,0 +1,240 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/extension.h"
#include "helper.h"
// #define DEBUG_EAGLE_KERNEL
__global__ void ComputeOrderKernel(
const int* seq_lens_this_time,
const int* seq_lens_encoder,
const int* base_model_seq_lens_this_time,
const int* base_model_seq_lens_encoder,
const int* accept_nums,
int* position_map,
int* output_token_num,
const int bsz,
const int actual_draft_token_num,
const int input_token_num) {
int in_offset = 0;   // offset into the (longer) input token stream
int out_offset = 0;  // offset into the (shorter) output token stream
if (threadIdx.x == 0) {
for (int i = 0; i < bsz; ++i) {
int cur_base_model_seq_lens_this_time = base_model_seq_lens_this_time[i];
int cur_base_model_seq_lens_encoder = base_model_seq_lens_encoder[i];
int cur_seq_lens_this_time = seq_lens_this_time[i];
int accept_num = accept_nums[i];
int cur_seq_lens_encoder = seq_lens_encoder[i];
#ifdef DEBUG_EAGLE_KERNEL
printf("batch %d: cur_base_model_seq_lens_this_time%d. cur_seq_lens_this_time%d, accept_num %d\n", i, cur_base_model_seq_lens_this_time, cur_seq_lens_this_time, accept_num);
#endif
// 1. eagle encoder. Base step=1
if (cur_seq_lens_encoder > 0) {
#ifdef DEBUG_EAGLE_KERNEL
printf("batch %d: cur_seq_lens_encoder > 0 \n", i);
#endif
for (int j = 0; j < cur_seq_lens_encoder; j++) {
position_map[in_offset++] = out_offset++;
}
// 2. base model encoder. Base step=0
} else if (cur_base_model_seq_lens_encoder != 0) {
// nothing to map while the base model is still in its encoder step
// 3. New end
} else if (cur_base_model_seq_lens_this_time != 0 && cur_seq_lens_this_time == 0) {
#ifdef DEBUG_EAGLE_KERNEL
printf("batch %d: base=0. draft !=0 \n", i);
#endif
in_offset += cur_base_model_seq_lens_this_time;
// 4. stopped
} else if (cur_base_model_seq_lens_this_time == 0 && cur_seq_lens_this_time == 0) /* end */ {
} else {
if (accept_num <= actual_draft_token_num) /*Accept partial draft tokens*/ {
#ifdef DEBUG_EAGLE_KERNEL
printf("batch %d: accept_num <= actual_draft_token_num \n", i);
#endif
position_map[in_offset + accept_num - 1] = out_offset++;
in_offset += cur_base_model_seq_lens_this_time;
} else /*Accept all draft tokens*/ {
#ifdef DEBUG_EAGLE_KERNEL
printf("batch %d: accept_num > actual_draft_token_num \n", i);
#endif
position_map[in_offset + accept_num - 2] = out_offset++;
position_map[in_offset + accept_num - 1] = out_offset++;
in_offset += cur_base_model_seq_lens_this_time;
}
}
}
output_token_num[0] = out_offset;
#ifdef DEBUG_EAGLE_KERNEL
printf("position map output_token_num%d:\n", output_token_num[0]);
for (int i = 0; i < output_token_num[0]; i++) {
printf("%d ", position_map[i]);
}
printf("\n");
#endif
}
}
template <typename T, int VecSize>
__global__ void rebuildHiddenStatesKernel(
const T* input,
const int* position_map,
T* out,
const int dim_embed,
const int elem_cnt) {
using LoadT = AlignedVector<T, VecSize>;
LoadT src_vec;
int global_thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
for (int elem_idx = global_thread_idx * VecSize; elem_idx < elem_cnt; elem_idx += blockDim.x * gridDim.x * VecSize) {
int ori_token_idx = elem_idx / dim_embed;
int token_idx = position_map[ori_token_idx];
if (token_idx >= 0) {
int offset = elem_idx % dim_embed;
Load<T, VecSize>(input + ori_token_idx * dim_embed + offset, &src_vec);
Store<T, VecSize>(src_vec, out + token_idx * dim_embed + offset);
}
}
}
template <paddle::DataType D>
std::vector<paddle::Tensor> DispatchDtype(
const paddle::Tensor& input,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& stop_flags,
const paddle::Tensor& accept_nums,
const paddle::Tensor& base_model_seq_lens_this_time,
const paddle::Tensor& base_model_seq_lens_encoder,
const int actual_draft_token_num) {
typedef PDTraits<D> traits_;
typedef typename traits_::DataType DataType_;
typedef typename traits_::data_t data_t;
auto input_token_num = input.shape()[0];
auto dim_embed = input.shape()[1];
int bsz = seq_lens_this_time.shape()[0];
auto position_map = paddle::empty({input_token_num}, seq_lens_this_time.dtype(), input.place());
cudaMemsetAsync(position_map.data<int>(), 0xFF, input_token_num * sizeof(int), seq_lens_this_time.stream());  // fill with -1
auto output_token_num = paddle::empty({1}, seq_lens_this_time.dtype(), seq_lens_this_time.place());
ComputeOrderKernel<<<1, 1, 0, seq_lens_this_time.stream()>>>(
seq_lens_this_time.data<int>(),
seq_lens_encoder.data<int>(),
base_model_seq_lens_this_time.data<int>(),
base_model_seq_lens_encoder.data<int>(),
accept_nums.data<int>(),
position_map.data<int>(),
output_token_num.data<int>(),
bsz,
actual_draft_token_num,
input_token_num);
int output_token_num_cpu = output_token_num.copy_to(paddle::CPUPlace(), false).data<int>()[0];
auto out = paddle::empty({output_token_num_cpu, dim_embed}, input.dtype(), input.place());
constexpr int packSize = VEC_16B / (sizeof(DataType_));
int elem_cnt = input_token_num * dim_embed;
assert(elem_cnt % packSize == 0);
int pack_num = elem_cnt / packSize;
int grid_size = 1;
GetNumBlocks(pack_num, &grid_size);
constexpr int thread_per_block = 128;
rebuildHiddenStatesKernel<DataType_, packSize><<<grid_size, thread_per_block, 0, input.stream()>>>(
reinterpret_cast<const DataType_*>(input.data<data_t>()),
position_map.data<int>(),
reinterpret_cast<DataType_*>(out.data<data_t>()),
dim_embed,
elem_cnt);
return {out};
}
std::vector<paddle::Tensor> EagleGetHiddenStates(
const paddle::Tensor& input,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& stop_flags,
const paddle::Tensor& accept_nums,
const paddle::Tensor& base_model_seq_lens_this_time,
const paddle::Tensor& base_model_seq_lens_encoder,
const int actual_draft_token_num) {
switch (input.dtype()) {
case paddle::DataType::FLOAT16: {
return DispatchDtype<paddle::DataType::FLOAT16>(
input,
seq_lens_this_time,
seq_lens_encoder,
seq_lens_decoder,
stop_flags,
accept_nums,
base_model_seq_lens_this_time,
base_model_seq_lens_encoder,
actual_draft_token_num);
}
case paddle::DataType::BFLOAT16: {
return DispatchDtype<paddle::DataType::BFLOAT16>(
input,
seq_lens_this_time,
seq_lens_encoder,
seq_lens_decoder,
stop_flags,
accept_nums,
base_model_seq_lens_this_time,
base_model_seq_lens_encoder,
actual_draft_token_num);
}
default: {
PD_THROW("Not support this data type");
}
}
}
PD_BUILD_STATIC_OP(eagle_get_hidden_states)
.Inputs({"input",
"seq_lens_this_time",
"seq_lens_encoder",
"seq_lens_decoder",
"stop_flags",
"accept_nums",
"base_model_seq_lens_this_time",
"base_model_seq_lens_encoder"})
.Attrs({"actual_draft_token_num: int"})
.Outputs({"out"})
.SetKernelFn(PD_KERNEL(EagleGetHiddenStates));
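Note: position_map is a scatter map built by ComputeOrderKernel: input token i lands at output row position_map[i], and -1 (the 0xFF memset) marks tokens that are dropped. A single-threaded CPU model of how rebuildHiddenStatesKernel consumes it; the helper is hypothetical and not part of the commit:

#include <vector>

// CPU model of the position_map scatter in rebuildHiddenStatesKernel:
// rows whose map entry is -1 are skipped, the rest are compacted.
std::vector<std::vector<float>> apply_position_map(
    const std::vector<std::vector<float>>& input,
    const std::vector<int>& position_map, int output_rows) {
    std::vector<std::vector<float>> out(output_rows);
    for (size_t i = 0; i < input.size(); ++i) {
        int dst = position_map[i];
        if (dst >= 0) out[dst] = input[i];  // copy the whole hidden vector
    }
    return out;
}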

View File

@@ -0,0 +1,190 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/extension.h"
#include "helper.h"
// #define DEBUG_EAGLE_KERNEL
__global__ void computeOrderKernel(
const int* last_seq_lens_this_time,
const int* seq_lens_this_time,
const int64_t* step_idx,
int* src_map,
int* output_token_num,
int bsz) {
int in_offset = 0;
int out_offset = 0;
if (threadIdx.x == 0) {
for (int i = 0; i < bsz; ++i) {
int cur_seq_lens_this_time = seq_lens_this_time[i];
int cur_last_seq_lens_this_time = last_seq_lens_this_time[i];
#ifdef DEBUG_EAGLE_KERNEL
printf("batch %d: cur_seq_lens_this_time:%d. cur_last_seq_lens_this_time:%d\n", i, cur_seq_lens_this_time, cur_last_seq_lens_this_time);
#endif
// 1. encoder
if (step_idx[i] == 1 && cur_seq_lens_this_time > 0) {
#ifdef DEBUG_EAGLE_KERNEL
printf("batch %d last_step is encoder \n", i);
#endif
in_offset += 1;
src_map[out_offset++] = in_offset - 1;
#ifdef DEBUG_EAGLE_KERNEL
printf("batch %d finish. src_map[%d]=%d \n", i, out_offset - 1, in_offset - 1);
#endif
// 2. decoder
} else if (cur_seq_lens_this_time > 0) /* =1 */ {
#ifdef DEBUG_EAGLE_KERNEL
printf("batch %d is decoder\n", i);
#endif
in_offset += cur_last_seq_lens_this_time;
src_map[out_offset++] = in_offset - 1;
// 3. stop
} else {
// first token end
if (step_idx[i] == 1) {
#ifdef DEBUG_EAGLE_KERNEL
printf("batch %d finished in first token \n", i);
#endif
in_offset += cur_last_seq_lens_this_time > 0 ? 1 : 0;
// normal end
} else {
#ifdef DEBUG_EAGLE_KERNEL
printf("batch %d finished in non-first token \n", i);
#endif
in_offset += cur_last_seq_lens_this_time;
}
}
}
output_token_num[0] = out_offset;
#ifdef DEBUG_EAGLE_KERNEL
printf("position map output_token_num%d:\n", output_token_num[0]);
for (int i = 0; i < output_token_num[0]; i++) {
printf("%d ", src_map[i]);
}
printf("\n");
#endif
}
}
template<typename T, int PackSize>
__global__ void rebuildSelfHiddenStatesKernel(
const T* input,
int* src_map,
T* output,
int dim_embed,
int elem_cnt) {
using LoadT = AlignedVector<T, PackSize>;
LoadT src_vec;
int global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
for (int elem_id = global_thread_idx * PackSize; elem_id < elem_cnt; elem_id += blockDim.x * gridDim.x * PackSize) {
int output_token_idx = elem_id / dim_embed;
int input_token_idx = src_map[output_token_idx];
int offset = elem_id % dim_embed;
Load<T, PackSize>(input + input_token_idx * dim_embed + offset, &src_vec);
Store<T, PackSize>(src_vec, output + output_token_idx * dim_embed + offset);
}
}
template<paddle::DataType D>
std::vector<paddle::Tensor> DispatchDtype(
const paddle::Tensor& input,
const paddle::Tensor& last_seq_lens_this_time,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& step_idx) {
typedef PDTraits<D> traits_;
typedef typename traits_::DataType DataType_;
typedef typename traits_::data_t data_t;
int input_token_num = input.shape()[0];
int dim_embed = input.shape()[1];
int bsz = seq_lens_this_time.shape()[0];
auto src_map = paddle::full({input_token_num}, -1, seq_lens_this_time.dtype(), seq_lens_this_time.place());
auto output_token_num = paddle::full({1}, 0, seq_lens_this_time.dtype(), seq_lens_this_time.place());
computeOrderKernel<<<1, 1, 0, seq_lens_this_time.stream()>>>(
last_seq_lens_this_time.data<int>(),
seq_lens_this_time.data<int>(),
step_idx.data<int64_t>(),
src_map.data<int>(),
output_token_num.data<int>(),
bsz);
int output_token_num_cpu = output_token_num.copy_to(paddle::CPUPlace(), false).data<int>()[0];
auto out = paddle::full({output_token_num_cpu, dim_embed}, -1, input.type(), input.place());
constexpr int packSize = VEC_16B / (sizeof(DataType_));
int elem_cnt = output_token_num_cpu * dim_embed;
// printf("output_token_num: %d, dim_embed: %d, cnt: %d. packSize: %d\n", output_token_num_cpu, dim_embed,elem_cnt, packSize);
assert(elem_cnt % packSize == 0);
int pack_num = elem_cnt / packSize;
int grid_size = 1;
GetNumBlocks(pack_num, &grid_size);
constexpr int threadPerBlock = 128;
rebuildSelfHiddenStatesKernel<DataType_, packSize><<<grid_size, threadPerBlock, 0, input.stream()>>>(
reinterpret_cast<const DataType_*>(input.data<data_t>()),
src_map.data<int>(),
reinterpret_cast<DataType_*>(out.data<data_t>()),
dim_embed,
elem_cnt);
return {out};
}
std::vector<paddle::Tensor> EagleGetSelfHiddenStates(
const paddle::Tensor& input,
const paddle::Tensor& last_seq_lens_this_time,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& step_idx) {
switch (input.dtype()) {
case paddle::DataType::BFLOAT16:
return DispatchDtype<paddle::DataType::BFLOAT16>(
input,
last_seq_lens_this_time,
seq_lens_this_time,
step_idx);
case paddle::DataType::FLOAT16:
return DispatchDtype<paddle::DataType::FLOAT16>(
input,
last_seq_lens_this_time,
seq_lens_this_time,
step_idx);
default:
PD_THROW("Not support this data type");
}
}
PD_BUILD_STATIC_OP(eagle_get_self_hidden_states)
.Inputs({"input",
"last_seq_lens_this_time",
"seq_lens_this_time",
"step_idx"})
.Outputs({"out"})
.SetKernelFn(PD_KERNEL(EagleGetSelfHiddenStates));
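Note: both rebuild kernels move hidden vectors in 16-byte packets (packSize = VEC_16B / sizeof(T)), which is why elem_cnt % packSize == 0 is asserted. A stripped-down CUDA sketch of the same idea using float4 instead of the commit's AlignedVector helper; illustrative only:

#include <cuda_runtime.h>

// 16-byte vectorized row gather: each thread moves one float4 packet.
// Assumes dim_embed is divisible by 4 and pointers are 16-byte aligned.
__global__ void gather_rows_vec4(const float4* input, const int* src_map,
                                 float4* output, int packs_per_row,
                                 int total_packs) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= total_packs) return;
    int out_row = idx / packs_per_row;
    int pack = idx % packs_per_row;
    int in_row = src_map[out_row];  // which input row feeds this output row
    output[out_row * packs_per_row + pack] =
        input[in_row * packs_per_row + pack];
}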

View File

@@ -0,0 +1,136 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "helper.h"
#include "paddle/extension.h"
template <typename T, int VecSize>
__global__ void HydraFetchHiddenStatesKernel(const T* output_hidden_states,
const int* output_padding_offset,
const int* accept_token_num,
T* hidden_states,
const int bsz,
const int max_seq_len,
const int hidden_size) {
const int token_id = blockIdx.x;
const int ori_token_id = token_id + output_padding_offset[token_id];
const int bid = ori_token_id / max_seq_len;
const int start_ori_token_id = bid * max_seq_len;
const int local_token_id = ori_token_id - start_ori_token_id;
if (local_token_id != accept_token_num[bid] - 1) return;
using LoadT = AlignedVector<T, VecSize>;
LoadT vec;
for (int idx = threadIdx.x * VecSize; idx < hidden_size;
idx += blockDim.x * VecSize) {
Load(&output_hidden_states[token_id * hidden_size + idx], &vec);
Store(vec, &hidden_states[bid * hidden_size + idx]);
}
}
template <paddle::DataType D>
std::vector<paddle::Tensor> HydraFetchHiddenStatesImpl(
const paddle::Tensor& output_hidden_states,
const paddle::Tensor& output_padding_offset,
const paddle::Tensor& accept_token_num,
const int bsz,
const int max_seq_length) {
typedef PDTraits<D> traits_;
typedef typename traits_::DataType DataType_;
typedef typename traits_::data_t data_t;
auto cu_stream = output_hidden_states.stream();
auto output_token_num = output_hidden_states.shape()[0];
auto hidden_size = output_hidden_states.shape()[1];
auto hidden_states = paddle::full({bsz, hidden_size},
0,
output_hidden_states.dtype(),
output_hidden_states.place());
constexpr int VecSize = 16 / sizeof(data_t);
HydraFetchHiddenStatesKernel<data_t, VecSize>
<<<output_token_num, 256, 0, cu_stream>>>(
output_hidden_states.data<data_t>(),
output_padding_offset.data<int>(),
accept_token_num.data<int>(),
hidden_states.data<data_t>(),
bsz,
max_seq_length,
hidden_size);
return {hidden_states};
}
std::vector<paddle::Tensor> HydraFetchHiddenStates(
const paddle::Tensor& output_hidden_states,
const paddle::Tensor& output_padding_offset,
const paddle::Tensor& accept_token_num,
const int bsz,
const int max_seq_length) {
switch (output_hidden_states.dtype()) {
case paddle::DataType::BFLOAT16: {
return HydraFetchHiddenStatesImpl<paddle::DataType::BFLOAT16>(
output_hidden_states,
output_padding_offset,
accept_token_num,
bsz,
max_seq_length);
}
case paddle::DataType::FLOAT16: {
return HydraFetchHiddenStatesImpl<paddle::DataType::FLOAT16>(
output_hidden_states,
output_padding_offset,
accept_token_num,
bsz,
max_seq_length);
}
default: {
PD_THROW(
"NOT supported data type. "
"Only float16, bfloat16 are supported. ");
break;
}
}
}
std::vector<std::vector<int64_t>> HydraFetchHiddenStatesInferShape(
const std::vector<int64_t>& output_hidden_states_shape,
const std::vector<int64_t>& output_padding_offset_shape,
const std::vector<int64_t>& accept_token_num_shape,
const int bsz,
const int max_seq_length) {
return {{bsz, output_hidden_states_shape[1]}};
}
std::vector<paddle::DataType> HydraFetchHiddenStatesInferDtype(
const paddle::DataType& output_hidden_states_dtype,
const paddle::DataType& output_padding_offset_dtype,
const paddle::DataType& accept_token_num_dtype) {
return {output_hidden_states_dtype};
}
PD_BUILD_STATIC_OP(hydra_fetch_hidden_states)
.Inputs({"output_hidden_states",
"output_padding_offset",
"accept_token_num"})
.Outputs({"hidden_states"})
.Attrs({"bsz: int", "max_seq_length,: int"})
.SetKernelFn(PD_KERNEL(HydraFetchHiddenStates))
.SetInferShapeFn(PD_INFER_SHAPE(HydraFetchHiddenStatesInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(HydraFetchHiddenStatesInferDtype));

View File

@@ -0,0 +1,160 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <stdio.h>
#include <string.h>
#include <sys/ipc.h>
#include <sys/msg.h>
#include <sys/types.h>
#include "paddle/extension.h"
#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif
#define MAX_BSZ 512
// #define SAVE_WITH_OUTPUT_DEBUG
#define MAX_DRAFT_TOKENS 6
struct msgdata {
long mtype;
int mtext[2 + MAX_BSZ + MAX_BSZ * MAX_DRAFT_TOKENS]; // stop_flag, token_num, tokens
};
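// Descriptive note (not in the original commit) on the mtext layout, all ints,
// as filled in by MTPSaveFirstToken below:
//   mtext[0]                                       message id; negated once generation stops
//   mtext[1]                                       batch size (bsz)
//   mtext[2 + i], i < bsz                          token count for sequence i (2 here)
//   mtext[2 + MAX_BSZ + i * MAX_DRAFT_TOKENS + j]  j-th saved token of sequence i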
void MTPSaveFirstToken(const paddle::Tensor& x,
const paddle::Tensor& not_need_stop,
int64_t rank_id,
int msg_queue_id,
bool save_each_rank) {
if (!save_each_rank && rank_id > 0) {
return;
}
int x_dim = x.shape()[1];
auto x_cpu = x.copy_to(paddle::CPUPlace(), false);
int64_t* x_data = x_cpu.data<int64_t>();
static struct msgdata msg_sed;
if (const char* inference_msg_queue_id_env_p =
std::getenv("INFERENCE_MSG_QUEUE_ID")) {
std::string inference_msg_queue_id_env_str(
inference_msg_queue_id_env_p);
int inference_msg_queue_id_from_env =
std::stoi(inference_msg_queue_id_env_str);
#ifdef SAVE_WITH_OUTPUT_DEBUG
std::cout << "Your INFERENCE_MSG_QUEUE_ID is: "
<< inference_msg_queue_id_from_env << std::endl;
#endif
msg_queue_id = inference_msg_queue_id_from_env;
}
static key_t key = ftok("./", msg_queue_id);
static int msgid = msgget(key, IPC_CREAT | 0666);
msg_sed.mtype = 1;
bool not_need_stop_data = not_need_stop.data<bool>()[0];
int inference_msg_id_from_env = 1;
if (const char* inference_msg_id_env_p = std::getenv("INFERENCE_MSG_ID")) {
std::string inference_msg_id_env_str(inference_msg_id_env_p);
inference_msg_id_from_env = std::stoi(inference_msg_id_env_str);
if (inference_msg_id_from_env == 2) {
// 2 and -2 are reserved for the no-output indication.
throw std::runtime_error(
" INFERENCE_MSG_ID cannot be 2, please use other number.");
}
if (inference_msg_id_from_env < 0) {
throw std::runtime_error(
" INFERENCE_MSG_ID cannot be negative, please use other "
"number.");
}
#ifdef SAVE_WITH_OUTPUT_DEBUG
std::cout << "Your INFERENCE_MSG_ID is: " << inference_msg_id_from_env
<< std::endl;
#endif
} else {
#ifdef SAVE_WITH_OUTPUT_DEBUG
std::cout
<< "Failed to got INFERENCE_MSG_ID at env, use (int)1 as default."
<< std::endl;
#endif
}
#ifdef SAVE_WITH_OUTPUT_DEBUG
std::cout << "save_output_key: " << key << std::endl;
std::cout << "save msgid: " << msgid << std::endl;
#endif
msg_sed.mtext[0] = not_need_stop_data ? inference_msg_id_from_env
: -inference_msg_id_from_env;
int bsz = x.shape()[0];
msg_sed.mtext[1] = bsz;
for (int i = 0; i < bsz; i++) {
#ifdef SAVE_WITH_OUTPUT_DEBUG
printf("bid: %d. 1: %d. 2: %d.\n", i, (int)x_data[i * x_dim], (int)x_data[i * x_dim + 1]);
#endif
msg_sed.mtext[i + 2] = 2;  // two tokens are saved per sequence
msg_sed.mtext[i * MAX_DRAFT_TOKENS + 2 + MAX_BSZ] = (int)x_data[i * x_dim];
msg_sed.mtext[i * MAX_DRAFT_TOKENS + 1 + 2 + MAX_BSZ] = (int)x_data[i * x_dim + 1];
#ifdef SAVE_WITH_OUTPUT_DEBUG
printf("mtext[%d]:%d. mtext[%d]:%d. \n", i * MAX_DRAFT_TOKENS + 2 + MAX_BSZ,
msg_sed.mtext[i * MAX_DRAFT_TOKENS + 2 + MAX_BSZ],
i * MAX_DRAFT_TOKENS + 1 + 2 + MAX_BSZ,
msg_sed.mtext[i * MAX_DRAFT_TOKENS + 1 + 2 + MAX_BSZ]);
#endif
}
#ifdef SAVE_WITH_OUTPUT_DEBUG
std::cout << "msg data: ";
for (int i = 0; i < bsz; i++) {
std::cout << " " << (int)x_data[2*i] << " ";
std::cout << " " << (int)x_data[2*i + 1];
}
std::cout << std::endl;
#endif
if ((msgsnd(msgid,
&msg_sed,
(2 + MAX_BSZ + MAX_BSZ * MAX_DRAFT_TOKENS) * 4, 0)) == -1) {
printf("full msg buffer\n");
}
return;
}
void MTPSaveFirstTokenStatic(const paddle::Tensor& x,
const paddle::Tensor& not_need_stop,
int64_t rank_id,
bool save_each_rank) {
MTPSaveFirstToken(x, not_need_stop, rank_id, 1, save_each_rank);
}
void MTPSaveFirstTokenDynamic(const paddle::Tensor& x,
const paddle::Tensor& not_need_stop,
int64_t rank_id,
int msg_queue_id,
bool save_each_rank) {
MTPSaveFirstToken(x, not_need_stop, rank_id, msg_queue_id, save_each_rank);
}
PD_BUILD_STATIC_OP(mtp_save_first_token)
.Inputs({"x", "not_need_stop"})
.Attrs({"rank_id: int64_t",
"save_each_rank: bool"})
.Outputs({"x_out"})
.SetInplaceMap({{"x", "x_out"}})
.SetKernelFn(PD_KERNEL(MTPSaveFirstTokenStatic));
PD_BUILD_STATIC_OP(mtp_save_first_token_dynamic)
.Inputs({"x", "not_need_stop"})
.Attrs({"rank_id: int64_t", "msg_queue_id: int", "save_each_rank: bool"})
.Outputs({"x_out"})
.SetInplaceMap({{"x", "x_out"}})
.SetKernelFn(PD_KERNEL(MTPSaveFirstTokenDynamic));
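Note: the op pushes results into a System V message queue; a consumer on the same key can drain it. A minimal hypothetical receiver, assuming the same MAX_BSZ / MAX_DRAFT_TOKENS layout and default queue id 1; not part of this commit:

#include <cstdio>
#include <sys/ipc.h>
#include <sys/msg.h>

#define MAX_BSZ 512
#define MAX_DRAFT_TOKENS 6

struct msgdata {
    long mtype;
    int mtext[2 + MAX_BSZ + MAX_BSZ * MAX_DRAFT_TOKENS];
};

int main() {
    key_t key = ftok("./", 1);  // same path and queue id as the sender
    int msgid = msgget(key, IPC_CREAT | 0666);
    struct msgdata msg;
    // Blocks until a message of mtype 1 arrives.
    msgrcv(msgid, &msg, sizeof(msg.mtext), 1, 0);
    int bsz = msg.mtext[1];
    for (int i = 0; i < bsz; i++)
        printf("seq %d first token: %d\n", i,
               msg.mtext[2 + MAX_BSZ + i * MAX_DRAFT_TOKENS]);
    return 0;
}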

View File

@@ -0,0 +1,226 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "helper.h" // NOLINT
template<int NUM_THREADS, int MAX_BATCH_SIZE=256>
__global__ void mtp_free_and_dispatch_block(
bool *base_model_stop_flags,
bool *stop_flags,
bool *batch_drop,
int *seq_lens_this_time,
int *seq_lens_decoder,
int *block_tables,
int *encoder_block_lens,
int *used_list_len,
int *free_list,
int *free_list_len,
const int bsz,
const int block_size,
const int block_num_per_seq,
const int max_draft_tokens) {
typedef cub::BlockReduce<cub::KeyValuePair<int, int>, NUM_THREADS> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
__shared__ int need_block_len;
__shared__ int need_block_list[MAX_BATCH_SIZE];
const int tid = threadIdx.x;
if (tid < bsz) {
if (tid == 0) {
need_block_len = 0;
}
need_block_list[tid] = 0;
int *block_table_now = block_tables + tid * block_num_per_seq;
if (base_model_stop_flags[tid] || batch_drop[tid]) {
// Reclaim this sequence's decoder blocks
const int encoder_block_len = encoder_block_lens[tid];
const int decoder_used_len = used_list_len[tid];
if (decoder_used_len > 0) {
const int ori_free_list_len =
atomicAdd(free_list_len, decoder_used_len);
#ifdef DEBUG_STEP
printf(
"free block seq_id: %d, free block num: %d, "
"encoder_block_len: %d, ori_free_list_len: %d\n",
tid,
decoder_used_len,
encoder_block_len,
ori_free_list_len);
#endif
for (int i = 0; i < decoder_used_len; i++) {
free_list[ori_free_list_len + i] =
block_table_now[encoder_block_len + i];
block_table_now[encoder_block_len + i] = -1;
}
encoder_block_lens[tid] = 0;
used_list_len[tid] = 0;
}
}
}
__syncthreads();
if (tid < bsz) {
int *block_table_now = block_tables + tid * block_num_per_seq;
int max_possible_block_idx = (seq_lens_decoder[tid] + max_draft_tokens + 1) / block_size;
if (!base_model_stop_flags[tid] && !batch_drop[tid] && max_possible_block_idx < block_num_per_seq &&
block_table_now[max_possible_block_idx] == -1) {
int ori_need_block_len = atomicAdd(&need_block_len, 1);
need_block_list[ori_need_block_len] = tid;
// Record which sequences need a new block and the total count;
// the actual allocation happens after the eviction loop below.
#ifdef DEBUG_STEP
printf("seq_id: %d need block\n", tid);
#endif
}
}
__syncthreads();
#ifdef DEBUG_STEP
if (tid == 0) {
printf("need_block_len:%d, free_list_len: %d\n", need_block_len, free_list_len[0]);
}
#endif
// Walk the batch from bid 0: while free blocks are insufficient, evict the sequence holding the most blocks
while (need_block_len > free_list_len[0]) {
#ifdef DEBUG_STEP
if (tid == 0) {
printf("in while need_block_len:%d, free_list_len: %d\n", need_block_len, free_list_len[0]);
}
#endif
const int used_block_num =
tid < bsz && !base_model_stop_flags[tid] ? used_list_len[tid] : 0;
cub::KeyValuePair<int, int> kv_pair = {tid, used_block_num};
kv_pair = BlockReduce(temp_storage).Reduce(kv_pair, cub::ArgMax());
if (tid == 0) {
const int encoder_block_len = encoder_block_lens[kv_pair.key];
int *block_table_now =
block_tables + kv_pair.key * block_num_per_seq;
for (int i = 0; i < kv_pair.value; i++) {
free_list[free_list_len[0] + i] =
block_table_now[encoder_block_len + i];
block_table_now[encoder_block_len + i] = -1;
}
const int ori_free_list_len = atomicAdd(free_list_len, kv_pair.value);
printf(
"MTP STEP need_block_len: %d. free_list_len: %d."
"Drop bid: %d, free block num: %d, "
"encoder_block_len: %d,"
"After drop free_list_len %d \n",
need_block_len,
ori_free_list_len,
kv_pair.key,
kv_pair.value,
encoder_block_len,
free_list_len[0]);
stop_flags[kv_pair.key] = true;
batch_drop[kv_pair.key] = true;
seq_lens_this_time[kv_pair.key] = 0;
seq_lens_decoder[kv_pair.key] = 0;
used_list_len[kv_pair.key] = 0;
}
__syncthreads();
}
if (tid < need_block_len) {
const int need_block_id = need_block_list[tid];
// Must check batch_drop here, not stop_flags
if (!batch_drop[need_block_id]) {
used_list_len[need_block_id] += 1;
const int ori_free_list_len = atomicSub(free_list_len, 1);
int *block_table_now =
block_tables + need_block_id * block_num_per_seq;
#ifdef DEBUG_STEP
printf("bid: %d allocate block_id %d. seq_lens_decoder:%d \n", need_block_id, free_list[ori_free_list_len - 1], seq_lens_decoder[need_block_id]);
#endif
block_table_now[(seq_lens_decoder[need_block_id] +
max_draft_tokens + 1) /
block_size] = free_list[ori_free_list_len - 1];
}
}
}
void MTPStepPaddle(
const paddle::Tensor &base_model_stop_flags,
const paddle::Tensor &stop_flags,
const paddle::Tensor &batch_drop,
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &block_tables, // [bsz, block_num_per_seq]
const paddle::Tensor &encoder_block_lens,
const paddle::Tensor &used_list_len,
const paddle::Tensor &free_list,
const paddle::Tensor &free_list_len,
const int block_size,
const int max_draft_tokens) {
auto cu_stream = seq_lens_this_time.stream();
const int bsz = seq_lens_this_time.shape()[0];
const int block_num_per_seq = block_tables.shape()[1];
constexpr int BlockSize = 512; // one thread per sequence; bsz <= 512
#ifdef DEBUG_STEP
printf(
"bsz: %d, block_num_per_seq: %d, block_size: %d, max_draft_tokens: "
"%d\n",
bsz,
block_num_per_seq,
block_size,
max_draft_tokens);
#endif
mtp_free_and_dispatch_block<BlockSize, BlockSize><<<1, BlockSize, 0, cu_stream>>>(
const_cast<bool *>(base_model_stop_flags.data<bool>()),
const_cast<bool *>(stop_flags.data<bool>()),
const_cast<bool *>(batch_drop.data<bool>()),
const_cast<int *>(seq_lens_this_time.data<int>()),
const_cast<int *>(seq_lens_decoder.data<int>()),
const_cast<int *>(block_tables.data<int>()),
const_cast<int *>(encoder_block_lens.data<int>()),
const_cast<int *>(used_list_len.data<int>()),
const_cast<int *>(free_list.data<int>()),
const_cast<int *>(free_list_len.data<int>()),
bsz,
block_size,
block_num_per_seq,
max_draft_tokens);
}
PD_BUILD_STATIC_OP(mtp_step_paddle)
.Inputs({"base_model_stop_flags",
"stop_flags",
"batch_drop",
"seq_lens_this_time",
"seq_lens_encoder",
"seq_lens_decoder",
"block_tables",
"encoder_block_lens",
"used_list_len",
"free_list",
"free_list_len"})
.Attrs({"block_size: int",
"max_draft_tokens: int"})
.Outputs({"block_tables_out",
"stop_flags_out",
"used_list_len_out",
"free_list_out",
"free_list_len_out"})
.SetInplaceMap({{"block_tables", "block_tables_out"},
{"stop_flags", "stop_flags_out"},
{"used_list_len", "used_list_len_out"},
{"free_list", "free_list_out"},
{"free_list_len", "free_list_len_out"}})
.SetKernelFn(PD_KERNEL(MTPStepPaddle));
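Note: the kernel treats free_list as a stack guarded by atomics: allocation pops from the tail via atomicSub on free_list_len, reclamation pushes via atomicAdd. A single-threaded CPU model of that accounting; illustrative only, not part of the commit:

#include <cassert>
#include <vector>

// CPU model of the free_list stack discipline in mtp_free_and_dispatch_block.
struct FreeList {
    std::vector<int> blocks;  // blocks[0..len) are free
    int len = 0;

    int alloc() {                 // mirrors: ori = atomicSub(free_list_len, 1)
        assert(len > 0);
        return blocks[--len];     //          free_list[ori - 1]
    }
    void reclaim(int block_id) {  // mirrors: ori = atomicAdd(free_list_len, 1)
        blocks[len++] = block_id; //          free_list[ori] = block_id
    }
};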