mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
[XPU] bind some OPs for VL model with pybind (#4522)
This commit is contained in:
File diff suppressed because it is too large
Load Diff
@@ -15,93 +15,97 @@
|
||||
#include "helper.h"
|
||||
|
||||
template <int THREADBLOCK_SIZE>
|
||||
__global__ void update_inputs_kernel(bool *not_need_stop,
|
||||
int *seq_lens_this_time,
|
||||
int *seq_lens_encoder,
|
||||
int *seq_lens_decoder,
|
||||
int64_t *input_ids,
|
||||
const int64_t *stop_nums,
|
||||
const bool *stop_flags,
|
||||
const bool *is_block_step,
|
||||
const int64_t *next_tokens,
|
||||
__global__ void update_inputs_kernel(bool* not_need_stop,
|
||||
int* seq_lens_this_time,
|
||||
int* seq_lens_encoder,
|
||||
int* seq_lens_decoder,
|
||||
int64_t* input_ids,
|
||||
const int64_t* stop_nums,
|
||||
const bool* stop_flags,
|
||||
const bool* is_block_step,
|
||||
const int64_t* next_tokens,
|
||||
const int bsz,
|
||||
const int max_bsz,
|
||||
const int input_ids_stride) {
|
||||
int thread_idx = threadIdx.x;
|
||||
typedef cub::BlockReduce<int64_t, THREADBLOCK_SIZE> BlockReduce;
|
||||
__shared__ typename BlockReduce::TempStorage temp_storage;
|
||||
int thread_idx = threadIdx.x;
|
||||
typedef cub::BlockReduce<int64_t, THREADBLOCK_SIZE> BlockReduce;
|
||||
__shared__ typename BlockReduce::TempStorage temp_storage;
|
||||
|
||||
bool stop_flag_now = false;
|
||||
int64_t stop_flag_now_int = 0;
|
||||
if (thread_idx < max_bsz) {
|
||||
if (thread_idx < bsz) {
|
||||
stop_flag_now = stop_flags[thread_idx];
|
||||
if (is_block_step[thread_idx]) {
|
||||
stop_flag_now_int = 0;
|
||||
} else {
|
||||
stop_flag_now_int = static_cast<int64_t>(stop_flag_now);
|
||||
}
|
||||
} else {
|
||||
stop_flag_now_int = 1;
|
||||
}
|
||||
}
|
||||
bool stop_flag_now = false;
|
||||
int64_t stop_flag_now_int = 0;
|
||||
if (thread_idx < max_bsz) {
|
||||
if (thread_idx < bsz) {
|
||||
const int seq_len_this_time = seq_lens_this_time[thread_idx];
|
||||
const int seq_len_encoder = seq_lens_encoder[thread_idx];
|
||||
const int seq_len_decoder = seq_lens_decoder[thread_idx];
|
||||
|
||||
seq_lens_decoder[thread_idx] = stop_flag_now ?
|
||||
0 : (seq_len_encoder > 0 ?
|
||||
(seq_len_encoder + seq_len_decoder) : seq_len_decoder + 1);
|
||||
|
||||
seq_lens_this_time[thread_idx] = stop_flag_now ? 0 : 1;
|
||||
seq_lens_encoder[thread_idx] = 0;
|
||||
int64_t *input_ids_now = input_ids + thread_idx * input_ids_stride;
|
||||
input_ids_now[0] = next_tokens[thread_idx];
|
||||
}
|
||||
__syncthreads();
|
||||
int64_t stop_sum = BlockReduce(temp_storage).Sum(stop_flag_now_int);
|
||||
if (thread_idx == 0) {
|
||||
not_need_stop[0] = stop_sum < stop_nums[0];
|
||||
stop_flag_now = stop_flags[thread_idx];
|
||||
if (is_block_step[thread_idx]) {
|
||||
stop_flag_now_int = 0;
|
||||
} else {
|
||||
stop_flag_now_int = static_cast<int64_t>(stop_flag_now);
|
||||
}
|
||||
} else {
|
||||
stop_flag_now_int = 1;
|
||||
}
|
||||
}
|
||||
if (thread_idx < bsz) {
|
||||
const int seq_len_this_time = seq_lens_this_time[thread_idx];
|
||||
const int seq_len_encoder = seq_lens_encoder[thread_idx];
|
||||
const int seq_len_decoder = seq_lens_decoder[thread_idx];
|
||||
|
||||
seq_lens_decoder[thread_idx] =
|
||||
stop_flag_now
|
||||
? 0
|
||||
: (seq_len_encoder > 0 ? (seq_len_encoder + seq_len_decoder)
|
||||
: seq_len_decoder + 1);
|
||||
|
||||
seq_lens_this_time[thread_idx] = stop_flag_now ? 0 : 1;
|
||||
seq_lens_encoder[thread_idx] = 0;
|
||||
int64_t* input_ids_now = input_ids + thread_idx * input_ids_stride;
|
||||
input_ids_now[0] = next_tokens[thread_idx];
|
||||
}
|
||||
__syncthreads();
|
||||
int64_t stop_sum = BlockReduce(temp_storage).Sum(stop_flag_now_int);
|
||||
if (thread_idx == 0) {
|
||||
not_need_stop[0] = stop_sum < stop_nums[0];
|
||||
}
|
||||
}
|
||||
|
||||
void UpdateInputes(const paddle::Tensor &stop_flags,
|
||||
const paddle::Tensor &not_need_stop, // only on cpu
|
||||
const paddle::Tensor &seq_lens_this_time,
|
||||
const paddle::Tensor &seq_lens_encoder,
|
||||
const paddle::Tensor &seq_lens_decoder,
|
||||
const paddle::Tensor &input_ids,
|
||||
const paddle::Tensor &stop_nums,
|
||||
const paddle::Tensor &next_tokens,
|
||||
const paddle::Tensor &is_block_step) {
|
||||
void UpdateInputs(const paddle::Tensor& stop_flags,
|
||||
const paddle::Tensor& not_need_stop, // only on cpu
|
||||
const paddle::Tensor& seq_lens_this_time,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& input_ids,
|
||||
const paddle::Tensor& stop_nums,
|
||||
const paddle::Tensor& next_tokens,
|
||||
const paddle::Tensor& is_block_step) {
|
||||
#ifdef PADDLE_WITH_CUSTOM_DEVICE
|
||||
auto dev_ctx = static_cast<const phi::CustomContext*>(paddle::experimental::DeviceContextPool::Instance().Get(input_ids.place()));
|
||||
auto cu_stream = dev_ctx->stream();
|
||||
auto dev_ctx = static_cast<const phi::CustomContext*>(
|
||||
paddle::experimental::DeviceContextPool::Instance().Get(
|
||||
input_ids.place()));
|
||||
auto cu_stream = dev_ctx->stream();
|
||||
#else
|
||||
auto cu_stream = input_ids.stream();
|
||||
auto cu_stream = input_ids.stream();
|
||||
#endif
|
||||
const int max_bsz = stop_flags.shape()[0];
|
||||
const int now_bsz = seq_lens_this_time.shape()[0];
|
||||
const int input_ids_stride = input_ids.shape()[1];
|
||||
auto not_need_stop_gpu = not_need_stop.copy_to(stop_flags.place(), false);
|
||||
update_inputs_kernel<1024><<<1, 1024, 0, cu_stream>>>(
|
||||
const_cast<bool *>(not_need_stop_gpu.data<bool>()),
|
||||
const_cast<int *>(seq_lens_this_time.data<int>()),
|
||||
const_cast<int *>(seq_lens_encoder.data<int>()),
|
||||
const_cast<int *>(seq_lens_decoder.data<int>()),
|
||||
const_cast<int64_t *>(input_ids.data<int64_t>()),
|
||||
stop_nums.data<int64_t>(),
|
||||
stop_flags.data<bool>(),
|
||||
is_block_step.data<bool>(),
|
||||
next_tokens.data<int64_t>(),
|
||||
now_bsz,
|
||||
max_bsz,
|
||||
input_ids_stride);
|
||||
auto not_need_stop_cpu =
|
||||
not_need_stop_gpu.copy_to(not_need_stop.place(), false);
|
||||
bool *not_need_stop_data = const_cast<bool *>(not_need_stop.data<bool>());
|
||||
not_need_stop_data[0] = not_need_stop_cpu.data<bool>()[0];
|
||||
const int max_bsz = stop_flags.shape()[0];
|
||||
const int now_bsz = seq_lens_this_time.shape()[0];
|
||||
const int input_ids_stride = input_ids.shape()[1];
|
||||
auto not_need_stop_gpu = not_need_stop.copy_to(stop_flags.place(), false);
|
||||
update_inputs_kernel<1024><<<1, 1024, 0, cu_stream>>>(
|
||||
const_cast<bool*>(not_need_stop_gpu.data<bool>()),
|
||||
const_cast<int*>(seq_lens_this_time.data<int>()),
|
||||
const_cast<int*>(seq_lens_encoder.data<int>()),
|
||||
const_cast<int*>(seq_lens_decoder.data<int>()),
|
||||
const_cast<int64_t*>(input_ids.data<int64_t>()),
|
||||
stop_nums.data<int64_t>(),
|
||||
stop_flags.data<bool>(),
|
||||
is_block_step.data<bool>(),
|
||||
next_tokens.data<int64_t>(),
|
||||
now_bsz,
|
||||
max_bsz,
|
||||
input_ids_stride);
|
||||
auto not_need_stop_cpu =
|
||||
not_need_stop_gpu.copy_to(not_need_stop.place(), false);
|
||||
bool* not_need_stop_data = const_cast<bool*>(not_need_stop.data<bool>());
|
||||
not_need_stop_data[0] = not_need_stop_cpu.data<bool>()[0];
|
||||
}
|
||||
|
||||
PD_BUILD_STATIC_OP(update_inputs)
|
||||
@@ -124,4 +128,4 @@ PD_BUILD_STATIC_OP(update_inputs)
|
||||
{"seq_lens_encoder", "seq_lens_encoder_out"},
|
||||
{"seq_lens_decoder", "seq_lens_decoder_out"},
|
||||
{"input_ids", "input_ids_out"}})
|
||||
.SetKernelFn(PD_KERNEL(UpdateInputes));
|
||||
.SetKernelFn(PD_KERNEL(UpdateInputs));
|
||||
|
||||
@@ -15,15 +15,14 @@
|
||||
#include "helper.h"
|
||||
|
||||
template <int THREADBLOCK_SIZE>
|
||||
__global__ void update_inputs_beam_kernel(
|
||||
int *seq_lens_this_time,
|
||||
int *seq_lens_encoder,
|
||||
int64_t *input_ids,
|
||||
float *logits,
|
||||
const int bsz,
|
||||
const int seq_len,
|
||||
const int hidden_size,
|
||||
const int beam_width) {
|
||||
__global__ void update_inputs_beam_kernel(int* seq_lens_this_time,
|
||||
int* seq_lens_encoder,
|
||||
int64_t* input_ids,
|
||||
float* logits,
|
||||
const int bsz,
|
||||
const int seq_len,
|
||||
const int hidden_size,
|
||||
const int beam_width) {
|
||||
int thread_idx = threadIdx.x;
|
||||
int block_idx = blockIdx.x;
|
||||
|
||||
@@ -35,23 +34,22 @@ __global__ void update_inputs_beam_kernel(
|
||||
seq_lens_encoder[thread_idx] = seq_lens_encoder[bsz_index];
|
||||
}
|
||||
if (block_idx < seq_len) {
|
||||
input_ids[thread_idx * seq_len + block_idx] = input_ids[bsz_index * seq_len + block_idx];
|
||||
input_ids[thread_idx * seq_len + block_idx] =
|
||||
input_ids[bsz_index * seq_len + block_idx];
|
||||
}
|
||||
|
||||
logits[thread_idx * hidden_size + block_idx] = logits[bsz_index * hidden_size + block_idx];
|
||||
|
||||
logits[thread_idx * hidden_size + block_idx] =
|
||||
logits[bsz_index * hidden_size + block_idx];
|
||||
}
|
||||
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
void UpdateInputesBeam(
|
||||
const paddle::Tensor& beam_width,
|
||||
const paddle::Tensor& seq_lens_this_time,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& input_ids,
|
||||
const paddle::Tensor& logits) {
|
||||
void UpdateInputsBeam(const paddle::Tensor& beam_width,
|
||||
const paddle::Tensor& seq_lens_this_time,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& input_ids,
|
||||
const paddle::Tensor& logits) {
|
||||
int beam_width_scalar = beam_width.data<int>()[0];
|
||||
|
||||
if (beam_width_scalar > 1) {
|
||||
@@ -59,16 +57,16 @@ void UpdateInputesBeam(
|
||||
const int seq_len = input_ids.shape()[1];
|
||||
const int hidden_size = logits.shape()[1];
|
||||
|
||||
update_inputs_beam_kernel<1024><<<hidden_size, 1024, 0, input_ids.stream()>>>(
|
||||
const_cast<int*>(seq_lens_this_time.data<int>()),
|
||||
const_cast<int*>(seq_lens_encoder.data<int>()),
|
||||
const_cast<int64_t*>(input_ids.data<int64_t>()),
|
||||
const_cast<float*>(logits.data<float>()),
|
||||
bsz,
|
||||
seq_len,
|
||||
hidden_size,
|
||||
beam_width_scalar
|
||||
);
|
||||
update_inputs_beam_kernel<1024>
|
||||
<<<hidden_size, 1024, 0, input_ids.stream()>>>(
|
||||
const_cast<int*>(seq_lens_this_time.data<int>()),
|
||||
const_cast<int*>(seq_lens_encoder.data<int>()),
|
||||
const_cast<int64_t*>(input_ids.data<int64_t>()),
|
||||
const_cast<float*>(logits.data<float>()),
|
||||
bsz,
|
||||
seq_len,
|
||||
hidden_size,
|
||||
beam_width_scalar);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -86,4 +84,4 @@ PD_BUILD_STATIC_OP(update_inputs_beam)
|
||||
{"seq_lens_encoder", "seq_lens_encoder_out"},
|
||||
{"input_ids", "input_ids_out"},
|
||||
{"logits", "logits_out"}})
|
||||
.SetKernelFn(PD_KERNEL(UpdateInputesBeam));
|
||||
.SetKernelFn(PD_KERNEL(UpdateInputsBeam));
|
||||
|
||||
@@ -15,146 +15,150 @@
|
||||
#include "helper.h"
|
||||
|
||||
template <int THREADBLOCK_SIZE>
|
||||
__global__ void update_inputs_kernel_v1(bool *not_need_stop,
|
||||
int *seq_lens_this_time,
|
||||
int *seq_lens_encoder,
|
||||
int *seq_lens_decoder,
|
||||
int *step_seq_lens_decoder,
|
||||
int64_t *prompt_lens,
|
||||
int64_t *topk_ids,
|
||||
int64_t *input_ids,
|
||||
int *block_tables,
|
||||
const int64_t *stop_nums,
|
||||
bool *stop_flags,
|
||||
bool *is_block_step,
|
||||
const int64_t *next_tokens,
|
||||
const int bsz,
|
||||
const int max_bsz,
|
||||
const int input_ids_stride,
|
||||
const int block_num_per_seq,
|
||||
const int block_size,
|
||||
bool prefill_one_step_stop) {
|
||||
int thread_idx = threadIdx.x;
|
||||
typedef cub::BlockReduce<int64_t, THREADBLOCK_SIZE> BlockReduce;
|
||||
__shared__ typename BlockReduce::TempStorage temp_storage;
|
||||
__global__ void update_inputs_kernel_v1(bool* not_need_stop,
|
||||
int* seq_lens_this_time,
|
||||
int* seq_lens_encoder,
|
||||
int* seq_lens_decoder,
|
||||
int* step_seq_lens_decoder,
|
||||
int64_t* prompt_lens,
|
||||
int64_t* topk_ids,
|
||||
int64_t* input_ids,
|
||||
int* block_tables,
|
||||
const int64_t* stop_nums,
|
||||
bool* stop_flags,
|
||||
bool* is_block_step,
|
||||
const int64_t* next_tokens,
|
||||
const int bsz,
|
||||
const int max_bsz,
|
||||
const int input_ids_stride,
|
||||
const int block_num_per_seq,
|
||||
const int block_size,
|
||||
bool prefill_one_step_stop) {
|
||||
int thread_idx = threadIdx.x;
|
||||
typedef cub::BlockReduce<int64_t, THREADBLOCK_SIZE> BlockReduce;
|
||||
__shared__ typename BlockReduce::TempStorage temp_storage;
|
||||
|
||||
bool stop_flag_now = false;
|
||||
int64_t stop_flag_now_int = 0;
|
||||
if (thread_idx < max_bsz) {
|
||||
if (thread_idx < bsz) {
|
||||
stop_flag_now = stop_flags[thread_idx];
|
||||
stop_flag_now_int = static_cast<int64_t>(stop_flag_now);
|
||||
} else {
|
||||
stop_flag_now_int = 1;
|
||||
}
|
||||
}
|
||||
bool stop_flag_now = false;
|
||||
int64_t stop_flag_now_int = 0;
|
||||
if (thread_idx < max_bsz) {
|
||||
if (thread_idx < bsz) {
|
||||
if(stop_flag_now) {
|
||||
seq_lens_this_time[thread_idx] = 0; // stop at next step
|
||||
seq_lens_decoder[thread_idx] = 0;
|
||||
seq_lens_encoder[thread_idx] = 0;
|
||||
stop_flag_now = stop_flags[thread_idx];
|
||||
stop_flag_now_int = static_cast<int64_t>(stop_flag_now);
|
||||
} else {
|
||||
stop_flag_now_int = 1;
|
||||
}
|
||||
}
|
||||
if (thread_idx < bsz) {
|
||||
if (stop_flag_now) {
|
||||
seq_lens_this_time[thread_idx] = 0; // stop at next step
|
||||
seq_lens_decoder[thread_idx] = 0;
|
||||
seq_lens_encoder[thread_idx] = 0;
|
||||
} else {
|
||||
if (seq_lens_this_time[thread_idx] + seq_lens_decoder[thread_idx] >=
|
||||
prompt_lens[thread_idx]) {
|
||||
if (prefill_one_step_stop) {
|
||||
// prefill done, stop
|
||||
stop_flags[thread_idx] = true;
|
||||
seq_lens_this_time[thread_idx] = 0;
|
||||
seq_lens_decoder[thread_idx] = 0;
|
||||
seq_lens_encoder[thread_idx] = 0;
|
||||
stop_flag_now_int = 1;
|
||||
} else {
|
||||
if (seq_lens_this_time[thread_idx] + seq_lens_decoder[thread_idx] >= prompt_lens[thread_idx]) {
|
||||
if (prefill_one_step_stop) {
|
||||
// prefill done, stop
|
||||
stop_flags[thread_idx] = true;
|
||||
seq_lens_this_time[thread_idx] = 0;
|
||||
seq_lens_decoder[thread_idx] = 0;
|
||||
seq_lens_encoder[thread_idx] = 0;
|
||||
stop_flag_now_int = 1;
|
||||
} else{
|
||||
// decoding
|
||||
seq_lens_decoder[thread_idx] += seq_lens_this_time[thread_idx];
|
||||
seq_lens_this_time[thread_idx] = 1;
|
||||
seq_lens_encoder[thread_idx] = 0;
|
||||
int64_t *input_ids_now = input_ids + thread_idx * input_ids_stride;
|
||||
input_ids_now[0] = next_tokens[thread_idx];
|
||||
// decoding
|
||||
seq_lens_decoder[thread_idx] += seq_lens_this_time[thread_idx];
|
||||
seq_lens_this_time[thread_idx] = 1;
|
||||
seq_lens_encoder[thread_idx] = 0;
|
||||
int64_t* input_ids_now = input_ids + thread_idx * input_ids_stride;
|
||||
input_ids_now[0] = next_tokens[thread_idx];
|
||||
|
||||
// to judge whether block is not enough
|
||||
int *block_table_now = block_tables + thread_idx * block_num_per_seq;
|
||||
if (seq_lens_this_time[thread_idx] != 0 && block_table_now[seq_lens_decoder[thread_idx] / block_size] == -1) {
|
||||
// should be scheduled by server
|
||||
is_block_step[thread_idx] = true;
|
||||
seq_lens_this_time[thread_idx]= 0;
|
||||
stop_flags[thread_idx] = true;
|
||||
step_seq_lens_decoder[thread_idx] = seq_lens_decoder[thread_idx];
|
||||
seq_lens_decoder[thread_idx] = 0;
|
||||
stop_flag_now_int = 1;
|
||||
}
|
||||
}
|
||||
} else
|
||||
{
|
||||
stop_flags[thread_idx] = true;
|
||||
seq_lens_this_time[thread_idx] = 0;
|
||||
seq_lens_decoder[thread_idx] = 0;
|
||||
seq_lens_encoder[thread_idx] = 0;
|
||||
topk_ids[thread_idx] = -1;
|
||||
stop_flag_now_int = 1;
|
||||
}
|
||||
// to judge whether block is not enough
|
||||
int* block_table_now = block_tables + thread_idx * block_num_per_seq;
|
||||
if (seq_lens_this_time[thread_idx] != 0 &&
|
||||
block_table_now[seq_lens_decoder[thread_idx] / block_size] ==
|
||||
-1) {
|
||||
// should be scheduled by server
|
||||
is_block_step[thread_idx] = true;
|
||||
seq_lens_this_time[thread_idx] = 0;
|
||||
stop_flags[thread_idx] = true;
|
||||
step_seq_lens_decoder[thread_idx] = seq_lens_decoder[thread_idx];
|
||||
seq_lens_decoder[thread_idx] = 0;
|
||||
stop_flag_now_int = 1;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
stop_flags[thread_idx] = true;
|
||||
seq_lens_this_time[thread_idx] = 0;
|
||||
seq_lens_decoder[thread_idx] = 0;
|
||||
seq_lens_encoder[thread_idx] = 0;
|
||||
topk_ids[thread_idx] = -1;
|
||||
stop_flag_now_int = 1;
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
int64_t stop_sum = BlockReduce(temp_storage).Sum(stop_flag_now_int);
|
||||
if (thread_idx == 0) {
|
||||
not_need_stop[0] = stop_sum < stop_nums[0];
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
int64_t stop_sum = BlockReduce(temp_storage).Sum(stop_flag_now_int);
|
||||
if (thread_idx == 0) {
|
||||
not_need_stop[0] = stop_sum < stop_nums[0];
|
||||
}
|
||||
}
|
||||
|
||||
void UpdateInputesV1(const paddle::Tensor &stop_flags,
|
||||
const paddle::Tensor &not_need_stop, // only on cpu
|
||||
const paddle::Tensor &seq_lens_this_time,
|
||||
const paddle::Tensor &seq_lens_encoder,
|
||||
const paddle::Tensor &seq_lens_decoder,
|
||||
const paddle::Tensor &step_seq_lens_decoder,
|
||||
const paddle::Tensor &prompt_lens,
|
||||
const paddle::Tensor &topk_ids,
|
||||
const paddle::Tensor &input_ids,
|
||||
const paddle::Tensor &block_tables,
|
||||
const paddle::Tensor &stop_nums,
|
||||
const paddle::Tensor &next_tokens,
|
||||
const paddle::Tensor &is_block_step,
|
||||
const int block_size) {
|
||||
void UpdateInputsV1(const paddle::Tensor& stop_flags,
|
||||
const paddle::Tensor& not_need_stop, // only on cpu
|
||||
const paddle::Tensor& seq_lens_this_time,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& step_seq_lens_decoder,
|
||||
const paddle::Tensor& prompt_lens,
|
||||
const paddle::Tensor& topk_ids,
|
||||
const paddle::Tensor& input_ids,
|
||||
const paddle::Tensor& block_tables,
|
||||
const paddle::Tensor& stop_nums,
|
||||
const paddle::Tensor& next_tokens,
|
||||
const paddle::Tensor& is_block_step,
|
||||
const int block_size) {
|
||||
#ifdef PADDLE_WITH_CUSTOM_DEVICE
|
||||
auto dev_ctx = static_cast<const phi::CustomContext*>(paddle::experimental::DeviceContextPool::Instance().Get(input_ids.place()));
|
||||
auto cu_stream = dev_ctx->stream();
|
||||
auto dev_ctx = static_cast<const phi::CustomContext*>(
|
||||
paddle::experimental::DeviceContextPool::Instance().Get(
|
||||
input_ids.place()));
|
||||
auto cu_stream = dev_ctx->stream();
|
||||
#else
|
||||
auto cu_stream = input_ids.stream();
|
||||
auto cu_stream = input_ids.stream();
|
||||
#endif
|
||||
bool prefill_one_step_stop = false;
|
||||
if (const char *env_p = std::getenv("PREFILL_NODE_ONE_STEP_STOP_V1")) {
|
||||
if (env_p[0] == '1') {
|
||||
prefill_one_step_stop = true;
|
||||
}
|
||||
bool prefill_one_step_stop = false;
|
||||
if (const char* env_p = std::getenv("PREFILL_NODE_ONE_STEP_STOP_V1")) {
|
||||
if (env_p[0] == '1') {
|
||||
prefill_one_step_stop = true;
|
||||
}
|
||||
const int max_bsz = stop_flags.shape()[0];
|
||||
const int now_bsz = seq_lens_this_time.shape()[0];
|
||||
const int input_ids_stride = input_ids.shape()[1];
|
||||
const int block_num_per_seq = block_tables.shape()[1];
|
||||
auto not_need_stop_gpu = not_need_stop.copy_to(stop_flags.place(), false);
|
||||
update_inputs_kernel_v1<1024><<<1, 1024, 0, cu_stream>>>(
|
||||
const_cast<bool *>(not_need_stop_gpu.data<bool>()),
|
||||
const_cast<int *>(seq_lens_this_time.data<int>()),
|
||||
const_cast<int *>(seq_lens_encoder.data<int>()),
|
||||
const_cast<int *>(seq_lens_decoder.data<int>()),
|
||||
const_cast<int *>(step_seq_lens_decoder.data<int>()),
|
||||
const_cast<int64_t *>(prompt_lens.data<int64_t>()),
|
||||
const_cast<int64_t *>(topk_ids.data<int64_t>()),
|
||||
const_cast<int64_t *>(input_ids.data<int64_t>()),
|
||||
const_cast<int *>(block_tables.data<int>()),
|
||||
stop_nums.data<int64_t>(),
|
||||
const_cast<bool *>(stop_flags.data<bool>()),
|
||||
const_cast<bool *>(is_block_step.data<bool>()),
|
||||
next_tokens.data<int64_t>(),
|
||||
now_bsz,
|
||||
max_bsz,
|
||||
input_ids_stride,
|
||||
block_num_per_seq,
|
||||
block_size,
|
||||
prefill_one_step_stop);
|
||||
auto not_need_stop_cpu =
|
||||
not_need_stop_gpu.copy_to(not_need_stop.place(), false);
|
||||
bool *not_need_stop_data = const_cast<bool *>(not_need_stop.data<bool>());
|
||||
not_need_stop_data[0] = not_need_stop_cpu.data<bool>()[0];
|
||||
}
|
||||
const int max_bsz = stop_flags.shape()[0];
|
||||
const int now_bsz = seq_lens_this_time.shape()[0];
|
||||
const int input_ids_stride = input_ids.shape()[1];
|
||||
const int block_num_per_seq = block_tables.shape()[1];
|
||||
auto not_need_stop_gpu = not_need_stop.copy_to(stop_flags.place(), false);
|
||||
update_inputs_kernel_v1<1024><<<1, 1024, 0, cu_stream>>>(
|
||||
const_cast<bool*>(not_need_stop_gpu.data<bool>()),
|
||||
const_cast<int*>(seq_lens_this_time.data<int>()),
|
||||
const_cast<int*>(seq_lens_encoder.data<int>()),
|
||||
const_cast<int*>(seq_lens_decoder.data<int>()),
|
||||
const_cast<int*>(step_seq_lens_decoder.data<int>()),
|
||||
const_cast<int64_t*>(prompt_lens.data<int64_t>()),
|
||||
const_cast<int64_t*>(topk_ids.data<int64_t>()),
|
||||
const_cast<int64_t*>(input_ids.data<int64_t>()),
|
||||
const_cast<int*>(block_tables.data<int>()),
|
||||
stop_nums.data<int64_t>(),
|
||||
const_cast<bool*>(stop_flags.data<bool>()),
|
||||
const_cast<bool*>(is_block_step.data<bool>()),
|
||||
next_tokens.data<int64_t>(),
|
||||
now_bsz,
|
||||
max_bsz,
|
||||
input_ids_stride,
|
||||
block_num_per_seq,
|
||||
block_size,
|
||||
prefill_one_step_stop);
|
||||
auto not_need_stop_cpu =
|
||||
not_need_stop_gpu.copy_to(not_need_stop.place(), false);
|
||||
bool* not_need_stop_data = const_cast<bool*>(not_need_stop.data<bool>());
|
||||
not_need_stop_data[0] = not_need_stop_cpu.data<bool>()[0];
|
||||
}
|
||||
|
||||
PD_BUILD_STATIC_OP(update_inputs_v1)
|
||||
@@ -190,4 +194,4 @@ PD_BUILD_STATIC_OP(update_inputs_v1)
|
||||
{"stop_flags", "stop_flags_out"},
|
||||
{"step_seq_lens_decoder", "step_seq_lens_decoder_out"},
|
||||
{"is_block_step", "is_block_step_out"}})
|
||||
.SetKernelFn(PD_KERNEL(UpdateInputesV1));
|
||||
.SetKernelFn(PD_KERNEL(UpdateInputsV1));
|
||||
|
||||
@@ -33,6 +33,20 @@ void prof_start();
|
||||
|
||||
void prof_stop();
|
||||
|
||||
std::vector<paddle::Tensor> AdjustBatch(
|
||||
const paddle::Tensor& x, // [token_num, dim_embed]
|
||||
const paddle::Tensor& cum_offsets, // [bsz, 1]
|
||||
const paddle::Tensor& encoder_seq_lod,
|
||||
const paddle::Tensor& encoder_batch_idx,
|
||||
const paddle::Tensor& decoder_batch_idx,
|
||||
const paddle::Tensor& encoder_seq_lod_cpu,
|
||||
const paddle::Tensor& encoder_batch_idx_cpu,
|
||||
const paddle::Tensor& decoder_batch_idx_cpu,
|
||||
const paddle::Tensor& enc_batch_tensor,
|
||||
const paddle::Tensor& dec_batch_tensor,
|
||||
const paddle::optional<paddle::Tensor>& output_padding_offset,
|
||||
int max_input_length);
|
||||
|
||||
void InitKVSignalPerQuery(const paddle::Tensor& seq_lens_encoder_tensor,
|
||||
const paddle::Tensor& seq_lens_this_time_tensor,
|
||||
const paddle::Tensor& seq_lens_decoder_tensor,
|
||||
@@ -73,6 +87,21 @@ std::vector<paddle::Tensor> BlockAttn(
|
||||
const std::string& pos_emb_type = "NORMAL",
|
||||
bool rope_3d = false);
|
||||
|
||||
std::vector<paddle::Tensor> MoeLayer(
|
||||
const paddle::Tensor& x,
|
||||
const paddle::Tensor& gate_weight,
|
||||
const paddle::optional<paddle::Tensor>& gate_correction_bias,
|
||||
const paddle::Tensor& up_gate_proj_weight,
|
||||
const paddle::Tensor& down_proj_weight,
|
||||
const paddle::optional<paddle::Tensor>& up_gate_proj_bias,
|
||||
const paddle::optional<paddle::Tensor>& down_proj_bias,
|
||||
const paddle::optional<paddle::Tensor>& up_gate_proj_weight_scale,
|
||||
const paddle::optional<paddle::Tensor>& down_proj_weight_scale,
|
||||
const paddle::optional<paddle::Tensor>& down_proj_in_scale,
|
||||
const std::string& quant_method,
|
||||
const int moe_top_k,
|
||||
const bool moe_group);
|
||||
|
||||
std::vector<paddle::Tensor> MoERedundantTopKSelect(
|
||||
const paddle::Tensor& gating_logits,
|
||||
const paddle::Tensor& expert_id_to_ep_rank_array,
|
||||
@@ -294,6 +323,65 @@ std::vector<paddle::Tensor> EagleGetSelfHiddenStates(
|
||||
const paddle::Tensor& seq_lens_this_time,
|
||||
const paddle::Tensor& step_idx);
|
||||
|
||||
std::vector<paddle::Tensor> GatherNextToken(
|
||||
const paddle::Tensor& tmp_out, // [token_num, dim_embed]
|
||||
const paddle::Tensor& cum_offsets, // [bsz, 1]
|
||||
const paddle::Tensor& encoder_seq_lod,
|
||||
const paddle::Tensor& encoder_batch_map,
|
||||
const paddle::Tensor& decoder_batch_map,
|
||||
const paddle::Tensor& encoder_seq_lod_cpu,
|
||||
const paddle::Tensor& encoder_batch_map_cpu,
|
||||
const paddle::Tensor& decoder_batch_map_cpu,
|
||||
const paddle::Tensor& enc_batch_tensor,
|
||||
const paddle::Tensor& dec_batch_tensor,
|
||||
const paddle::optional<paddle::Tensor>& output_padding_offset,
|
||||
int max_input_length);
|
||||
|
||||
std::vector<paddle::Tensor> GetImgBoundaries(
|
||||
const paddle::Tensor& task_input_ids,
|
||||
const paddle::Tensor& grid_thw,
|
||||
const int64_t image_patch_id);
|
||||
|
||||
std::vector<paddle::Tensor> GetInferParam(
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& seq_lens_this_time,
|
||||
const paddle::Tensor& block_tables,
|
||||
int block_size);
|
||||
|
||||
void GetOutputStatic(const paddle::Tensor& x, int64_t rank_id, bool wait_flag);
|
||||
|
||||
void GetOutputDynamic(const paddle::Tensor& x,
|
||||
int64_t rank_id,
|
||||
bool wait_flag,
|
||||
int msg_queue_id);
|
||||
|
||||
std::vector<paddle::Tensor> GetPaddingOffset(const paddle::Tensor& input_ids,
|
||||
const paddle::Tensor& cum_offsets,
|
||||
const paddle::Tensor& token_num,
|
||||
const paddle::Tensor& seq_len);
|
||||
|
||||
void GetStopFlagsMulti(const paddle::Tensor& topk_ids,
|
||||
const paddle::Tensor& stop_flags,
|
||||
const paddle::Tensor& seq_lens,
|
||||
const paddle::Tensor& end_ids,
|
||||
const paddle::Tensor& next_tokens,
|
||||
const bool beam_search);
|
||||
|
||||
void RecoverDecodeTask(const paddle::Tensor& stop_flags,
|
||||
const paddle::Tensor& seq_lens_this_time,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& step_seq_lens_decoder,
|
||||
const paddle::Tensor& block_tables,
|
||||
const paddle::Tensor& is_block_step,
|
||||
const int block_size);
|
||||
|
||||
std::vector<paddle::Tensor> ShareExternalData(const paddle::Tensor& input,
|
||||
const std::string shm_name,
|
||||
const std::vector<int>& shape,
|
||||
bool use_ipc);
|
||||
|
||||
std::vector<paddle::Tensor> SpeculateGetOutputPaddingOffset(
|
||||
const paddle::Tensor& output_cum_offsets_tmp,
|
||||
const paddle::Tensor& out_token_num,
|
||||
@@ -308,6 +396,31 @@ std::vector<paddle::Tensor> SpeculateGetPaddingOffset(
|
||||
const paddle::Tensor& seq_len,
|
||||
const paddle::Tensor& seq_lens_encoder);
|
||||
|
||||
void StepPaddle(const paddle::Tensor& stop_flags,
|
||||
const paddle::Tensor& seq_lens_this_time,
|
||||
const paddle::Tensor& ori_seq_lens_encoder,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& block_tables, // [bsz, block_num_per_seq]
|
||||
const paddle::Tensor& encoder_block_lens,
|
||||
const paddle::Tensor& is_block_step,
|
||||
const paddle::Tensor& step_block_list,
|
||||
const paddle::Tensor& step_lens,
|
||||
const paddle::Tensor& recover_block_list,
|
||||
const paddle::Tensor& recover_lens,
|
||||
const paddle::Tensor& need_block_list,
|
||||
const paddle::Tensor& need_block_len,
|
||||
const paddle::Tensor& used_list_len,
|
||||
const paddle::Tensor& free_list,
|
||||
const paddle::Tensor& free_list_len,
|
||||
const paddle::Tensor& input_ids,
|
||||
const paddle::Tensor& pre_ids,
|
||||
const paddle::Tensor& step_idx,
|
||||
const paddle::Tensor& next_tokens,
|
||||
const paddle::Tensor& first_token_ids,
|
||||
const int block_size,
|
||||
const int encoder_decoder_block_num);
|
||||
|
||||
void MTPStepPaddle(
|
||||
const paddle::Tensor& base_model_stop_flags,
|
||||
const paddle::Tensor& stop_flags,
|
||||
@@ -323,6 +436,17 @@ void MTPStepPaddle(
|
||||
const int block_size,
|
||||
const int max_draft_tokens);
|
||||
|
||||
void SaveOutMmsgStatic(const paddle::Tensor& x,
|
||||
const paddle::Tensor& not_need_stop,
|
||||
int64_t rank_id,
|
||||
bool save_each_rank);
|
||||
|
||||
void SaveOutMmsgDynamic(const paddle::Tensor& x,
|
||||
const paddle::Tensor& not_need_stop,
|
||||
int64_t rank_id,
|
||||
int msg_queue_id,
|
||||
bool save_each_rank);
|
||||
|
||||
void SpeculateStepSchedule(
|
||||
const paddle::Tensor& stop_flags,
|
||||
const paddle::Tensor& seq_lens_this_time,
|
||||
@@ -356,7 +480,78 @@ std::vector<paddle::Tensor> SpeculateGetSeqLensOutput(
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& seq_lens_decoder);
|
||||
|
||||
void SetDataIpc(const paddle::Tensor& tmp_input, const std::string& shm_name);
|
||||
|
||||
void TextImageGatherScatter(paddle::Tensor& input,
|
||||
paddle::Tensor& text_input,
|
||||
paddle::Tensor& image_input,
|
||||
paddle::Tensor& token_type_ids,
|
||||
paddle::Tensor& text_index,
|
||||
paddle::Tensor& image_index,
|
||||
const bool is_scatter);
|
||||
|
||||
void TextImageIndexOut(const paddle::Tensor& token_type_ids,
|
||||
const paddle::Tensor& text_index,
|
||||
const paddle::Tensor& image_index);
|
||||
|
||||
void TokenPenaltyMultiScores(const paddle::Tensor& pre_ids,
|
||||
const paddle::Tensor& logits,
|
||||
const paddle::Tensor& penalty_scores,
|
||||
const paddle::Tensor& frequency_scores,
|
||||
const paddle::Tensor& presence_scores,
|
||||
const paddle::Tensor& temperatures,
|
||||
const paddle::Tensor& bad_tokens,
|
||||
const paddle::Tensor& cur_len,
|
||||
const paddle::Tensor& min_len,
|
||||
const paddle::Tensor& eos_token_id);
|
||||
|
||||
void UpdateInputs(const paddle::Tensor& stop_flags,
|
||||
const paddle::Tensor& not_need_stop, // cpu
|
||||
const paddle::Tensor& seq_lens_this_time,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& input_ids,
|
||||
const paddle::Tensor& stop_nums,
|
||||
const paddle::Tensor& next_tokens,
|
||||
const paddle::Tensor& is_block_step);
|
||||
|
||||
void UpdateInputsV1(const paddle::Tensor& stop_flags,
|
||||
const paddle::Tensor& not_need_stop, // only on cpu
|
||||
const paddle::Tensor& seq_lens_this_time,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& step_seq_lens_decoder,
|
||||
const paddle::Tensor& prompt_lens,
|
||||
const paddle::Tensor& topk_ids,
|
||||
const paddle::Tensor& input_ids,
|
||||
const paddle::Tensor& block_tables,
|
||||
const paddle::Tensor& stop_nums,
|
||||
const paddle::Tensor& next_tokens,
|
||||
const paddle::Tensor& is_block_step,
|
||||
const int block_size);
|
||||
|
||||
std::vector<paddle::Tensor> WeightQuantize(const paddle::Tensor& x,
|
||||
const std::string& algo,
|
||||
const int32_t arch,
|
||||
const int32_t group_size);
|
||||
|
||||
PYBIND11_MODULE(fastdeploy_ops, m) {
|
||||
m.def("adjust_batch",
|
||||
&AdjustBatch,
|
||||
py::arg("x"),
|
||||
py::arg("cum_offsets"),
|
||||
py::arg("encoder_seq_lod"),
|
||||
py::arg("encoder_batch_idx"),
|
||||
py::arg("decoder_batch_idx"),
|
||||
py::arg("encoder_seq_lod_cpu"),
|
||||
py::arg("encoder_batch_idx_cpu"),
|
||||
py::arg("decoder_batch_idx_cpu"),
|
||||
py::arg("enc_batch_tensor"),
|
||||
py::arg("dec_batch_tensor"),
|
||||
py::arg("output_padding_offset"),
|
||||
py::arg("max_input_length"),
|
||||
"adjust batch in XPU");
|
||||
|
||||
m.def("block_attn",
|
||||
&BlockAttn,
|
||||
py::arg("qkv"),
|
||||
@@ -388,96 +583,107 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
|
||||
py::arg("pos_emb_type") = "NORMAL",
|
||||
py::arg("rope_3d") = false,
|
||||
"block attention in XPU");
|
||||
|
||||
m.def("create_kv_signal_sender",
|
||||
&create_cachekv_signal_thread,
|
||||
"init write cache kv signal thread");
|
||||
|
||||
m.def("cuda_host_alloc",
|
||||
&custom_xpu_host_alloc,
|
||||
"Allocate pinned memory",
|
||||
py::arg("size"),
|
||||
py::arg("flags") = 0x00);
|
||||
|
||||
m.def("cuda_host_free",
|
||||
&custom_xpu_host_free,
|
||||
"Free pinned memory",
|
||||
py::arg("ptr"));
|
||||
m.def("get_peer_mem_addr",
|
||||
&xpu_get_peer_mem_addr,
|
||||
"Get Host memory address of device pointer",
|
||||
py::arg("ptr"));
|
||||
|
||||
m.def("cuda_host_register",
|
||||
&xpu_cuda_host_register,
|
||||
"Register pinned memory",
|
||||
py::arg("ptr"),
|
||||
py::arg("size"),
|
||||
py::arg("flags") = cudaHostRegisterDefault);
|
||||
m.def("create_kv_signal_sender",
|
||||
&create_cachekv_signal_thread,
|
||||
"init write cache kv signal thread");
|
||||
|
||||
m.def("destroy_kv_signal_sender",
|
||||
&destroy_cachekv_signal_thread,
|
||||
"write cache kv signal thread exit");
|
||||
m.def("prof_start", &prof_start, "prof_start");
|
||||
m.def("prof_stop", &prof_stop, "prof_stop");
|
||||
m.def("moe_redundant_topk_select",
|
||||
&MoERedundantTopKSelect,
|
||||
py::arg("gating_logits"),
|
||||
py::arg("expert_id_to_ep_rank_array"),
|
||||
py::arg("expert_in_rank_num_list"),
|
||||
py::arg("tokens_per_expert_stats_list"),
|
||||
py::arg("bias"),
|
||||
py::arg("moe_topk"),
|
||||
py::arg("apply_norm_weight"),
|
||||
py::arg("enable_softmax_top_k_fused"),
|
||||
py::arg("redundant_ep_rank_num_plus_one"),
|
||||
"moe export RedundantTopKSelect function");
|
||||
m.def("set_ncluster", &set_ncluster, "set ncluster");
|
||||
|
||||
/**
|
||||
* open_shm_and_get_meta_signal.cc
|
||||
* InitKVSignalPerQuery
|
||||
*/
|
||||
m.def("init_kv_signal_per_query",
|
||||
&InitKVSignalPerQuery,
|
||||
py::arg("seq_lens_encoder_tensor"),
|
||||
py::arg("seq_lens_this_time_tensor"),
|
||||
py::arg("seq_lens_decoder_tensor"),
|
||||
py::arg("rank"),
|
||||
py::arg("num_layers"),
|
||||
"init_kv_signal_per_query function");
|
||||
m.def("draft_model_preprocess",
|
||||
&DraftModelPreprocess,
|
||||
py::arg("draft_tokens"),
|
||||
py::arg("input_ids"),
|
||||
py::arg("stop_flags"),
|
||||
py::arg("seq_lens_this_time"),
|
||||
py::arg("seq_lens_encoder"),
|
||||
py::arg("seq_lens_decoder"),
|
||||
py::arg("step_idx"),
|
||||
py::arg("seq_lens_encoder_record"),
|
||||
py::arg("seq_lens_decoder_record"),
|
||||
py::arg("not_need_stop"),
|
||||
py::arg("batch_drop"),
|
||||
py::arg("accept_tokens"),
|
||||
py::arg("accept_num"),
|
||||
py::arg("base_model_seq_lens_encoder"),
|
||||
py::arg("base_model_seq_lens_decoder"),
|
||||
py::arg("base_model_step_idx"),
|
||||
py::arg("base_model_stop_flags"),
|
||||
py::arg("base_model_is_block_step"),
|
||||
py::arg("base_model_draft_tokens"),
|
||||
py::arg("max_draft_token"),
|
||||
py::arg("truncate_first_token"),
|
||||
py::arg("splitwise_prefill"),
|
||||
"Preprocess data for draft model in speculative decoding");
|
||||
|
||||
/**
|
||||
* GetOutputKVSignal
|
||||
*/
|
||||
m.def("get_output_kv_signal",
|
||||
&GetOutputKVSignal,
|
||||
py::arg("x"),
|
||||
py::arg("rank_id"),
|
||||
py::arg("wait_flag"),
|
||||
"get_output_kv_signal function");
|
||||
m.def("draft_model_postprocess",
|
||||
&DraftModelPostprocess,
|
||||
py::arg("base_model_draft_tokens"),
|
||||
py::arg("base_model_seq_lens_this_time"),
|
||||
py::arg("base_model_seq_lens_encoder"),
|
||||
py::arg("base_model_stop_flags"),
|
||||
"Postprocess data for draft model in speculative decoding");
|
||||
|
||||
m.def("fused_rms_norm_xpu",
|
||||
&RmsNorm,
|
||||
"Fused RMS normalization for XPU",
|
||||
py::arg("x"), // 输入张量
|
||||
py::arg("bias"), // 偏置(可选)
|
||||
py::arg("residual"), // 残差连接(可选)
|
||||
py::arg("norm_weight"), // 归一化权重
|
||||
py::arg("norm_bias"), // 归一化偏置(可选)
|
||||
py::arg("epsilon"), // 数值稳定项
|
||||
py::arg("begin_norm_axis"), // 归一化起始维度
|
||||
py::arg("quant_scale"), // 量化缩放因子
|
||||
py::arg("quant_round_type"), // 量化舍入类型
|
||||
py::arg("quant_max_bound"), // 量化最大值边界
|
||||
py::arg("quant_min_bound") // 量化最小值边界
|
||||
m.def("draft_model_update",
|
||||
&DraftModelUpdate,
|
||||
"Update draft model states during speculative decoding",
|
||||
py::arg("inter_next_tokens"), // 中间next tokens张量
|
||||
py::arg("draft_tokens"), // 草稿token张量
|
||||
py::arg("pre_ids"), // 前置ID张量
|
||||
py::arg("seq_lens_this_time"), // 当前步骤序列长度张量
|
||||
py::arg("seq_lens_encoder"), // 编码器序列长度张量
|
||||
py::arg("seq_lens_decoder"), // 解码器序列长度张量
|
||||
py::arg("step_idx"), // 步骤索引张量
|
||||
py::arg("output_cum_offsets"), // 输出累积偏移量张量
|
||||
py::arg("stop_flags"), // 停止标志张量
|
||||
py::arg("not_need_stop"), // 无需停止标志张量
|
||||
py::arg("max_dec_len"), // 最大解码长度张量
|
||||
py::arg("end_ids"), // 结束ID张量
|
||||
py::arg("base_model_draft_tokens"), // 基础模型草稿token张量
|
||||
py::arg("max_seq_len"), // 最大序列长度(int)
|
||||
py::arg("substep") // 子步骤编号(int)
|
||||
);
|
||||
|
||||
m.def("weight_only_linear_xpu",
|
||||
&WeightOnlyLinear,
|
||||
"Weight-only quantized linear layer",
|
||||
py::arg("x"),
|
||||
py::arg("weight"),
|
||||
py::arg("weight_scale"),
|
||||
py::arg("bias"),
|
||||
py::arg("weight_dtype"),
|
||||
py::arg("arch"),
|
||||
py::arg("group_size") = -1);
|
||||
m.def("eagle_get_hidden_states",
|
||||
&EagleGetHiddenStates,
|
||||
py::arg("input"),
|
||||
py::arg("seq_lens_this_time"),
|
||||
py::arg("seq_lens_encoder"),
|
||||
py::arg("seq_lens_decoder"),
|
||||
py::arg("stop_flags"),
|
||||
py::arg("accept_nums"),
|
||||
py::arg("base_model_seq_lens_this_time"),
|
||||
py::arg("base_model_seq_lens_encoder"),
|
||||
py::arg("actual_draft_token_num"),
|
||||
"Get draft model hidden states");
|
||||
|
||||
m.def("eagle_get_self_hidden_states",
|
||||
&EagleGetSelfHiddenStates,
|
||||
py::arg("input"),
|
||||
py::arg("last_seq_lens_this_time"),
|
||||
py::arg("seq_lens_this_time"),
|
||||
py::arg("step_idx"),
|
||||
"Rebuild draft model hidden states");
|
||||
|
||||
m.def("ep_moe_expert_combine",
|
||||
&MoeEPCombine,
|
||||
@@ -502,6 +708,157 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
|
||||
py::arg("token_nums_this_rank"),
|
||||
py::arg("quant_method"));
|
||||
|
||||
m.def("fused_rms_norm_xpu",
|
||||
&RmsNorm,
|
||||
"Fused RMS normalization for XPU",
|
||||
py::arg("x"), // 输入张量
|
||||
py::arg("bias"), // 偏置(可选)
|
||||
py::arg("residual"), // 残差连接(可选)
|
||||
py::arg("norm_weight"), // 归一化权重
|
||||
py::arg("norm_bias"), // 归一化偏置(可选)
|
||||
py::arg("epsilon"), // 数值稳定项
|
||||
py::arg("begin_norm_axis"), // 归一化起始维度
|
||||
py::arg("quant_scale"), // 量化缩放因子
|
||||
py::arg("quant_round_type"), // 量化舍入类型
|
||||
py::arg("quant_max_bound"), // 量化最大值边界
|
||||
py::arg("quant_min_bound") // 量化最小值边界
|
||||
);
|
||||
|
||||
m.def("gather_next_token",
|
||||
&GatherNextToken,
|
||||
py::arg("tmp_out"),
|
||||
py::arg("cum_offsets"),
|
||||
py::arg("encoder_seq_lod"),
|
||||
py::arg("encoder_batch_map"),
|
||||
py::arg("decoder_batch_map"),
|
||||
py::arg("encoder_seq_lod_cpu"),
|
||||
py::arg("encoder_batch_map_cpu"),
|
||||
py::arg("decoder_batch_map_cpu"),
|
||||
py::arg("enc_batch_tensor"),
|
||||
py::arg("dec_batch_tensor"),
|
||||
py::arg("output_padding_offset"),
|
||||
py::arg("max_input_length"),
|
||||
"Gather next token for XPU");
|
||||
|
||||
m.def("get_img_boundaries",
|
||||
&GetImgBoundaries,
|
||||
py::arg("task_input_ids"),
|
||||
py::arg("grid_thw"),
|
||||
py::arg("image_patch_id"),
|
||||
"Get image boundaries in VL model");
|
||||
|
||||
m.def("get_infer_param",
|
||||
&GetInferParam,
|
||||
py::arg("seq_lens_encoder"),
|
||||
py::arg("seq_lens_decoder"),
|
||||
py::arg("seq_lens_this_time"),
|
||||
py::arg("block_tables"),
|
||||
py::arg("block_size"),
|
||||
"Get infer parameters for block attention in XPU");
|
||||
|
||||
m.def("get_peer_mem_addr",
|
||||
&xpu_get_peer_mem_addr,
|
||||
"Get Host memory address of device pointer",
|
||||
py::arg("ptr"));
|
||||
|
||||
m.def("get_token_penalty_multi_scores",
|
||||
&TokenPenaltyMultiScores,
|
||||
py::arg("pre_ids"),
|
||||
py::arg("logits"),
|
||||
py::arg("penalty_scores"),
|
||||
py::arg("frequency_scores"),
|
||||
py::arg("presence_scores"),
|
||||
py::arg("temperatures"),
|
||||
py::arg("bad_tokens"),
|
||||
py::arg("cur_len"),
|
||||
py::arg("min_len"),
|
||||
py::arg("eos_token_id"),
|
||||
"get token_penalty_multi_scores function");
|
||||
|
||||
m.def("get_output",
|
||||
&GetOutputStatic,
|
||||
py::arg("x"),
|
||||
py::arg("rank_id"),
|
||||
py::arg("wait_flag"),
|
||||
"get_output function");
|
||||
|
||||
m.def("get_output_ep",
|
||||
&GetOutputStatic,
|
||||
py::arg("x"),
|
||||
py::arg("rank_id"),
|
||||
py::arg("wait_flag"),
|
||||
"get_output_ep function");
|
||||
|
||||
m.def("get_output_dynamic",
|
||||
&GetOutputDynamic,
|
||||
py::arg("x"),
|
||||
py::arg("rank_id"),
|
||||
py::arg("wait_flag"),
|
||||
py::arg("msg_queue_id"),
|
||||
"get_output_dynamic function");
|
||||
|
||||
m.def("get_output_ep_dynamic",
|
||||
&GetOutputDynamic,
|
||||
py::arg("x"),
|
||||
py::arg("rank_id"),
|
||||
py::arg("wait_flag"),
|
||||
py::arg("msg_queue_id"),
|
||||
"get_output_ep_dynamic function");
|
||||
|
||||
m.def("get_output_kv_signal",
|
||||
&GetOutputKVSignal,
|
||||
py::arg("x"),
|
||||
py::arg("rank_id"),
|
||||
py::arg("wait_flag"),
|
||||
"get_output_kv_signal function");
|
||||
|
||||
m.def("get_padding_offset",
|
||||
&GetPaddingOffset,
|
||||
py::arg("input_ids"),
|
||||
py::arg("cum_offsets"),
|
||||
py::arg("token_num"),
|
||||
py::arg("seq_len"),
|
||||
"get padding offset function");
|
||||
|
||||
m.def("init_kv_signal_per_query",
|
||||
&InitKVSignalPerQuery,
|
||||
py::arg("seq_lens_encoder_tensor"),
|
||||
py::arg("seq_lens_this_time_tensor"),
|
||||
py::arg("seq_lens_decoder_tensor"),
|
||||
py::arg("rank"),
|
||||
py::arg("num_layers"),
|
||||
"init_kv_signal_per_query function");
|
||||
|
||||
m.def("moe_redundant_topk_select",
|
||||
&MoERedundantTopKSelect,
|
||||
py::arg("gating_logits"),
|
||||
py::arg("expert_id_to_ep_rank_array"),
|
||||
py::arg("expert_in_rank_num_list"),
|
||||
py::arg("tokens_per_expert_stats_list"),
|
||||
py::arg("bias"),
|
||||
py::arg("moe_topk"),
|
||||
py::arg("apply_norm_weight"),
|
||||
py::arg("enable_softmax_top_k_fused"),
|
||||
py::arg("redundant_ep_rank_num_plus_one"),
|
||||
"moe export RedundantTopKSelect function");
|
||||
|
||||
m.def("mtp_step_paddle",
|
||||
&MTPStepPaddle,
|
||||
py::arg("base_model_stop_flags"),
|
||||
py::arg("stop_flags"),
|
||||
py::arg("batch_drop"),
|
||||
py::arg("seq_lens_this_time"),
|
||||
py::arg("seq_lens_encoder"),
|
||||
py::arg("seq_lens_decoder"),
|
||||
py::arg("block_tables"), // [bsz, block_num_per_seq]
|
||||
py::arg("encoder_block_lens"),
|
||||
py::arg("used_list_len"),
|
||||
py::arg("free_list"),
|
||||
py::arg("free_list_len"),
|
||||
py::arg("block_size"),
|
||||
py::arg("max_draft_tokens"),
|
||||
"MTP step paddle");
|
||||
|
||||
m.def("moe_expert_ffn",
|
||||
&MoeExpertFFN,
|
||||
"MoE expert feed-forward network with quantization support",
|
||||
@@ -529,25 +886,46 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
|
||||
py::arg("moe_topk"),
|
||||
py::arg("apply_norm_weight"));
|
||||
|
||||
m.def("draft_model_update",
|
||||
&DraftModelUpdate,
|
||||
"Update draft model states during speculative decoding",
|
||||
py::arg("inter_next_tokens"), // 中间next tokens张量
|
||||
py::arg("draft_tokens"), // 草稿token张量
|
||||
py::arg("pre_ids"), // 前置ID张量
|
||||
py::arg("seq_lens_this_time"), // 当前步骤序列长度张量
|
||||
py::arg("seq_lens_encoder"), // 编码器序列长度张量
|
||||
py::arg("seq_lens_decoder"), // 解码器序列长度张量
|
||||
py::arg("step_idx"), // 步骤索引张量
|
||||
py::arg("output_cum_offsets"), // 输出累积偏移量张量
|
||||
py::arg("stop_flags"), // 停止标志张量
|
||||
py::arg("not_need_stop"), // 无需停止标志张量
|
||||
py::arg("max_dec_len"), // 最大解码长度张量
|
||||
py::arg("end_ids"), // 结束ID张量
|
||||
py::arg("base_model_draft_tokens"), // 基础模型草稿token张量
|
||||
py::arg("max_seq_len"), // 最大序列长度(int)
|
||||
py::arg("substep") // 子步骤编号(int)
|
||||
);
|
||||
m.def("prof_start", &prof_start, "prof_start");
|
||||
|
||||
m.def("prof_stop", &prof_stop, "prof_stop");
|
||||
|
||||
m.def("recover_decode_task",
|
||||
&RecoverDecodeTask,
|
||||
py::arg("stop_flags"),
|
||||
py::arg("seq_lens_this_time"),
|
||||
py::arg("seq_lens_encoder"),
|
||||
py::arg("seq_lens_decoder"),
|
||||
py::arg("step_seq_lens_decoder"),
|
||||
py::arg("block_tables"),
|
||||
py::arg("is_block_step"),
|
||||
py::arg("block_size"),
|
||||
"Recover decode task function");
|
||||
|
||||
m.def("save_output",
|
||||
&SaveOutMmsgStatic,
|
||||
py::arg("x"),
|
||||
py::arg("not_need_stop"),
|
||||
py::arg("rank_id"),
|
||||
py::arg("save_each_rank"),
|
||||
"Save output function");
|
||||
|
||||
m.def("save_output_dynamic",
|
||||
&SaveOutMmsgDynamic,
|
||||
py::arg("x"),
|
||||
py::arg("not_need_stop"),
|
||||
py::arg("rank_id"),
|
||||
py::arg("msg_queue_id"),
|
||||
py::arg("save_each_rank"),
|
||||
"Save output dynamic function");
|
||||
|
||||
m.def("share_external_data",
|
||||
&ShareExternalData,
|
||||
py::arg("input"),
|
||||
py::arg("shm_name"),
|
||||
py::arg("shape"),
|
||||
py::arg("use_ipc"),
|
||||
"Share external data function");
|
||||
|
||||
m.def("speculate_get_token_penalty_multi_scores",
|
||||
&SpeculateTokenPenaltyMultiScores,
|
||||
@@ -582,15 +960,6 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
|
||||
py::arg("stop_nums"),
|
||||
"Update speculative decoding states (V3)");
|
||||
|
||||
m.def("top_p_candidates",
|
||||
&TopPCandidates,
|
||||
py::arg("probs"),
|
||||
py::arg("top_p"),
|
||||
py::arg("output_padding_offset"),
|
||||
py::arg("candidates_len"),
|
||||
py::arg("max_seq_len"),
|
||||
"Generate top-p candidates based on probability distributions");
|
||||
|
||||
m.def("speculate_verify",
|
||||
&SpeculateVerify,
|
||||
py::arg("accept_tokens"),
|
||||
@@ -633,61 +1002,6 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
|
||||
py::arg("step_idx"),
|
||||
"Set values based on flags and indices in speculative decoding");
|
||||
|
||||
m.def("draft_model_preprocess",
|
||||
&DraftModelPreprocess,
|
||||
py::arg("draft_tokens"),
|
||||
py::arg("input_ids"),
|
||||
py::arg("stop_flags"),
|
||||
py::arg("seq_lens_this_time"),
|
||||
py::arg("seq_lens_encoder"),
|
||||
py::arg("seq_lens_decoder"),
|
||||
py::arg("step_idx"),
|
||||
py::arg("seq_lens_encoder_record"),
|
||||
py::arg("seq_lens_decoder_record"),
|
||||
py::arg("not_need_stop"),
|
||||
py::arg("batch_drop"),
|
||||
py::arg("accept_tokens"),
|
||||
py::arg("accept_num"),
|
||||
py::arg("base_model_seq_lens_encoder"),
|
||||
py::arg("base_model_seq_lens_decoder"),
|
||||
py::arg("base_model_step_idx"),
|
||||
py::arg("base_model_stop_flags"),
|
||||
py::arg("base_model_is_block_step"),
|
||||
py::arg("base_model_draft_tokens"),
|
||||
py::arg("max_draft_token"),
|
||||
py::arg("truncate_first_token"),
|
||||
py::arg("splitwise_prefill"),
|
||||
"Preprocess data for draft model in speculative decoding");
|
||||
|
||||
m.def("draft_model_postprocess",
|
||||
&DraftModelPostprocess,
|
||||
py::arg("base_model_draft_tokens"),
|
||||
py::arg("base_model_seq_lens_this_time"),
|
||||
py::arg("base_model_seq_lens_encoder"),
|
||||
py::arg("base_model_stop_flags"),
|
||||
"Postprocess data for draft model in speculative decoding");
|
||||
|
||||
m.def("eagle_get_hidden_states",
|
||||
&EagleGetHiddenStates,
|
||||
py::arg("input"),
|
||||
py::arg("seq_lens_this_time"),
|
||||
py::arg("seq_lens_encoder"),
|
||||
py::arg("seq_lens_decoder"),
|
||||
py::arg("stop_flags"),
|
||||
py::arg("accept_nums"),
|
||||
py::arg("base_model_seq_lens_this_time"),
|
||||
py::arg("base_model_seq_lens_encoder"),
|
||||
py::arg("actual_draft_token_num"),
|
||||
"Get draft model hidden states");
|
||||
|
||||
m.def("eagle_get_self_hidden_states",
|
||||
&EagleGetSelfHiddenStates,
|
||||
py::arg("input"),
|
||||
py::arg("last_seq_lens_this_time"),
|
||||
py::arg("seq_lens_this_time"),
|
||||
py::arg("step_idx"),
|
||||
"Rebuild draft model hidden states");
|
||||
|
||||
m.def("speculate_get_output_padding_offset",
|
||||
&SpeculateGetOutputPaddingOffset,
|
||||
py::arg("output_cum_offsets_tmp"),
|
||||
@@ -706,23 +1020,6 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
|
||||
py::arg("seq_lens_encoder"),
|
||||
"Get padding offset");
|
||||
|
||||
m.def("mtp_step_paddle",
|
||||
&MTPStepPaddle,
|
||||
py::arg("base_model_stop_flags"),
|
||||
py::arg("stop_flags"),
|
||||
py::arg("batch_drop"),
|
||||
py::arg("seq_lens_this_time"),
|
||||
py::arg("seq_lens_encoder"),
|
||||
py::arg("seq_lens_decoder"),
|
||||
py::arg("block_tables"), // [bsz, block_num_per_seq]
|
||||
py::arg("encoder_block_lens"),
|
||||
py::arg("used_list_len"),
|
||||
py::arg("free_list"),
|
||||
py::arg("free_list_len"),
|
||||
py::arg("block_size"),
|
||||
py::arg("max_draft_tokens"),
|
||||
"MTP step paddle");
|
||||
|
||||
m.def("speculate_step_reschedule",
|
||||
&SpeculateStepSchedule,
|
||||
py::arg("stop_flags"),
|
||||
@@ -760,6 +1057,147 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
|
||||
py::arg("seq_lens_decoder"),
|
||||
"Get sequence lengths output");
|
||||
|
||||
m.def("set_data_ipc",
|
||||
&SetDataIpc,
|
||||
py::arg("tmp_input"),
|
||||
py::arg("shm_name"),
|
||||
"Set data IPC function");
|
||||
|
||||
m.def("set_ncluster", &set_ncluster, "set ncluster");
|
||||
|
||||
m.def("set_stop_value_multi_ends",
|
||||
&GetStopFlagsMulti,
|
||||
py::arg("topk_ids"),
|
||||
py::arg("stop_flags"),
|
||||
py::arg("seq_lens"),
|
||||
py::arg("end_ids"),
|
||||
py::arg("next_tokens"),
|
||||
py::arg("beam_search"),
|
||||
"Set stop value multi ends function");
|
||||
|
||||
m.def("step_paddle",
|
||||
&StepPaddle,
|
||||
py::arg("stop_flags"),
|
||||
py::arg("seq_lens_this_time"),
|
||||
py::arg("ori_seq_lens_encoder"),
|
||||
py::arg("seq_lens_encoder"),
|
||||
py::arg("seq_lens_decoder"),
|
||||
py::arg("block_tables"),
|
||||
py::arg("encoder_block_lens"),
|
||||
py::arg("is_block_step"),
|
||||
py::arg("step_block_list"),
|
||||
py::arg("step_lens"),
|
||||
py::arg("recover_block_list"),
|
||||
py::arg("recover_lens"),
|
||||
py::arg("need_block_list"),
|
||||
py::arg("need_block_len"),
|
||||
py::arg("used_list_len"),
|
||||
py::arg("free_list"),
|
||||
py::arg("free_list_len"),
|
||||
py::arg("input_ids"),
|
||||
py::arg("pre_ids"),
|
||||
py::arg("step_idx"),
|
||||
py::arg("next_tokens"),
|
||||
py::arg("first_token_ids"),
|
||||
py::arg("block_size"),
|
||||
py::arg("encoder_decoder_block_num"),
|
||||
"Step paddle function");
|
||||
|
||||
m.def("text_image_gather_scatter",
|
||||
&TextImageGatherScatter,
|
||||
py::arg("input"),
|
||||
py::arg("text_input"),
|
||||
py::arg("image_input"),
|
||||
py::arg("token_type_ids"),
|
||||
py::arg("text_index"),
|
||||
py::arg("image_index"),
|
||||
py::arg("is_scatter"),
|
||||
"Scatter image and text from hidden states, or gather them to hidden "
|
||||
"states");
|
||||
|
||||
m.def("text_image_index_out",
|
||||
&TextImageIndexOut,
|
||||
py::arg("token_type_ids"),
|
||||
py::arg("text_index"),
|
||||
py::arg("image_index"),
|
||||
"Generate index for text and image");
|
||||
|
||||
m.def("top_p_candidates",
|
||||
&TopPCandidates,
|
||||
py::arg("probs"),
|
||||
py::arg("top_p"),
|
||||
py::arg("output_padding_offset"),
|
||||
py::arg("candidates_len"),
|
||||
py::arg("max_seq_len"),
|
||||
"Generate top-p candidates based on probability distributions");
|
||||
|
||||
m.def("update_inputs",
|
||||
&UpdateInputs,
|
||||
py::arg("stop_flags"),
|
||||
py::arg("not_need_stop"),
|
||||
py::arg("seq_lens_this_time"),
|
||||
py::arg("seq_lens_encoder"),
|
||||
py::arg("seq_lens_decoder"),
|
||||
py::arg("input_ids"),
|
||||
py::arg("stop_nums"),
|
||||
py::arg("next_tokens"),
|
||||
py::arg("is_block_step"),
|
||||
"Update inputs function");
|
||||
|
||||
m.def("update_inputs_v1",
|
||||
&UpdateInputsV1,
|
||||
py::arg("stop_flags"),
|
||||
py::arg("not_need_stop"),
|
||||
py::arg("seq_lens_this_time"),
|
||||
py::arg("seq_lens_encoder"),
|
||||
py::arg("seq_lens_decoder"),
|
||||
py::arg("step_seq_lens_decoder"),
|
||||
py::arg("prompt_lens"),
|
||||
py::arg("topk_ids"),
|
||||
py::arg("input_ids"),
|
||||
py::arg("block_tables"),
|
||||
py::arg("stop_nums"),
|
||||
py::arg("next_tokens"),
|
||||
py::arg("is_block_step"),
|
||||
py::arg("block_size"),
|
||||
"Update inputs v1 function");
|
||||
|
||||
m.def("weight_quantize_xpu",
|
||||
&WeightQuantize,
|
||||
py::arg("x"),
|
||||
py::arg("algo"),
|
||||
py::arg("arch"),
|
||||
py::arg("group_size"),
|
||||
"Quantize weights on XPU");
|
||||
|
||||
m.def("weight_only_linear_xpu",
|
||||
&WeightOnlyLinear,
|
||||
"Weight-only quantized linear layer",
|
||||
py::arg("x"),
|
||||
py::arg("weight"),
|
||||
py::arg("weight_scale"),
|
||||
py::arg("bias"),
|
||||
py::arg("weight_dtype"),
|
||||
py::arg("arch"),
|
||||
py::arg("group_size") = -1);
|
||||
|
||||
m.def("xpu_moe_layer",
|
||||
&MoeLayer,
|
||||
py::arg("x"),
|
||||
py::arg("gate_weight"),
|
||||
py::arg("gate_correction_bias"),
|
||||
py::arg("up_gate_proj_weight"),
|
||||
py::arg("down_proj_weight"),
|
||||
py::arg("up_gate_proj_bias"),
|
||||
py::arg("down_proj_bias"),
|
||||
py::arg("up_gate_proj_weight_scale"),
|
||||
py::arg("down_proj_weight_scale"),
|
||||
py::arg("down_proj_in_scale"),
|
||||
py::arg("quant_method"),
|
||||
py::arg("moe_top_k"),
|
||||
py::arg("moe_group"),
|
||||
"fused moe op(topk + dispatch + ffn + combine) in XPU");
|
||||
|
||||
// 添加XPU错误信息的异常处理类
|
||||
py::register_exception<XPUError>(m, "XPUError");
|
||||
}
|
||||
|
||||
@@ -17,18 +17,18 @@
|
||||
#include "paddle/phi/core/enforce.h"
|
||||
#include "xpu/plugin.h"
|
||||
|
||||
void UpdateInputes(const paddle::Tensor &stop_flags,
|
||||
const paddle::Tensor &not_need_stop, // cpu
|
||||
const paddle::Tensor &seq_lens_this_time,
|
||||
const paddle::Tensor &seq_lens_encoder,
|
||||
const paddle::Tensor &seq_lens_decoder,
|
||||
const paddle::Tensor &input_ids,
|
||||
const paddle::Tensor &stop_nums,
|
||||
const paddle::Tensor &next_tokens,
|
||||
const paddle::Tensor &is_block_step) {
|
||||
void UpdateInputs(const paddle::Tensor& stop_flags,
|
||||
const paddle::Tensor& not_need_stop, // cpu
|
||||
const paddle::Tensor& seq_lens_this_time,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& input_ids,
|
||||
const paddle::Tensor& stop_nums,
|
||||
const paddle::Tensor& next_tokens,
|
||||
const paddle::Tensor& is_block_step) {
|
||||
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
|
||||
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
|
||||
auto xpu_ctx = static_cast<const phi::XPUContext *>(dev_ctx);
|
||||
auto xpu_ctx = static_cast<const phi::XPUContext*>(dev_ctx);
|
||||
|
||||
const int max_bsz = stop_flags.shape()[0];
|
||||
PADDLE_ENFORCE_LE(
|
||||
@@ -42,11 +42,11 @@ void UpdateInputes(const paddle::Tensor &stop_flags,
|
||||
|
||||
int r = baidu::xpu::api::plugin::update_inputs(
|
||||
xpu_ctx->x_context(),
|
||||
const_cast<bool *>(not_need_stop_xpu.data<bool>()),
|
||||
const_cast<int *>(seq_lens_this_time.data<int>()),
|
||||
const_cast<int *>(seq_lens_encoder.data<int>()),
|
||||
const_cast<int *>(seq_lens_decoder.data<int>()),
|
||||
const_cast<int64_t *>(input_ids.data<int64_t>()),
|
||||
const_cast<bool*>(not_need_stop_xpu.data<bool>()),
|
||||
const_cast<int*>(seq_lens_this_time.data<int>()),
|
||||
const_cast<int*>(seq_lens_encoder.data<int>()),
|
||||
const_cast<int*>(seq_lens_decoder.data<int>()),
|
||||
const_cast<int64_t*>(input_ids.data<int64_t>()),
|
||||
stop_nums.data<int64_t>(),
|
||||
stop_flags.data<bool>(),
|
||||
is_block_step.data<bool>(),
|
||||
@@ -57,7 +57,7 @@ void UpdateInputes(const paddle::Tensor &stop_flags,
|
||||
PD_CHECK(r == 0, "baidu::xpu::api::plugin::update_inputs failed.");
|
||||
auto not_need_stop_cpu =
|
||||
not_need_stop_xpu.copy_to(not_need_stop.place(), false);
|
||||
bool *not_need_stop_data = const_cast<bool *>(not_need_stop.data<bool>());
|
||||
bool* not_need_stop_data = const_cast<bool*>(not_need_stop.data<bool>());
|
||||
not_need_stop_data[0] = not_need_stop_cpu.data<bool>()[0];
|
||||
}
|
||||
|
||||
@@ -81,4 +81,4 @@ PD_BUILD_OP(update_inputs)
|
||||
{"seq_lens_encoder", "seq_lens_encoder_out"},
|
||||
{"seq_lens_decoder", "seq_lens_decoder_out"},
|
||||
{"input_ids", "input_ids_out"}})
|
||||
.SetKernelFn(PD_KERNEL(UpdateInputes));
|
||||
.SetKernelFn(PD_KERNEL(UpdateInputs));
|
||||
|
||||
@@ -17,23 +17,23 @@
|
||||
#include "paddle/phi/core/enforce.h"
|
||||
#include "xpu/plugin.h"
|
||||
|
||||
void UpdateInputesV1(const paddle::Tensor &stop_flags,
|
||||
const paddle::Tensor &not_need_stop, // only on cpu
|
||||
const paddle::Tensor &seq_lens_this_time,
|
||||
const paddle::Tensor &seq_lens_encoder,
|
||||
const paddle::Tensor &seq_lens_decoder,
|
||||
const paddle::Tensor &step_seq_lens_decoder,
|
||||
const paddle::Tensor &prompt_lens,
|
||||
const paddle::Tensor &topk_ids,
|
||||
const paddle::Tensor &input_ids,
|
||||
const paddle::Tensor &block_tables,
|
||||
const paddle::Tensor &stop_nums,
|
||||
const paddle::Tensor &next_tokens,
|
||||
const paddle::Tensor &is_block_step,
|
||||
const int block_size) {
|
||||
void UpdateInputsV1(const paddle::Tensor& stop_flags,
|
||||
const paddle::Tensor& not_need_stop, // only on cpu
|
||||
const paddle::Tensor& seq_lens_this_time,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& step_seq_lens_decoder,
|
||||
const paddle::Tensor& prompt_lens,
|
||||
const paddle::Tensor& topk_ids,
|
||||
const paddle::Tensor& input_ids,
|
||||
const paddle::Tensor& block_tables,
|
||||
const paddle::Tensor& stop_nums,
|
||||
const paddle::Tensor& next_tokens,
|
||||
const paddle::Tensor& is_block_step,
|
||||
const int block_size) {
|
||||
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
|
||||
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
|
||||
auto xpu_ctx = static_cast<const phi::XPUContext *>(dev_ctx);
|
||||
auto xpu_ctx = static_cast<const phi::XPUContext*>(dev_ctx);
|
||||
|
||||
const int max_bsz = stop_flags.shape()[0];
|
||||
const int now_bsz = seq_lens_this_time.shape()[0];
|
||||
@@ -43,18 +43,18 @@ void UpdateInputesV1(const paddle::Tensor &stop_flags,
|
||||
auto not_need_stop_gpu = not_need_stop.copy_to(stop_flags.place(), false);
|
||||
int r = baidu::xpu::api::plugin::update_inputs_v1(
|
||||
xpu_ctx->x_context(),
|
||||
const_cast<bool *>(not_need_stop_gpu.data<bool>()),
|
||||
const_cast<int *>(seq_lens_this_time.data<int>()),
|
||||
const_cast<int *>(seq_lens_encoder.data<int>()),
|
||||
const_cast<int *>(seq_lens_decoder.data<int>()),
|
||||
const_cast<int *>(step_seq_lens_decoder.data<int>()),
|
||||
const_cast<int64_t *>(prompt_lens.data<int64_t>()),
|
||||
const_cast<int64_t *>(topk_ids.data<int64_t>()),
|
||||
const_cast<int64_t *>(input_ids.data<int64_t>()),
|
||||
const_cast<int *>(block_tables.data<int>()),
|
||||
const_cast<bool*>(not_need_stop_gpu.data<bool>()),
|
||||
const_cast<int*>(seq_lens_this_time.data<int>()),
|
||||
const_cast<int*>(seq_lens_encoder.data<int>()),
|
||||
const_cast<int*>(seq_lens_decoder.data<int>()),
|
||||
const_cast<int*>(step_seq_lens_decoder.data<int>()),
|
||||
const_cast<int64_t*>(prompt_lens.data<int64_t>()),
|
||||
const_cast<int64_t*>(topk_ids.data<int64_t>()),
|
||||
const_cast<int64_t*>(input_ids.data<int64_t>()),
|
||||
const_cast<int*>(block_tables.data<int>()),
|
||||
stop_nums.data<int64_t>(),
|
||||
const_cast<bool *>(stop_flags.data<bool>()),
|
||||
const_cast<bool *>(is_block_step.data<bool>()),
|
||||
const_cast<bool*>(stop_flags.data<bool>()),
|
||||
const_cast<bool*>(is_block_step.data<bool>()),
|
||||
next_tokens.data<int64_t>(),
|
||||
now_bsz,
|
||||
max_bsz,
|
||||
@@ -64,7 +64,7 @@ void UpdateInputesV1(const paddle::Tensor &stop_flags,
|
||||
PD_CHECK(r == 0, "baidu::xpu::api::plugin::update_inputs_kernel_v1 failed.");
|
||||
auto not_need_stop_cpu =
|
||||
not_need_stop_gpu.copy_to(not_need_stop.place(), false);
|
||||
bool *not_need_stop_data = const_cast<bool *>(not_need_stop.data<bool>());
|
||||
bool* not_need_stop_data = const_cast<bool*>(not_need_stop.data<bool>());
|
||||
not_need_stop_data[0] = not_need_stop_cpu.data<bool>()[0];
|
||||
}
|
||||
|
||||
@@ -101,4 +101,4 @@ PD_BUILD_OP(update_inputs_v1)
|
||||
{"stop_flags", "stop_flags_out"},
|
||||
{"step_seq_lens_decoder", "step_seq_lens_decoder_out"},
|
||||
{"is_block_step", "is_block_step_out"}})
|
||||
.SetKernelFn(PD_KERNEL(UpdateInputesV1));
|
||||
.SetKernelFn(PD_KERNEL(UpdateInputsV1));
|
||||
|
||||
Reference in New Issue
Block a user