[XPU] bind some OPs for VL model with pybind (#4522)

This commit is contained in:
Lucas
2025-10-27 10:50:08 +08:00
committed by GitHub
parent cdc40cdc2a
commit 5c6105f4a2
8 changed files with 1789 additions and 1087 deletions

File diff suppressed because it is too large Load Diff

View File

@@ -15,93 +15,97 @@
#include "helper.h"
// Per-step bookkeeping after one decode iteration.
//
// Launched as <<<1, THREADBLOCK_SIZE>>>: a single block where thread i owns
// request slot i (requires THREADBLOCK_SIZE >= max_bsz; the host wrapper
// launches 1024 threads).
//
// For every active request (i < bsz) that has not stopped, the decoder length
// is advanced (absorbing the encoder length on the first step after prefill),
// seq_lens_this_time is set to 1 token, and the freshly sampled token is
// written to position 0 of that request's input_ids row.  A block-wide
// reduction then counts stopped slots; `not_need_stop[0]` stays true while the
// stop count is below `stop_nums[0]`.
//
// NOTE(review): slots parked by the block scheduler (is_block_step) are
// deliberately not counted as stopped — presumably so the serving loop keeps
// running while they wait to be re-scheduled; confirm against the scheduler.
template <int THREADBLOCK_SIZE>
__global__ void update_inputs_kernel(bool* not_need_stop,
                                     int* seq_lens_this_time,
                                     int* seq_lens_encoder,
                                     int* seq_lens_decoder,
                                     int64_t* input_ids,
                                     const int64_t* stop_nums,
                                     const bool* stop_flags,
                                     const bool* is_block_step,
                                     const int64_t* next_tokens,
                                     const int bsz,
                                     const int max_bsz,
                                     const int input_ids_stride) {
  int thread_idx = threadIdx.x;
  typedef cub::BlockReduce<int64_t, THREADBLOCK_SIZE> BlockReduce;
  __shared__ typename BlockReduce::TempStorage temp_storage;
  bool stop_flag_now = false;
  int64_t stop_flag_now_int = 0;
  if (thread_idx < max_bsz) {
    if (thread_idx < bsz) {
      stop_flag_now = stop_flags[thread_idx];
      if (is_block_step[thread_idx]) {
        // Parked by the block scheduler: do not count toward the stop quota.
        stop_flag_now_int = 0;
      } else {
        stop_flag_now_int = static_cast<int64_t>(stop_flag_now);
      }
    } else {
      // Padding slots beyond the live batch always count as stopped.
      stop_flag_now_int = 1;
    }
  }
  if (thread_idx < bsz) {
    const int seq_len_encoder = seq_lens_encoder[thread_idx];
    const int seq_len_decoder = seq_lens_decoder[thread_idx];
    // First step after prefill: decoder length absorbs the prompt length.
    // Subsequent steps: one more generated token.
    seq_lens_decoder[thread_idx] =
        stop_flag_now
            ? 0
            : (seq_len_encoder > 0 ? (seq_len_encoder + seq_len_decoder)
                                   : seq_len_decoder + 1);
    seq_lens_this_time[thread_idx] = stop_flag_now ? 0 : 1;
    seq_lens_encoder[thread_idx] = 0;
    // Feed the sampled token back as next step's single input token.
    int64_t* input_ids_now = input_ids + thread_idx * input_ids_stride;
    input_ids_now[0] = next_tokens[thread_idx];
  }
  __syncthreads();
  int64_t stop_sum = BlockReduce(temp_storage).Sum(stop_flag_now_int);
  if (thread_idx == 0) {
    not_need_stop[0] = stop_sum < stop_nums[0];
  }
}
// Host wrapper for update_inputs_kernel (registered as op `update_inputs`).
//
// `not_need_stop` lives on CPU: it is copied to the device, updated by the
// kernel, then copied back and written through in place so the Python side
// observes the new value.  All other tensors are updated in place on device.
void UpdateInputs(const paddle::Tensor& stop_flags,
                  const paddle::Tensor& not_need_stop,  // only on cpu
                  const paddle::Tensor& seq_lens_this_time,
                  const paddle::Tensor& seq_lens_encoder,
                  const paddle::Tensor& seq_lens_decoder,
                  const paddle::Tensor& input_ids,
                  const paddle::Tensor& stop_nums,
                  const paddle::Tensor& next_tokens,
                  const paddle::Tensor& is_block_step) {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
  // Custom devices (e.g. XPU) expose the stream via the device context.
  auto dev_ctx = static_cast<const phi::CustomContext*>(
      paddle::experimental::DeviceContextPool::Instance().Get(
          input_ids.place()));
  auto cu_stream = dev_ctx->stream();
#else
  auto cu_stream = input_ids.stream();
#endif
  const int max_bsz = stop_flags.shape()[0];
  const int now_bsz = seq_lens_this_time.shape()[0];
  const int input_ids_stride = input_ids.shape()[1];
  // Stage the CPU flag on device so the kernel can write it.
  auto not_need_stop_gpu = not_need_stop.copy_to(stop_flags.place(), false);
  // Single block of 1024 threads; one thread per request slot
  // (assumes max_bsz <= 1024 — TODO confirm upstream guarantee).
  update_inputs_kernel<1024><<<1, 1024, 0, cu_stream>>>(
      const_cast<bool*>(not_need_stop_gpu.data<bool>()),
      const_cast<int*>(seq_lens_this_time.data<int>()),
      const_cast<int*>(seq_lens_encoder.data<int>()),
      const_cast<int*>(seq_lens_decoder.data<int>()),
      const_cast<int64_t*>(input_ids.data<int64_t>()),
      stop_nums.data<int64_t>(),
      stop_flags.data<bool>(),
      is_block_step.data<bool>(),
      next_tokens.data<int64_t>(),
      now_bsz,
      max_bsz,
      input_ids_stride);
  // Copy the result back and publish it into the CPU tensor in place.
  auto not_need_stop_cpu =
      not_need_stop_gpu.copy_to(not_need_stop.place(), false);
  bool* not_need_stop_data = const_cast<bool*>(not_need_stop.data<bool>());
  not_need_stop_data[0] = not_need_stop_cpu.data<bool>()[0];
}
PD_BUILD_STATIC_OP(update_inputs)
@@ -124,4 +128,4 @@ PD_BUILD_STATIC_OP(update_inputs)
{"seq_lens_encoder", "seq_lens_encoder_out"},
{"seq_lens_decoder", "seq_lens_decoder_out"},
{"input_ids", "input_ids_out"}})
.SetKernelFn(PD_KERNEL(UpdateInputes));
.SetKernelFn(PD_KERNEL(UpdateInputs));

View File

@@ -15,15 +15,14 @@
#include "helper.h"
template <int THREADBLOCK_SIZE>
__global__ void update_inputs_beam_kernel(
int *seq_lens_this_time,
int *seq_lens_encoder,
int64_t *input_ids,
float *logits,
const int bsz,
const int seq_len,
const int hidden_size,
const int beam_width) {
__global__ void update_inputs_beam_kernel(int* seq_lens_this_time,
int* seq_lens_encoder,
int64_t* input_ids,
float* logits,
const int bsz,
const int seq_len,
const int hidden_size,
const int beam_width) {
int thread_idx = threadIdx.x;
int block_idx = blockIdx.x;
@@ -35,23 +34,22 @@ __global__ void update_inputs_beam_kernel(
seq_lens_encoder[thread_idx] = seq_lens_encoder[bsz_index];
}
if (block_idx < seq_len) {
input_ids[thread_idx * seq_len + block_idx] = input_ids[bsz_index * seq_len + block_idx];
input_ids[thread_idx * seq_len + block_idx] =
input_ids[bsz_index * seq_len + block_idx];
}
logits[thread_idx * hidden_size + block_idx] = logits[bsz_index * hidden_size + block_idx];
logits[thread_idx * hidden_size + block_idx] =
logits[bsz_index * hidden_size + block_idx];
}
}
__syncthreads();
}
void UpdateInputesBeam(
const paddle::Tensor& beam_width,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& input_ids,
const paddle::Tensor& logits) {
void UpdateInputsBeam(const paddle::Tensor& beam_width,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& input_ids,
const paddle::Tensor& logits) {
int beam_width_scalar = beam_width.data<int>()[0];
if (beam_width_scalar > 1) {
@@ -59,16 +57,16 @@ void UpdateInputesBeam(
const int seq_len = input_ids.shape()[1];
const int hidden_size = logits.shape()[1];
update_inputs_beam_kernel<1024><<<hidden_size, 1024, 0, input_ids.stream()>>>(
const_cast<int*>(seq_lens_this_time.data<int>()),
const_cast<int*>(seq_lens_encoder.data<int>()),
const_cast<int64_t*>(input_ids.data<int64_t>()),
const_cast<float*>(logits.data<float>()),
bsz,
seq_len,
hidden_size,
beam_width_scalar
);
update_inputs_beam_kernel<1024>
<<<hidden_size, 1024, 0, input_ids.stream()>>>(
const_cast<int*>(seq_lens_this_time.data<int>()),
const_cast<int*>(seq_lens_encoder.data<int>()),
const_cast<int64_t*>(input_ids.data<int64_t>()),
const_cast<float*>(logits.data<float>()),
bsz,
seq_len,
hidden_size,
beam_width_scalar);
}
}
@@ -86,4 +84,4 @@ PD_BUILD_STATIC_OP(update_inputs_beam)
{"seq_lens_encoder", "seq_lens_encoder_out"},
{"input_ids", "input_ids_out"},
{"logits", "logits_out"}})
.SetKernelFn(PD_KERNEL(UpdateInputesBeam));
.SetKernelFn(PD_KERNEL(UpdateInputsBeam));

View File

@@ -15,146 +15,150 @@
#include "helper.h"
// V1 per-step bookkeeping kernel, prompt-length-aware (chunked prefill).
//
// Launched as <<<1, THREADBLOCK_SIZE>>>: thread i owns request slot i
// (requires THREADBLOCK_SIZE >= max_bsz; host launches 1024 threads).
//
// Per active request (i < bsz):
//   * already stopped            -> zero all of its length counters;
//   * this_time + decoder >= prompt_lens (prompt fully consumed):
//       - prefill_one_step_stop  -> mark stopped immediately after prefill;
//       - otherwise              -> enter/continue decoding: advance decoder
//         length, set this_time = 1, feed the sampled token into input_ids,
//         and if the block table has no block for the new decoder position
//         (entry == -1), park the request (is_block_step = true) and stash
//         its decoder length in step_seq_lens_decoder for later recovery;
//   * prompt not yet consumed    -> flag the slot stopped for this step and
//         write topk_ids = -1 (NOTE(review): presumably a sentinel telling
//         the sampler to ignore this slot mid-prefill — confirm with caller).
// Finally a block reduction counts stopped slots and sets
// not_need_stop[0] = (stopped < stop_nums[0]).
template <int THREADBLOCK_SIZE>
__global__ void update_inputs_kernel_v1(bool* not_need_stop,
                                        int* seq_lens_this_time,
                                        int* seq_lens_encoder,
                                        int* seq_lens_decoder,
                                        int* step_seq_lens_decoder,
                                        int64_t* prompt_lens,
                                        int64_t* topk_ids,
                                        int64_t* input_ids,
                                        int* block_tables,
                                        const int64_t* stop_nums,
                                        bool* stop_flags,
                                        bool* is_block_step,
                                        const int64_t* next_tokens,
                                        const int bsz,
                                        const int max_bsz,
                                        const int input_ids_stride,
                                        const int block_num_per_seq,
                                        const int block_size,
                                        bool prefill_one_step_stop) {
  int thread_idx = threadIdx.x;
  typedef cub::BlockReduce<int64_t, THREADBLOCK_SIZE> BlockReduce;
  __shared__ typename BlockReduce::TempStorage temp_storage;
  bool stop_flag_now = false;
  int64_t stop_flag_now_int = 0;
  if (thread_idx < max_bsz) {
    if (thread_idx < bsz) {
      stop_flag_now = stop_flags[thread_idx];
      stop_flag_now_int = static_cast<int64_t>(stop_flag_now);
    } else {
      // Padding slots beyond the live batch always count as stopped.
      stop_flag_now_int = 1;
    }
  }
  if (thread_idx < bsz) {
    if (stop_flag_now) {
      seq_lens_this_time[thread_idx] = 0;  // stop at next step
      seq_lens_decoder[thread_idx] = 0;
      seq_lens_encoder[thread_idx] = 0;
    } else {
      if (seq_lens_this_time[thread_idx] + seq_lens_decoder[thread_idx] >=
          prompt_lens[thread_idx]) {
        if (prefill_one_step_stop) {
          // prefill done, stop
          stop_flags[thread_idx] = true;
          seq_lens_this_time[thread_idx] = 0;
          seq_lens_decoder[thread_idx] = 0;
          seq_lens_encoder[thread_idx] = 0;
          stop_flag_now_int = 1;
        } else {
          // decoding: absorb this step's tokens into the decoder length and
          // feed the sampled token back as next step's single input token.
          seq_lens_decoder[thread_idx] += seq_lens_this_time[thread_idx];
          seq_lens_this_time[thread_idx] = 1;
          seq_lens_encoder[thread_idx] = 0;
          int64_t* input_ids_now = input_ids + thread_idx * input_ids_stride;
          input_ids_now[0] = next_tokens[thread_idx];
          // to judge whether block is not enough
          int* block_table_now = block_tables + thread_idx * block_num_per_seq;
          if (seq_lens_this_time[thread_idx] != 0 &&
              block_table_now[seq_lens_decoder[thread_idx] / block_size] ==
                  -1) {
            // should be scheduled by server: park the request and remember
            // its decoder length so it can be resumed later.
            is_block_step[thread_idx] = true;
            seq_lens_this_time[thread_idx] = 0;
            stop_flags[thread_idx] = true;
            step_seq_lens_decoder[thread_idx] = seq_lens_decoder[thread_idx];
            seq_lens_decoder[thread_idx] = 0;
            stop_flag_now_int = 1;
          }
        }
      } else {
        // Prompt not fully consumed yet (chunked prefill in progress):
        // no token is emitted this step.
        stop_flags[thread_idx] = true;
        seq_lens_this_time[thread_idx] = 0;
        seq_lens_decoder[thread_idx] = 0;
        seq_lens_encoder[thread_idx] = 0;
        topk_ids[thread_idx] = -1;
        stop_flag_now_int = 1;
      }
    }
  }
  __syncthreads();
  int64_t stop_sum = BlockReduce(temp_storage).Sum(stop_flag_now_int);
  if (thread_idx == 0) {
    not_need_stop[0] = stop_sum < stop_nums[0];
  }
}
// Host wrapper for update_inputs_kernel_v1 (registered as op
// `update_inputs_v1`).
//
// `not_need_stop` lives on CPU: staged to device, updated by the kernel,
// then copied back and written through in place.  Setting the environment
// variable PREFILL_NODE_ONE_STEP_STOP_V1=1 makes requests stop immediately
// after prefill (used on prefill-only nodes in splitwise serving).
void UpdateInputsV1(const paddle::Tensor& stop_flags,
                    const paddle::Tensor& not_need_stop,  // only on cpu
                    const paddle::Tensor& seq_lens_this_time,
                    const paddle::Tensor& seq_lens_encoder,
                    const paddle::Tensor& seq_lens_decoder,
                    const paddle::Tensor& step_seq_lens_decoder,
                    const paddle::Tensor& prompt_lens,
                    const paddle::Tensor& topk_ids,
                    const paddle::Tensor& input_ids,
                    const paddle::Tensor& block_tables,
                    const paddle::Tensor& stop_nums,
                    const paddle::Tensor& next_tokens,
                    const paddle::Tensor& is_block_step,
                    const int block_size) {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
  // Custom devices (e.g. XPU) expose the stream via the device context.
  auto dev_ctx = static_cast<const phi::CustomContext*>(
      paddle::experimental::DeviceContextPool::Instance().Get(
          input_ids.place()));
  auto cu_stream = dev_ctx->stream();
#else
  auto cu_stream = input_ids.stream();
#endif
  // Only the first character is inspected: "1" enables one-step prefill stop.
  bool prefill_one_step_stop = false;
  if (const char* env_p = std::getenv("PREFILL_NODE_ONE_STEP_STOP_V1")) {
    if (env_p[0] == '1') {
      prefill_one_step_stop = true;
    }
  }
  const int max_bsz = stop_flags.shape()[0];
  const int now_bsz = seq_lens_this_time.shape()[0];
  const int input_ids_stride = input_ids.shape()[1];
  const int block_num_per_seq = block_tables.shape()[1];
  // Stage the CPU flag on device so the kernel can write it.
  auto not_need_stop_gpu = not_need_stop.copy_to(stop_flags.place(), false);
  // Single block of 1024 threads; one thread per request slot
  // (assumes max_bsz <= 1024 — TODO confirm upstream guarantee).
  update_inputs_kernel_v1<1024><<<1, 1024, 0, cu_stream>>>(
      const_cast<bool*>(not_need_stop_gpu.data<bool>()),
      const_cast<int*>(seq_lens_this_time.data<int>()),
      const_cast<int*>(seq_lens_encoder.data<int>()),
      const_cast<int*>(seq_lens_decoder.data<int>()),
      const_cast<int*>(step_seq_lens_decoder.data<int>()),
      const_cast<int64_t*>(prompt_lens.data<int64_t>()),
      const_cast<int64_t*>(topk_ids.data<int64_t>()),
      const_cast<int64_t*>(input_ids.data<int64_t>()),
      const_cast<int*>(block_tables.data<int>()),
      stop_nums.data<int64_t>(),
      const_cast<bool*>(stop_flags.data<bool>()),
      const_cast<bool*>(is_block_step.data<bool>()),
      next_tokens.data<int64_t>(),
      now_bsz,
      max_bsz,
      input_ids_stride,
      block_num_per_seq,
      block_size,
      prefill_one_step_stop);
  // Copy the result back and publish it into the CPU tensor in place.
  auto not_need_stop_cpu =
      not_need_stop_gpu.copy_to(not_need_stop.place(), false);
  bool* not_need_stop_data = const_cast<bool*>(not_need_stop.data<bool>());
  not_need_stop_data[0] = not_need_stop_cpu.data<bool>()[0];
}
PD_BUILD_STATIC_OP(update_inputs_v1)
@@ -190,4 +194,4 @@ PD_BUILD_STATIC_OP(update_inputs_v1)
{"stop_flags", "stop_flags_out"},
{"step_seq_lens_decoder", "step_seq_lens_decoder_out"},
{"is_block_step", "is_block_step_out"}})
.SetKernelFn(PD_KERNEL(UpdateInputesV1));
.SetKernelFn(PD_KERNEL(UpdateInputsV1));

View File

@@ -33,6 +33,20 @@ void prof_start();
void prof_stop();
std::vector<paddle::Tensor> AdjustBatch(
const paddle::Tensor& x, // [token_num, dim_embed]
const paddle::Tensor& cum_offsets, // [bsz, 1]
const paddle::Tensor& encoder_seq_lod,
const paddle::Tensor& encoder_batch_idx,
const paddle::Tensor& decoder_batch_idx,
const paddle::Tensor& encoder_seq_lod_cpu,
const paddle::Tensor& encoder_batch_idx_cpu,
const paddle::Tensor& decoder_batch_idx_cpu,
const paddle::Tensor& enc_batch_tensor,
const paddle::Tensor& dec_batch_tensor,
const paddle::optional<paddle::Tensor>& output_padding_offset,
int max_input_length);
void InitKVSignalPerQuery(const paddle::Tensor& seq_lens_encoder_tensor,
const paddle::Tensor& seq_lens_this_time_tensor,
const paddle::Tensor& seq_lens_decoder_tensor,
@@ -73,6 +87,21 @@ std::vector<paddle::Tensor> BlockAttn(
const std::string& pos_emb_type = "NORMAL",
bool rope_3d = false);
std::vector<paddle::Tensor> MoeLayer(
const paddle::Tensor& x,
const paddle::Tensor& gate_weight,
const paddle::optional<paddle::Tensor>& gate_correction_bias,
const paddle::Tensor& up_gate_proj_weight,
const paddle::Tensor& down_proj_weight,
const paddle::optional<paddle::Tensor>& up_gate_proj_bias,
const paddle::optional<paddle::Tensor>& down_proj_bias,
const paddle::optional<paddle::Tensor>& up_gate_proj_weight_scale,
const paddle::optional<paddle::Tensor>& down_proj_weight_scale,
const paddle::optional<paddle::Tensor>& down_proj_in_scale,
const std::string& quant_method,
const int moe_top_k,
const bool moe_group);
std::vector<paddle::Tensor> MoERedundantTopKSelect(
const paddle::Tensor& gating_logits,
const paddle::Tensor& expert_id_to_ep_rank_array,
@@ -294,6 +323,65 @@ std::vector<paddle::Tensor> EagleGetSelfHiddenStates(
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& step_idx);
std::vector<paddle::Tensor> GatherNextToken(
const paddle::Tensor& tmp_out, // [token_num, dim_embed]
const paddle::Tensor& cum_offsets, // [bsz, 1]
const paddle::Tensor& encoder_seq_lod,
const paddle::Tensor& encoder_batch_map,
const paddle::Tensor& decoder_batch_map,
const paddle::Tensor& encoder_seq_lod_cpu,
const paddle::Tensor& encoder_batch_map_cpu,
const paddle::Tensor& decoder_batch_map_cpu,
const paddle::Tensor& enc_batch_tensor,
const paddle::Tensor& dec_batch_tensor,
const paddle::optional<paddle::Tensor>& output_padding_offset,
int max_input_length);
std::vector<paddle::Tensor> GetImgBoundaries(
const paddle::Tensor& task_input_ids,
const paddle::Tensor& grid_thw,
const int64_t image_patch_id);
std::vector<paddle::Tensor> GetInferParam(
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& block_tables,
int block_size);
void GetOutputStatic(const paddle::Tensor& x, int64_t rank_id, bool wait_flag);
void GetOutputDynamic(const paddle::Tensor& x,
int64_t rank_id,
bool wait_flag,
int msg_queue_id);
std::vector<paddle::Tensor> GetPaddingOffset(const paddle::Tensor& input_ids,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& token_num,
const paddle::Tensor& seq_len);
void GetStopFlagsMulti(const paddle::Tensor& topk_ids,
const paddle::Tensor& stop_flags,
const paddle::Tensor& seq_lens,
const paddle::Tensor& end_ids,
const paddle::Tensor& next_tokens,
const bool beam_search);
void RecoverDecodeTask(const paddle::Tensor& stop_flags,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& step_seq_lens_decoder,
const paddle::Tensor& block_tables,
const paddle::Tensor& is_block_step,
const int block_size);
std::vector<paddle::Tensor> ShareExternalData(const paddle::Tensor& input,
const std::string shm_name,
const std::vector<int>& shape,
bool use_ipc);
std::vector<paddle::Tensor> SpeculateGetOutputPaddingOffset(
const paddle::Tensor& output_cum_offsets_tmp,
const paddle::Tensor& out_token_num,
@@ -308,6 +396,31 @@ std::vector<paddle::Tensor> SpeculateGetPaddingOffset(
const paddle::Tensor& seq_len,
const paddle::Tensor& seq_lens_encoder);
void StepPaddle(const paddle::Tensor& stop_flags,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& ori_seq_lens_encoder,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& block_tables, // [bsz, block_num_per_seq]
const paddle::Tensor& encoder_block_lens,
const paddle::Tensor& is_block_step,
const paddle::Tensor& step_block_list,
const paddle::Tensor& step_lens,
const paddle::Tensor& recover_block_list,
const paddle::Tensor& recover_lens,
const paddle::Tensor& need_block_list,
const paddle::Tensor& need_block_len,
const paddle::Tensor& used_list_len,
const paddle::Tensor& free_list,
const paddle::Tensor& free_list_len,
const paddle::Tensor& input_ids,
const paddle::Tensor& pre_ids,
const paddle::Tensor& step_idx,
const paddle::Tensor& next_tokens,
const paddle::Tensor& first_token_ids,
const int block_size,
const int encoder_decoder_block_num);
void MTPStepPaddle(
const paddle::Tensor& base_model_stop_flags,
const paddle::Tensor& stop_flags,
@@ -323,6 +436,17 @@ void MTPStepPaddle(
const int block_size,
const int max_draft_tokens);
void SaveOutMmsgStatic(const paddle::Tensor& x,
const paddle::Tensor& not_need_stop,
int64_t rank_id,
bool save_each_rank);
void SaveOutMmsgDynamic(const paddle::Tensor& x,
const paddle::Tensor& not_need_stop,
int64_t rank_id,
int msg_queue_id,
bool save_each_rank);
void SpeculateStepSchedule(
const paddle::Tensor& stop_flags,
const paddle::Tensor& seq_lens_this_time,
@@ -356,7 +480,78 @@ std::vector<paddle::Tensor> SpeculateGetSeqLensOutput(
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder);
void SetDataIpc(const paddle::Tensor& tmp_input, const std::string& shm_name);
void TextImageGatherScatter(paddle::Tensor& input,
paddle::Tensor& text_input,
paddle::Tensor& image_input,
paddle::Tensor& token_type_ids,
paddle::Tensor& text_index,
paddle::Tensor& image_index,
const bool is_scatter);
void TextImageIndexOut(const paddle::Tensor& token_type_ids,
const paddle::Tensor& text_index,
const paddle::Tensor& image_index);
void TokenPenaltyMultiScores(const paddle::Tensor& pre_ids,
const paddle::Tensor& logits,
const paddle::Tensor& penalty_scores,
const paddle::Tensor& frequency_scores,
const paddle::Tensor& presence_scores,
const paddle::Tensor& temperatures,
const paddle::Tensor& bad_tokens,
const paddle::Tensor& cur_len,
const paddle::Tensor& min_len,
const paddle::Tensor& eos_token_id);
void UpdateInputs(const paddle::Tensor& stop_flags,
const paddle::Tensor& not_need_stop, // cpu
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& input_ids,
const paddle::Tensor& stop_nums,
const paddle::Tensor& next_tokens,
const paddle::Tensor& is_block_step);
void UpdateInputsV1(const paddle::Tensor& stop_flags,
const paddle::Tensor& not_need_stop, // only on cpu
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& step_seq_lens_decoder,
const paddle::Tensor& prompt_lens,
const paddle::Tensor& topk_ids,
const paddle::Tensor& input_ids,
const paddle::Tensor& block_tables,
const paddle::Tensor& stop_nums,
const paddle::Tensor& next_tokens,
const paddle::Tensor& is_block_step,
const int block_size);
std::vector<paddle::Tensor> WeightQuantize(const paddle::Tensor& x,
const std::string& algo,
const int32_t arch,
const int32_t group_size);
PYBIND11_MODULE(fastdeploy_ops, m) {
m.def("adjust_batch",
&AdjustBatch,
py::arg("x"),
py::arg("cum_offsets"),
py::arg("encoder_seq_lod"),
py::arg("encoder_batch_idx"),
py::arg("decoder_batch_idx"),
py::arg("encoder_seq_lod_cpu"),
py::arg("encoder_batch_idx_cpu"),
py::arg("decoder_batch_idx_cpu"),
py::arg("enc_batch_tensor"),
py::arg("dec_batch_tensor"),
py::arg("output_padding_offset"),
py::arg("max_input_length"),
"adjust batch in XPU");
m.def("block_attn",
&BlockAttn,
py::arg("qkv"),
@@ -388,96 +583,107 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
py::arg("pos_emb_type") = "NORMAL",
py::arg("rope_3d") = false,
"block attention in XPU");
m.def("create_kv_signal_sender",
&create_cachekv_signal_thread,
"init write cache kv signal thread");
m.def("cuda_host_alloc",
&custom_xpu_host_alloc,
"Allocate pinned memory",
py::arg("size"),
py::arg("flags") = 0x00);
m.def("cuda_host_free",
&custom_xpu_host_free,
"Free pinned memory",
py::arg("ptr"));
m.def("get_peer_mem_addr",
&xpu_get_peer_mem_addr,
"Get Host memory address of device pointer",
py::arg("ptr"));
m.def("cuda_host_register",
&xpu_cuda_host_register,
"Register pinned memory",
py::arg("ptr"),
py::arg("size"),
py::arg("flags") = cudaHostRegisterDefault);
m.def("create_kv_signal_sender",
&create_cachekv_signal_thread,
"init write cache kv signal thread");
m.def("destroy_kv_signal_sender",
&destroy_cachekv_signal_thread,
"write cache kv signal thread exit");
m.def("prof_start", &prof_start, "prof_start");
m.def("prof_stop", &prof_stop, "prof_stop");
m.def("moe_redundant_topk_select",
&MoERedundantTopKSelect,
py::arg("gating_logits"),
py::arg("expert_id_to_ep_rank_array"),
py::arg("expert_in_rank_num_list"),
py::arg("tokens_per_expert_stats_list"),
py::arg("bias"),
py::arg("moe_topk"),
py::arg("apply_norm_weight"),
py::arg("enable_softmax_top_k_fused"),
py::arg("redundant_ep_rank_num_plus_one"),
"moe export RedundantTopKSelect function");
m.def("set_ncluster", &set_ncluster, "set ncluster");
/**
* open_shm_and_get_meta_signal.cc
* InitKVSingnalPerQuery
*/
m.def("init_kv_signal_per_query",
&InitKVSignalPerQuery,
py::arg("seq_lens_encoder_tensor"),
py::arg("seq_lens_this_time_tensor"),
py::arg("seq_lens_decoder_tensor"),
py::arg("rank"),
py::arg("num_layers"),
"init_kv_signal_per_query function");
m.def("draft_model_preprocess",
&DraftModelPreprocess,
py::arg("draft_tokens"),
py::arg("input_ids"),
py::arg("stop_flags"),
py::arg("seq_lens_this_time"),
py::arg("seq_lens_encoder"),
py::arg("seq_lens_decoder"),
py::arg("step_idx"),
py::arg("seq_lens_encoder_record"),
py::arg("seq_lens_decoder_record"),
py::arg("not_need_stop"),
py::arg("batch_drop"),
py::arg("accept_tokens"),
py::arg("accept_num"),
py::arg("base_model_seq_lens_encoder"),
py::arg("base_model_seq_lens_decoder"),
py::arg("base_model_step_idx"),
py::arg("base_model_stop_flags"),
py::arg("base_model_is_block_step"),
py::arg("base_model_draft_tokens"),
py::arg("max_draft_token"),
py::arg("truncate_first_token"),
py::arg("splitwise_prefill"),
"Preprocess data for draft model in speculative decoding");
/**
* GetOutputKVSignal
*/
m.def("get_output_kv_signal",
&GetOutputKVSignal,
py::arg("x"),
py::arg("rank_id"),
py::arg("wait_flag"),
"get_output_kv_signal function");
m.def("draft_model_postprocess",
&DraftModelPostprocess,
py::arg("base_model_draft_tokens"),
py::arg("base_model_seq_lens_this_time"),
py::arg("base_model_seq_lens_encoder"),
py::arg("base_model_stop_flags"),
"Postprocess data for draft model in speculative decoding");
m.def("fused_rms_norm_xpu",
&RmsNorm,
"Fused RMS normalization for XPU",
py::arg("x"), // 输入张量
py::arg("bias"), // 偏置(可选)
py::arg("residual"), // 残差连接(可选)
py::arg("norm_weight"), // 归一化权重
py::arg("norm_bias"), // 归一化偏置(可选)
py::arg("epsilon"), // 数值稳定项
py::arg("begin_norm_axis"), // 归一化起始维度
py::arg("quant_scale"), // 量化缩放因子
py::arg("quant_round_type"), // 量化舍入类型
py::arg("quant_max_bound"), // 量化最大值边界
py::arg("quant_min_bound") // 量化最小值边界
m.def("draft_model_update",
&DraftModelUpdate,
"Update draft model states during speculative decoding",
py::arg("inter_next_tokens"), // 中间next tokens张量
py::arg("draft_tokens"), // 草稿token张量
py::arg("pre_ids"), // 前置ID张量
py::arg("seq_lens_this_time"), // 当前步骤序列长度张量
py::arg("seq_lens_encoder"), // 编码器序列长度张量
py::arg("seq_lens_decoder"), // 解码器序列长度张量
py::arg("step_idx"), // 步骤索引张量
py::arg("output_cum_offsets"), // 输出累积偏移量张
py::arg("stop_flags"), // 停止标志张量
py::arg("not_need_stop"), // 无需停止标志张量
py::arg("max_dec_len"), // 最大解码长度张量
py::arg("end_ids"), // 结束ID张量
py::arg("base_model_draft_tokens"), // 基础模型草稿token张量
py::arg("max_seq_len"), // 最大序列长度int
py::arg("substep") // 子步骤编号int
);
m.def("weight_only_linear_xpu",
&WeightOnlyLinear,
"Weight-only quantized linear layer",
py::arg("x"),
py::arg("weight"),
py::arg("weight_scale"),
py::arg("bias"),
py::arg("weight_dtype"),
py::arg("arch"),
py::arg("group_size") = -1);
m.def("eagle_get_hidden_states",
&EagleGetHiddenStates,
py::arg("input"),
py::arg("seq_lens_this_time"),
py::arg("seq_lens_encoder"),
py::arg("seq_lens_decoder"),
py::arg("stop_flags"),
py::arg("accept_nums"),
py::arg("base_model_seq_lens_this_time"),
py::arg("base_model_seq_lens_encoder"),
py::arg("actual_draft_token_num"),
"Get draft model hidden states");
m.def("eagle_get_self_hidden_states",
&EagleGetSelfHiddenStates,
py::arg("input"),
py::arg("last_seq_lens_this_time"),
py::arg("seq_lens_this_time"),
py::arg("step_idx"),
"Rebuild draft model hidden states");
m.def("ep_moe_expert_combine",
&MoeEPCombine,
@@ -502,6 +708,157 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
py::arg("token_nums_this_rank"),
py::arg("quant_method"));
// Fused RMSNorm binding for XPU (argument comments translated from Chinese).
m.def("fused_rms_norm_xpu",
&RmsNorm,
"Fused RMS normalization for XPU",
py::arg("x"),  // input tensor
py::arg("bias"),  // bias (optional)
py::arg("residual"),  // residual connection (optional)
py::arg("norm_weight"),  // normalization weight
py::arg("norm_bias"),  // normalization bias (optional)
py::arg("epsilon"),  // numerical-stability epsilon
py::arg("begin_norm_axis"),  // first axis to normalize over
py::arg("quant_scale"),  // quantization scale factor
py::arg("quant_round_type"),  // quantization rounding mode
py::arg("quant_max_bound"),  // quantization upper bound
py::arg("quant_min_bound")  // quantization lower bound
);
// Gather the next-token hidden states per sequence from padded model output.
m.def("gather_next_token",
&GatherNextToken,
py::arg("tmp_out"),
py::arg("cum_offsets"),
py::arg("encoder_seq_lod"),
py::arg("encoder_batch_map"),
py::arg("decoder_batch_map"),
py::arg("encoder_seq_lod_cpu"),
py::arg("encoder_batch_map_cpu"),
py::arg("decoder_batch_map_cpu"),
py::arg("enc_batch_tensor"),
py::arg("dec_batch_tensor"),
py::arg("output_padding_offset"),
py::arg("max_input_length"),
"Gather next token for XPU");
// Locate image-patch token spans inside the input ids (VL models).
m.def("get_img_boundaries",
&GetImgBoundaries,
py::arg("task_input_ids"),
py::arg("grid_thw"),
py::arg("image_patch_id"),
"Get image boundaries in VL model");
// Derive block-attention inference parameters from sequence lengths.
m.def("get_infer_param",
&GetInferParam,
py::arg("seq_lens_encoder"),
py::arg("seq_lens_decoder"),
py::arg("seq_lens_this_time"),
py::arg("block_tables"),
py::arg("block_size"),
"Get infer parameters for block attention in XPU");
m.def("get_peer_mem_addr",
&xpu_get_peer_mem_addr,
"Get Host memory address of device pointer",
py::arg("ptr"));
// Apply repetition/frequency/presence penalties and temperature to logits.
m.def("get_token_penalty_multi_scores",
&TokenPenaltyMultiScores,
py::arg("pre_ids"),
py::arg("logits"),
py::arg("penalty_scores"),
py::arg("frequency_scores"),
py::arg("presence_scores"),
py::arg("temperatures"),
py::arg("bad_tokens"),
py::arg("cur_len"),
py::arg("min_len"),
py::arg("eos_token_id"),
"get token_penalty_multi_scores function");
m.def("get_output",
&GetOutputStatic,
py::arg("x"),
py::arg("rank_id"),
py::arg("wait_flag"),
"get_output function");
// NOTE(review): bound to the same GetOutputStatic as "get_output" —
// confirm the EP variant is intentionally an alias.
m.def("get_output_ep",
&GetOutputStatic,
py::arg("x"),
py::arg("rank_id"),
py::arg("wait_flag"),
"get_output_ep function");
m.def("get_output_dynamic",
&GetOutputDynamic,
py::arg("x"),
py::arg("rank_id"),
py::arg("wait_flag"),
py::arg("msg_queue_id"),
"get_output_dynamic function");
// NOTE(review): same GetOutputDynamic as "get_output_dynamic" — intentional alias?
m.def("get_output_ep_dynamic",
&GetOutputDynamic,
py::arg("x"),
py::arg("rank_id"),
py::arg("wait_flag"),
py::arg("msg_queue_id"),
"get_output_ep_dynamic function");
m.def("get_output_kv_signal",
&GetOutputKVSignal,
py::arg("x"),
py::arg("rank_id"),
py::arg("wait_flag"),
"get_output_kv_signal function");
m.def("get_padding_offset",
&GetPaddingOffset,
py::arg("input_ids"),
py::arg("cum_offsets"),
py::arg("token_num"),
py::arg("seq_len"),
"get padding offset function");
m.def("init_kv_signal_per_query",
&InitKVSignalPerQuery,
py::arg("seq_lens_encoder_tensor"),
py::arg("seq_lens_this_time_tensor"),
py::arg("seq_lens_decoder_tensor"),
py::arg("rank"),
py::arg("num_layers"),
"init_kv_signal_per_query function");
// Top-k expert selection with redundant-expert-aware rank mapping (EP MoE).
m.def("moe_redundant_topk_select",
&MoERedundantTopKSelect,
py::arg("gating_logits"),
py::arg("expert_id_to_ep_rank_array"),
py::arg("expert_in_rank_num_list"),
py::arg("tokens_per_expert_stats_list"),
py::arg("bias"),
py::arg("moe_topk"),
py::arg("apply_norm_weight"),
py::arg("enable_softmax_top_k_fused"),
py::arg("redundant_ep_rank_num_plus_one"),
"moe export RedundantTopKSelect function");
// Multi-token-prediction scheduling step (KV-block bookkeeping).
m.def("mtp_step_paddle",
&MTPStepPaddle,
py::arg("base_model_stop_flags"),
py::arg("stop_flags"),
py::arg("batch_drop"),
py::arg("seq_lens_this_time"),
py::arg("seq_lens_encoder"),
py::arg("seq_lens_decoder"),
py::arg("block_tables"),  // [bsz, block_num_per_seq]
py::arg("encoder_block_lens"),
py::arg("used_list_len"),
py::arg("free_list"),
py::arg("free_list_len"),
py::arg("block_size"),
py::arg("max_draft_tokens"),
"MTP step paddle");
m.def("moe_expert_ffn",
&MoeExpertFFN,
"MoE expert feed-forward network with quantization support",
@@ -529,25 +886,46 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
py::arg("moe_topk"),
py::arg("apply_norm_weight"));
// Draft-model state update used by speculative decoding
// (argument comments translated from Chinese).
m.def("draft_model_update",
&DraftModelUpdate,
"Update draft model states during speculative decoding",
py::arg("inter_next_tokens"),  // intermediate next-tokens tensor
py::arg("draft_tokens"),  // draft-token tensor
py::arg("pre_ids"),  // previously generated ids
py::arg("seq_lens_this_time"),  // per-step sequence lengths
py::arg("seq_lens_encoder"),  // encoder sequence lengths
py::arg("seq_lens_decoder"),  // decoder sequence lengths
py::arg("step_idx"),  // step-index tensor
py::arg("output_cum_offsets"),  // output cumulative offsets
py::arg("stop_flags"),  // per-sequence stop flags
py::arg("not_need_stop"),  // "no need to stop" flag tensor
py::arg("max_dec_len"),  // max decode length tensor
py::arg("end_ids"),  // end-of-sequence token ids
py::arg("base_model_draft_tokens"),  // base-model draft tokens
py::arg("max_seq_len"),  // max sequence length (int)
py::arg("substep")  // substep number (int)
);
m.def("prof_start", &prof_start, "prof_start");  // begin profiling capture
m.def("prof_stop", &prof_stop, "prof_stop");  // end profiling capture
// Restore a decode task previously parked by the block-step mechanism.
m.def("recover_decode_task",
&RecoverDecodeTask,
py::arg("stop_flags"),
py::arg("seq_lens_this_time"),
py::arg("seq_lens_encoder"),
py::arg("seq_lens_decoder"),
py::arg("step_seq_lens_decoder"),
py::arg("block_tables"),
py::arg("is_block_step"),
py::arg("block_size"),
"Recover decode task function");
m.def("save_output",
&SaveOutMmsgStatic,
py::arg("x"),
py::arg("not_need_stop"),
py::arg("rank_id"),
py::arg("save_each_rank"),
"Save output function");
m.def("save_output_dynamic",
&SaveOutMmsgDynamic,
py::arg("x"),
py::arg("not_need_stop"),
py::arg("rank_id"),
py::arg("msg_queue_id"),
py::arg("save_each_rank"),
"Save output dynamic function");
// Map a named shared-memory region into a tensor without copying.
m.def("share_external_data",
&ShareExternalData,
py::arg("input"),
py::arg("shm_name"),
py::arg("shape"),
py::arg("use_ipc"),
"Share external data function");
m.def("speculate_get_token_penalty_multi_scores",
&SpeculateTokenPenaltyMultiScores,
@@ -582,15 +960,6 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
py::arg("stop_nums"),
"Update speculative decoding states (V3)");
m.def("top_p_candidates",
&TopPCandidates,
py::arg("probs"),
py::arg("top_p"),
py::arg("output_padding_offset"),
py::arg("candidates_len"),
py::arg("max_seq_len"),
"Generate top-p candidates based on probability distributions");
m.def("speculate_verify",
&SpeculateVerify,
py::arg("accept_tokens"),
@@ -633,61 +1002,6 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
py::arg("step_idx"),
"Set values based on flags and indices in speculative decoding");
m.def("draft_model_preprocess",
&DraftModelPreprocess,
py::arg("draft_tokens"),
py::arg("input_ids"),
py::arg("stop_flags"),
py::arg("seq_lens_this_time"),
py::arg("seq_lens_encoder"),
py::arg("seq_lens_decoder"),
py::arg("step_idx"),
py::arg("seq_lens_encoder_record"),
py::arg("seq_lens_decoder_record"),
py::arg("not_need_stop"),
py::arg("batch_drop"),
py::arg("accept_tokens"),
py::arg("accept_num"),
py::arg("base_model_seq_lens_encoder"),
py::arg("base_model_seq_lens_decoder"),
py::arg("base_model_step_idx"),
py::arg("base_model_stop_flags"),
py::arg("base_model_is_block_step"),
py::arg("base_model_draft_tokens"),
py::arg("max_draft_token"),
py::arg("truncate_first_token"),
py::arg("splitwise_prefill"),
"Preprocess data for draft model in speculative decoding");
m.def("draft_model_postprocess",
&DraftModelPostprocess,
py::arg("base_model_draft_tokens"),
py::arg("base_model_seq_lens_this_time"),
py::arg("base_model_seq_lens_encoder"),
py::arg("base_model_stop_flags"),
"Postprocess data for draft model in speculative decoding");
m.def("eagle_get_hidden_states",
&EagleGetHiddenStates,
py::arg("input"),
py::arg("seq_lens_this_time"),
py::arg("seq_lens_encoder"),
py::arg("seq_lens_decoder"),
py::arg("stop_flags"),
py::arg("accept_nums"),
py::arg("base_model_seq_lens_this_time"),
py::arg("base_model_seq_lens_encoder"),
py::arg("actual_draft_token_num"),
"Get draft model hidden states");
m.def("eagle_get_self_hidden_states",
&EagleGetSelfHiddenStates,
py::arg("input"),
py::arg("last_seq_lens_this_time"),
py::arg("seq_lens_this_time"),
py::arg("step_idx"),
"Rebuild draft model hidden states");
m.def("speculate_get_output_padding_offset",
&SpeculateGetOutputPaddingOffset,
py::arg("output_cum_offsets_tmp"),
@@ -706,23 +1020,6 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
py::arg("seq_lens_encoder"),
"Get padding offset");
m.def("mtp_step_paddle",
&MTPStepPaddle,
py::arg("base_model_stop_flags"),
py::arg("stop_flags"),
py::arg("batch_drop"),
py::arg("seq_lens_this_time"),
py::arg("seq_lens_encoder"),
py::arg("seq_lens_decoder"),
py::arg("block_tables"), // [bsz, block_num_per_seq]
py::arg("encoder_block_lens"),
py::arg("used_list_len"),
py::arg("free_list"),
py::arg("free_list_len"),
py::arg("block_size"),
py::arg("max_draft_tokens"),
"MTP step paddle");
m.def("speculate_step_reschedule",
&SpeculateStepSchedule,
py::arg("stop_flags"),
@@ -760,6 +1057,147 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
py::arg("seq_lens_decoder"),
"Get sequence lengths output");
m.def("set_data_ipc",
&SetDataIpc,
py::arg("tmp_input"),
py::arg("shm_name"),
"Set data IPC function");
m.def("set_ncluster", &set_ncluster, "set ncluster");
// Mark sequences as finished when an end id is generated.
m.def("set_stop_value_multi_ends",
&GetStopFlagsMulti,
py::arg("topk_ids"),
py::arg("stop_flags"),
py::arg("seq_lens"),
py::arg("end_ids"),
py::arg("next_tokens"),
py::arg("beam_search"),
"Set stop value multi ends function");
// Scheduler step: allocate/free/recover KV-cache blocks across the batch.
m.def("step_paddle",
&StepPaddle,
py::arg("stop_flags"),
py::arg("seq_lens_this_time"),
py::arg("ori_seq_lens_encoder"),
py::arg("seq_lens_encoder"),
py::arg("seq_lens_decoder"),
py::arg("block_tables"),
py::arg("encoder_block_lens"),
py::arg("is_block_step"),
py::arg("step_block_list"),
py::arg("step_lens"),
py::arg("recover_block_list"),
py::arg("recover_lens"),
py::arg("need_block_list"),
py::arg("need_block_len"),
py::arg("used_list_len"),
py::arg("free_list"),
py::arg("free_list_len"),
py::arg("input_ids"),
py::arg("pre_ids"),
py::arg("step_idx"),
py::arg("next_tokens"),
py::arg("first_token_ids"),
py::arg("block_size"),
py::arg("encoder_decoder_block_num"),
"Step paddle function");
// Move text/image tokens between hidden states and per-modality buffers
// (direction selected by is_scatter).
m.def("text_image_gather_scatter",
&TextImageGatherScatter,
py::arg("input"),
py::arg("text_input"),
py::arg("image_input"),
py::arg("token_type_ids"),
py::arg("text_index"),
py::arg("image_index"),
py::arg("is_scatter"),
"Scatter image and text from hidden states, or gather them to hidden "
"states");
m.def("text_image_index_out",
&TextImageIndexOut,
py::arg("token_type_ids"),
py::arg("text_index"),
py::arg("image_index"),
"Generate index for text and image");
m.def("top_p_candidates",
&TopPCandidates,
py::arg("probs"),
py::arg("top_p"),
py::arg("output_padding_offset"),
py::arg("candidates_len"),
py::arg("max_seq_len"),
"Generate top-p candidates based on probability distributions");
// Post-step update: advance sequence lengths, append next tokens, set stop state.
m.def("update_inputs",
&UpdateInputs,
py::arg("stop_flags"),
py::arg("not_need_stop"),
py::arg("seq_lens_this_time"),
py::arg("seq_lens_encoder"),
py::arg("seq_lens_decoder"),
py::arg("input_ids"),
py::arg("stop_nums"),
py::arg("next_tokens"),
py::arg("is_block_step"),
"Update inputs function");
m.def("update_inputs_v1",
&UpdateInputsV1,
py::arg("stop_flags"),
py::arg("not_need_stop"),
py::arg("seq_lens_this_time"),
py::arg("seq_lens_encoder"),
py::arg("seq_lens_decoder"),
py::arg("step_seq_lens_decoder"),
py::arg("prompt_lens"),
py::arg("topk_ids"),
py::arg("input_ids"),
py::arg("block_tables"),
py::arg("stop_nums"),
py::arg("next_tokens"),
py::arg("is_block_step"),
py::arg("block_size"),
"Update inputs v1 function");
m.def("weight_quantize_xpu",
&WeightQuantize,
py::arg("x"),
py::arg("algo"),
py::arg("arch"),
py::arg("group_size"),
"Quantize weights on XPU");
m.def("weight_only_linear_xpu",
&WeightOnlyLinear,
"Weight-only quantized linear layer",
py::arg("x"),
py::arg("weight"),
py::arg("weight_scale"),
py::arg("bias"),
py::arg("weight_dtype"),
py::arg("arch"),
py::arg("group_size") = -1);
// Fully fused MoE layer: top-k routing + dispatch + expert FFN + combine.
m.def("xpu_moe_layer",
&MoeLayer,
py::arg("x"),
py::arg("gate_weight"),
py::arg("gate_correction_bias"),
py::arg("up_gate_proj_weight"),
py::arg("down_proj_weight"),
py::arg("up_gate_proj_bias"),
py::arg("down_proj_bias"),
py::arg("up_gate_proj_weight_scale"),
py::arg("down_proj_weight_scale"),
py::arg("down_proj_in_scale"),
py::arg("quant_method"),
py::arg("moe_top_k"),
py::arg("moe_group"),
"fused moe op(topk + dispatch + ffn + combine) in XPU");
// Register exception translation so XPU errors surface in Python as XPUError.
py::register_exception<XPUError>(m, "XPUError");
}

View File

@@ -17,18 +17,18 @@
#include "paddle/phi/core/enforce.h"
#include "xpu/plugin.h"
void UpdateInputes(const paddle::Tensor &stop_flags,
const paddle::Tensor &not_need_stop, // cpu
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &input_ids,
const paddle::Tensor &stop_nums,
const paddle::Tensor &next_tokens,
const paddle::Tensor &is_block_step) {
void UpdateInputs(const paddle::Tensor& stop_flags,
const paddle::Tensor& not_need_stop, // cpu
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& input_ids,
const paddle::Tensor& stop_nums,
const paddle::Tensor& next_tokens,
const paddle::Tensor& is_block_step) {
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
auto xpu_ctx = static_cast<const phi::XPUContext *>(dev_ctx);
auto xpu_ctx = static_cast<const phi::XPUContext*>(dev_ctx);
const int max_bsz = stop_flags.shape()[0];
PADDLE_ENFORCE_LE(
@@ -42,11 +42,11 @@ void UpdateInputes(const paddle::Tensor &stop_flags,
int r = baidu::xpu::api::plugin::update_inputs(
xpu_ctx->x_context(),
const_cast<bool *>(not_need_stop_xpu.data<bool>()),
const_cast<int *>(seq_lens_this_time.data<int>()),
const_cast<int *>(seq_lens_encoder.data<int>()),
const_cast<int *>(seq_lens_decoder.data<int>()),
const_cast<int64_t *>(input_ids.data<int64_t>()),
const_cast<bool*>(not_need_stop_xpu.data<bool>()),
const_cast<int*>(seq_lens_this_time.data<int>()),
const_cast<int*>(seq_lens_encoder.data<int>()),
const_cast<int*>(seq_lens_decoder.data<int>()),
const_cast<int64_t*>(input_ids.data<int64_t>()),
stop_nums.data<int64_t>(),
stop_flags.data<bool>(),
is_block_step.data<bool>(),
@@ -57,7 +57,7 @@ void UpdateInputes(const paddle::Tensor &stop_flags,
PD_CHECK(r == 0, "baidu::xpu::api::plugin::update_inputs failed.");
auto not_need_stop_cpu =
not_need_stop_xpu.copy_to(not_need_stop.place(), false);
bool *not_need_stop_data = const_cast<bool *>(not_need_stop.data<bool>());
bool* not_need_stop_data = const_cast<bool*>(not_need_stop.data<bool>());
not_need_stop_data[0] = not_need_stop_cpu.data<bool>()[0];
}
@@ -81,4 +81,4 @@ PD_BUILD_OP(update_inputs)
{"seq_lens_encoder", "seq_lens_encoder_out"},
{"seq_lens_decoder", "seq_lens_decoder_out"},
{"input_ids", "input_ids_out"}})
.SetKernelFn(PD_KERNEL(UpdateInputes));
.SetKernelFn(PD_KERNEL(UpdateInputs));

View File

@@ -17,23 +17,23 @@
#include "paddle/phi/core/enforce.h"
#include "xpu/plugin.h"
void UpdateInputesV1(const paddle::Tensor &stop_flags,
const paddle::Tensor &not_need_stop, // only on cpu
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &step_seq_lens_decoder,
const paddle::Tensor &prompt_lens,
const paddle::Tensor &topk_ids,
const paddle::Tensor &input_ids,
const paddle::Tensor &block_tables,
const paddle::Tensor &stop_nums,
const paddle::Tensor &next_tokens,
const paddle::Tensor &is_block_step,
const int block_size) {
void UpdateInputsV1(const paddle::Tensor& stop_flags,
const paddle::Tensor& not_need_stop, // only on cpu
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& step_seq_lens_decoder,
const paddle::Tensor& prompt_lens,
const paddle::Tensor& topk_ids,
const paddle::Tensor& input_ids,
const paddle::Tensor& block_tables,
const paddle::Tensor& stop_nums,
const paddle::Tensor& next_tokens,
const paddle::Tensor& is_block_step,
const int block_size) {
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
auto xpu_ctx = static_cast<const phi::XPUContext *>(dev_ctx);
auto xpu_ctx = static_cast<const phi::XPUContext*>(dev_ctx);
const int max_bsz = stop_flags.shape()[0];
const int now_bsz = seq_lens_this_time.shape()[0];
@@ -43,18 +43,18 @@ void UpdateInputesV1(const paddle::Tensor &stop_flags,
auto not_need_stop_gpu = not_need_stop.copy_to(stop_flags.place(), false);
int r = baidu::xpu::api::plugin::update_inputs_v1(
xpu_ctx->x_context(),
const_cast<bool *>(not_need_stop_gpu.data<bool>()),
const_cast<int *>(seq_lens_this_time.data<int>()),
const_cast<int *>(seq_lens_encoder.data<int>()),
const_cast<int *>(seq_lens_decoder.data<int>()),
const_cast<int *>(step_seq_lens_decoder.data<int>()),
const_cast<int64_t *>(prompt_lens.data<int64_t>()),
const_cast<int64_t *>(topk_ids.data<int64_t>()),
const_cast<int64_t *>(input_ids.data<int64_t>()),
const_cast<int *>(block_tables.data<int>()),
const_cast<bool*>(not_need_stop_gpu.data<bool>()),
const_cast<int*>(seq_lens_this_time.data<int>()),
const_cast<int*>(seq_lens_encoder.data<int>()),
const_cast<int*>(seq_lens_decoder.data<int>()),
const_cast<int*>(step_seq_lens_decoder.data<int>()),
const_cast<int64_t*>(prompt_lens.data<int64_t>()),
const_cast<int64_t*>(topk_ids.data<int64_t>()),
const_cast<int64_t*>(input_ids.data<int64_t>()),
const_cast<int*>(block_tables.data<int>()),
stop_nums.data<int64_t>(),
const_cast<bool *>(stop_flags.data<bool>()),
const_cast<bool *>(is_block_step.data<bool>()),
const_cast<bool*>(stop_flags.data<bool>()),
const_cast<bool*>(is_block_step.data<bool>()),
next_tokens.data<int64_t>(),
now_bsz,
max_bsz,
@@ -64,7 +64,7 @@ void UpdateInputesV1(const paddle::Tensor &stop_flags,
PD_CHECK(r == 0, "baidu::xpu::api::plugin::update_inputs_kernel_v1 failed.");
auto not_need_stop_cpu =
not_need_stop_gpu.copy_to(not_need_stop.place(), false);
bool *not_need_stop_data = const_cast<bool *>(not_need_stop.data<bool>());
bool* not_need_stop_data = const_cast<bool*>(not_need_stop.data<bool>());
not_need_stop_data[0] = not_need_stop_cpu.data<bool>()[0];
}
@@ -101,4 +101,4 @@ PD_BUILD_OP(update_inputs_v1)
{"stop_flags", "stop_flags_out"},
{"step_seq_lens_decoder", "step_seq_lens_decoder_out"},
{"is_block_step", "is_block_step_out"}})
.SetKernelFn(PD_KERNEL(UpdateInputesV1));
.SetKernelFn(PD_KERNEL(UpdateInputsV1));

View File

@@ -27,6 +27,7 @@ from fastdeploy.model_executor.ops.xpu import (
moe_expert_ffn,
moe_topk_select,
weight_quantize_xpu,
xpu_moe_layer,
)
@@ -153,8 +154,6 @@ class XPUMoEMethod(MoEMethodBase):
"""
Apply TP Fused Op.
"""
from fastdeploy.model_executor.ops.xpu import xpu_moe_layer
fused_moe_out = xpu_moe_layer(
x,
gate.weight.transpose([1, 0]),