Mirror of https://github.com/PaddlePaddle/FastDeploy.git, synced 2025-10-05 16:48:03 +08:00
[LLM] First commit of the LLM deployment code
@@ -0,0 +1,80 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/extension.h"

#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif

// Recomputes seq_lens_this_time for the base model from the draft tokens:
// every slot that is not -1 counts as one token to verify this step.
__global__ void draft_model_update_seq_lens_this_time_kernel(
    const int64_t* base_model_draft_tokens,
    int* base_model_seq_lens_this_time,
    const int* base_model_seq_lens_encoder,
    const bool* base_model_stop_flags,
    int bsz,
    int base_model_draft_token_len) {
  int tid = threadIdx.x;
  if (tid < bsz) {
    if (!base_model_stop_flags[tid] &&
        base_model_seq_lens_encoder[tid] == 0) {
      const int64_t* base_model_draft_tokens_now =
          base_model_draft_tokens + tid * base_model_draft_token_len;
      int token_num = 0;

      for (int i = 0; i < base_model_draft_token_len; ++i) {
        if (base_model_draft_tokens_now[i] != -1) {
          token_num++;
        }
      }
      base_model_seq_lens_this_time[tid] = token_num;
    } else if (base_model_stop_flags[tid]) {
      base_model_seq_lens_this_time[tid] = 0;
    }
  }
}

void DraftModelPostprocess(const paddle::Tensor& base_model_draft_tokens,
                           const paddle::Tensor& base_model_seq_lens_this_time,
                           const paddle::Tensor& base_model_seq_lens_encoder,
                           const paddle::Tensor& base_model_stop_flags) {
  int real_bsz = base_model_seq_lens_this_time.shape()[0];
  auto cu_stream = base_model_seq_lens_this_time.stream();
  // Round the batch size up to a multiple of the warp size (32); the kernel
  // is launched as a single block, so real_bsz must not exceed 1024.
  int block_size = (real_bsz + 32 - 1) / 32 * 32;
  int base_model_draft_token_len = base_model_draft_tokens.shape()[1];
  draft_model_update_seq_lens_this_time_kernel<<<1,
                                                 block_size,
                                                 0,
                                                 cu_stream>>>(
      base_model_draft_tokens.data<int64_t>(),
      const_cast<int*>(base_model_seq_lens_this_time.data<int>()),
      base_model_seq_lens_encoder.data<int>(),
      base_model_stop_flags.data<bool>(),
      real_bsz,
      base_model_draft_token_len);
}

PD_BUILD_STATIC_OP(draft_model_postprocess)
    .Inputs({"base_model_draft_tokens",
             "base_model_seq_lens_this_time",
             "base_model_seq_lens_encoder",
             "base_model_stop_flags"})
    .Outputs({"base_model_draft_tokens_out",
              "base_model_seq_lens_this_time_out",
              "base_model_stop_flags_out"})
    .SetInplaceMap({{"base_model_draft_tokens", "base_model_draft_tokens_out"},
                    {"base_model_seq_lens_this_time",
                     "base_model_seq_lens_this_time_out"},
                    {"base_model_stop_flags", "base_model_stop_flags_out"}})
    .SetKernelFn(PD_KERNEL(DraftModelPostprocess));
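// For reference, a minimal standalone harness exercising the counting kernel
// above on a toy batch. This is an illustrative sketch only, not part of the
// commit; all buffer sizes and values below are assumptions.
#include <cstdio>
#include <cuda_runtime.h>

int main() {
  const int bsz = 2, draft_len = 4;
  // Batch 0: three valid draft tokens and one -1 slot; batch 1: stopped.
  int64_t h_draft[bsz * draft_len] = {7, 8, 9, -1, -1, -1, -1, -1};
  int h_seq_lens[bsz] = {0, 0};
  int h_encoder[bsz] = {0, 0};
  bool h_stop[bsz] = {false, true};

  int64_t* d_draft; int* d_seq; int* d_enc; bool* d_stop;
  cudaMalloc(&d_draft, sizeof(h_draft));
  cudaMalloc(&d_seq, sizeof(h_seq_lens));
  cudaMalloc(&d_enc, sizeof(h_encoder));
  cudaMalloc(&d_stop, sizeof(h_stop));
  cudaMemcpy(d_draft, h_draft, sizeof(h_draft), cudaMemcpyHostToDevice);
  cudaMemcpy(d_seq, h_seq_lens, sizeof(h_seq_lens), cudaMemcpyHostToDevice);
  cudaMemcpy(d_enc, h_encoder, sizeof(h_encoder), cudaMemcpyHostToDevice);
  cudaMemcpy(d_stop, h_stop, sizeof(h_stop), cudaMemcpyHostToDevice);

  // Same launch shape as DraftModelPostprocess: one block, bsz rounded to 32.
  draft_model_update_seq_lens_this_time_kernel<<<1, 32>>>(
      d_draft, d_seq, d_enc, d_stop, bsz, draft_len);
  cudaMemcpy(h_seq_lens, d_seq, sizeof(h_seq_lens), cudaMemcpyDeviceToHost);
  printf("seq_lens_this_time = {%d, %d}\n",
         h_seq_lens[0], h_seq_lens[1]);  // expect {3, 0}
  return 0;
}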
@@ -0,0 +1,500 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "helper.h"
#include "paddle/extension.h"

// Prefill-only path for splitwise deployment: restores the recorded encoder
// length and seeds the draft model's input with the base model's first token.
template <int THREADBLOCK_SIZE, bool TRUNCATE_FIRST_TOKEN>
__global__ void process_splitwise_prefill(
    int64_t* draft_tokens,
    int64_t* input_ids,
    bool* stop_flags,
    int* seq_lens_this_time,
    int* seq_lens_encoder,
    int* seq_lens_decoder,
    int64_t* step_idx,
    int* seq_lens_encoder_record,
    int* seq_lens_decoder_record,
    bool* not_need_stop,
    bool* batch_drop,
    const int64_t* accept_tokens,
    const int* accept_num,
    const int* base_model_seq_lens_encoder,
    const int* base_model_seq_lens_decoder,
    const int64_t* base_model_step_idx,
    const bool* base_model_stop_flags,
    const bool* base_model_is_block_step,
    int64_t* base_model_draft_tokens,
    const int bsz,
    const int max_draft_token,
    const int accept_tokens_len,
    const int draft_tokens_len,
    const int input_ids_len,
    const int base_model_draft_tokens_len) {
  typedef cub::BlockReduce<int64_t, THREADBLOCK_SIZE> BlockReduce;
  __shared__ typename BlockReduce::TempStorage temp_storage;
  int64_t not_stop_flag = 0;

  int tid = threadIdx.x;
  if (tid < bsz) {
    int base_model_step_idx_now = base_model_step_idx[tid];
    auto* input_ids_now = input_ids + tid * input_ids_len;
    auto* accept_tokens_now = accept_tokens + tid * accept_tokens_len;
    if (base_model_step_idx_now == 1 && seq_lens_encoder_record[tid] > 0) {
      not_stop_flag = 1;
      int seq_len_encoder_record = seq_lens_encoder_record[tid];
      seq_lens_encoder[tid] = seq_len_encoder_record;
      seq_lens_encoder_record[tid] = -1;
      stop_flags[tid] = false;
      int64_t base_model_first_token = accept_tokens_now[0];
      int position = seq_len_encoder_record;
      if (TRUNCATE_FIRST_TOKEN) {
        input_ids_now[position - 1] = base_model_first_token;
        seq_lens_this_time[tid] = seq_len_encoder_record;
      } else {
        input_ids_now[position] = base_model_first_token;
        seq_lens_this_time[tid] = seq_len_encoder_record + 1;
      }
    } else {
      stop_flags[tid] = true;
      seq_lens_this_time[tid] = 0;
      seq_lens_decoder[tid] = 0;
      seq_lens_encoder[tid] = 0;
      not_stop_flag = 0;
    }
  }
  __syncthreads();

  int64_t not_stop_flag_sum = BlockReduce(temp_storage).Sum(not_stop_flag);
  if (tid == 0) {
    not_need_stop[0] = not_stop_flag_sum > 0;
  }
}

// Regular per-step preprocessing: consumes the base model's accept results,
// rolls the draft model back over rejected tokens, and handles stops/drops.
template <int THREADBLOCK_SIZE, bool TRUNCATE_FIRST_TOKEN>
__global__ void draft_model_preprocess_kernel(
    int64_t* draft_tokens,
    int64_t* input_ids,
    bool* stop_flags,
    int* seq_lens_this_time,
    int* seq_lens_encoder,
    int* seq_lens_decoder,
    int64_t* step_idx,
    int* seq_lens_encoder_record,
    int* seq_lens_decoder_record,
    bool* not_need_stop,
    bool* batch_drop,
    const int64_t* accept_tokens,
    const int* accept_num,
    const int* base_model_seq_lens_encoder,
    const int* base_model_seq_lens_decoder,
    const int64_t* base_model_step_idx,
    const bool* base_model_stop_flags,
    const bool* base_model_is_block_step,
    int64_t* base_model_draft_tokens,
    const int bsz,
    const int max_draft_token,
    const int accept_tokens_len,
    const int draft_tokens_len,
    const int input_ids_len,
    const int base_model_draft_tokens_len) {
  typedef cub::BlockReduce<int64_t, THREADBLOCK_SIZE> BlockReduce;
  __shared__ typename BlockReduce::TempStorage temp_storage;
  int64_t not_stop_flag = 0;

  int tid = threadIdx.x;

  if (tid < bsz) {
    auto base_model_step_idx_now = base_model_step_idx[tid];
    auto* accept_tokens_now = accept_tokens + tid * accept_tokens_len;
    auto* draft_tokens_now = draft_tokens + tid * draft_tokens_len;
    auto accept_num_now = accept_num[tid];
    auto* input_ids_now = input_ids + tid * input_ids_len;
    auto* base_model_draft_tokens_now =
        base_model_draft_tokens + tid * base_model_draft_tokens_len;
    // Reset every draft slot except slot 0 to -1.
#pragma unroll
    for (int i = 1; i < base_model_draft_tokens_len; i++) {
      base_model_draft_tokens_now[i] = -1;
    }
    // Handle the base model's recover logic: a request the base model
    // stopped during a block step is dropped from the draft batch.
    if (base_model_stop_flags[tid] && base_model_is_block_step[tid]) {
      batch_drop[tid] = true;
      stop_flags[tid] = true;
    }

    if (!(base_model_stop_flags[tid] || batch_drop[tid])) {
      not_stop_flag = 1;
      // 1. first token
      if (base_model_step_idx_now == 0) {
        seq_lens_this_time[tid] = 0;
        not_stop_flag = 0;
      } else if (base_model_step_idx_now == 1 &&
                 seq_lens_encoder_record[tid] > 0) {
        // Can be extended to the first few tokens
        int seq_len_encoder_record = seq_lens_encoder_record[tid];
        seq_lens_encoder[tid] = seq_len_encoder_record;
        seq_lens_encoder_record[tid] = -1;
        seq_lens_decoder[tid] = seq_lens_decoder_record[tid];
        seq_lens_decoder_record[tid] = 0;
        stop_flags[tid] = false;
        int64_t base_model_first_token = accept_tokens_now[0];
        int position = seq_len_encoder_record;
        if (TRUNCATE_FIRST_TOKEN) {
          input_ids_now[position - 1] = base_model_first_token;
          seq_lens_this_time[tid] = seq_len_encoder_record;
        } else {
          input_ids_now[position] = base_model_first_token;
          seq_lens_this_time[tid] = seq_len_encoder_record + 1;
        }
      } else if (accept_num_now <=
                 max_draft_token) /*Accept partial draft tokens*/ {
        // The base model rejected the draft's stop: resume from the base
        // model's state; otherwise roll back by the rejected token count.
        if (stop_flags[tid]) {
          stop_flags[tid] = false;
          seq_lens_decoder[tid] = base_model_seq_lens_decoder[tid];
          step_idx[tid] = base_model_step_idx[tid];
        } else {
          seq_lens_decoder[tid] -= max_draft_token - accept_num_now;
          step_idx[tid] -= max_draft_token - accept_num_now;
        }
        int64_t modified_token = accept_tokens_now[accept_num_now - 1];
        draft_tokens_now[0] = modified_token;
        seq_lens_this_time[tid] = 1;
      } else /*Accept all draft tokens*/ {
        draft_tokens_now[1] = accept_tokens_now[max_draft_token];
        seq_lens_this_time[tid] = 2;
      }
    } else {
      stop_flags[tid] = true;
      seq_lens_this_time[tid] = 0;
      seq_lens_decoder[tid] = 0;
      seq_lens_encoder[tid] = 0;
    }
  }
  __syncthreads();
  int64_t not_stop_flag_sum = BlockReduce(temp_storage).Sum(not_stop_flag);
  if (tid == 0) {
    not_need_stop[0] = not_stop_flag_sum > 0;
  }
}

template <bool TRUNCATE_FIRST_TOKEN>
void DispatchRunner(
    const cudaStream_t& stream,
    int64_t* draft_tokens,
    int64_t* input_ids,
    bool* stop_flags,
    int* seq_lens_this_time,
    int* seq_lens_encoder,
    int* seq_lens_decoder,
    int64_t* step_idx,
    int* seq_lens_encoder_record,
    int* seq_lens_decoder_record,
    bool* not_need_stop,
    bool* batch_drop,
    const int64_t* accept_tokens,
    const int* accept_num,
    const int* base_model_seq_lens_encoder,
    const int* base_model_seq_lens_decoder,
    const int64_t* base_model_step_idx,
    const bool* base_model_stop_flags,
    const bool* base_model_is_block_step,
    int64_t* base_model_draft_tokens,
    const int bsz,
    const int max_draft_token,
    const int accept_tokens_len,
    const int draft_tokens_len,
    const int input_ids_len,
    const int base_model_draft_tokens_len,
    const bool splitwise_prefill) {
  constexpr int BlockSize = 512;
  if (splitwise_prefill) {
    process_splitwise_prefill<BlockSize, TRUNCATE_FIRST_TOKEN>
        <<<1, BlockSize, 0, stream>>>(
            draft_tokens, input_ids, stop_flags, seq_lens_this_time,
            seq_lens_encoder, seq_lens_decoder, step_idx,
            seq_lens_encoder_record, seq_lens_decoder_record, not_need_stop,
            batch_drop, accept_tokens, accept_num,
            base_model_seq_lens_encoder, base_model_seq_lens_decoder,
            base_model_step_idx, base_model_stop_flags,
            base_model_is_block_step, base_model_draft_tokens, bsz,
            max_draft_token, accept_tokens_len, draft_tokens_len,
            input_ids_len, base_model_draft_tokens_len);
  } else {
    draft_model_preprocess_kernel<BlockSize, TRUNCATE_FIRST_TOKEN>
        <<<1, BlockSize, 0, stream>>>(
            draft_tokens, input_ids, stop_flags, seq_lens_this_time,
            seq_lens_encoder, seq_lens_decoder, step_idx,
            seq_lens_encoder_record, seq_lens_decoder_record, not_need_stop,
            batch_drop, accept_tokens, accept_num,
            base_model_seq_lens_encoder, base_model_seq_lens_decoder,
            base_model_step_idx, base_model_stop_flags,
            base_model_is_block_step, base_model_draft_tokens, bsz,
            max_draft_token, accept_tokens_len, draft_tokens_len,
            input_ids_len, base_model_draft_tokens_len);
  }
}

// Promotes the runtime truncate_first_token flag to a compile-time template
// parameter before dispatching to the kernels above.
void DispatchTokenMode(
    const cudaStream_t& stream,
    int64_t* draft_tokens,
    int64_t* input_ids,
    bool* stop_flags,
    int* seq_lens_this_time,
    int* seq_lens_encoder,
    int* seq_lens_decoder,
    int64_t* step_idx,
    int* seq_lens_encoder_record,
    int* seq_lens_decoder_record,
    bool* not_need_stop,
    bool* batch_drop,
    const int64_t* accept_tokens,
    const int* accept_num,
    const int* base_model_seq_lens_encoder,
    const int* base_model_seq_lens_decoder,
    const int64_t* base_model_step_idx,
    const bool* base_model_stop_flags,
    const bool* base_model_is_block_step,
    int64_t* base_model_draft_tokens,
    const int bsz,
    const int max_draft_token,
    const int accept_tokens_len,
    const int draft_tokens_len,
    const int input_ids_len,
    const int base_model_draft_tokens_len,
    const bool truncate_first_token,
    const bool splitwise_prefill) {
  if (truncate_first_token) {
    DispatchRunner<true>(
        stream, draft_tokens, input_ids, stop_flags, seq_lens_this_time,
        seq_lens_encoder, seq_lens_decoder, step_idx,
        seq_lens_encoder_record, seq_lens_decoder_record, not_need_stop,
        batch_drop, accept_tokens, accept_num, base_model_seq_lens_encoder,
        base_model_seq_lens_decoder, base_model_step_idx,
        base_model_stop_flags, base_model_is_block_step,
        base_model_draft_tokens, bsz, max_draft_token, accept_tokens_len,
        draft_tokens_len, input_ids_len, base_model_draft_tokens_len,
        splitwise_prefill);
  } else {
    DispatchRunner<false>(
        stream, draft_tokens, input_ids, stop_flags, seq_lens_this_time,
        seq_lens_encoder, seq_lens_decoder, step_idx,
        seq_lens_encoder_record, seq_lens_decoder_record, not_need_stop,
        batch_drop, accept_tokens, accept_num, base_model_seq_lens_encoder,
        base_model_seq_lens_decoder, base_model_step_idx,
        base_model_stop_flags, base_model_is_block_step,
        base_model_draft_tokens, bsz, max_draft_token, accept_tokens_len,
        draft_tokens_len, input_ids_len, base_model_draft_tokens_len,
        splitwise_prefill);
  }
}

void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
                          const paddle::Tensor& input_ids,
                          const paddle::Tensor& stop_flags,
                          const paddle::Tensor& seq_lens_this_time,
                          const paddle::Tensor& seq_lens_encoder,
                          const paddle::Tensor& seq_lens_decoder,
                          const paddle::Tensor& step_idx,
                          const paddle::Tensor& seq_lens_encoder_record,
                          const paddle::Tensor& seq_lens_decoder_record,
                          const paddle::Tensor& not_need_stop,
                          const paddle::Tensor& batch_drop,
                          const paddle::Tensor& accept_tokens,
                          const paddle::Tensor& accept_num,
                          const paddle::Tensor& base_model_seq_lens_encoder,
                          const paddle::Tensor& base_model_seq_lens_decoder,
                          const paddle::Tensor& base_model_step_idx,
                          const paddle::Tensor& base_model_stop_flags,
                          const paddle::Tensor& base_model_is_block_step,
                          const paddle::Tensor& base_model_draft_tokens,
                          const int max_draft_token,
                          const bool truncate_first_token,
                          const bool splitwise_prefill) {
  int real_bsz = seq_lens_this_time.shape()[0];
  int accept_tokens_len = accept_tokens.shape()[1];
  int input_ids_len = input_ids.shape()[1];
  int draft_tokens_len = draft_tokens.shape()[1];
  auto cu_stream = seq_lens_this_time.stream();
  int base_model_draft_tokens_len = base_model_draft_tokens.shape()[1];
  // not_need_stop lives on the CPU; shuttle it through a GPU copy so the
  // kernel can write it, then copy the flag back below.
  auto not_need_stop_gpu =
      not_need_stop.copy_to(seq_lens_this_time.place(), false);

  DispatchTokenMode(
      cu_stream,
      const_cast<int64_t*>(draft_tokens.data<int64_t>()),
      const_cast<int64_t*>(input_ids.data<int64_t>()),
      const_cast<bool*>(stop_flags.data<bool>()),
      const_cast<int*>(seq_lens_this_time.data<int>()),
      const_cast<int*>(seq_lens_encoder.data<int>()),
      const_cast<int*>(seq_lens_decoder.data<int>()),
      const_cast<int64_t*>(step_idx.data<int64_t>()),
      const_cast<int*>(seq_lens_encoder_record.data<int>()),
      const_cast<int*>(seq_lens_decoder_record.data<int>()),
      const_cast<bool*>(not_need_stop_gpu.data<bool>()),
      const_cast<bool*>(batch_drop.data<bool>()),
      accept_tokens.data<int64_t>(),
      accept_num.data<int>(),
      base_model_seq_lens_encoder.data<int>(),
      base_model_seq_lens_decoder.data<int>(),
      base_model_step_idx.data<int64_t>(),
      base_model_stop_flags.data<bool>(),
      base_model_is_block_step.data<bool>(),
      const_cast<int64_t*>(base_model_draft_tokens.data<int64_t>()),
      real_bsz,
      max_draft_token,
      accept_tokens_len,
      draft_tokens_len,
      input_ids_len,
      base_model_draft_tokens_len,
      truncate_first_token,
      splitwise_prefill);

  auto not_need_stop_cpu =
      not_need_stop_gpu.copy_to(not_need_stop.place(), false);
  bool* not_need_stop_data = const_cast<bool*>(not_need_stop.data<bool>());
  not_need_stop_data[0] = not_need_stop_cpu.data<bool>()[0];
}

PD_BUILD_STATIC_OP(draft_model_preprocess)
    .Inputs({"draft_tokens",
             "input_ids",
             "stop_flags",
             "seq_lens_this_time",
             "seq_lens_encoder",
             "seq_lens_decoder",
             "step_idx",
             "seq_lens_encoder_record",
             "seq_lens_decoder_record",
             "not_need_stop",
             "batch_drop",
             "accept_tokens",
             "accept_num",
             "base_model_seq_lens_encoder",
             "base_model_seq_lens_decoder",
             "base_model_step_idx",
             "base_model_stop_flags",
             "base_model_is_block_step",
             "base_model_draft_tokens"})
    .Outputs({"draft_tokens_out",
              "input_ids_out",
              "stop_flags_out",
              "seq_lens_this_time_out",
              "seq_lens_encoder_out",
              "seq_lens_decoder_out",
              "step_idx_out",
              "not_need_stop_out",
              "batch_drop_out",
              "seq_lens_encoder_record_out",
              "seq_lens_decoder_record_out"})
    .Attrs({"max_draft_token: int",
            "truncate_first_token: bool",
            "splitwise_prefill: bool"})
    .SetInplaceMap({{"draft_tokens", "draft_tokens_out"},
                    {"input_ids", "input_ids_out"},
                    {"stop_flags", "stop_flags_out"},
                    {"seq_lens_this_time", "seq_lens_this_time_out"},
                    {"seq_lens_encoder", "seq_lens_encoder_out"},
                    {"seq_lens_decoder", "seq_lens_decoder_out"},
                    {"step_idx", "step_idx_out"},
                    {"not_need_stop", "not_need_stop_out"},
                    {"batch_drop", "batch_drop_out"},
                    {"seq_lens_encoder_record", "seq_lens_encoder_record_out"},
                    {"seq_lens_decoder_record", "seq_lens_decoder_record_out"}})
    .SetKernelFn(PD_KERNEL(DraftModelPreprocess));
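// DispatchTokenMode/DispatchRunner above use a common CUDA idiom: a runtime
// bool is promoted to a compile-time template parameter so the branch is
// resolved at compile time inside the kernel. A minimal sketch of the same
// pattern (illustrative only; the kernel and names below are made up):
template <bool TRUNCATE_FIRST_TOKEN>
__global__ void toy_kernel(int* out) {
  if (TRUNCATE_FIRST_TOKEN) {  // resolved at compile time, no runtime branch
    out[0] = 1;
  } else {
    out[0] = 0;
  }
}

void dispatch_toy(int* out, bool truncate_first_token) {
  if (truncate_first_token) {
    toy_kernel<true><<<1, 1>>>(out);
  } else {
    toy_kernel<false><<<1, 1>>>(out);
  }
}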
@@ -0,0 +1,76 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "helper.h"

// Appends the tokens generated this step to pre_ids_all (the running history
// of generated ids) and resets seq_lens_this_time to 1.
__global__ void update_pre_ids_kernel(const int64_t* draft_tokens,
                                      int64_t* pre_ids_all,
                                      const bool* stop_flags,
                                      int* seq_lens_this_time,
                                      const int64_t* step_idx,
                                      int bs,
                                      int pre_id_length,
                                      int max_draft_token) {
  int tid = threadIdx.x;
  if (tid < bs && seq_lens_this_time[tid] != 0 && !stop_flags[tid]) {
    int64_t* pre_ids_all_now = pre_ids_all + tid * pre_id_length;
    const int64_t* draft_token_now = draft_tokens + tid * max_draft_token;
    const int seq_len_this_time = seq_lens_this_time[tid];
    if (step_idx[tid] - 1 > 0 /*Decoder Step*/) {
      // Write the newest tokens backwards from position step_idx.
      for (int i = 0; i < seq_len_this_time; ++i) {
        pre_ids_all_now[step_idx[tid] - i] =
            draft_token_now[seq_len_this_time - 1 - i];
      }
    } else if (step_idx[tid] == 1 /*Encoder Step*/) {
      pre_ids_all_now[1] = draft_token_now[0];
    }
    seq_lens_this_time[tid] = 1;
  }
}

void SpeculateDraftModelUpdate(const paddle::Tensor& draft_tokens,
                               const paddle::Tensor& pre_ids_all,
                               const paddle::Tensor& stop_flags,
                               const paddle::Tensor& seq_lens_this_time,
                               const paddle::Tensor& seq_lens_encoder,
                               const paddle::Tensor& seq_lens_decoder,
                               const paddle::Tensor& step_idx) {
  int64_t real_bs = seq_lens_this_time.shape()[0];
  int64_t pre_id_length = pre_ids_all.shape()[1];
  auto cu_stream = seq_lens_this_time.stream();
  int64_t max_draft_token = draft_tokens.shape()[1];

  // Round the batch size up to a multiple of the warp size (32).
  int block_size = (real_bs + 32 - 1) / 32 * 32;
  update_pre_ids_kernel<<<1, block_size, 0, cu_stream>>>(
      draft_tokens.data<int64_t>(),
      const_cast<int64_t*>(pre_ids_all.data<int64_t>()),
      stop_flags.data<bool>(),
      const_cast<int*>(seq_lens_this_time.data<int>()),
      step_idx.data<int64_t>(),
      real_bs,
      pre_id_length,
      max_draft_token);
}

PD_BUILD_STATIC_OP(draft_model_set_value_by_flags)
    .Inputs({"draft_tokens",
             "pre_ids_all",
             "stop_flags",
             "seq_lens_this_time",
             "seq_lens_encoder",
             "seq_lens_decoder",
             "step_idx"})
    .Outputs({"pre_ids_all_out"})
    .SetInplaceMap({{"pre_ids_all", "pre_ids_all_out"}})
    .SetKernelFn(PD_KERNEL(SpeculateDraftModelUpdate));
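// The decoder-step write-back above stores the newest tokens at positions
// step_idx, step_idx-1, ... A host-side sketch of the same index math on toy
// values (illustrative only; values are assumptions, not part of the op):
#include <cstdint>
#include <cstdio>

int main() {
  int64_t pre_ids[8] = {0};
  const int64_t draft[3] = {11, 12, 13};  // tokens produced this step
  const int seq_len_this_time = 2;
  const int64_t step_idx = 5;             // decoder step: step_idx - 1 > 0
  for (int i = 0; i < seq_len_this_time; ++i) {
    pre_ids[step_idx - i] = draft[seq_len_this_time - 1 - i];
  }
  // Result: pre_ids[4] = 11, pre_ids[5] = 12.
  printf("pre_ids[4]=%lld pre_ids[5]=%lld\n",
         (long long)pre_ids[4], (long long)pre_ids[5]);
  return 0;
}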
@@ -0,0 +1,215 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <cstdlib>

#include "helper.h"
#include "paddle/extension.h"

// Consumes the tokens sampled this step: appends them to pre_ids, advances
// step_idx and seq_lens_decoder, forwards the newest token to the base
// model's draft buffer, and applies the stop conditions.
template <int THREADBLOCK_SIZE>
__global__ void draft_model_update_kernel(const int64_t* inter_next_tokens,
                                          int64_t* draft_tokens,
                                          int64_t* pre_ids,
                                          int* seq_lens_this_time,
                                          int* seq_lens_encoder,
                                          int* seq_lens_decoder,
                                          int64_t* step_idx,
                                          const int* output_cum_offsets,
                                          bool* stop_flags,
                                          bool* not_need_stop,
                                          const int64_t* max_dec_len,
                                          const int64_t* end_ids,
                                          int64_t* base_model_draft_tokens,
                                          const int bsz,
                                          const int max_draft_token,
                                          const int pre_id_length,
                                          const int max_base_model_draft_token,
                                          const int end_ids_len,
                                          const int max_seq_len,
                                          const int substep,
                                          const bool prefill_one_step_stop) {
  typedef cub::BlockReduce<int64_t, THREADBLOCK_SIZE> BlockReduce;
  __shared__ typename BlockReduce::TempStorage temp_storage;
  int64_t stop_flag_now_int = 0;

  int tid = threadIdx.x;
  if (tid < bsz) {
    auto* draft_token_now = draft_tokens + tid * max_draft_token;
    auto* pre_ids_now = pre_ids + tid * pre_id_length;
    auto* base_model_draft_tokens_now =
        base_model_draft_tokens + tid * max_base_model_draft_token;
    const int next_tokens_start_id =
        tid * max_seq_len - output_cum_offsets[tid];
    auto* next_tokens_start = inter_next_tokens + next_tokens_start_id;
    auto seq_len_this_time = seq_lens_this_time[tid];
    auto seq_len_encoder = seq_lens_encoder[tid];
    auto seq_len_decoder = seq_lens_decoder[tid];

    // 1. update step_idx && seq_lens_decoder
    if (!stop_flags[tid] /* seq_lens_decoder > 0 or seq_lens_encoder > 0 */) {
      int64_t token_this_time = -1;
      // decoder step
      if (seq_len_decoder > 0 && seq_len_encoder <= 0) {
        seq_lens_decoder[tid] += seq_len_this_time;
        token_this_time = next_tokens_start[seq_len_this_time - 1];
        draft_token_now[0] = next_tokens_start[seq_len_this_time - 1];
        base_model_draft_tokens_now[substep + 1] = token_this_time;
        for (int i = 0; i < seq_len_this_time; ++i) {
          pre_ids_now[step_idx[tid] + 1 + i] = next_tokens_start[i];
        }
        step_idx[tid] += seq_len_this_time;
      } else {  // encoder step: only the first sampled token counts
        token_this_time = next_tokens_start[0];
        seq_lens_decoder[tid] = seq_len_encoder + seq_len_decoder;
        seq_lens_encoder[tid] = 0;
        pre_ids_now[1] = token_this_time;
        step_idx[tid] += 1;
        draft_token_now[0] = token_this_time;
        base_model_draft_tokens_now[substep + 1] = token_this_time;
      }

      // multi_end
      if (is_in_end(token_this_time, end_ids, end_ids_len) ||
          prefill_one_step_stop) {
        stop_flags[tid] = true;
        stop_flag_now_int = 1;
        // max_dec_len
      } else if (step_idx[tid] >= max_dec_len[tid]) {
        stop_flags[tid] = true;
        draft_token_now[seq_len_this_time - 1] = end_ids[0];
        base_model_draft_tokens_now[substep + 1] = end_ids[0];
        stop_flag_now_int = 1;
      }
    } else {
      draft_token_now[0] = -1;
      base_model_draft_tokens_now[substep + 1] = -1;
      stop_flag_now_int = 1;
    }

    // 2. set end
    if (!stop_flags[tid]) {
      seq_lens_this_time[tid] = 1;
    } else {
      seq_lens_this_time[tid] = 0;
      seq_lens_encoder[tid] = 0;
    }
  }
  __syncthreads();
  int64_t stop_sum = BlockReduce(temp_storage).Sum(stop_flag_now_int);
  if (tid == 0) {
    not_need_stop[0] = stop_sum < bsz;
  }
}

void DraftModelUpdate(const paddle::Tensor& inter_next_tokens,
                      const paddle::Tensor& draft_tokens,
                      const paddle::Tensor& pre_ids,
                      const paddle::Tensor& seq_lens_this_time,
                      const paddle::Tensor& seq_lens_encoder,
                      const paddle::Tensor& seq_lens_decoder,
                      const paddle::Tensor& step_idx,
                      const paddle::Tensor& output_cum_offsets,
                      const paddle::Tensor& stop_flags,
                      const paddle::Tensor& not_need_stop,
                      const paddle::Tensor& max_dec_len,
                      const paddle::Tensor& end_ids,
                      const paddle::Tensor& base_model_draft_tokens,
                      const int max_seq_len,
                      const int substep) {
  auto seq_lens_this_time_shape = seq_lens_this_time.shape();
  auto cu_stream = seq_lens_this_time.stream();
  const int real_bsz = seq_lens_this_time_shape[0];
  auto not_need_stop_gpu =
      not_need_stop.copy_to(seq_lens_this_time.place(), false);
  const int end_ids_len = end_ids.shape()[0];
  const int max_draft_token = draft_tokens.shape()[1];
  const int pre_id_length = pre_ids.shape()[1];
  const int max_base_model_draft_token = base_model_draft_tokens.shape()[1];
  constexpr int BlockSize = 512;

  // On prefill-only nodes the draft model stops after a single step.
  bool prefill_one_step_stop = false;
  if (const char* env_p = std::getenv("PREFILL_NODE_ONE_STEP_STOP")) {
    if (env_p[0] == '1') {
      prefill_one_step_stop = true;
    }
  }

  draft_model_update_kernel<BlockSize><<<1, BlockSize, 0, cu_stream>>>(
      inter_next_tokens.data<int64_t>(),
      const_cast<int64_t*>(draft_tokens.data<int64_t>()),
      const_cast<int64_t*>(pre_ids.data<int64_t>()),
      const_cast<int*>(seq_lens_this_time.data<int>()),
      const_cast<int*>(seq_lens_encoder.data<int>()),
      const_cast<int*>(seq_lens_decoder.data<int>()),
      const_cast<int64_t*>(step_idx.data<int64_t>()),
      output_cum_offsets.data<int>(),
      const_cast<bool*>(stop_flags.data<bool>()),
      not_need_stop_gpu.data<bool>(),
      max_dec_len.data<int64_t>(),
      end_ids.data<int64_t>(),
      const_cast<int64_t*>(base_model_draft_tokens.data<int64_t>()),
      real_bsz,
      max_draft_token,
      pre_id_length,
      max_base_model_draft_token,
      end_ids_len,
      max_seq_len,
      substep,
      prefill_one_step_stop);

  auto not_need_stop_cpu =
      not_need_stop_gpu.copy_to(not_need_stop.place(), false);
  bool* not_need_stop_data = const_cast<bool*>(not_need_stop.data<bool>());
  not_need_stop_data[0] = not_need_stop_cpu.data<bool>()[0];
}

PD_BUILD_STATIC_OP(draft_model_update)
    .Inputs({"inter_next_tokens",
             "draft_tokens",
             "pre_ids",
             "seq_lens_this_time",
             "seq_lens_encoder",
             "seq_lens_decoder",
             "step_idx",
             "output_cum_offsets",
             "stop_flags",
             "not_need_stop",
             "max_dec_len",
             "end_ids",
             "base_model_draft_tokens"})
    .Attrs({"max_seq_len: int", "substep: int"})
    .Outputs({"draft_tokens_out",
              "pre_ids_out",
              "seq_lens_this_time_out",
              "seq_lens_encoder_out",
              "seq_lens_decoder_out",
              "step_idx_out",
              "stop_flags_out",
              "not_need_stop_out",
              "base_model_draft_tokens_out"})
    .SetInplaceMap({{"draft_tokens", "draft_tokens_out"},
                    {"pre_ids", "pre_ids_out"},
                    {"seq_lens_this_time", "seq_lens_this_time_out"},
                    {"seq_lens_encoder", "seq_lens_encoder_out"},
                    {"seq_lens_decoder", "seq_lens_decoder_out"},
                    {"step_idx", "step_idx_out"},
                    {"stop_flags", "stop_flags_out"},
                    {"not_need_stop", "not_need_stop_out"},
                    {"base_model_draft_tokens", "base_model_draft_tokens_out"}})
    .SetKernelFn(PD_KERNEL(DraftModelUpdate));
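// The kernel ends with the same cub::BlockReduce idiom used throughout these
// ops: each thread contributes 0 or 1 and thread 0 publishes whether any
// request is still running. A minimal standalone sketch of just that pattern
// (illustrative only; the toy kernel below is not part of the op):
#include <cub/cub.cuh>

template <int THREADBLOCK_SIZE>
__global__ void stop_reduce_demo(const bool* stop_flags,
                                 bool* not_need_stop,
                                 int bsz) {
  typedef cub::BlockReduce<int64_t, THREADBLOCK_SIZE> BlockReduce;
  __shared__ typename BlockReduce::TempStorage temp_storage;
  int tid = threadIdx.x;
  int64_t stopped = (tid < bsz && stop_flags[tid]) ? 1 : 0;
  int64_t stop_sum = BlockReduce(temp_storage).Sum(stopped);
  if (tid == 0) {
    // Keep running as long as at least one request has not stopped.
    not_need_stop[0] = stop_sum < bsz;
  }
}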
@@ -0,0 +1,240 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/extension.h"

#include "helper.h"

// #define DEBUG_EAGLE_KERNEL

// Builds a map from each input token position to its position in the
// compacted output (-1 means the token is dropped) and counts the outputs.
__global__ void ComputeOrderKernel(
    const int* seq_lens_this_time,
    const int* seq_lens_encoder,
    const int* base_model_seq_lens_this_time,
    const int* base_model_seq_lens_encoder,
    const int* accept_nums,
    int* position_map,
    int* output_token_num,
    const int bsz,
    const int actual_draft_token_num,
    const int input_token_num) {
  int in_offset = 0;   // input offset (long sequence)
  int out_offset = 0;  // output offset (short sequence)
  if (threadIdx.x == 0) {
    for (int i = 0; i < bsz; ++i) {
      int cur_base_model_seq_lens_this_time = base_model_seq_lens_this_time[i];
      int cur_base_model_seq_lens_encoder = base_model_seq_lens_encoder[i];
      int cur_seq_lens_this_time = seq_lens_this_time[i];
      int accept_num = accept_nums[i];
      int cur_seq_lens_encoder = seq_lens_encoder[i];
#ifdef DEBUG_EAGLE_KERNEL
      printf(
          "batch %d: cur_base_model_seq_lens_this_time:%d, "
          "cur_seq_lens_this_time:%d, accept_num:%d\n",
          i, cur_base_model_seq_lens_this_time, cur_seq_lens_this_time,
          accept_num);
#endif
      // 1. eagle encoder. Base step = 1
      if (cur_seq_lens_encoder > 0) {
#ifdef DEBUG_EAGLE_KERNEL
        printf("batch %d: cur_seq_lens_encoder > 0 \n", i);
#endif
        for (int j = 0; j < cur_seq_lens_encoder; j++) {
          position_map[in_offset++] = out_offset++;
        }
        // 2. base model encoder. Base step = 0: nothing to gather yet
      } else if (cur_base_model_seq_lens_encoder != 0) {
        // 3. New end
      } else if (cur_base_model_seq_lens_this_time != 0 &&
                 cur_seq_lens_this_time == 0) {
#ifdef DEBUG_EAGLE_KERNEL
        printf("batch %d: base != 0, draft == 0 \n", i);
#endif
        in_offset += cur_base_model_seq_lens_this_time;
        // 4. stopped
      } else if (cur_base_model_seq_lens_this_time == 0 &&
                 cur_seq_lens_this_time == 0) /* end */ {
      } else {
        if (accept_num <=
            actual_draft_token_num) /*Accept partial draft tokens*/ {
#ifdef DEBUG_EAGLE_KERNEL
          printf("batch %d: accept_num <= actual_draft_token_num \n", i);
#endif
          position_map[in_offset + accept_num - 1] = out_offset++;
          in_offset += cur_base_model_seq_lens_this_time;
        } else /*Accept all draft tokens*/ {
#ifdef DEBUG_EAGLE_KERNEL
          printf("batch %d: accept_num > actual_draft_token_num \n", i);
#endif
          position_map[in_offset + accept_num - 2] = out_offset++;
          position_map[in_offset + accept_num - 1] = out_offset++;
          in_offset += cur_base_model_seq_lens_this_time;
        }
      }
    }
    output_token_num[0] = out_offset;
#ifdef DEBUG_EAGLE_KERNEL
    printf("position map, output_token_num:%d\n", output_token_num[0]);
    for (int i = 0; i < output_token_num[0]; i++) {
      printf("%d ", position_map[i]);
    }
    printf("\n");
#endif
  }
}

// Gathers the selected hidden-state rows into a compact output tensor
// according to position_map, using 16-byte vectorized loads/stores.
template <typename T, int VecSize>
__global__ void rebuildHiddenStatesKernel(const T* input,
                                          const int* position_map,
                                          T* out,
                                          const int dim_embed,
                                          const int elem_cnt) {
  using LoadT = AlignedVector<T, VecSize>;
  LoadT src_vec;

  int global_thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
  for (int elem_idx = global_thread_idx * VecSize; elem_idx < elem_cnt;
       elem_idx += blockDim.x * gridDim.x * VecSize) {
    int ori_token_idx = elem_idx / dim_embed;
    int token_idx = position_map[ori_token_idx];
    if (token_idx >= 0) {  // -1 marks a dropped token
      int offset = elem_idx % dim_embed;
      Load<T, VecSize>(input + ori_token_idx * dim_embed + offset, &src_vec);
      Store<T, VecSize>(src_vec, out + token_idx * dim_embed + offset);
    }
  }
}

template <paddle::DataType D>
std::vector<paddle::Tensor> DispatchDtype(
    const paddle::Tensor& input,
    const paddle::Tensor& seq_lens_this_time,
    const paddle::Tensor& seq_lens_encoder,
    const paddle::Tensor& seq_lens_decoder,
    const paddle::Tensor& stop_flags,
    const paddle::Tensor& accept_nums,
    const paddle::Tensor& base_model_seq_lens_this_time,
    const paddle::Tensor& base_model_seq_lens_encoder,
    const int actual_draft_token_num) {
  typedef PDTraits<D> traits_;
  typedef typename traits_::DataType DataType_;
  typedef typename traits_::data_t data_t;

  auto input_token_num = input.shape()[0];
  auto dim_embed = input.shape()[1];
  int bsz = seq_lens_this_time.shape()[0];

  auto position_map = paddle::empty(
      {input_token_num}, seq_lens_this_time.dtype(), input.place());
  // Fill the int32 map with -1 (every byte 0xFF). Note the element size is
  // sizeof(int); sizeof of the dtype() enum would measure the enum itself.
  cudaMemsetAsync(position_map.data<int>(),
                  0xFF,
                  input_token_num * sizeof(int),
                  seq_lens_this_time.stream());

  auto output_token_num = paddle::empty(
      {1}, seq_lens_this_time.dtype(), seq_lens_this_time.place());
  // Launch on the same stream as the memset so the -1 fill is ordered first.
  ComputeOrderKernel<<<1, 1, 0, seq_lens_this_time.stream()>>>(
      seq_lens_this_time.data<int>(),
      seq_lens_encoder.data<int>(),
      base_model_seq_lens_this_time.data<int>(),
      base_model_seq_lens_encoder.data<int>(),
      accept_nums.data<int>(),
      position_map.data<int>(),
      output_token_num.data<int>(),
      bsz,
      actual_draft_token_num,
      input_token_num);

  int output_token_num_cpu =
      output_token_num.copy_to(paddle::CPUPlace(), false).data<int>()[0];

  auto out = paddle::empty(
      {output_token_num_cpu, dim_embed}, input.dtype(), input.place());

  constexpr int packSize = VEC_16B / (sizeof(DataType_));
  int elem_cnt = input_token_num * dim_embed;

  assert(elem_cnt % packSize == 0);

  int pack_num = elem_cnt / packSize;
  int grid_size = 1;
  GetNumBlocks(pack_num, &grid_size);
  constexpr int thread_per_block = 128;

  rebuildHiddenStatesKernel<DataType_, packSize>
      <<<grid_size, thread_per_block, 0, input.stream()>>>(
          reinterpret_cast<const DataType_*>(input.data<data_t>()),
          position_map.data<int>(),
          reinterpret_cast<DataType_*>(out.data<data_t>()),
          dim_embed,
          elem_cnt);

  return {out};
}

std::vector<paddle::Tensor> EagleGetHiddenStates(
    const paddle::Tensor& input,
    const paddle::Tensor& seq_lens_this_time,
    const paddle::Tensor& seq_lens_encoder,
    const paddle::Tensor& seq_lens_decoder,
    const paddle::Tensor& stop_flags,
    const paddle::Tensor& accept_nums,
    const paddle::Tensor& base_model_seq_lens_this_time,
    const paddle::Tensor& base_model_seq_lens_encoder,
    const int actual_draft_token_num) {
  switch (input.dtype()) {
    case paddle::DataType::FLOAT16: {
      return DispatchDtype<paddle::DataType::FLOAT16>(
          input,
          seq_lens_this_time,
          seq_lens_encoder,
          seq_lens_decoder,
          stop_flags,
          accept_nums,
          base_model_seq_lens_this_time,
          base_model_seq_lens_encoder,
          actual_draft_token_num);
    }
    case paddle::DataType::BFLOAT16: {
      return DispatchDtype<paddle::DataType::BFLOAT16>(
          input,
          seq_lens_this_time,
          seq_lens_encoder,
          seq_lens_decoder,
          stop_flags,
          accept_nums,
          base_model_seq_lens_this_time,
          base_model_seq_lens_encoder,
          actual_draft_token_num);
    }
    default: {
      PD_THROW("Not support this data type");
    }
  }
}

PD_BUILD_STATIC_OP(eagle_get_hidden_states)
    .Inputs({"input",
             "seq_lens_this_time",
             "seq_lens_encoder",
             "seq_lens_decoder",
             "stop_flags",
             "accept_nums",
             "base_model_seq_lens_this_time",
             "base_model_seq_lens_encoder"})
    .Attrs({"actual_draft_token_num: int"})
    .Outputs({"out"})
    .SetKernelFn(PD_KERNEL(EagleGetHiddenStates));
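// The two kernels implement a compact scatter: ComputeOrderKernel builds an
// input-token -> output-token map (-1 means "drop"), and
// rebuildHiddenStatesKernel copies rows according to it. A host-side sketch
// of the mapping contract (illustrative toy data; the real map is built on
// the GPU):
#include <cstdio>
#include <vector>

int main() {
  const int dim_embed = 2;
  std::vector<float> input = {0, 0, 1, 1, 2, 2, 3, 3};  // 4 input rows
  // Keep rows 1 and 3, dropping the rest (-1 means "not selected").
  std::vector<int> position_map = {-1, 0, -1, 1};
  std::vector<float> out(2 * dim_embed);
  for (size_t row = 0; row < position_map.size(); ++row) {
    int dst = position_map[row];
    if (dst >= 0) {
      for (int c = 0; c < dim_embed; ++c) {
        out[dst * dim_embed + c] = input[row * dim_embed + c];
      }
    }
  }
  printf("out row0 = {%g, %g}, row1 = {%g, %g}\n",
         out[0], out[1], out[2], out[3]);  // rows 1 and 3 of the input
  return 0;
}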
@@ -0,0 +1,190 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/extension.h"
#include "helper.h"

// #define DEBUG_EAGLE_KERNEL

// For every live request, records the index of the last hidden state it
// produced in the previous step (src_map) and counts the selected tokens.
__global__ void computeOrderKernel(
    const int* last_seq_lens_this_time,
    const int* seq_lens_this_time,
    const int64_t* step_idx,
    int* src_map,
    int* output_token_num,
    int bsz) {
  int in_offset = 0;
  int out_offset = 0;
  if (threadIdx.x == 0) {
    for (int i = 0; i < bsz; ++i) {
      int cur_seq_lens_this_time = seq_lens_this_time[i];
      int cur_last_seq_lens_this_time = last_seq_lens_this_time[i];
#ifdef DEBUG_EAGLE_KERNEL
      printf("batch %d: cur_seq_lens_this_time:%d, cur_last_seq_lens_this_time:%d\n",
             i, cur_seq_lens_this_time, cur_last_seq_lens_this_time);
#endif
      // 1. encoder
      if (step_idx[i] == 1 && cur_seq_lens_this_time > 0) {
#ifdef DEBUG_EAGLE_KERNEL
        printf("batch %d last_step is encoder \n", i);
#endif
        in_offset += 1;
        src_map[out_offset++] = in_offset - 1;
#ifdef DEBUG_EAGLE_KERNEL
        printf("batch %d finish. src_map[%d]=%d \n", i, out_offset - 1,
               in_offset - 1);
#endif
        // 2. decoder
      } else if (cur_seq_lens_this_time > 0) /* =1 */ {
#ifdef DEBUG_EAGLE_KERNEL
        printf("batch %d is decoder\n", i);
#endif
        in_offset += cur_last_seq_lens_this_time;
        src_map[out_offset++] = in_offset - 1;
        // 3. stop
      } else {
        // first token end
        if (step_idx[i] == 1) {
#ifdef DEBUG_EAGLE_KERNEL
          printf("batch %d finished in first token \n", i);
#endif
          in_offset += cur_last_seq_lens_this_time > 0 ? 1 : 0;
          // normal end
        } else {
#ifdef DEBUG_EAGLE_KERNEL
          printf("batch %d finished in non-first token \n", i);
#endif
          in_offset += cur_last_seq_lens_this_time;
        }
      }
    }
    output_token_num[0] = out_offset;
#ifdef DEBUG_EAGLE_KERNEL
    printf("src map, output_token_num:%d\n", output_token_num[0]);
    for (int i = 0; i < output_token_num[0]; i++) {
      printf("%d ", src_map[i]);
    }
    printf("\n");
#endif
  }
}

// Gathers the selected rows (one per live request) into a dense output.
template <typename T, int PackSize>
__global__ void rebuildSelfHiddenStatesKernel(
    const T* input,
    const int* src_map,
    T* output,
    int dim_embed,
    int elem_cnt) {
  using LoadT = AlignedVector<T, PackSize>;
  LoadT src_vec;

  int global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
  for (int elem_id = global_thread_idx * PackSize; elem_id < elem_cnt;
       elem_id += blockDim.x * gridDim.x * PackSize) {
    int output_token_idx = elem_id / dim_embed;
    int input_token_idx = src_map[output_token_idx];
    int offset = elem_id % dim_embed;
    Load<T, PackSize>(input + input_token_idx * dim_embed + offset, &src_vec);
    Store<T, PackSize>(src_vec, output + output_token_idx * dim_embed + offset);
  }
}

template <paddle::DataType D>
std::vector<paddle::Tensor> DispatchDtype(
    const paddle::Tensor& input,
    const paddle::Tensor& last_seq_lens_this_time,
    const paddle::Tensor& seq_lens_this_time,
    const paddle::Tensor& step_idx) {
  typedef PDTraits<D> traits_;
  typedef typename traits_::DataType DataType_;
  typedef typename traits_::data_t data_t;

  int input_token_num = input.shape()[0];
  int dim_embed = input.shape()[1];
  int bsz = seq_lens_this_time.shape()[0];
  auto src_map = paddle::full({input_token_num}, -1,
                              seq_lens_this_time.dtype(),
                              seq_lens_this_time.place());
  auto output_token_num = paddle::full({1}, 0,
                                       seq_lens_this_time.dtype(),
                                       seq_lens_this_time.place());

  computeOrderKernel<<<1, 1, 0, seq_lens_this_time.stream()>>>(
      last_seq_lens_this_time.data<int>(),
      seq_lens_this_time.data<int>(),
      step_idx.data<int64_t>(),
      src_map.data<int>(),
      output_token_num.data<int>(),
      bsz);

  int output_token_num_cpu =
      output_token_num.copy_to(paddle::CPUPlace(), false).data<int>()[0];

  auto out = paddle::full({output_token_num_cpu, dim_embed}, -1,
                          input.dtype(), input.place());

  constexpr int packSize = VEC_16B / (sizeof(DataType_));
  int elem_cnt = output_token_num_cpu * dim_embed;
  assert(elem_cnt % packSize == 0);

  int pack_num = elem_cnt / packSize;
  int grid_size = 1;
  GetNumBlocks(pack_num, &grid_size);
  constexpr int threadPerBlock = 128;

  rebuildSelfHiddenStatesKernel<DataType_, packSize>
      <<<grid_size, threadPerBlock, 0, input.stream()>>>(
          reinterpret_cast<const DataType_*>(input.data<data_t>()),
          src_map.data<int>(),
          reinterpret_cast<DataType_*>(out.data<data_t>()),
          dim_embed,
          elem_cnt);

  return {out};
}

std::vector<paddle::Tensor> EagleGetSelfHiddenStates(
    const paddle::Tensor& input,
    const paddle::Tensor& last_seq_lens_this_time,
    const paddle::Tensor& seq_lens_this_time,
    const paddle::Tensor& step_idx) {
  switch (input.dtype()) {
    case paddle::DataType::BFLOAT16:
      return DispatchDtype<paddle::DataType::BFLOAT16>(
          input,
          last_seq_lens_this_time,
          seq_lens_this_time,
          step_idx);
    case paddle::DataType::FLOAT16:
      return DispatchDtype<paddle::DataType::FLOAT16>(
          input,
          last_seq_lens_this_time,
          seq_lens_this_time,
          step_idx);
    default:
      PD_THROW("Not support this data type");
  }
}

PD_BUILD_STATIC_OP(eagle_get_self_hidden_states)
    .Inputs({"input",
             "last_seq_lens_this_time",
             "seq_lens_this_time",
             "step_idx"})
    .Outputs({"out"})
    .SetKernelFn(PD_KERNEL(EagleGetSelfHiddenStates));
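// computeOrderKernel selects, for every live request, the index of the last
// hidden state it produced in the previous step. A host-side sketch of that
// selection (illustrative toy values; the real kernel also handles the
// first-token and stop cases):
#include <cstdio>

int main() {
  const int bsz = 2;
  int last_seq_lens_this_time[bsz] = {3, 2};  // tokens emitted last step
  int seq_lens_this_time[bsz] = {1, 1};       // both requests still decoding
  int src_map[bsz];
  int in_offset = 0, out_offset = 0;
  for (int i = 0; i < bsz; ++i) {
    if (seq_lens_this_time[i] > 0) {
      in_offset += last_seq_lens_this_time[i];
      src_map[out_offset++] = in_offset - 1;    // last hidden state of req i
    } else {
      in_offset += last_seq_lens_this_time[i];  // skip a finished request
    }
  }
  printf("src_map = {%d, %d}\n", src_map[0], src_map[1]);  // {2, 4}
  return 0;
}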
@@ -0,0 +1,136 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "helper.h"
#include "paddle/extension.h"

// Copies, for each request, the hidden state of its last accepted token out
// of the packed [token_num, hidden_size] buffer into a dense
// [bsz, hidden_size] buffer.
template <typename T, int VecSize>
__global__ void HydraFetchHiddenStatesKernel(const T* output_hidden_states,
                                             const int* output_padding_offset,
                                             const int* accept_token_num,
                                             T* hidden_states,
                                             const int bsz,
                                             const int max_seq_len,
                                             const int hidden_size) {
  const int token_id = blockIdx.x;
  const int ori_token_id = token_id + output_padding_offset[token_id];
  const int bid = ori_token_id / max_seq_len;
  const int start_ori_token_id = bid * max_seq_len;
  const int local_token_id = ori_token_id - start_ori_token_id;

  // Only the last accepted token of each request is copied.
  if (local_token_id != accept_token_num[bid] - 1) return;

  using LoadT = AlignedVector<T, VecSize>;

  LoadT vec;

  for (int idx = threadIdx.x * VecSize; idx < hidden_size;
       idx += blockDim.x * VecSize) {
    Load(&output_hidden_states[token_id * hidden_size + idx], &vec);
    Store(vec, &hidden_states[bid * hidden_size + idx]);
  }
}

template <paddle::DataType D>
std::vector<paddle::Tensor> HydraFetchHiddenStatesImpl(
    const paddle::Tensor& output_hidden_states,
    const paddle::Tensor& output_padding_offset,
    const paddle::Tensor& accept_token_num,
    const int bsz,
    const int max_seq_length) {
  typedef PDTraits<D> traits_;
  typedef typename traits_::DataType DataType_;
  typedef typename traits_::data_t data_t;

  auto cu_stream = output_hidden_states.stream();

  auto output_token_num = output_hidden_states.shape()[0];
  auto hidden_size = output_hidden_states.shape()[1];

  auto hidden_states = paddle::full({bsz, hidden_size},
                                    0,
                                    output_hidden_states.dtype(),
                                    output_hidden_states.place());

  // 16-byte packs: one vectorized load/store per VecSize elements.
  constexpr int VecSize = 16 / sizeof(data_t);

  HydraFetchHiddenStatesKernel<data_t, VecSize>
      <<<output_token_num, 256, 0, cu_stream>>>(
          output_hidden_states.data<data_t>(),
          output_padding_offset.data<int>(),
          accept_token_num.data<int>(),
          hidden_states.data<data_t>(),
          bsz,
          max_seq_length,
          hidden_size);
  return {hidden_states};
}

std::vector<paddle::Tensor> HydraFetchHiddenStates(
    const paddle::Tensor& output_hidden_states,
    const paddle::Tensor& output_padding_offset,
    const paddle::Tensor& accept_token_num,
    const int bsz,
    const int max_seq_length) {
  switch (output_hidden_states.dtype()) {
    case paddle::DataType::BFLOAT16: {
      return HydraFetchHiddenStatesImpl<paddle::DataType::BFLOAT16>(
          output_hidden_states,
          output_padding_offset,
          accept_token_num,
          bsz,
          max_seq_length);
    }
    case paddle::DataType::FLOAT16: {
      return HydraFetchHiddenStatesImpl<paddle::DataType::FLOAT16>(
          output_hidden_states,
          output_padding_offset,
          accept_token_num,
          bsz,
          max_seq_length);
    }
    default: {
      PD_THROW(
          "NOT supported data type. "
          "Only float16 and bfloat16 are supported. ");
      break;
    }
  }
}

std::vector<std::vector<int64_t>> HydraFetchHiddenStatesInferShape(
    const std::vector<int64_t>& output_hidden_states_shape,
    const std::vector<int64_t>& output_padding_offset_shape,
    const std::vector<int64_t>& accept_token_num_shape,
    const int bsz,
    const int max_seq_length) {
  return {{bsz, output_hidden_states_shape[1]}};
}

std::vector<paddle::DataType> HydraFetchHiddenStatesInferDtype(
    const paddle::DataType& output_hidden_states_dtype,
    const paddle::DataType& output_padding_offset_dtype,
    const paddle::DataType& accept_token_num_dtype) {
  return {output_hidden_states_dtype};
}

PD_BUILD_STATIC_OP(hydra_fetch_hidden_states)
    .Inputs({"output_hidden_states",
             "output_padding_offset",
             "accept_token_num"})
    .Outputs({"hidden_states"})
    .Attrs({"bsz: int", "max_seq_length: int"})
    .SetKernelFn(PD_KERNEL(HydraFetchHiddenStates))
    .SetInferShapeFn(PD_INFER_SHAPE(HydraFetchHiddenStatesInferShape))
    .SetInferDtypeFn(PD_INFER_DTYPE(HydraFetchHiddenStatesInferDtype));
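// The kernel recovers each token's batch id and position from
// output_padding_offset: adding the offset restores the token's index in the
// padded [bsz, max_seq_len] layout. A quick arithmetic sketch (illustrative;
// toy values only):
#include <cstdio>

int main() {
  const int max_seq_len = 8;
  // Two requests with 3 and 2 valid tokens; packed token ids 0..4.
  // output_padding_offset[token] restores the padded position.
  int output_padding_offset[5] = {0, 0, 0, 5, 5};
  for (int token_id = 0; token_id < 5; ++token_id) {
    int ori = token_id + output_padding_offset[token_id];
    int bid = ori / max_seq_len;
    int local = ori - bid * max_seq_len;
    printf("token %d -> batch %d, local position %d\n", token_id, bid, local);
  }
  return 0;  // tokens 0-2 -> batch 0 pos 0-2; tokens 3-4 -> batch 1 pos 0-1
}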
@@ -0,0 +1,160 @@
|
||||
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <stdio.h>
|
||||
#include <string.h>
|
||||
#include <sys/ipc.h>
|
||||
#include <sys/msg.h>
|
||||
#include <sys/types.h>
|
||||
#include "paddle/extension.h"
|
||||
|
||||
#ifndef PD_BUILD_STATIC_OP
|
||||
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
|
||||
#endif
|
||||
|
||||
#define MAX_BSZ 512
|
||||
|
||||
// #define SAVE_WITH_OUTPUT_DEBUG
|
||||
#define MAX_DRAFT_TOKENS 6
|
||||
struct msgdata {
|
||||
long mtype;
|
||||
int mtext[2 + MAX_BSZ + MAX_BSZ * MAX_DRAFT_TOKENS]; // stop_flag, token_num, tokens
|
||||
};
|
||||
|
||||
void MTPSaveFirstToken(const paddle::Tensor& x,
|
||||
const paddle::Tensor& not_need_stop,
|
||||
int64_t rank_id,
|
||||
int msg_queue_id,
|
||||
bool save_each_rank) {
|
||||
if (!save_each_rank && rank_id > 0) {
|
||||
return;
|
||||
}
|
||||
int x_dim = x.shape()[1];
|
||||
auto x_cpu = x.copy_to(paddle::CPUPlace(), false);
|
||||
int64_t* x_data = x_cpu.data<int64_t>();
|
||||
static struct msgdata msg_sed;
|
||||
|
||||
if (const char* inference_msg_queue_id_env_p =
|
||||
std::getenv("INFERENCE_MSG_QUEUE_ID")) {
|
||||
std::string inference_msg_queue_id_env_str(
|
||||
inference_msg_queue_id_env_p);
|
||||
int inference_msg_queue_id_from_env =
|
||||
std::stoi(inference_msg_queue_id_env_str);
|
||||
#ifdef SAVE_WITH_OUTPUT_DEBUG
|
||||
std::cout << "Your INFERENCE_MSG_QUEUE_ID is: "
|
||||
<< inference_msg_queue_id_from_env << std::endl;
|
||||
#endif
|
||||
msg_queue_id = inference_msg_queue_id_from_env;
|
||||
}
|
||||
|
||||
static key_t key = ftok("./", msg_queue_id);
|
||||
static int msgid = msgget(key, IPC_CREAT | 0666);

    msg_sed.mtype = 1;
    bool not_need_stop_data = not_need_stop.data<bool>()[0];
    int inference_msg_id_from_env = 1;
    if (const char* inference_msg_id_env_p = std::getenv("INFERENCE_MSG_ID")) {
        std::string inference_msg_id_env_str(inference_msg_id_env_p);
        inference_msg_id_from_env = std::stoi(inference_msg_id_env_str);
        if (inference_msg_id_from_env == 2) {
            // 2 and -2 are reserved for the no-output indication.
            throw std::runtime_error(
                " INFERENCE_MSG_ID cannot be 2, please use another number.");
        }
        if (inference_msg_id_from_env < 0) {
            throw std::runtime_error(
                " INFERENCE_MSG_ID cannot be negative, please use another "
                "number.");
        }

#ifdef SAVE_WITH_OUTPUT_DEBUG
        std::cout << "Your INFERENCE_MSG_ID is: " << inference_msg_id_from_env
                  << std::endl;
#endif
    } else {
#ifdef SAVE_WITH_OUTPUT_DEBUG
        std::cout
            << "Failed to get INFERENCE_MSG_ID from env, using (int)1 as "
               "default."
            << std::endl;
#endif
    }
#ifdef SAVE_WITH_OUTPUT_DEBUG
    std::cout << "save_output_key: " << key << std::endl;
    std::cout << "save msgid: " << msgid << std::endl;
#endif
    msg_sed.mtext[0] = not_need_stop_data ? inference_msg_id_from_env
                                          : -inference_msg_id_from_env;
    int bsz = x.shape()[0];
    msg_sed.mtext[1] = bsz;
    for (int i = 0; i < bsz; i++) {
#ifdef SAVE_WITH_OUTPUT_DEBUG
        printf("bid: %d. 1: %d. 2: %d.\n",
               i,
               (int)x_data[i * x_dim],
               (int)x_data[i * x_dim + 1]);
#endif
        // Each sequence reports two tokens for this first step.
        msg_sed.mtext[i + 2] = 2;
        msg_sed.mtext[i * MAX_DRAFT_TOKENS + 2 + MAX_BSZ] =
            (int)x_data[i * x_dim];
        msg_sed.mtext[i * MAX_DRAFT_TOKENS + 1 + 2 + MAX_BSZ] =
            (int)x_data[i * x_dim + 1];
#ifdef SAVE_WITH_OUTPUT_DEBUG
        printf("mtext[%d]:%d. mtext[%d]:%d. \n",
               i * MAX_DRAFT_TOKENS + 2 + MAX_BSZ,
               msg_sed.mtext[i * MAX_DRAFT_TOKENS + 2 + MAX_BSZ],
               i * MAX_DRAFT_TOKENS + 1 + 2 + MAX_BSZ,
               msg_sed.mtext[i * MAX_DRAFT_TOKENS + 1 + 2 + MAX_BSZ]);
#endif
    }

#ifdef SAVE_WITH_OUTPUT_DEBUG
    std::cout << "msg data: ";
    for (int i = 0; i < bsz; i++) {
        // Print back the same per-sequence pair that was just packed.
        std::cout << " " << (int)x_data[i * x_dim] << " ";
        std::cout << " " << (int)x_data[i * x_dim + 1];
    }
    std::cout << std::endl;
#endif
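    // With msgflg 0, msgsnd blocks when the queue is full, so -1 below
    // signals a genuine error (e.g. a removed queue) rather than a full
    // buffer.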
    if ((msgsnd(msgid,
                &msg_sed,
                (2 + MAX_BSZ + MAX_BSZ * MAX_DRAFT_TOKENS) * 4,
                0)) == -1) {
        printf("full msg buffer\n");
    }
    return;
}
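
// Illustrative only: a minimal sketch of the consumer side of this queue,
// assuming the mtext layout written above. The helper name and the polling
// strategy are hypothetical, not part of FastDeploy's API.
#ifdef SAVE_WITH_OUTPUT_DEBUG
static void DebugPollFirstToken(int msgid) {
    struct msgdata msg_rcv;
    // msgsz excludes the leading mtype field; msgtyp 0 takes the first
    // message in the queue.
    if (msgrcv(msgid,
               &msg_rcv,
               (2 + MAX_BSZ + MAX_BSZ * MAX_DRAFT_TOKENS) * 4,
               0,
               0) != -1) {
        const int msg_flag = msg_rcv.mtext[0];  // negative => stop signal
        const int bsz = msg_rcv.mtext[1];
        for (int i = 0; i < bsz; i++) {
            printf("flag: %d, bid: %d, tokens: %d %d\n",
                   msg_flag,
                   i,
                   msg_rcv.mtext[i * MAX_DRAFT_TOKENS + 2 + MAX_BSZ],
                   msg_rcv.mtext[i * MAX_DRAFT_TOKENS + 3 + MAX_BSZ]);
        }
    }
}
#endif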

void MTPSaveFirstTokenStatic(const paddle::Tensor& x,
                             const paddle::Tensor& not_need_stop,
                             int64_t rank_id,
                             bool save_each_rank) {
    MTPSaveFirstToken(x, not_need_stop, rank_id, 1, save_each_rank);
}

void MTPSaveFirstTokenDynamic(const paddle::Tensor& x,
                              const paddle::Tensor& not_need_stop,
                              int64_t rank_id,
                              int msg_queue_id,
                              bool save_each_rank) {
    MTPSaveFirstToken(x, not_need_stop, rank_id, msg_queue_id, save_each_rank);
}
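
// Note on the registrations below: "x" appears to be mapped inplace to
// "x_out" so the op exposes an output for static-graph wiring; the kernel
// itself only reads x.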

PD_BUILD_STATIC_OP(mtp_save_first_token)
    .Inputs({"x", "not_need_stop"})
    .Attrs({"rank_id: int64_t",
            "save_each_rank: bool"})
    .Outputs({"x_out"})
    .SetInplaceMap({{"x", "x_out"}})
    .SetKernelFn(PD_KERNEL(MTPSaveFirstTokenStatic));

PD_BUILD_STATIC_OP(mtp_save_first_token_dynamic)
    .Inputs({"x", "not_need_stop"})
    .Attrs({"rank_id: int64_t", "msg_queue_id: int", "save_each_rank: bool"})
    .Outputs({"x_out"})
    .SetInplaceMap({{"x", "x_out"}})
    .SetKernelFn(PD_KERNEL(MTPSaveFirstTokenDynamic));
@@ -0,0 +1,226 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "helper.h"  // NOLINT

template <int NUM_THREADS, int MAX_BATCH_SIZE = 256>
__global__ void mtp_free_and_dispatch_block(
    bool *base_model_stop_flags,
    bool *stop_flags,
    bool *batch_drop,
    int *seq_lens_this_time,
    int *seq_lens_decoder,
    int *block_tables,
    int *encoder_block_lens,
    int *used_list_len,
    int *free_list,
    int *free_list_len,
    const int bsz,
    const int block_size,
    const int block_num_per_seq,
    const int max_draft_tokens) {
    typedef cub::BlockReduce<cub::KeyValuePair<int, int>, NUM_THREADS>
        BlockReduce;
    __shared__ typename BlockReduce::TempStorage temp_storage;

    __shared__ int need_block_len;
    __shared__ int need_block_list[MAX_BATCH_SIZE];
    const int tid = threadIdx.x;

    if (tid < bsz) {
        if (tid == 0) {
            need_block_len = 0;
        }
        need_block_list[tid] = 0;
        int *block_table_now = block_tables + tid * block_num_per_seq;
        if (base_model_stop_flags[tid] || batch_drop[tid]) {
            // Reclaim this sequence's decoder blocks.
            const int encoder_block_len = encoder_block_lens[tid];
            const int decoder_used_len = used_list_len[tid];
            if (decoder_used_len > 0) {
                const int ori_free_list_len =
                    atomicAdd(free_list_len, decoder_used_len);
#ifdef DEBUG_STEP
                printf(
                    "free block seq_id: %d, free block num: %d, "
                    "encoder_block_len: %d, ori_free_list_len: %d\n",
                    tid,
                    decoder_used_len,
                    encoder_block_len,
                    ori_free_list_len);
#endif
                for (int i = 0; i < decoder_used_len; i++) {
                    free_list[ori_free_list_len + i] =
                        block_table_now[encoder_block_len + i];
                    block_table_now[encoder_block_len + i] = -1;
                }
                encoder_block_lens[tid] = 0;
                used_list_len[tid] = 0;
            }
        }
    }
    __syncthreads();
    if (tid < bsz) {
        int *block_table_now = block_tables + tid * block_num_per_seq;
        int max_possible_block_idx =
            (seq_lens_decoder[tid] + max_draft_tokens + 1) / block_size;
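        // Worked example (assumed values): with block_size = 64,
        // max_draft_tokens = 5 and seq_lens_decoder[tid] = 122, the next step
        // may touch position 122 + 5 + 1 = 128, i.e. slot 128 / 64 = 2; if
        // block_table_now[2] == -1, this sequence must request a new block.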
        if (!base_model_stop_flags[tid] && !batch_drop[tid] &&
            max_possible_block_idx < block_num_per_seq &&
            block_table_now[max_possible_block_idx] == -1) {
            // Record which sequences need a block, and the total count.
            int ori_need_block_len = atomicAdd(&need_block_len, 1);
            need_block_list[ori_need_block_len] = tid;
            // const int ori_free_list_len = atomicSub(free_list_len, 1);
            // block_table_now[(seq_lens_decoder[tid] + max_draft_tokens + 1) /
            //                 block_size] = free_list[ori_free_list_len - 1];
            // used_list_len[tid] += 1;

#ifdef DEBUG_STEP
            printf("seq_id: %d need block\n", tid);
#endif
        }
    }
    __syncthreads();
#ifdef DEBUG_STEP
    if (tid == 0) {
        printf("need_block_len:%d, free_list_len: %d\n",
               need_block_len,
               free_list_len[0]);
    }
#endif
    // Sweep directly from bid 0 here: while free blocks are insufficient,
    // evict one sequence per round.
    while (need_block_len > free_list_len[0]) {
#ifdef DEBUG_STEP
        if (tid == 0) {
            printf("in while need_block_len:%d, free_list_len: %d\n",
                   need_block_len,
                   free_list_len[0]);
        }
#endif
        const int used_block_num =
            tid < bsz && !base_model_stop_flags[tid] ? used_list_len[tid] : 0;
        cub::KeyValuePair<int, int> kv_pair = {tid, used_block_num};
        kv_pair = BlockReduce(temp_storage).Reduce(kv_pair, cub::ArgMax());
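        // After the ArgMax reduce, thread 0 holds {bid, used_list_len} of the
        // sequence occupying the most decoder blocks; that sequence is
        // dropped and its blocks reclaimed until demand fits the free list.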

        if (tid == 0) {
            const int encoder_block_len = encoder_block_lens[kv_pair.key];
            int *block_table_now =
                block_tables + kv_pair.key * block_num_per_seq;
            for (int i = 0; i < kv_pair.value; i++) {
                free_list[free_list_len[0] + i] =
                    block_table_now[encoder_block_len + i];
                block_table_now[encoder_block_len + i] = -1;
            }
            const int ori_free_list_len =
                atomicAdd(free_list_len, kv_pair.value);

            printf(
                "MTP STEP need_block_len: %d. free_list_len: %d."
                "Drop bid: %d, free block num: %d, "
                "encoder_block_len: %d,"
                "After drop free_list_len %d \n",
                need_block_len,
                ori_free_list_len,
                kv_pair.key,
                kv_pair.value,
                encoder_block_len,
                free_list_len[0]);
            stop_flags[kv_pair.key] = true;
            batch_drop[kv_pair.key] = true;
            seq_lens_this_time[kv_pair.key] = 0;
            seq_lens_decoder[kv_pair.key] = 0;
            used_list_len[kv_pair.key] = 0;
        }
        __syncthreads();
    }

    if (tid < need_block_len) {
        const int need_block_id = need_block_list[tid];
        // batch_drop must be used here, not stop_flags.
        if (!batch_drop[need_block_id]) {
            used_list_len[need_block_id] += 1;
            const int ori_free_list_len = atomicSub(free_list_len, 1);
            int *block_table_now =
                block_tables + need_block_id * block_num_per_seq;
#ifdef DEBUG_STEP
            printf("bid: %d allocate block_id %d. seq_lens_decoder:%d \n",
                   need_block_id,
                   free_list[ori_free_list_len - 1],
                   seq_lens_decoder[need_block_id]);
#endif
            block_table_now[(seq_lens_decoder[need_block_id] +
                             max_draft_tokens + 1) /
                            block_size] = free_list[ori_free_list_len - 1];
        }
    }
}

void MTPStepPaddle(
    const paddle::Tensor &base_model_stop_flags,
    const paddle::Tensor &stop_flags,
    const paddle::Tensor &batch_drop,
    const paddle::Tensor &seq_lens_this_time,
    const paddle::Tensor &seq_lens_encoder,
    const paddle::Tensor &seq_lens_decoder,
    const paddle::Tensor &block_tables,  // [bsz, block_num_per_seq]
    const paddle::Tensor &encoder_block_lens,
    const paddle::Tensor &used_list_len,
    const paddle::Tensor &free_list,
    const paddle::Tensor &free_list_len,
    const int block_size,
    const int max_draft_tokens) {
    auto cu_stream = seq_lens_this_time.stream();
    const int bsz = seq_lens_this_time.shape()[0];
    const int block_num_per_seq = block_tables.shape()[1];
    constexpr int BlockSize = 512;  // supports bsz <= 512
#ifdef DEBUG_STEP
    printf(
        "bsz: %d, block_num_per_seq: %d, block_size: %d, "
        "max_draft_tokens: %d\n",
        bsz,
        block_num_per_seq,
        block_size,
        max_draft_tokens);
#endif
    mtp_free_and_dispatch_block<BlockSize, BlockSize>
        <<<1, BlockSize, 0, cu_stream>>>(
            const_cast<bool *>(base_model_stop_flags.data<bool>()),
            const_cast<bool *>(stop_flags.data<bool>()),
            const_cast<bool *>(batch_drop.data<bool>()),
            const_cast<int *>(seq_lens_this_time.data<int>()),
            const_cast<int *>(seq_lens_decoder.data<int>()),
            const_cast<int *>(block_tables.data<int>()),
            const_cast<int *>(encoder_block_lens.data<int>()),
            const_cast<int *>(used_list_len.data<int>()),
            const_cast<int *>(free_list.data<int>()),
            const_cast<int *>(free_list_len.data<int>()),
            bsz,
            block_size,
            block_num_per_seq,
            max_draft_tokens);
}

PD_BUILD_STATIC_OP(mtp_step_paddle)
    .Inputs({"base_model_stop_flags",
             "stop_flags",
             "batch_drop",
             "seq_lens_this_time",
             "seq_lens_encoder",
             "seq_lens_decoder",
             "block_tables",
             "encoder_block_lens",
             "used_list_len",
             "free_list",
             "free_list_len"})
    .Attrs({"block_size: int", "max_draft_tokens: int"})
    .Outputs({"block_tables_out",
              "stop_flags_out",
              "used_list_len_out",
              "free_list_out",
              "free_list_len_out"})
    .SetInplaceMap({{"block_tables", "block_tables_out"},
                    {"stop_flags", "stop_flags_out"},
                    {"used_list_len", "used_list_len_out"},
                    {"free_list", "free_list_out"},
                    {"free_list_len", "free_list_len_out"}})
    .SetKernelFn(PD_KERNEL(MTPStepPaddle));
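
// Illustrative only: a standalone sketch (not part of this op) of the slot
// arithmetic used by mtp_free_and_dispatch_block, with assumed values for
// block_size and max_draft_tokens.
//
//     #include <cstdio>
//
//     int main() {
//         const int block_size = 64, max_draft_tokens = 5;  // assumed config
//         const int seq_lens_decoder[3] = {10, 122, 250};
//         for (int bid = 0; bid < 3; ++bid) {
//             // Highest block-table slot the next step can touch; if that
//             // slot is -1 in the kernel, the sequence requests a new block.
//             const int slot =
//                 (seq_lens_decoder[bid] + max_draft_tokens + 1) / block_size;
//             std::printf("bid %d -> slot %d\n", bid, slot);  // 0, 2, 4
//         }
//         return 0;
//     }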