mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 00:33:03 +08:00
[Feature] Support block scheduler v1 for FD (#2928)
* Support FD block scheduler v1 * Support FD block scheduler v1 * Support FD block scheduler v1 * Fix according to copilot review * Fix according to review * Remove is_dummy * Fix bug when real_bsz=1 * Fix infer first token cost time --------- Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
This commit is contained in:
@@ -284,6 +284,32 @@ void UpdateInputes(const paddle::Tensor &stop_flags,
|
||||
const paddle::Tensor &next_tokens,
|
||||
const paddle::Tensor &is_block_step);
|
||||
|
||||
void UpdateInputesV1(const paddle::Tensor &stop_flags,
|
||||
const paddle::Tensor ¬_need_stop, // only on cpu
|
||||
const paddle::Tensor &seq_lens_this_time,
|
||||
const paddle::Tensor &seq_lens_encoder,
|
||||
const paddle::Tensor &seq_lens_decoder,
|
||||
const paddle::Tensor &step_seq_lens_decoder,
|
||||
const paddle::Tensor &prompt_lens,
|
||||
const paddle::Tensor &topk_ids,
|
||||
const paddle::Tensor &input_ids,
|
||||
const paddle::Tensor &block_tables,
|
||||
const paddle::Tensor &stop_nums,
|
||||
const paddle::Tensor &next_tokens,
|
||||
const paddle::Tensor &is_block_step,
|
||||
const int block_size);
|
||||
|
||||
void RecoverDecodeTask(const paddle::Tensor &stop_flags,
|
||||
const paddle::Tensor &seq_lens_this_time,
|
||||
const paddle::Tensor &seq_lens_encoder,
|
||||
const paddle::Tensor &seq_lens_decoder,
|
||||
const paddle::Tensor &step_seq_lens_decoder,
|
||||
const paddle::Tensor &block_tables,
|
||||
const paddle::Tensor &is_block_step,
|
||||
const int block_size);
|
||||
|
||||
|
||||
|
||||
paddle::Tensor
|
||||
GroupSwigluWithMasked(const paddle::Tensor &fc1_out_tensor,
|
||||
const paddle::Tensor &token_nums_per_expert);
|
||||
@@ -941,6 +967,18 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
|
||||
*/
|
||||
m.def("update_inputs", &UpdateInputes, "update_inputs function");
|
||||
|
||||
/**
|
||||
* update_inputs_v1.cu
|
||||
* update_inputs_v1
|
||||
*/
|
||||
m.def("update_inputs_v1", &UpdateInputesV1, "update inputs for scheduler v1 function");
|
||||
|
||||
/**
|
||||
* recover_decode_task.cu
|
||||
* recover_decode_task
|
||||
*/
|
||||
m.def("recover_decode_task", &RecoverDecodeTask, "recover decode task for scheduler v1 function");
|
||||
|
||||
/**
|
||||
* extract_text_token_output.cu
|
||||
* extract_text_token_output
|
||||
|
91
custom_ops/gpu_ops/recover_decode_task.cu
Normal file
91
custom_ops/gpu_ops/recover_decode_task.cu
Normal file
@@ -0,0 +1,91 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "helper.h"
|
||||
|
||||
__global__ void recover_decode_task(bool *stop_flags,
|
||||
int *seq_lens_this_time,
|
||||
int *seq_lens_encoder,
|
||||
int *seq_lens_decoder,
|
||||
int *step_seq_lens_decoder,
|
||||
int *block_tables,
|
||||
bool *is_block_step,
|
||||
const int bsz,
|
||||
const int block_num_per_seq,
|
||||
const int block_size) {
|
||||
int thread_idx = threadIdx.x;
|
||||
if (thread_idx < bsz) {
|
||||
if(is_block_step[thread_idx] == true) {
|
||||
int *block_table_now = block_tables + thread_idx * block_num_per_seq;
|
||||
if (block_table_now[step_seq_lens_decoder[thread_idx] / block_size] != -1) {
|
||||
// can be recovered for decoding
|
||||
is_block_step[thread_idx] = false;
|
||||
seq_lens_this_time[thread_idx]= 1;
|
||||
stop_flags[thread_idx] = false;
|
||||
seq_lens_encoder[thread_idx] = 0;
|
||||
seq_lens_decoder[thread_idx] = step_seq_lens_decoder[thread_idx];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void RecoverDecodeTask(const paddle::Tensor &stop_flags,
|
||||
const paddle::Tensor &seq_lens_this_time,
|
||||
const paddle::Tensor &seq_lens_encoder,
|
||||
const paddle::Tensor &seq_lens_decoder,
|
||||
const paddle::Tensor &step_seq_lens_decoder,
|
||||
const paddle::Tensor &block_tables,
|
||||
const paddle::Tensor &is_block_step,
|
||||
const int block_size) {
|
||||
#ifdef PADDLE_WITH_CUSTOM_DEVICE
|
||||
auto dev_ctx = static_cast<const phi::CustomContext*>(paddle::experimental::DeviceContextPool::Instance().Get(seq_lens_this_time.place()));
|
||||
auto cu_stream = dev_ctx->stream();
|
||||
#else
|
||||
auto cu_stream = seq_lens_this_time.stream();
|
||||
#endif
|
||||
const int bsz = seq_lens_this_time.shape()[0];
|
||||
const int block_num_per_seq = block_tables.shape()[1];
|
||||
recover_decode_task<<<1, 1024, 0, cu_stream>>>(
|
||||
const_cast<bool *>(stop_flags.data<bool>()),
|
||||
const_cast<int *>(seq_lens_this_time.data<int>()),
|
||||
const_cast<int *>(seq_lens_encoder.data<int>()),
|
||||
const_cast<int *>(seq_lens_decoder.data<int>()),
|
||||
const_cast<int *>(step_seq_lens_decoder.data<int>()),
|
||||
const_cast<int *>(block_tables.data<int>()),
|
||||
const_cast<bool *>(is_block_step.data<bool>()),
|
||||
bsz,
|
||||
block_num_per_seq,
|
||||
block_size);
|
||||
}
|
||||
|
||||
PD_BUILD_STATIC_OP(recover_decode_task)
|
||||
.Inputs({"stop_flags",
|
||||
"seq_lens_this_time",
|
||||
"seq_lens_encoder",
|
||||
"seq_lens_decoder",
|
||||
"step_seq_lens_decoder",
|
||||
"block_tables",
|
||||
"is_block_step"})
|
||||
.Attrs({"block_size: int"})
|
||||
.Outputs({"seq_lens_this_time_out",
|
||||
"seq_lens_encoder_out",
|
||||
"seq_lens_decoder_out",
|
||||
"stop_flags_out",
|
||||
"is_block_step_out"})
|
||||
.SetInplaceMap({{"seq_lens_this_time", "seq_lens_this_time_out"},
|
||||
{"seq_lens_encoder", "seq_lens_encoder_out"},
|
||||
{"seq_lens_decoder", "seq_lens_decoder_out"},
|
||||
{"stop_flags", "stop_flags_out"},
|
||||
{"is_block_step", "is_block_step_out"}})
|
||||
.SetKernelFn(PD_KERNEL(RecoverDecodeTask));
|
176
custom_ops/gpu_ops/update_inputs_v1.cu
Normal file
176
custom_ops/gpu_ops/update_inputs_v1.cu
Normal file
@@ -0,0 +1,176 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "helper.h"
|
||||
|
||||
template <int THREADBLOCK_SIZE>
|
||||
__global__ void update_inputs_kernel_v1(bool *not_need_stop,
|
||||
int *seq_lens_this_time,
|
||||
int *seq_lens_encoder,
|
||||
int *seq_lens_decoder,
|
||||
int *step_seq_lens_decoder,
|
||||
int64_t *prompt_lens,
|
||||
int64_t *topk_ids,
|
||||
int64_t *input_ids,
|
||||
int *block_tables,
|
||||
const int64_t *stop_nums,
|
||||
bool *stop_flags,
|
||||
bool *is_block_step,
|
||||
const int64_t *next_tokens,
|
||||
const int bsz,
|
||||
const int max_bsz,
|
||||
const int input_ids_stride,
|
||||
const int block_num_per_seq,
|
||||
const int block_size) {
|
||||
int thread_idx = threadIdx.x;
|
||||
typedef cub::BlockReduce<int64_t, THREADBLOCK_SIZE> BlockReduce;
|
||||
__shared__ typename BlockReduce::TempStorage temp_storage;
|
||||
|
||||
bool stop_flag_now = false;
|
||||
int64_t stop_flag_now_int = 0;
|
||||
if (thread_idx < max_bsz) {
|
||||
if (thread_idx < bsz) {
|
||||
stop_flag_now = stop_flags[thread_idx];
|
||||
stop_flag_now_int = static_cast<int64_t>(stop_flag_now);
|
||||
} else {
|
||||
stop_flag_now_int = 1;
|
||||
}
|
||||
}
|
||||
if (thread_idx < bsz) {
|
||||
if(stop_flag_now) {
|
||||
seq_lens_this_time[thread_idx] = 0; // stop at next step
|
||||
seq_lens_decoder[thread_idx] = 0;
|
||||
seq_lens_encoder[thread_idx] = 0;
|
||||
} else {
|
||||
if (seq_lens_this_time[thread_idx] + seq_lens_decoder[thread_idx] >= prompt_lens[thread_idx]) {
|
||||
// decoding
|
||||
seq_lens_decoder[thread_idx] += seq_lens_this_time[thread_idx];
|
||||
seq_lens_this_time[thread_idx] = 1;
|
||||
seq_lens_encoder[thread_idx] = 0;
|
||||
int64_t *input_ids_now = input_ids + thread_idx * input_ids_stride;
|
||||
input_ids_now[0] = next_tokens[thread_idx];
|
||||
|
||||
// to judge whether block is not enough
|
||||
int *block_table_now = block_tables + thread_idx * block_num_per_seq;
|
||||
if (seq_lens_this_time[thread_idx] != 0 && block_table_now[seq_lens_decoder[thread_idx] / block_size] == -1) {
|
||||
// should be scheduled by server
|
||||
is_block_step[thread_idx] = true;
|
||||
seq_lens_this_time[thread_idx]= 0;
|
||||
stop_flags[thread_idx] = true;
|
||||
step_seq_lens_decoder[thread_idx] = seq_lens_decoder[thread_idx];
|
||||
seq_lens_decoder[thread_idx] = 0;
|
||||
stop_flag_now_int = 1;
|
||||
}
|
||||
} else
|
||||
{
|
||||
stop_flags[thread_idx] = true;
|
||||
seq_lens_this_time[thread_idx] = 0;
|
||||
seq_lens_decoder[thread_idx] = 0;
|
||||
seq_lens_encoder[thread_idx] = 0;
|
||||
topk_ids[thread_idx] = -1;
|
||||
stop_flag_now_int = 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
int64_t stop_sum = BlockReduce(temp_storage).Sum(stop_flag_now_int);
|
||||
if (thread_idx == 0) {
|
||||
not_need_stop[0] = stop_sum < stop_nums[0];
|
||||
}
|
||||
}
|
||||
|
||||
void UpdateInputesV1(const paddle::Tensor &stop_flags,
|
||||
const paddle::Tensor ¬_need_stop, // only on cpu
|
||||
const paddle::Tensor &seq_lens_this_time,
|
||||
const paddle::Tensor &seq_lens_encoder,
|
||||
const paddle::Tensor &seq_lens_decoder,
|
||||
const paddle::Tensor &step_seq_lens_decoder,
|
||||
const paddle::Tensor &prompt_lens,
|
||||
const paddle::Tensor &topk_ids,
|
||||
const paddle::Tensor &input_ids,
|
||||
const paddle::Tensor &block_tables,
|
||||
const paddle::Tensor &stop_nums,
|
||||
const paddle::Tensor &next_tokens,
|
||||
const paddle::Tensor &is_block_step,
|
||||
const int block_size) {
|
||||
#ifdef PADDLE_WITH_CUSTOM_DEVICE
|
||||
auto dev_ctx = static_cast<const phi::CustomContext*>(paddle::experimental::DeviceContextPool::Instance().Get(input_ids.place()));
|
||||
auto cu_stream = dev_ctx->stream();
|
||||
#else
|
||||
auto cu_stream = input_ids.stream();
|
||||
#endif
|
||||
const int max_bsz = stop_flags.shape()[0];
|
||||
const int now_bsz = seq_lens_this_time.shape()[0];
|
||||
const int input_ids_stride = input_ids.shape()[1];
|
||||
const int block_num_per_seq = block_tables.shape()[1];
|
||||
auto not_need_stop_gpu = not_need_stop.copy_to(stop_flags.place(), false);
|
||||
update_inputs_kernel_v1<1024><<<1, 1024, 0, cu_stream>>>(
|
||||
const_cast<bool *>(not_need_stop_gpu.data<bool>()),
|
||||
const_cast<int *>(seq_lens_this_time.data<int>()),
|
||||
const_cast<int *>(seq_lens_encoder.data<int>()),
|
||||
const_cast<int *>(seq_lens_decoder.data<int>()),
|
||||
const_cast<int *>(step_seq_lens_decoder.data<int>()),
|
||||
const_cast<int64_t *>(prompt_lens.data<int64_t>()),
|
||||
const_cast<int64_t *>(topk_ids.data<int64_t>()),
|
||||
const_cast<int64_t *>(input_ids.data<int64_t>()),
|
||||
const_cast<int *>(block_tables.data<int>()),
|
||||
stop_nums.data<int64_t>(),
|
||||
const_cast<bool *>(stop_flags.data<bool>()),
|
||||
const_cast<bool *>(is_block_step.data<bool>()),
|
||||
next_tokens.data<int64_t>(),
|
||||
now_bsz,
|
||||
max_bsz,
|
||||
input_ids_stride,
|
||||
block_num_per_seq,
|
||||
block_size);
|
||||
auto not_need_stop_cpu =
|
||||
not_need_stop_gpu.copy_to(not_need_stop.place(), false);
|
||||
bool *not_need_stop_data = const_cast<bool *>(not_need_stop.data<bool>());
|
||||
not_need_stop_data[0] = not_need_stop_cpu.data<bool>()[0];
|
||||
}
|
||||
|
||||
PD_BUILD_STATIC_OP(update_inputs_v1)
|
||||
.Inputs({"stop_flags",
|
||||
"not_need_stop",
|
||||
"seq_lens_this_time",
|
||||
"seq_lens_encoder",
|
||||
"seq_lens_decoder",
|
||||
"step_seq_lens_decoder",
|
||||
"prompt_lens",
|
||||
"topk_ids",
|
||||
"input_ids",
|
||||
"block_tables",
|
||||
"stop_nums",
|
||||
"next_tokens",
|
||||
"is_block_step"})
|
||||
.Attrs({"block_size: int"})
|
||||
.Outputs({"not_need_stop_out",
|
||||
"seq_lens_this_time_out",
|
||||
"seq_lens_encoder_out",
|
||||
"seq_lens_decoder_out",
|
||||
"step_seq_lens_decoder_out",
|
||||
"topk_ids_out",
|
||||
"input_ids_out",
|
||||
"stop_flags_out",
|
||||
"is_block_step_out"})
|
||||
.SetInplaceMap({{"not_need_stop", "not_need_stop_out"},
|
||||
{"seq_lens_this_time", "seq_lens_this_time_out"},
|
||||
{"seq_lens_encoder", "seq_lens_encoder_out"},
|
||||
{"seq_lens_decoder", "seq_lens_decoder_out"},
|
||||
{"topk_ids", "topk_ids_out"},
|
||||
{"input_ids", "input_ids_out"},
|
||||
{"stop_flags", "stop_flags_out"},
|
||||
{"step_seq_lens_decoder", "step_seq_lens_decoder_out"},
|
||||
{"is_block_step", "is_block_step_out"}})
|
||||
.SetKernelFn(PD_KERNEL(UpdateInputesV1));
|
Reference in New Issue
Block a user